# jaxaht-benchmark/common/agent_loader_from_config.py
# lainwired's picture
# Initial jaxaht-benchmark deployment
# 5146e76
import logging
import jax
import numpy as np
import os
from omegaconf import OmegaConf
from agents.initialize_agents import initialize_s5_agent, initialize_mlp_agent, \
initialize_rnn_agent, initialize_actor_with_double_critic, \
initialize_actor_with_conditional_critic, \
initialize_liam_agent, initialize_meliba_agent
from agents.lbf.agent_policy_wrappers import (
LBFRandomPolicyWrapper, LBFSequentialFruitPolicyWrapper,
LBFEntitledPolicyWrapper, LBFGreedyHeuristicPolicyWrapper,
)
from agents.overcooked.agent_policy_wrappers import (
OvercookedRandomPolicyWrapper, OvercookedIndependentPolicyWrapper,
OvercookedOnionPolicyWrapper, OvercookedPlatePolicyWrapper,
OvercookedStaticPolicyWrapper,
)
from agents.hanabi.agent_policy_wrappers import (
HanabiRandomPolicyWrapper, HanabiRuleBasedPolicyWrapper,
HanabiIGGIPolicyWrapper, HanabiPiersPolicyWrapper,
HanabiFlawedPolicyWrapper, HanabiOuterPolicyWrapper,
HanabiVanDenBerghPolicyWrapper, HanabiSmartBotPolicyWrapper,
HanabiOBLPolicyWrapper, HanabiBCLSTMPolicyWrapper,
HanabiInternalPolicyWrapper, HanabiCautiousPolicyWrapper,
)
from agents.dsse.agent_policy_wrappers import (
DSSERandomPolicyWrapper, DSSEGreedySearchPolicyWrapper,
DSSESweepPolicyWrapper,
)
from common.save_load_utils import load_checkpoints, REPO_PATH
from envs.overcooked.augmented_layouts import augmented_layouts
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time, so importing this module
# configures the root logger at INFO level if no handlers are installed yet.
logging.basicConfig(level=logging.INFO)
def _validate_teammate_path(path: str) -> str:
"""Validate checkpoint path and fail fast if it does not exist."""
resolved_path = path if os.path.isabs(path) else os.path.join(REPO_PATH, path)
if not os.path.exists(resolved_path):
raise FileNotFoundError(
f"Checkpoint path does not exist: {path}. "
"Use the new eval_teammates/... layout."
)
return path
def process_idx_list(idx_list):
    """Convert a checkpoint index list into something usable as ``x[idxs]``.

    Args:
        idx_list: Either a flat list of indices (e.g. ``[0, 1, 2]``) or a
            list of ``[row, col]`` pairs (e.g. ``[[0, -1], [1, -1]]``).

    Returns:
        A 1-D array for a flat list, or a ``(rows, cols)`` tuple of 1-D
        arrays for a list of pairs, so that ``x[rows, cols]`` picks one entry
        per pair along the first two axes.

    Raises:
        ValueError: If the list is not 1-D, or is 2-D with a row length other
            than 2. Previously extra columns were silently ignored, which
            would select the wrong checkpoints without any warning.
    """
    idxs = np.array(idx_list)
    if idxs.ndim == 1:
        return idxs
    # Only (N, 2) is meaningful here: one (row, col) pair per selection.
    if idxs.ndim == 2 and idxs.shape[1] == 2:
        return (idxs[:, 0], idxs[:, 1])
    raise ValueError(f"Invalid index list shape: {idxs.shape}")
def create_idx_labels(idx_list, checkpoint_shape):
    """Create human-readable index labels based on the checkpoint extraction pattern.

    Args:
        idx_list: The indices used to extract checkpoints. ``None`` or the
            string ``"all"`` means every checkpoint was kept. Otherwise a flat
            list such as ``[0, 1, 2]`` or a list of pairs such as
            ``[[0, -1], [1, -1]]`` (pairs may be lists or tuples).
        checkpoint_shape: Shape of one checkpoint leaf before selection.

    Returns:
        A list of string labels.
        * All checkpoints, shape with >= 3 dims: nested list of
          ``"row, col"`` labels covering the first two dimensions.
        * All checkpoints, otherwise: flat list ``"0" .. "n-1"`` over the
          first dimension.
        * Specific indices: one label per entry of ``idx_list``
          (``"row, col"`` for pairs, ``"idx"`` for scalars); an empty
          ``idx_list`` yields an empty list.
    """
    load_all = idx_list is None or (isinstance(idx_list, str) and idx_list == "all")
    if load_all:
        # Three or more dims is treated as a 2-D grid of checkpoints
        # (e.g. seeds x steps) followed by parameter dimensions.
        if len(checkpoint_shape) >= 3:
            rows, cols = checkpoint_shape[0], checkpoint_shape[1]
            return [
                [f"{row_idx}, {col_idx}" for col_idx in range(cols)]
                for row_idx in range(rows)
            ]
        return [f"{i}" for i in range(checkpoint_shape[0])]

    # Nothing selected: there is no first element to inspect.
    if len(idx_list) == 0:
        return []

    # The first entry decides whether this is a list of pairs or of scalars.
    if isinstance(idx_list[0], (list, tuple)):
        return [f"{idx[0]}, {idx[1]}" for idx in idx_list]
    return [f"{idx}" for idx in idx_list]
def _unrecognized_actor_type(task_name, actor_type, agent_name):
    """Build the error raised when a task family has no such heuristic."""
    return ValueError(f"Unrecognized actor type for {task_name}: '{actor_type}' ({agent_name})")


def _init_lbf_heuristic_agent(agent_config, actor_type, agent_name, task_name, env_kwargs):
    """Build the LBF heuristic policy wrapper selected by ``actor_type``."""
    # Grid dimensions: per-agent config > env_kwargs > defaults (7x7, 3 fruits).
    grid_size = agent_config.get("grid_size", env_kwargs.get("grid_size", 7))
    # Per-agent key is 'num_fruits'; the env-level key is 'num_food'.
    num_fruits = agent_config.get("num_fruits", env_kwargs.get("num_food", 3))

    if actor_type == "random_agent":
        return LBFRandomPolicyWrapper()
    if actor_type == "seq_agent":
        return LBFSequentialFruitPolicyWrapper(
            grid_size=grid_size,
            num_fruits=num_fruits,
            ordering_strategy=agent_config.get("ordering_strategy", "lexicographic"),
            using_log_wrapper=True,
        )
    if actor_type == "entitled_agent":
        return LBFEntitledPolicyWrapper(
            grid_size=grid_size,
            num_fruits=num_fruits,
            using_log_wrapper=True,
        )
    if actor_type == "greedy_agent":
        return LBFGreedyHeuristicPolicyWrapper(
            grid_size=grid_size,
            num_fruits=num_fruits,
            heuristic=agent_config.get("heuristic", "closest_self"),
            using_log_wrapper=True,
        )
    raise _unrecognized_actor_type(task_name, actor_type, agent_name)


def _init_overcooked_heuristic_agent(agent_config, actor_type, agent_name, task_name, env_kwargs):
    """Build the Overcooked heuristic policy wrapper selected by ``actor_type``."""
    # The layout is looked up before the actor type is checked, so a missing
    # 'layout' key surfaces as a KeyError regardless of actor_type.
    aug_layout_dict = augmented_layouts[env_kwargs["layout"]]

    if actor_type == "random_agent":
        return OvercookedRandomPolicyWrapper(aug_layout_dict, using_log_wrapper=True)
    if actor_type == "static_agent":
        return OvercookedStaticPolicyWrapper(aug_layout_dict, using_log_wrapper=True)
    if actor_type == "independent_agent":
        return OvercookedIndependentPolicyWrapper(
            aug_layout_dict,
            using_log_wrapper=True,
            p_onion_on_counter=agent_config.get("p_onion_on_counter", 0.0),
            p_plate_on_counter=agent_config.get("p_plate_on_counter", 0.0),
        )
    if actor_type == "onion_agent":
        return OvercookedOnionPolicyWrapper(
            aug_layout_dict,
            using_log_wrapper=True,
            p_onion_on_counter=agent_config.get("p_onion_on_counter", 0.0),
        )
    if actor_type == "plate_agent":
        return OvercookedPlatePolicyWrapper(
            aug_layout_dict,
            using_log_wrapper=True,
            p_plate_on_counter=agent_config.get("p_plate_on_counter", 0.0),
        )
    raise _unrecognized_actor_type(task_name, actor_type, agent_name)


def _init_dsse_heuristic_agent(agent_config, actor_type, agent_name, task_name, env_kwargs):
    """Build the DSSE heuristic policy wrapper selected by ``actor_type``."""
    grid_size = agent_config.get("grid_size", env_kwargs.get("grid_size", 7))

    if actor_type == "random_agent":
        return DSSERandomPolicyWrapper(using_log_wrapper=True)
    if actor_type == "greedy_search_agent":
        return DSSEGreedySearchPolicyWrapper(grid_size=grid_size, using_log_wrapper=True)
    if actor_type == "sweep_agent":
        return DSSESweepPolicyWrapper(grid_size=grid_size, using_log_wrapper=True)
    raise _unrecognized_actor_type(task_name, actor_type, agent_name)


def _init_hanabi_heuristic_agent(agent_config, actor_type, agent_name, task_name, env_kwargs):
    """Build the Hanabi heuristic policy wrapper selected by ``actor_type``."""
    # Default to full Hanabi shape; mini-hanabi callers must override
    # hand_size / num_colors / num_ranks via agent_config or env_kwargs.
    hand_size = agent_config.get("hand_size", env_kwargs.get("hand_size", 5))
    num_colors = agent_config.get("num_colors", env_kwargs.get("num_colors", 5))
    num_ranks = agent_config.get("num_ranks", env_kwargs.get("num_ranks", 5))
    # Action layout: discard + play + color hints + rank hints + noop.
    # num_actions itself is only overridable through agent_config.
    num_actions = agent_config.get(
        "num_actions", 2 * hand_size + num_colors + num_ranks + 1
    )
    common = dict(
        hand_size=hand_size, num_colors=num_colors, num_ranks=num_ranks,
        num_actions=num_actions, using_log_wrapper=True,
    )

    if actor_type == "random_agent":
        return HanabiRandomPolicyWrapper(
            num_actions=num_actions, using_log_wrapper=True
        )
    if actor_type == "rule_based":
        return HanabiRuleBasedPolicyWrapper(
            strategy=agent_config.get("strategy", "cautious"), **common
        )
    if actor_type == "iggi":
        return HanabiIGGIPolicyWrapper(**common)
    if actor_type == "piers":
        return HanabiPiersPolicyWrapper(
            play_threshold=agent_config.get("play_threshold", 0.6),
            hint_threshold=agent_config.get("hint_threshold", 4),
            **common,
        )
    if actor_type == "flawed":
        return HanabiFlawedPolicyWrapper(
            play_threshold=agent_config.get("play_threshold", 0.4), **common
        )
    if actor_type == "outer":
        return HanabiOuterPolicyWrapper(**common)
    if actor_type == "van_den_bergh":
        return HanabiVanDenBerghPolicyWrapper(**common)
    if actor_type == "internal":
        return HanabiInternalPolicyWrapper(**common)
    if actor_type == "cautious":
        return HanabiCautiousPolicyWrapper(**common)
    if actor_type == "smartbot":
        return HanabiSmartBotPolicyWrapper(
            card_counts=agent_config.get("card_counts", None), **common
        )
    # The two weight-file agents require agent_config["weight_file"].
    if actor_type == "obl_r2d2":
        return HanabiOBLPolicyWrapper(
            weight_file=agent_config["weight_file"], using_log_wrapper=True
        )
    if actor_type == "bc_lstm":
        return HanabiBCLSTMPolicyWrapper(
            weight_file=agent_config["weight_file"],
            using_log_wrapper=True,
            greedy=agent_config.get("greedy", True),
        )
    raise _unrecognized_actor_type(task_name, actor_type, agent_name)


def initialize_heuristic_agent_from_config(agent_config, agent_name, task_name, env_kwargs=None):
    """Load a heuristic (non-RL) agent from config, dispatching on task_name.

    Task families are matched in this order, first match wins:
    ``'lbf'`` as a substring, ``'overcooked-v1'`` as a substring, ``'dsse'``
    as an exact match, then ``'hanabi'`` as a substring.

    Args:
        agent_config: Mapping that must include "actor_type", plus any
            per-agent options for the chosen heuristic.
        agent_name: Name used only in error messages.
        task_name: Task identifier used to pick the task family.
        env_kwargs: Optional env-level parameters used as a fallback when
            agent_config does not set a value (grid_size and num_food for
            lbf; grid_size for dsse; hand_size, num_colors and num_ranks for
            hanabi). Per-agent values in agent_config take priority. Note
            that the LBF environment calls the fruit count 'num_food' while
            agent_config calls it 'num_fruits'. For overcooked-v1 the layout
            is always read from env_kwargs["layout"].

    Returns:
        policy: policy wrapper instance (no checkpoint or params required).

    Raises:
        AssertionError: If "actor_type" is missing from agent_config.
        KeyError: For overcooked-v1 when env_kwargs has no "layout" (or the
            layout is not in augmented_layouts), or when a required
            "weight_file" is missing.
        ValueError: If the task or the actor type is not recognized.
    """
    assert "actor_type" in agent_config, "Actor type must be provided."
    actor_type = agent_config["actor_type"]
    if env_kwargs is None:
        env_kwargs = {}

    if 'lbf' in task_name:
        build = _init_lbf_heuristic_agent
    elif 'overcooked-v1' in task_name:
        build = _init_overcooked_heuristic_agent
    elif task_name == 'dsse':
        build = _init_dsse_heuristic_agent
    elif 'hanabi' in task_name:
        build = _init_hanabi_heuristic_agent
    else:
        raise ValueError(
            f"Unknown task '{task_name}' for heuristic agent {agent_name}. "
            f"Expected 'lbf', 'overcooked-v1', 'dsse', or a task containing 'hanabi'."
        )
    return build(agent_config, actor_type, agent_name, task_name, env_kwargs)
def initialize_rl_agent_from_config(agent_config, agent_name, env, rng):
    """Load an RL agent's checkpoint and build its policy from config.

    Expected keys in agent_config:
        "path": str, checkpoint location (required).
        "actor_type": str, one of "s5", "mlp", "rnn",
            "actor_with_double_critic", "actor_with_conditional_critic",
            "liam", "meliba" (required).
        "idx_list": list or None, indices to select from the checkpoint;
            None keeps every checkpoint (required, may be None).
        "ckpt_key": str, key to load from the checkpoint. Default "checkpoints".
        "custom_loader": dict, custom loader config. Default None.
        ...plus whatever the chosen agent initializer reads.

    Args:
        agent_config: Mapping described above.
        agent_name: Name used in the log message.
        env: Environment handed to the agent initializer.
        rng: PRNG key; split once to obtain the initializer's key.

    Returns:
        policy: policy object from the agent initializer
        agent_params: parameters taken from the checkpoint (optionally indexed)
        init_params: freshly initialized parameters from the initializer
        idx_labels: string labels for the selected checkpoints

    Raises:
        AssertionError: If a required key is missing, or if the checkpoint
            and initialized parameter trees differ in structure.
        ValueError: If the actor type is not recognized.
    """
    assert "path" in agent_config, "Path to agent checkpoint must be provided."
    assert "actor_type" in agent_config, "Actor type must be provided."
    assert "idx_list" in agent_config, "Indices to load from checkpoint must be provided."

    checkpoint = load_checkpoints(
        _validate_teammate_path(agent_config["path"]),
        ckpt_key=agent_config.get("ckpt_key", "checkpoints"),
        custom_loader_cfg=agent_config.get("custom_loader", None),
    )
    first_leaf_shape = jax.tree.leaves(checkpoint)[0].shape

    requested = agent_config["idx_list"]
    if requested is None:
        # Keep every checkpoint as loaded.
        idx_list = None
        agent_params = checkpoint
    else:
        # Convert an OmegaConf list config to plain lists, recursively;
        # fall back to the raw value if it is not an OmegaConf object.
        try:
            idx_list = OmegaConf.to_object(requested)
        except Exception as e:
            log.warning(f"Error interpreting agent_config['idx_list'] as OmegaConf object: {e}. Treating as list.")
            idx_list = requested
        idx_list = jax.tree.map(int, idx_list)
        selector = process_idx_list(idx_list)
        agent_params = jax.tree.map(lambda leaf: leaf[selector], checkpoint)

    selection = idx_list if idx_list is not None else 'all'
    log.info(f"Loaded {agent_name} checkpoint where leaf 0 has shape {first_leaf_shape}. "
             f" Selecting indices {selection} for evaluation.")

    # Labels describe the selection relative to the unindexed checkpoint shape.
    idx_labels = create_idx_labels(idx_list, first_leaf_shape)

    rng, init_rng = jax.random.split(rng, 2)

    actor_type = agent_config["actor_type"]
    initializers = {
        "s5": initialize_s5_agent,
        "mlp": initialize_mlp_agent,
        "rnn": initialize_rnn_agent,
        "actor_with_double_critic": initialize_actor_with_double_critic,
        "actor_with_conditional_critic": initialize_actor_with_conditional_critic,
        "liam": initialize_liam_agent,
        "meliba": initialize_meliba_agent,
    }
    if actor_type not in initializers:
        raise ValueError(f"Invalid actor type: {agent_config['actor_type']}")
    policy, init_params = initializers[actor_type](agent_config, env, init_rng)

    # Make compatible with old naming for S5 layers: checkpoints that still
    # use "...body..." keys are renamed in place to "...body_layers...".
    if actor_type == "s5" and "action_body_0" in agent_params['params'].keys():  # CLEANUP FLAG
        for old_key in list(agent_params['params'].keys()):
            if "body" in old_key:
                renamed = old_key.replace("body", "body_layers")
                agent_params['params'][renamed] = agent_params['params'][old_key]
                del agent_params['params'][old_key]

    assert jax.tree.structure(agent_params) == jax.tree.structure(init_params), "Agent parameters and initial parameters must have the same structure."
    return policy, agent_params, init_params, idx_labels