Spaces:
Running
Running
'''Evaluate ego agents against a set of heldout partner agents.

Warning: ActorCritic agents that rely on auxiliary information to compute actions are not currently supported.
'''
| import jax | |
| import numpy as np | |
| from prettytable import PrettyTable | |
| from functools import partial | |
| import time | |
| import os | |
| import hydra | |
| from common.agent_loader_from_config import ( | |
| initialize_rl_agent_from_config, | |
| initialize_heuristic_agent_from_config, | |
| ) | |
| from common.run_episodes import run_episodes | |
| from common.tree_utils import tree_stack | |
| from common.plot_utils import get_metric_names | |
| from common.stat_utils import compute_aggregate_stat_and_ci_per_task | |
| from envs import make_env | |
| from envs.log_wrapper import LogWrapper | |
def extract_params(params, init_params, idx_labels=None):
    '''Split a pytree of stacked checkpoints into a list of single-checkpoint pytrees.

    Each leaf of params carries zero or more leading checkpoint dimensions followed by
    the shape of the matching leaf in init_params. The leading dimensions are flattened
    into one, and one pytree is produced per flattened index.

    Args:
        params: pytree of n checkpoints (n >= 1)
        init_params: pytree corresp. to ONE checkpoint. used as a reference for the
            structure and per-leaf shape of the output pytrees.
        idx_labels: labels laid out like the checkpoint dimensions of params (any
            nesting is flattened). If None, numeric indices "0".."n-1" are used.
    Returns:
        Tuple of:
        - list of n pytrees with same structure as init_params
        - flat sequence of n index labels identifying the original location of each
          checkpoint (list of str when idx_labels is None, 1-D numpy array otherwise)
    Raises:
        AssertionError: if params and init_params have different tree structures.
        ValueError: if params holds several checkpoints and idx_labels cannot be
            reshaped to one label per checkpoint.
    '''
    assert jax.tree.structure(params) == jax.tree.structure(init_params), "Params and init_params must have the same structure."

    # Only the first leaf is inspected to decide whether checkpoint dimensions exist.
    params_shape = jax.tree.leaves(params)[0].shape
    init_params_shape = jax.tree.leaves(init_params)[0].shape

    if params_shape == init_params_shape:
        # already matches init_params_shape, no extraction needed
        model_list = [params]
        n_models = 1
        if idx_labels is not None:
            # Flatten here as well, so that a nested label such as [["seed0"]]
            # comes back as a flat sequence exactly like in the multi-model branch.
            flattened_idx_labels = np.array(idx_labels).reshape(-1)
    else:
        # multiple models: collapse all checkpoint dimensions into a single leading axis
        flattened_params = jax.tree.map(lambda x, y: x.reshape((-1,) + y.shape), params, init_params)
        n_models = jax.tree.leaves(flattened_params)[0].shape[0]
        if idx_labels is not None:
            # flatten the labels in the same (row-major) order as the parameters
            flattened_idx_labels = np.array(idx_labels).reshape(n_models)
        # slice out one checkpoint per index of the flattened leading axis
        model_list = [
            jax.tree.map(lambda x, i=i: x[i], flattened_params)
            for i in range(n_models)
        ]

    if idx_labels is None:
        flattened_idx_labels = [str(i) for i in range(n_models)]
    return model_list, flattened_idx_labels
def extract_performance_bounds(agent_config, n_models):
    '''Split the "performance_bounds" entry of agent_config into one dict per model.

    The config entry, when present, maps each stat name to a sequence indexed by
    model; entry i of that sequence is the two-element (lower, upper) bound pair for
    model i.

    Args:
        agent_config: mapping that may contain a "performance_bounds" entry.
        n_models: number of models to produce bounds for.
    Returns:
        List of length n_models. Each element is a dict mapping stat name to that
        model's bound pair, or None for every model when no bounds are configured.
    Raises:
        AssertionError: if a bound entry does not have exactly two values.
    '''
    performance_bounds = agent_config.get("performance_bounds", None)
    if performance_bounds is None:
        return [None] * n_models

    ret_list = []
    for i in range(n_models):
        perf_i = {}
        for stat_name, bound_list in performance_bounds.items():
            bounds_i = bound_list[i]
            # Bounds are consumed as (lower, upper), so the message states that order.
            assert len(bounds_i) == 2, "Performance bounds must be a list of two values (lower and upper bounds)."
            perf_i[stat_name] = bounds_i
        ret_list.append(perf_i)
    return ret_list
def load_heldout_set(heldout_config, env, task_name, env_kwargs, rng):
    '''Build the dictionary of heldout evaluation agents described by heldout_config.

    Entries that contain a "path" key are loaded as RL agents and expanded into one
    dictionary entry per checkpoint, labelled "<agent_name> (<idx_label>)". All other
    entries are loaded as heuristic agents under their own name.

    Returns:
        Dict mapping agent label to a tuple
        (policy, params, test_mode, performance_bounds); params is None for
        heuristic agents.
    '''
    agents_by_label = {}
    for agent_name, agent_config in heldout_config.items():
        # A null entry lets an env-specific config disable something inherited
        # from a base config, so it is simply skipped.
        if agent_config is None:
            continue

        test_mode = agent_config.get("test_mode", False)

        if "path" not in agent_config:
            # Non-RL heuristic agent: a single entry, no parameters.
            performance_bounds = agent_config.get("performance_bounds", None)
            policy = initialize_heuristic_agent_from_config(
                agent_config, agent_name, task_name, env_kwargs
            )
            agents_by_label[agent_name] = (policy, None, test_mode, performance_bounds)
            continue

        # RL agent: draw a fresh key so every RL agent gets its own initialization rng.
        rng, init_rng = jax.random.split(rng)
        policy, params, init_params, idx_labels = initialize_rl_agent_from_config(
            agent_config, agent_name, env, init_rng
        )
        # params may hold several checkpoints; split them into individual pytrees.
        params_list, idx_labels = extract_params(params, init_params, idx_labels)
        performance_bounds_list = extract_performance_bounds(agent_config, len(params_list))

        for i, params_i in enumerate(params_list):
            suffix = i if idx_labels is None else idx_labels[i]
            agent_label = f'{agent_name} ({suffix})'
            agents_by_label[agent_label] = (policy, params_i, test_mode, performance_bounds_list[i])
    return agents_by_label
def normalize_metrics(metrics, performance_bounds):
    '''Rescale the entries of metrics named in performance_bounds.

    For every stat in performance_bounds, whose value holds the lower bound at index 0
    and the upper bound at index 1, the matching metric is replaced by
    (value - lower) / (upper - lower). The metrics mapping is modified in place and
    also returned; stats without bounds are left untouched.
    '''
    for stat_name, bounds in performance_bounds.items():
        lo = bounds[0]
        hi = bounds[1]
        shifted = metrics[stat_name] - lo
        metrics[stat_name] = shifted / (hi - lo)
    return metrics
def eval_egos_vs_heldouts(config, env, rng, num_episodes, ego_policy, ego_params,
                          heldout_agent_list, heldout_agent_names=None, ego_test_mode=False):
    '''Evaluate all ego agents against all heldout partners using vmap over egos.

    The heldout partners are heterogeneous, so they are handled by a Python loop;
    for each partner the stacked ego checkpoints are evaluated with a single vmap.

    Args:
        config: mapping with a "global_heldout_settings" entry providing
            "MAX_EPISODE_STEPS" and "NORMALIZE_RETURNS".
        env: environment; must report num_agents == 2.
        rng: PRNG key; split into one key per partner and then one key per ego.
        num_episodes: number of episodes passed to run_episodes for each pairing.
        ego_policy: policy shared by all ego checkpoints (not mapped by vmap).
        ego_params: pytree whose leaves have shape (num_ego_agents, ...).
        heldout_agent_list: list of tuples
            (policy, params, test_mode, performance_bounds), one per partner.
        heldout_agent_names: optional names, used only in the warning printed when
            a partner has no performance bounds.
        ego_test_mode: forwarded to run_episodes as agent_0_test_mode.
    Returns:
        Pytree of metrics with the per-partner results stacked and the first two
        axes swapped, i.e. leaves are indexed (ego, partner, ...).
    '''
    num_agents = env.num_agents
    assert num_agents == 2, "This eval code assumes exactly 2 agents."

    heldout_settings = config["global_heldout_settings"]
    num_ego_agents = jax.tree.leaves(ego_params)[0].shape[0]
    num_partner_total = len(heldout_agent_list)

    def _eval_ego_vs_one_partner(single_ego_policy, single_ego_params, rng_for_ego,
                                 heldout_policy, heldout_params, heldout_test_mode):
        # The ego is always agent 0 and the heldout partner agent 1.
        return run_episodes(rng_for_ego, env,
                            agent_0_policy=single_ego_policy, agent_0_param=single_ego_params,
                            agent_1_policy=heldout_policy, agent_1_param=heldout_params,
                            max_episode_steps=heldout_settings["MAX_EPISODE_STEPS"],
                            num_eps=num_episodes,
                            agent_0_test_mode=ego_test_mode,
                            agent_1_test_mode=heldout_test_mode)

    # Outer Python loop over heterogeneous heldout partners
    all_metrics_for_partners = []
    _, sub_rng = jax.random.split(rng)
    partner_rngs = jax.random.split(sub_rng, num_partner_total)

    start_time = time.time()
    for partner_idx, partner in enumerate(heldout_agent_list):
        heldout_policy, heldout_params, heldout_test_mode, heldout_performance_bounds = partner
        ego_rngs = jax.random.split(partner_rngs[partner_idx], num_ego_agents)

        # Use partial to fix the heldout agent for the function being vmapped
        func_to_vmap = partial(_eval_ego_vs_one_partner,
                               heldout_policy=heldout_policy,
                               heldout_params=heldout_params,
                               heldout_test_mode=heldout_test_mode)

        # The policy is shared (in_axes None); only ego_params and ego_rngs are
        # mapped over their leading axis.
        results_for_this_partner = jax.vmap(
            func_to_vmap,
            in_axes=(None, 0, 0)
        )(ego_policy, ego_params, ego_rngs)
        # leaves of results_for_this_partner have a leading num_ego_agents axis

        if heldout_settings["NORMALIZE_RETURNS"]:
            if heldout_performance_bounds is not None:
                results_for_this_partner = normalize_metrics(results_for_this_partner, heldout_performance_bounds)
            else:
                agent_name = heldout_agent_names[partner_idx] if heldout_agent_names is not None else f"partner_{partner_idx}"
                print(f"Warning: no performance bounds provided for {agent_name}. Skipping normalization.")

        all_metrics_for_partners.append(results_for_this_partner)
    end_time = time.time()
    print(f"Time taken for vmap evaluation loop: {end_time - start_time:.2f} seconds")

    # Stacked leaves: (num_partners, num_egos, ...)
    final_metrics = tree_stack(all_metrics_for_partners)
    # Swap the first two axes to get (num_egos, num_partners, ...). swapaxes equals
    # transpose(1, 0, 2, 3) for 4-D leaves but does not require exactly four dims.
    final_metrics = jax.tree.map(lambda x: x.swapaxes(0, 1), final_metrics)
    return final_metrics
def run_heldout_evaluation(config, print_metrics=False):
    '''Run heldout evaluation of the configured ego agent(s) against the heldout set.

    Args:
        config: mapping providing "ENV_NAME", "ENV_KWARGS", "TASK_NAME", "ego_agent",
            "heldout_set" (indexed by task name) and "global_heldout_settings"
            ("EVAL_SEED", "NUM_EVAL_EPISODES", "NORMALIZE_RETURNS" and, when
            printing, "AGGREGATE_STAT").
        print_metrics: if True, print one table per metric name of the environment.
    Returns:
        The metrics pytree produced by eval_egos_vs_heldouts.
    '''
    heldout_settings = config["global_heldout_settings"]

    # Create only one environment instance
    env = make_env(config["ENV_NAME"], config["ENV_KWARGS"])
    env = LogWrapper(env)

    rng = jax.random.PRNGKey(heldout_settings["EVAL_SEED"])
    _, ego_init_rng, heldout_init_rng, eval_rng = jax.random.split(rng, 4)

    # load ego agents
    ego_agent_config = dict(config["ego_agent"])
    ego_test_mode = ego_agent_config.get("test_mode", False)
    ego_policy, ego_params, init_ego_params, ego_idx_labels = initialize_rl_agent_from_config(ego_agent_config, "ego", env, ego_init_rng)

    # flatten all checkpoint dimensions of the ego params into one leading axis
    flattened_ego_params = jax.tree.map(lambda x, y: x.reshape((-1,) + y.shape), ego_params, init_ego_params)

    # Flatten the labels the same way. When the loader supplies no labels, fall back
    # to numeric indices; np.array(None).reshape(-1) would otherwise yield a single
    # None label, and only one mislabelled ego row would be printed.
    if ego_idx_labels is None:
        num_ego_agents = jax.tree.leaves(flattened_ego_params)[0].shape[0]
        ego_idx_labels = np.array([str(i) for i in range(num_ego_agents)])
    else:
        ego_idx_labels = np.array(ego_idx_labels).reshape(-1)

    # load heldout agents
    heldout_cfg = config["heldout_set"][config["TASK_NAME"]]
    heldout_agents = load_heldout_set(heldout_cfg, env, config["TASK_NAME"], config["ENV_KWARGS"], heldout_init_rng)
    heldout_agent_names = list(heldout_agents.keys())
    heldout_agent_list = list(heldout_agents.values())

    # run evaluation
    eval_metrics = eval_egos_vs_heldouts(
        config, env, eval_rng, heldout_settings["NUM_EVAL_EPISODES"],
        ego_policy, flattened_ego_params, heldout_agent_list, heldout_agent_names, ego_test_mode)

    if print_metrics:
        # each leaf of eval_metrics is indexed (ego agent, heldout agent, ...)
        metric_names = get_metric_names(config["ENV_NAME"])
        aggregate_stat = heldout_settings["AGGREGATE_STAT"]
        ego_names = [f"ego ({label})" for label in ego_idx_labels]
        for metric_name in metric_names:
            print_metrics_table(eval_metrics, metric_name, ego_names, heldout_agent_names,
                                aggregate_stat, heldout_settings["NORMALIZE_RETURNS"])
    return eval_metrics
def print_metrics_table(eval_metrics, metric_name, ego_names, heldout_names,
                        aggregate_stat: str, normalized_metrics: bool,
                        save: bool = False, save_heatmap: bool = False):
    '''Print (and optionally save) a table of one metric for every ego/heldout pairing.

    Each cell shows the aggregate stat followed by its confidence interval, as
    computed by compute_aggregate_stat_and_ci_per_task over the evaluation episodes.

    Args:
        eval_metrics: mapping from metric name to an array indexed
            (ego agent, heldout agent, eval episode, agent in env).
        metric_name: key of the metric to tabulate.
        ego_names: row labels, one per ego agent.
        heldout_names: column labels, one per heldout agent.
        aggregate_stat: name of the aggregate statistic, forwarded to the stat helper.
        normalized_metrics: whether the metrics were normalized; only affects the
            printed note, the saved file names and the tidy CSV column.
        save: if True, write the table and a tidy (long-format) CSV into the Hydra
            run output directory.
        save_heatmap: if True (and save is True), also try to render a heatmap from
            the saved table CSV; failures are reported as a warning.
    '''
    # Average over the trailing per-agent axis first:
    # (num_ego_agents, num_heldout_agents, num_eval_episodes)
    per_episode = np.array(eval_metrics[metric_name]).mean(axis=-1)

    table = PrettyTable()
    table.field_names = ["---", *heldout_names]
    tidy_rows = []

    for ego_idx, ego_name in enumerate(ego_names):
        # (num_eval_episodes, num_heldout_agents)
        episodes_by_partner = per_episode[ego_idx].transpose(1, 0)
        point_estimates, intervals = compute_aggregate_stat_and_ci_per_task(
            episodes_by_partner, aggregate_stat, return_interval_est=True)
        ci_low = intervals[:, 0]
        ci_high = intervals[:, 1]

        cells = []
        for col, heldout_name in enumerate(heldout_names):
            est, low, high = point_estimates[col], ci_low[col], ci_high[col]
            cells.append(f"{est:.2f} ({low:.2f}, {high:.2f})")
            tidy_rows.append({
                "row_agent": ego_name,
                "col_agent": heldout_name,
                "metric_name": metric_name,
                "aggregate_stat": aggregate_stat,
                "normalized": normalized_metrics,
                "mean": float(est),
                "ci_lower": float(low),
                "ci_upper": float(high),
            })
        table.add_row([ego_name] + cells)

    print(f"\n{metric_name} ({aggregate_stat} ± CI):")
    if normalized_metrics:
        print("Metrics are normalized to [lower_bound, upper_bound].")
    print(table)

    if not save:
        return

    output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    # Sanitize metric_name for use in filename
    safe_metric_name = "".join(c if c.isalnum() else "_" for c in metric_name)
    file_stem = f"{safe_metric_name}_{aggregate_stat}_normalized={normalized_metrics}"

    csv_filename = os.path.join(output_dir, f"{file_stem}.csv")
    with open(csv_filename, 'w', newline='') as f_output:
        f_output.write(table.get_csv_string())
    print(f"Table saved to {csv_filename}")

    tidy_csv_filename = os.path.join(output_dir, f"{file_stem}_tidy.csv")
    import csv
    tidy_columns = [
        "row_agent",
        "col_agent",
        "metric_name",
        "aggregate_stat",
        "normalized",
        "mean",
        "ci_lower",
        "ci_upper",
    ]
    with open(tidy_csv_filename, 'w', newline='') as tidy_file:
        writer = csv.DictWriter(tidy_file, fieldnames=tidy_columns)
        writer.writeheader()
        writer.writerows(tidy_rows)
    print(f"Tidy table saved to {tidy_csv_filename}")

    if save_heatmap:
        # Best effort: the heatmap is built from the table CSV written above, and a
        # failure here must not abort the evaluation.
        try:
            from pathlib import Path
            from evaluation.plot_xp_csv_heatmap import generate_heatmap_from_csv
            heatmap_title = f"XP Matrix: {metric_name} ({aggregate_stat})"
            png_path = generate_heatmap_from_csv(Path(csv_filename), title=heatmap_title)
            print(f"Heatmap saved to {png_path}")
        except Exception as exc:
            print(f"Warning: failed to generate heatmap for {csv_filename}: {exc}")