Spaces:
Running
Running
| ''''Implementation of heldout evaluation helper functions used by learners.''' | |
| import logging | |
| import time | |
| from functools import partial | |
| import shutil | |
| import jax | |
| import numpy as np | |
| import hydra | |
| from common.save_load_utils import save_train_run | |
| from common.run_episodes import run_episodes | |
| from common.tree_utils import tree_stack | |
| from common.stat_utils import compute_aggregate_stat_and_ci, compute_aggregate_stat_and_ci_per_task, get_aggregate_stat_fn | |
| from envs import make_env | |
| from envs.log_wrapper import LogWrapper | |
| from evaluation.heldout_evaluator import load_heldout_set, normalize_metrics, eval_egos_vs_heldouts as eval_1d_egos_vs_heldouts | |
# Module-level logger for this file.
log = logging.getLogger(name=__name__)
# NOTE: this configures the root logger as a side effect of importing the module;
# logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level="INFO")
def eval_2d_egos_vs_heldouts(config, env, rng, num_episodes, ego_policy, ego_params,
                             heldout_agent_list, ego_test_mode=False):
    '''Evaluate all ego agents against all heldout partners using vmap over egos.

    The egos form a 2D grid indexed by (seed, open-ended iteration). Heldout
    partners may differ in policy and parameter structure, so they are visited
    in a Python loop. For each partner, the whole ego grid is evaluated with a
    nested ``jax.vmap`` (outer over seeds, inner over iterations).

    Args:
        config: Configuration dict; reads ``MAX_EPISODE_STEPS`` and
            ``NORMALIZE_RETURNS`` from ``config["global_heldout_settings"]``.
        env: Environment handed to ``run_episodes``; must have exactly 2 agents.
        rng: JAX PRNG key; split once per partner, then once per ego agent.
        num_episodes: Episodes rolled out per (ego, partner) pair.
        ego_policy: Policy shared by every ego agent (played as agent 0).
        ego_params: Pytree whose leaves have leading shape (num_seeds, num_oel_iters, ...).
        heldout_agent_list: Sequence of ``(policy, params, test_mode, performance_bounds)``
            tuples, one per heldout partner (played as agent 1).
        ego_test_mode: Forwarded to ``run_episodes`` as ``agent_0_test_mode``.

    Returns:
        Metrics pytree whose leaves have leading axes
        (num_seeds, num_oel_iters, num_partners), followed by the per-pair axes
        produced by ``run_episodes`` (presumably (num_episodes, ...) — confirm).
    '''
    num_agents = env.num_agents
    assert num_agents == 2, "This eval code assumes exactly 2 agents."
    num_ego_seeds, num_ego_iters = jax.tree.leaves(ego_params)[0].shape[:2]
    tot_ego_agents = num_ego_seeds * num_ego_iters
    num_partner_total = len(heldout_agent_list)
    heldout_settings = config["global_heldout_settings"]

    def _eval_ego_vs_one_partner(rng_for_ego, single_ego_params, single_ego_policy,
                                 heldout_params, heldout_policy, heldout_test_mode):
        # Roll out num_episodes with one ego as agent 0 and one heldout partner as agent 1.
        return run_episodes(rng_for_ego, env,
                            agent_0_policy=single_ego_policy, agent_0_param=single_ego_params,
                            agent_1_policy=heldout_policy, agent_1_param=heldout_params,
                            max_episode_steps=heldout_settings["MAX_EPISODE_STEPS"],
                            num_eps=num_episodes,
                            agent_0_test_mode=ego_test_mode,
                            agent_1_test_mode=heldout_test_mode)

    # Outer Python loop over heterogeneous heldout partners.
    all_metrics_for_partners = []
    partner_rngs = jax.random.split(rng, num_partner_total)
    start_time = time.time()
    for partner_idx, heldout_agent in enumerate(heldout_agent_list):
        heldout_policy, heldout_params, heldout_test_mode, heldout_performance_bounds = heldout_agent
        ego_rngs = jax.random.split(partner_rngs[partner_idx], tot_ego_agents)
        # Lay the keys out on the ego grid, keeping whatever trailing key shape
        # split() produced (e.g. (2,) for raw uint32 keys) instead of hard-coding it.
        ego_rngs = ego_rngs.reshape((num_ego_seeds, num_ego_iters) + ego_rngs.shape[1:])
        # Bind the partner-specific arguments; only (rng, ego params) are vmapped.
        func_to_vmap = partial(_eval_ego_vs_one_partner,
                               single_ego_policy=ego_policy,
                               heldout_params=heldout_params,
                               heldout_policy=heldout_policy,
                               heldout_test_mode=heldout_test_mode)
        # Inner vmap maps over num_oel_iters; the outer vmap maps over num_seeds.
        vmap_over_iters = jax.vmap(func_to_vmap, in_axes=(0, 0))
        vmap_over_seeds_and_iters = jax.vmap(vmap_over_iters, in_axes=(0, 0))
        results_for_this_partner = vmap_over_seeds_and_iters(
            ego_rngs,    # (num_seeds, num_oel_iters, *key_shape)
            ego_params,  # (num_seeds, num_oel_iters, ...)
        )
        # Leaves: (num_seeds, num_oel_iters, <per-pair axes from run_episodes>).
        if heldout_settings["NORMALIZE_RETURNS"]:
            if heldout_performance_bounds is not None:
                results_for_this_partner = normalize_metrics(results_for_this_partner, heldout_performance_bounds)
            else:
                # Log the partner index, not the agent tuple (which holds full param pytrees).
                log.warning("No performance bounds provided for heldout partner %d. "
                            "Skipping normalization.", partner_idx)
        all_metrics_for_partners.append(results_for_this_partner)
    # JAX dispatches asynchronously; wait for the results so the timing is real.
    jax.block_until_ready(all_metrics_for_partners)
    end_time = time.time()
    log.info("Time taken for vmap evaluation loop: %.2f seconds", end_time - start_time)
    # Stack partners: (num_partners, num_seeds, num_oel_iters, ...).
    final_metrics = tree_stack(all_metrics_for_partners)
    # Move the partner axis to position 2: (num_seeds, num_oel_iters, num_partners, ...).
    # Works for leaves of any rank >= 3 rather than only rank 5.
    final_metrics = jax.tree.map(
        lambda x: x.transpose((1, 2, 0) + tuple(range(3, x.ndim))), final_metrics)
    return final_metrics
def run_heldout_evaluation(config, ego_policy, ego_params, init_ego_params,
                           ego_as_2d: bool, ego_test_mode=False):
    '''Evaluate ego agents against the configured heldout set for the current task.

    Ego_params can be a pytree of shape (num_seeds, num_oel_iters, ...) or (num_seeds, ...).

    Args:
        config: Configuration dictionary. Reads ENV_NAME, ENV_KWARGS, TASK_NAME,
            heldout_set[TASK_NAME], and global_heldout_settings (EVAL_SEED,
            NUM_EVAL_EPISODES).
        ego_policy: Policy for the ego agent
        ego_params: Parameters for the ego agent
        init_ego_params: Initial (single-agent) parameters; their leaf shapes are used
            to flatten ``ego_params`` into one ego axis when ``ego_as_2d`` is False.
        ego_as_2d: Whether to treat the ego agent params as a 2D or 1D array of ego agents
        ego_test_mode: Whether the ego agent should run in test mode (default: False)

    Returns:
        Tuple ``(eval_metrics, ego_names, heldout_names)``. ``ego_names`` is
        seed-major in the 2D case; ``heldout_names`` follows the key order of the
        loaded heldout set.
    '''
    log.info("Running heldout evaluation...")
    env = LogWrapper(make_env(config["ENV_NAME"], config["ENV_KWARGS"]))
    heldout_settings = config["global_heldout_settings"]
    key = jax.random.PRNGKey(heldout_settings["EVAL_SEED"])
    _, heldout_init_key, eval_key = jax.random.split(key, 3)

    if ego_as_2d:
        n_seeds, n_iters = jax.tree.leaves(ego_params)[0].shape[:2]
        ego_names = [f"ego (seed={s}, iter={t})" for s in range(n_seeds) for t in range(n_iters)]
    else:
        # Collapse all leading axes into a single ego axis, using each leaf of
        # init_ego_params as the per-agent shape.
        ego_params = jax.tree.map(lambda leaf, ref: leaf.reshape((-1, *ref.shape)),
                                  ego_params, init_ego_params)
        ego_names = [f"ego ({k})" for k in range(jax.tree.leaves(ego_params)[0].shape[0])]

    task_name = config["TASK_NAME"]
    heldout_agents = load_heldout_set(config["heldout_set"][task_name], env, task_name,
                                      config["ENV_KWARGS"], heldout_init_key)
    heldout_agent_list = list(heldout_agents.values())
    heldout_names = list(heldout_agents)

    eval_fn = eval_2d_egos_vs_heldouts if ego_as_2d else eval_1d_egos_vs_heldouts
    eval_metrics = eval_fn(config, env, eval_key, heldout_settings["NUM_EVAL_EPISODES"],
                           ego_policy, ego_params, heldout_agent_list, ego_test_mode)
    return eval_metrics, ego_names, heldout_names
def log_heldout_metrics(config, logger, eval_metrics,
                        ego_names, heldout_names, metric_names: tuple,
                        ego_as_2d: bool):
    '''Log heldout evaluation metrics as scalars and a summary table, then save them.

    Builds the summary table with ``heldout_metrics_2d`` (when ``ego_as_2d``) or
    ``heldout_metrics_1d``, prepends "Algorithm" and "Metric" columns, and logs
    one scalar per metric plus the full table through ``logger``. It then saves
    ``eval_metrics`` under the Hydra run output directory.

    Args:
        config: Configuration dict; reads ``config["algorithm"]["ALG"]``,
            ``config["global_heldout_settings"]["AGGREGATE_STAT"]``,
            ``config["logger"]["log_eval_out"]`` and
            ``config["local_logger"]["save_eval_out"]``.
        logger: Object providing ``log_item``, ``log_xp_matrix`` and ``log_artifact``.
        eval_metrics: Evaluation metrics to summarize and save.
        ego_names: Ego agent display names (forwarded to the table builder).
        heldout_names: Heldout agent names; used as the per-agent column headers.
        metric_names: Metric keys to summarize, one table row each.
        ego_as_2d: Whether the metrics carry (seeds, iters, ...) leading axes.
    '''
    if ego_as_2d:
        table_data = heldout_metrics_2d(config, logger, eval_metrics, ego_names, heldout_names, metric_names)
    else:
        table_data = heldout_metrics_1d(config, logger, eval_metrics, ego_names, heldout_names, metric_names)
    # table_data: (num_metrics, 1 + num_heldout_agents) strings "point (lower, upper)";
    # column 0 aggregates over all heldout agents.
    metric_names_array = np.array(metric_names).reshape(-1, 1)  # column vector
    algo_name = config["algorithm"]["ALG"]
    # np.full takes the dtype from the fill value, so the array gets the algo name's
    # own string width. np.full_like(metric_names_array, ...) would inherit the
    # metric-name width and silently truncate algorithm names longer than the
    # longest metric name.
    algo_name_array = np.full(metric_names_array.shape, str(algo_name))
    # Columns: Algorithm | Metric | <stat> (all) | one column per heldout agent.
    table_data_with_names = np.hstack((algo_name_array, metric_names_array, table_data))
    # Also log each metric's overall point estimate separately, for parameter-sweep
    # analysis. The value is parsed back from the formatted "(all)" cell, so it has
    # 3 decimals. Despite the "/mean" key suffix, it is the AGGREGATE_STAT point estimate.
    for row in table_data_with_names:
        logger.log_item(f"HeldoutEval/FinalEgoVsHeldout/{row[1]}/mean", float(row[2].split()[0]))
    aggregate_stat = config["global_heldout_settings"]["AGGREGATE_STAT"]
    logger.log_xp_matrix(f"HeldoutEval/FinalEgoVsHeldout-{aggregate_stat.capitalize()}-CI", table_data_with_names,
                         columns=["Algorithm", "Metric", f"{aggregate_stat.capitalize()} (all)"] + list(heldout_names),
                         commit=True)

    # Save the raw metrics into the Hydra run's output directory.
    savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    out_savepath = save_train_run(eval_metrics, savedir, savename="heldout_eval_metrics")
    if config["logger"]["log_eval_out"]:
        logger.log_artifact(name="heldout_eval_metrics", path=out_savepath, type_name="eval_metrics")
    # Remove the local copy unless configured to keep it.
    if not config["local_logger"]["save_eval_out"]:
        # NOTE(review): rmtree assumes save_train_run returns a directory path — confirm.
        shutil.rmtree(out_savepath)
def heldout_metrics_1d(config, logger, eval_metrics,
                       ego_names, heldout_names, metric_names: tuple):
    '''Summarize heldout metrics for a flat (1D) collection of ego agents.

    Each ``eval_metrics[name]`` is expected to have shape
    (num_egos, num_heldout_agents, num_eval_episodes, num_agents_per_game).
    For each metric, values are averaged over the agents-per-game axis. The ego and
    episode axes are then pooled into one sample axis before computing the
    configured aggregate statistic
    (``config["global_heldout_settings"]["AGGREGATE_STAT"]``) and its interval
    estimate.

    ``logger``, ``ego_names`` and ``heldout_names`` are not used here.

    Returns:
        String array of shape (len(metric_names), 1 + num_heldout_agents). Each
        cell reads "point (lower, upper)" with 3 decimals. Column 0 aggregates over
        all heldout agents; the remaining columns are per heldout agent.
    '''
    _, n_heldout, _, _ = eval_metrics[metric_names[0]].shape
    stat_name = config["global_heldout_settings"]["AGGREGATE_STAT"]

    def _cell(point, low, high):
        # Render one table cell as "point (lower, upper)".
        return f"{point:.3f} ({low:.3f}, {high:.3f})"

    rows = []
    for name in metric_names:
        # Average over agents in a game, move episodes next to egos, then pool:
        # (egos, heldout, eps) -> (egos, eps, heldout) -> (egos*eps, heldout).
        samples = eval_metrics[name].mean(axis=-1).transpose(0, 2, 1).reshape(-1, n_heldout)
        samples = np.array(samples)
        # Per-heldout-agent estimates first, then the pooled one. The helpers may
        # resample randomly, so this call order is kept fixed.
        task_points, task_intervals = compute_aggregate_stat_and_ci_per_task(
            samples, stat_name, return_interval_est=True)
        overall_point, overall_interval = compute_aggregate_stat_and_ci(
            samples, stat_name, return_interval_est=True)
        row = [_cell(overall_point, overall_interval[0], overall_interval[1])]
        row.extend(_cell(p, lo, hi) for p, lo, hi in
                   zip(task_points, task_intervals[:, 0], task_intervals[:, 1]))
        rows.append(row)
    return np.array(rows)
def heldout_metrics_2d(config, logger, eval_metrics,
                       ego_names, heldout_names, metric_names: tuple):
    '''Summarize heldout metrics for a 2D (seed x open-ended iteration) grid of egos.

    Each ``eval_metrics[name]`` is expected to have shape
    (num_seeds, num_oel_iter, num_heldout_agents, num_eval_episodes, num_agents_per_game).
    For every metric this function:

    1. logs a curve ``HeldoutEval/AvgEgo_{metric}_`` over the iteration axis. Each
       point is the configured aggregate-stat function applied to that
       iteration's (num_seeds*num_eval_episodes, num_heldout_agents) sample matrix.
    2. builds one table row from the LAST iteration only.

    ``ego_names`` and ``heldout_names`` are not used here.

    Returns:
        String array of shape (len(metric_names), 1 + num_heldout_agents). Each
        cell reads "point (lower, upper)" with 3 decimals. Column 0 aggregates over
        all heldout agents; the remaining columns are per heldout agent.

    Raises:
        ValueError: if the iteration axis is empty (there is no last iteration).
    '''
    _, num_oel_iter, num_heldout_agents, _, _ = eval_metrics[metric_names[0]].shape
    if num_oel_iter == 0:
        # Without at least one iteration there is no "last iteration" to summarize.
        raise ValueError("heldout_metrics_2d needs at least one ego iteration; got num_oel_iter=0.")
    aggregate_stat = config["global_heldout_settings"]["AGGREGATE_STAT"]
    aggregate_stat_fn = get_aggregate_stat_fn(aggregate_stat)

    def _iter_samples(metric_arr, iter_idx):
        # One iteration's samples: (seeds, heldout, eps, agents) -> mean over agents
        # -> (seeds, eps, heldout) -> (seeds*eps, heldout), as a NumPy array.
        per_game = metric_arr[:, iter_idx].mean(axis=-1)
        return np.array(per_game.transpose(0, 2, 1).reshape(-1, num_heldout_agents))

    table_data = []
    for metric_name in metric_names:
        metric_arr = eval_metrics[metric_name]
        # Curve over ego iterations, aggregated over all heldout agents.
        for i in range(num_oel_iter):
            point_est = aggregate_stat_fn(_iter_samples(metric_arr, i))
            logger.log_item(f"HeldoutEval/AvgEgo_{metric_name}_", point_est, iter=i)
        # The table row uses the LAST ego iteration. It is indexed explicitly rather
        # than relying on a loop variable surviving the loop above.
        last_iter_data = _iter_samples(metric_arr, num_oel_iter - 1)
        # Per-heldout-agent estimates first, then the pooled one. The helpers may
        # resample randomly, so this call order is kept fixed.
        point_est_per_task, interval_ests_per_task = compute_aggregate_stat_and_ci_per_task(
            last_iter_data, aggregate_stat, return_interval_est=True)
        col_strs = [f"{p:.3f} ({lo:.3f}, {hi:.3f})"
                    for p, lo, hi in zip(point_est_per_task,
                                         interval_ests_per_task[:, 0],
                                         interval_ests_per_task[:, 1])]
        point_est_all, interval_ests_all = compute_aggregate_stat_and_ci(
            last_iter_data, aggregate_stat, return_interval_est=True)
        col_strs.insert(0, f"{point_est_all:.3f} ({interval_ests_all[0]:.3f}, {interval_ests_all[1]:.3f})")
        table_data.append(col_strs)
    return np.array(table_data)