Spaces:
Running
Running
| import logging | |
| import jax | |
| import numpy as np | |
| import os | |
| from omegaconf import OmegaConf | |
| from agents.initialize_agents import initialize_s5_agent, initialize_mlp_agent, \ | |
| initialize_rnn_agent, initialize_actor_with_double_critic, \ | |
| initialize_actor_with_conditional_critic, \ | |
| initialize_liam_agent, initialize_meliba_agent | |
| from agents.lbf.agent_policy_wrappers import ( | |
| LBFRandomPolicyWrapper, LBFSequentialFruitPolicyWrapper, | |
| LBFEntitledPolicyWrapper, LBFGreedyHeuristicPolicyWrapper, | |
| ) | |
| from agents.overcooked.agent_policy_wrappers import ( | |
| OvercookedRandomPolicyWrapper, OvercookedIndependentPolicyWrapper, | |
| OvercookedOnionPolicyWrapper, OvercookedPlatePolicyWrapper, | |
| OvercookedStaticPolicyWrapper, | |
| ) | |
| from agents.hanabi.agent_policy_wrappers import ( | |
| HanabiRandomPolicyWrapper, HanabiRuleBasedPolicyWrapper, | |
| HanabiIGGIPolicyWrapper, HanabiPiersPolicyWrapper, | |
| HanabiFlawedPolicyWrapper, HanabiOuterPolicyWrapper, | |
| HanabiVanDenBerghPolicyWrapper, HanabiSmartBotPolicyWrapper, | |
| HanabiOBLPolicyWrapper, HanabiBCLSTMPolicyWrapper, | |
| HanabiInternalPolicyWrapper, HanabiCautiousPolicyWrapper, | |
| ) | |
| from agents.dsse.agent_policy_wrappers import ( | |
| DSSERandomPolicyWrapper, DSSEGreedySearchPolicyWrapper, | |
| DSSESweepPolicyWrapper, | |
| ) | |
| from common.save_load_utils import load_checkpoints, REPO_PATH | |
| from envs.overcooked.augmented_layouts import augmented_layouts | |
# Module-level logger named after this module; used by the loaders below.
log = logging.getLogger(__name__)
# Configure the root logger at INFO level. This runs as an import-time side
# effect of this module.
logging.basicConfig(level=logging.INFO)
def _validate_teammate_path(path: str) -> str:
    """Check that a checkpoint path exists and fail fast if it does not.

    Absolute paths are checked as given; relative paths are resolved against
    ``REPO_PATH`` for the existence check.

    Args:
        path: Checkpoint location, absolute or relative to ``REPO_PATH``.

    Returns:
        The ``path`` argument unchanged (not the resolved location).

    Raises:
        FileNotFoundError: If the resolved location does not exist.
    """
    resolved_path = path if os.path.isabs(path) else os.path.join(REPO_PATH, path)
    if not os.path.exists(resolved_path):
        # When a relative path was resolved, report where we actually looked so
        # the failure can be diagnosed without knowing REPO_PATH.
        if resolved_path == path:
            location = f"{path}"
        else:
            location = f"{path} (resolved to {resolved_path})"
        raise FileNotFoundError(
            f"Checkpoint path does not exist: {location}. "
            "Use the new eval_teammates/... layout."
        )
    # NOTE(review): the caller-supplied path is returned as-is; presumably the
    # checkpoint loader resolves relative paths itself -- confirm.
    return path
def process_idx_list(idx_list):
    """Convert a list of checkpoint indices into array-indexing form.

    Args:
        idx_list: A flat sequence of indices, or a sequence of index pairs.

    Returns:
        For one-dimensional input, the indices as a NumPy array. For
        two-dimensional input, a ``(rows, cols)`` tuple holding the first and
        second column of the input.

    Raises:
        ValueError: If the input is neither one- nor two-dimensional.
    """
    index_array = np.array(idx_list)
    n_dims = index_array.ndim

    if n_dims not in (1, 2):
        raise ValueError(f"Invalid index list shape: {index_array.shape}")

    if n_dims == 1:
        return index_array

    # Split index pairs into one array per axis so the result can be used
    # directly as a two-axis index.
    first_axis = index_array[:, 0]
    second_axis = index_array[:, 1]
    return (first_axis, second_axis)
def create_idx_labels(idx_list, checkpoint_shape):
    """Create human-readable labels for the checkpoints selected by ``idx_list``.

    Args:
        idx_list: The indices used to extract checkpoints. ``None`` or the
            string ``"all"`` means every checkpoint was kept. Otherwise a
            sequence of ints (1D indices) or of 2-element lists/tuples
            (2D indices such as ``[[0, -1], [1, -1]]``).
        checkpoint_shape: Shape of one checkpoint array. Only consulted when
            all checkpoints are kept.

    Returns:
        When all checkpoints are kept and ``checkpoint_shape`` has at least
        three dimensions: a nested list of ``"row, col"`` strings covering the
        first two dimensions. When all are kept and the shape has fewer
        dimensions: a flat list of ``"i"`` strings covering the first
        dimension. Otherwise: a flat list with one label per entry of
        ``idx_list`` (empty if ``idx_list`` is empty).
    """
    keep_all = idx_list is None or (isinstance(idx_list, str) and idx_list == "all")
    if keep_all:
        # NOTE(review): a shape with >= 3 dims is taken to mean the first two
        # dims both enumerate checkpoints (e.g. seeds x steps) -- this is a
        # heuristic on the shape of a single leaf; confirm against callers.
        if len(checkpoint_shape) >= 3:
            n_rows, n_cols = checkpoint_shape[0], checkpoint_shape[1]
            return [
                [f"{row_idx}, {col_idx}" for col_idx in range(n_cols)]
                for row_idx in range(n_rows)
            ]
        return [f"{i}" for i in range(checkpoint_shape[0])]

    # An empty selection has no labels; guard before peeking at element 0.
    if len(idx_list) == 0:
        return []

    # 2D indices like [[0, -1], [1, -1]]; tuples are accepted as pairs too.
    if isinstance(idx_list[0], (list, tuple)):
        return [f"{idx[0]}, {idx[1]}" for idx in idx_list]

    # 1D indices like [0, 1, 2].
    return [f"{idx}" for idx in idx_list]
def _unknown_actor_error(task_name, actor_type, agent_name):
    """Build the error raised when a task has no heuristic of the given type."""
    return ValueError(
        f"Unrecognized actor type for {task_name}: '{actor_type}' ({agent_name})"
    )


def _make_lbf_heuristic(agent_config, agent_name, task_name, env_kwargs):
    """Build the LBF heuristic policy selected by agent_config["actor_type"]."""
    actor_type = agent_config["actor_type"]
    # Grid dimensions: per-agent config > env_kwargs > defaults (7x7, 3 fruits).
    # The env-level key for the fruit count is "num_food".
    grid_size = agent_config.get("grid_size", env_kwargs.get("grid_size", 7))
    num_fruits = agent_config.get("num_fruits", env_kwargs.get("num_food", 3))

    if actor_type == "random_agent":
        return LBFRandomPolicyWrapper()
    if actor_type == "seq_agent":
        return LBFSequentialFruitPolicyWrapper(
            grid_size=grid_size,
            num_fruits=num_fruits,
            ordering_strategy=agent_config.get("ordering_strategy", "lexicographic"),
            using_log_wrapper=True,
        )
    if actor_type == "entitled_agent":
        return LBFEntitledPolicyWrapper(
            grid_size=grid_size,
            num_fruits=num_fruits,
            using_log_wrapper=True,
        )
    if actor_type == "greedy_agent":
        return LBFGreedyHeuristicPolicyWrapper(
            grid_size=grid_size,
            num_fruits=num_fruits,
            heuristic=agent_config.get("heuristic", "closest_self"),
            using_log_wrapper=True,
        )
    raise _unknown_actor_error(task_name, actor_type, agent_name)


def _make_overcooked_heuristic(agent_config, agent_name, task_name, env_kwargs):
    """Build the overcooked-v1 heuristic policy selected by the actor type."""
    actor_type = agent_config["actor_type"]
    # The layout must come from env_kwargs; a missing or unknown layout raises
    # KeyError before the actor type is inspected.
    aug_layout_dict = augmented_layouts[env_kwargs["layout"]]

    if actor_type == "random_agent":
        return OvercookedRandomPolicyWrapper(aug_layout_dict, using_log_wrapper=True)
    if actor_type == "static_agent":
        return OvercookedStaticPolicyWrapper(aug_layout_dict, using_log_wrapper=True)
    if actor_type == "independent_agent":
        return OvercookedIndependentPolicyWrapper(
            aug_layout_dict,
            using_log_wrapper=True,
            p_onion_on_counter=agent_config.get("p_onion_on_counter", 0.0),
            p_plate_on_counter=agent_config.get("p_plate_on_counter", 0.0),
        )
    if actor_type == "onion_agent":
        return OvercookedOnionPolicyWrapper(
            aug_layout_dict,
            using_log_wrapper=True,
            p_onion_on_counter=agent_config.get("p_onion_on_counter", 0.0),
        )
    if actor_type == "plate_agent":
        return OvercookedPlatePolicyWrapper(
            aug_layout_dict,
            using_log_wrapper=True,
            p_plate_on_counter=agent_config.get("p_plate_on_counter", 0.0),
        )
    raise _unknown_actor_error(task_name, actor_type, agent_name)


def _make_dsse_heuristic(agent_config, agent_name, task_name, env_kwargs):
    """Build the DSSE heuristic policy selected by agent_config["actor_type"]."""
    actor_type = agent_config["actor_type"]
    grid_size = agent_config.get("grid_size", env_kwargs.get("grid_size", 7))

    if actor_type == "random_agent":
        return DSSERandomPolicyWrapper(using_log_wrapper=True)
    if actor_type == "greedy_search_agent":
        return DSSEGreedySearchPolicyWrapper(grid_size=grid_size, using_log_wrapper=True)
    if actor_type == "sweep_agent":
        return DSSESweepPolicyWrapper(grid_size=grid_size, using_log_wrapper=True)
    raise _unknown_actor_error(task_name, actor_type, agent_name)


def _make_hanabi_heuristic(agent_config, agent_name, task_name, env_kwargs):
    """Build the Hanabi heuristic policy selected by agent_config["actor_type"]."""
    actor_type = agent_config["actor_type"]
    # Default to full Hanabi shape; mini-hanabi callers must pass num_colors,
    # num_ranks, hand_size, num_actions through agent_config or env_kwargs.
    hand_size = agent_config.get("hand_size", env_kwargs.get("hand_size", 5))
    num_colors = agent_config.get("num_colors", env_kwargs.get("num_colors", 5))
    num_ranks = agent_config.get("num_ranks", env_kwargs.get("num_ranks", 5))
    # Action layout: discard + play + color hints + rank hints + noop.
    derived_num_actions = 2 * hand_size + num_colors + num_ranks + 1
    # Same precedence as the other shape parameters:
    # per-agent config > env_kwargs > value derived from the shape.
    num_actions = agent_config.get(
        "num_actions", env_kwargs.get("num_actions", derived_num_actions)
    )
    common = dict(
        hand_size=hand_size, num_colors=num_colors, num_ranks=num_ranks,
        num_actions=num_actions, using_log_wrapper=True,
    )

    if actor_type == "random_agent":
        return HanabiRandomPolicyWrapper(
            num_actions=num_actions, using_log_wrapper=True
        )
    if actor_type == "rule_based":
        return HanabiRuleBasedPolicyWrapper(
            strategy=agent_config.get("strategy", "cautious"), **common
        )
    if actor_type == "iggi":
        return HanabiIGGIPolicyWrapper(**common)
    if actor_type == "piers":
        return HanabiPiersPolicyWrapper(
            play_threshold=agent_config.get("play_threshold", 0.6),
            hint_threshold=agent_config.get("hint_threshold", 4),
            **common,
        )
    if actor_type == "flawed":
        return HanabiFlawedPolicyWrapper(
            play_threshold=agent_config.get("play_threshold", 0.4), **common
        )
    if actor_type == "outer":
        return HanabiOuterPolicyWrapper(**common)
    if actor_type == "van_den_bergh":
        return HanabiVanDenBerghPolicyWrapper(**common)
    if actor_type == "internal":
        return HanabiInternalPolicyWrapper(**common)
    if actor_type == "cautious":
        return HanabiCautiousPolicyWrapper(**common)
    if actor_type == "smartbot":
        return HanabiSmartBotPolicyWrapper(
            card_counts=agent_config.get("card_counts", None), **common
        )
    # The weight-file based agents require "weight_file" in agent_config and do
    # not receive the shape parameters.
    if actor_type == "obl_r2d2":
        return HanabiOBLPolicyWrapper(
            weight_file=agent_config["weight_file"], using_log_wrapper=True
        )
    if actor_type == "bc_lstm":
        return HanabiBCLSTMPolicyWrapper(
            weight_file=agent_config["weight_file"],
            using_log_wrapper=True,
            greedy=agent_config.get("greedy", True),
        )
    raise _unknown_actor_error(task_name, actor_type, agent_name)


def initialize_heuristic_agent_from_config(agent_config, agent_name, task_name, env_kwargs=None):
    """Load a heuristic (non-RL) agent from config, dispatching on task_name.

    Task matching, in order: any task_name containing 'lbf', any containing
    'overcooked-v1', exactly 'dsse', any containing 'hanabi'.

    Args:
        agent_config: Mapping that must include "actor_type"; may also carry
            per-agent parameters, which take priority over env_kwargs.
        agent_name: Name of the agent; only used in error messages.
        task_name: Task identifier used to pick the family of heuristics.
        env_kwargs: Optional mapping used as a fallback for env-level
            parameters (e.g. grid_size, num_fruits for lbf; layout for
            overcooked-v1). Note that the LBF environment calls the fruit
            count 'num_food'. Defaults to an empty dict.

    Returns:
        policy: policy function (no checkpoint or params required)

    Raises:
        AssertionError: If "actor_type" is missing from agent_config.
        ValueError: If the task is unknown or the actor type is not recognized
            for the task.
        KeyError: If a required key is missing (the overcooked-v1 layout, or
            "weight_file" for the weight-file based Hanabi agents).
    """
    assert "actor_type" in agent_config, "Actor type must be provided."
    if env_kwargs is None:
        env_kwargs = {}

    if 'lbf' in task_name:
        return _make_lbf_heuristic(agent_config, agent_name, task_name, env_kwargs)
    if 'overcooked-v1' in task_name:
        return _make_overcooked_heuristic(agent_config, agent_name, task_name, env_kwargs)
    if task_name == 'dsse':
        return _make_dsse_heuristic(agent_config, agent_name, task_name, env_kwargs)
    if 'hanabi' in task_name:
        return _make_hanabi_heuristic(agent_config, agent_name, task_name, env_kwargs)

    raise ValueError(
        f"Unknown task '{task_name}' for heuristic agent {agent_name}. "
        f"Expected 'lbf', 'overcooked-v1', 'dsse', or a task containing 'hanabi'."
    )
def initialize_rl_agent_from_config(agent_config, agent_name, env, rng):
    """Load RL agent from checkpoint and initialize from config.

    The agent_config dictionary should have the following structure:
    {
        "path": str,
        "actor_type": str,
        "ckpt_key": str, # key to load from checkpoint. Default is "checkpoints".
        "custom_loader": dict, # custom loader for the checkpoint. Default is None.
        "idx_list": list, # list of indices to load from checkpoint. If null, all checkpoints will be loaded.
        # and any other parameters needed to initialize the agent policy
    }

    Args:
        agent_config: Mapping described above. "path", "actor_type" and
            "idx_list" are required keys ("idx_list" may be None).
        agent_name: Name of the agent; only used in the log message.
        env: Environment handed to the actor-specific initializer.
        rng: JAX PRNG key; split once to seed the initializer.

    Returns:
        A 4-tuple ``(policy, agent_params, init_params, idx_labels)``:
        policy: policy function
        agent_params: agent parameters from checkpoint (restricted to idx_list
            when one is given)
        init_params: initial agent parameters from initialization
        idx_labels: string labels corresponding to the selected checkpoints

    Raises:
        AssertionError: If a required config key is missing, or if the
            checkpoint parameters and the freshly initialized parameters do
            not share the same tree structure.
        FileNotFoundError: If the checkpoint path does not exist.
        ValueError: If "actor_type" is not a supported actor type.
    """
    assert "path" in agent_config, "Path to agent checkpoint must be provided."
    assert "actor_type" in agent_config, "Actor type must be provided."
    assert "idx_list" in agent_config, "Indices to load from checkpoint must be provided."

    agent_path = _validate_teammate_path(agent_config["path"])
    ckpt_key = agent_config.get("ckpt_key", "checkpoints")
    custom_loader_cfg = agent_config.get("custom_loader", None)
    agent_ckpt = load_checkpoints(agent_path, ckpt_key=ckpt_key, custom_loader_cfg=custom_loader_cfg)

    leaf0_shape = jax.tree.leaves(agent_ckpt)[0].shape

    raw_idx_list = agent_config["idx_list"]
    if raw_idx_list is None:  # load all checkpoints
        idx_list = None
        agent_params = agent_ckpt
    else:  # load specific checkpoints
        # Convert an OmegaConf list config to a plain (nested) list; plain
        # Python lists are used as they are. Checking the type up front avoids
        # both a spurious warning for plain lists and hiding real conversion
        # errors behind a fallback.
        if OmegaConf.is_config(raw_idx_list):
            idx_list = OmegaConf.to_object(raw_idx_list)
        else:
            idx_list = raw_idx_list
        idx_list = jax.tree.map(int, idx_list)
        idxs = process_idx_list(idx_list)
        agent_params = jax.tree.map(lambda x: x[idxs], agent_ckpt)

    log.info(f"Loaded {agent_name} checkpoint where leaf 0 has shape {leaf0_shape}. "
             f" Selecting indices {idx_list if idx_list is not None else 'all'} for evaluation.")

    # Create index labels for the loaded checkpoints
    idx_labels = create_idx_labels(idx_list, leaf0_shape)

    rng, init_rng = jax.random.split(rng, 2)

    # Every initializer shares the (agent_config, env, rng) signature and
    # returns (policy, init_params).
    initializers = {
        "s5": initialize_s5_agent,
        "mlp": initialize_mlp_agent,
        "rnn": initialize_rnn_agent,
        "actor_with_double_critic": initialize_actor_with_double_critic,
        "actor_with_conditional_critic": initialize_actor_with_conditional_critic,
        "liam": initialize_liam_agent,
        "meliba": initialize_meliba_agent,
    }
    actor_type = agent_config["actor_type"]
    initializer = initializers.get(actor_type)
    if initializer is None:
        raise ValueError(f"Invalid actor type: {agent_config['actor_type']}")
    policy, init_params = initializer(agent_config, env, init_rng)

    if actor_type == "s5":
        # Make compatible with old naming for S5 layers  # CLEANUP FLAG
        # NOTE: this renames keys in place; when idx_list is None agent_params
        # is the loaded checkpoint object itself, so that object is modified.
        s5_params = agent_params['params']
        if "action_body_0" in s5_params.keys():
            # Snapshot the keys first: the dict is modified while renaming.
            for old_key in list(s5_params.keys()):
                if "body" in old_key:
                    new_key = old_key.replace("body", "body_layers")
                    s5_params[new_key] = s5_params[old_key]
                    del s5_params[old_key]

    assert jax.tree.structure(agent_params) == jax.tree.structure(init_params), "Agent parameters and initial parameters must have the same structure."

    return policy, agent_params, init_params, idx_labels