jaxaht-benchmark / evaluation /heldout_eval.py
lainwired's picture
Initial jaxaht-benchmark deployment
5146e76
'''Implementation of heldout evaluation helper functions used by learners.'''
import logging
import time
from functools import partial
import shutil
import jax
import numpy as np
import hydra
from common.save_load_utils import save_train_run
from common.run_episodes import run_episodes
from common.tree_utils import tree_stack
from common.stat_utils import compute_aggregate_stat_and_ci, compute_aggregate_stat_and_ci_per_task, get_aggregate_stat_fn
from envs import make_env
from envs.log_wrapper import LogWrapper
from evaluation.heldout_evaluator import load_heldout_set, normalize_metrics, eval_egos_vs_heldouts as eval_1d_egos_vs_heldouts
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Configure root logging at import time with the INFO threshold.
# NOTE(review): basicConfig does nothing if the root logger already has handlers,
# so an application that configures logging earlier keeps its own settings.
logging.basicConfig(level=logging.INFO)
def eval_2d_egos_vs_heldouts(config, env, rng, num_episodes, ego_policy, ego_params,
                             heldout_agent_list, ego_test_mode=False):
    '''Evaluate a 2D grid of ego agents against every heldout partner.

    The heldout partners are iterated in a Python loop (each entry carries its own
    policy object), while the egos are evaluated with a nested ``jax.vmap`` over the
    two leading axes of ``ego_params``.

    Args:
        config: Configuration mapping. Reads
            ``config["global_heldout_settings"]["MAX_EPISODE_STEPS"]`` and
            ``config["global_heldout_settings"]["NORMALIZE_RETURNS"]``.
        env: Environment exposing ``num_agents``; exactly 2 agents are required.
        rng: JAX PRNG key. It is split once per partner, then once per ego agent.
        num_episodes: Number of episodes forwarded to ``run_episodes`` as ``num_eps``.
        ego_policy: Policy shared by all ego agents (played as agent 0).
        ego_params: Pytree whose leaves have leading shape (num_seeds, num_oel_iters, ...).
        heldout_agent_list: Sequence of
            ``(policy, params, test_mode, performance_bounds)`` tuples, one per partner
            (played as agent 1). ``performance_bounds`` may be None.
        ego_test_mode: Forwarded to ``run_episodes`` as ``agent_0_test_mode``.

    Returns:
        Pytree of metrics whose leaves have leading axes
        (num_seeds, num_oel_iters, num_partners, ...); the remaining axes are whatever
        ``run_episodes`` produces (presumably (num_episodes, num_agents) -- confirm
        against ``run_episodes``).
    '''
    num_agents = env.num_agents
    assert num_agents == 2, "This eval code assumes exactly 2 agents."

    num_ego_seeds, num_ego_iters = jax.tree.leaves(ego_params)[0].shape[:2]
    tot_ego_agents = num_ego_seeds * num_ego_iters
    num_partner_total = len(heldout_agent_list)

    def _eval_ego_vs_one_partner(rng_for_ego, single_ego_params, single_ego_policy,
                                 heldout_params, heldout_policy, heldout_test_mode):
        return run_episodes(rng_for_ego, env,
                            agent_0_policy=single_ego_policy, agent_0_param=single_ego_params,
                            agent_1_policy=heldout_policy, agent_1_param=heldout_params,
                            max_episode_steps=config["global_heldout_settings"]["MAX_EPISODE_STEPS"],
                            num_eps=num_episodes,
                            agent_0_test_mode=ego_test_mode,
                            agent_1_test_mode=heldout_test_mode)

    def _partner_axis_to_third(x):
        # Move the stacked partner axis (axis 0) behind the (seed, iter) axes while
        # leaving every trailing axis in place, whatever the rank of the leaf is.
        return x.transpose(1, 2, 0, *range(3, x.ndim))

    # Outer Python loop over heterogeneous heldout partners
    all_metrics_for_partners = []
    partner_rngs = jax.random.split(rng, num_partner_total)
    start_time = time.time()
    for partner_idx, heldout_agent in enumerate(heldout_agent_list):
        heldout_policy, heldout_params, heldout_test_mode, heldout_performance_bounds = heldout_agent

        ego_rngs = jax.random.split(partner_rngs[partner_idx], tot_ego_agents)
        # Reshape only the leading axis so that any key layout works: legacy raw
        # keys keep their trailing key-data axis, typed keys have none.
        ego_rngs = ego_rngs.reshape((num_ego_seeds, num_ego_iters) + ego_rngs.shape[1:])

        # Use partial to fix the ego policy and the heldout agent, leaving only
        # (rng_for_ego, single_ego_params) as positional arguments to be vmapped.
        func_to_vmap = partial(_eval_ego_vs_one_partner,
                               single_ego_policy=ego_policy,
                               heldout_params=heldout_params,
                               heldout_policy=heldout_policy,
                               heldout_test_mode=heldout_test_mode)

        # Inner vmap: maps over the 'num_oel_iters' axis of both the rngs and the params.
        vmap_over_iters = jax.vmap(func_to_vmap, in_axes=(0, 0))
        # Outer vmap: maps 'vmap_over_iters' over the 'num_seeds' axis.
        vmap_over_seeds_and_iters = jax.vmap(vmap_over_iters, in_axes=(0, 0))

        # Leaves of the result have leading shape (num_seeds, num_oel_iters, ...)
        results_for_this_partner = vmap_over_seeds_and_iters(ego_rngs, ego_params)

        if config["global_heldout_settings"]["NORMALIZE_RETURNS"]:
            if heldout_performance_bounds is not None:
                results_for_this_partner = normalize_metrics(results_for_this_partner, heldout_performance_bounds)
            else:
                # Report the partner by index; formatting the whole agent tuple would
                # dump its policy object and parameter pytree to stdout.
                print(f"Warning: no performance bounds provided for heldout partner {partner_idx}. Skipping normalization.")
        all_metrics_for_partners.append(results_for_this_partner)
    end_time = time.time()
    print(f"Time taken for vmap evaluation loop: {end_time - start_time:.2f} seconds")

    # Stacked leaves: (num_partners, num_seeds, num_oel_iters, ...)
    final_metrics = tree_stack(all_metrics_for_partners)
    # Reordered leaves: (num_seeds, num_oel_iters, num_partners, ...)
    final_metrics = jax.tree.map(_partner_axis_to_third, final_metrics)
    return final_metrics
def run_heldout_evaluation(config, ego_policy, ego_params, init_ego_params,
                           ego_as_2d: bool, ego_test_mode=False):
    '''Build the environment, load the heldout partners and evaluate the egos against them.

    Args:
        config: Configuration dictionary. Reads "ENV_NAME", "ENV_KWARGS", "TASK_NAME",
            "heldout_set" and the "EVAL_SEED" / "NUM_EVAL_EPISODES" entries of
            "global_heldout_settings".
        ego_policy: Policy for the ego agent.
        ego_params: Parameters for the ego agent(s); a pytree with leading shape
            (num_seeds, num_oel_iters, ...) when ``ego_as_2d`` is True, otherwise any
            leading batch shape (it is flattened to a single leading axis).
        init_ego_params: Parameters of a single ego agent; only the leaf shapes are
            used, as the per-agent shape when flattening ``ego_params``.
        ego_as_2d: Whether to treat the ego params as a 2D grid (seeds x iters) or a
            1D array of ego agents.
        ego_test_mode: Whether the ego agent should run in test mode (default: False).

    Returns:
        Tuple ``(eval_metrics, ego_names, heldout_names)``.
    '''
    log.info("Running heldout evaluation...")
    env = LogWrapper(make_env(config["ENV_NAME"], config["ENV_KWARGS"]))

    base_key = jax.random.PRNGKey(config["global_heldout_settings"]["EVAL_SEED"])
    _, heldout_init_rng, eval_rng = jax.random.split(base_key, 3)

    if not ego_as_2d:
        # Collapse every leading batch axis into one, keeping the per-agent leaf shape.
        ego_params = jax.tree.map(
            lambda batched, single: batched.reshape((-1,) + single.shape),
            ego_params, init_ego_params,
        )
        ego_count = jax.tree.leaves(ego_params)[0].shape[0]
        ego_names = []
        for agent_idx in range(ego_count):
            ego_names.append(f"ego ({agent_idx})")
        evaluate_egos = eval_1d_egos_vs_heldouts
    else:
        seed_count, iter_count = jax.tree.leaves(ego_params)[0].shape[:2]
        ego_names = []
        for seed_idx in range(seed_count):
            for iter_idx in range(iter_count):
                ego_names.append(f"ego (seed={seed_idx}, iter={iter_idx})")
        evaluate_egos = eval_2d_egos_vs_heldouts

    # Load the heldout partner set configured for this task.
    task_name = config["TASK_NAME"]
    heldout_agents = load_heldout_set(config["heldout_set"][task_name], env, task_name,
                                      config["ENV_KWARGS"], heldout_init_rng)
    heldout_agent_list = list(heldout_agents.values())
    heldout_names = list(heldout_agents.keys())

    # Both evaluators share the same positional signature.
    eval_metrics = evaluate_egos(config, env, eval_rng,
                                 config["global_heldout_settings"]["NUM_EVAL_EPISODES"],
                                 ego_policy, ego_params, heldout_agent_list, ego_test_mode)
    return eval_metrics, ego_names, heldout_names
def log_heldout_metrics(config, logger, eval_metrics,
                        ego_names, heldout_names, metric_names: tuple,
                        ego_as_2d: bool):
    '''Log heldout evaluation metrics and save them as an artifact.

    Steps performed:
      1. Build the summary table via ``heldout_metrics_2d`` or ``heldout_metrics_1d``.
      2. Log the aggregate-over-all-partners point estimate of each metric as a scalar.
      3. Log the full table (algorithm, metric, aggregate, one column per partner).
      4. Save ``eval_metrics`` under the Hydra run output directory, optionally upload
         it as an artifact, and optionally remove the local copy afterwards.

    Args:
        config: Configuration mapping. Reads ``config["algorithm"]["ALG"]``,
            ``config["global_heldout_settings"]["AGGREGATE_STAT"]``,
            ``config["logger"]["log_eval_out"]`` and
            ``config["local_logger"]["save_eval_out"]``.
        logger: Object providing ``log_item``, ``log_xp_matrix`` and ``log_artifact``.
        eval_metrics: Mapping from metric name to evaluation array.
        ego_names: Names of the ego agents (forwarded to the table builders).
        heldout_names: Names of the heldout partners, used as table column headers.
        metric_names: Names of the metrics to report; one table row each.
        ego_as_2d: Whether ``eval_metrics`` has (seeds, iters, ...) leading axes.
    '''
    if ego_as_2d:
        table_data = heldout_metrics_2d(config, logger, eval_metrics, ego_names, heldout_names, metric_names)
    else:
        table_data = heldout_metrics_1d(config, logger, eval_metrics, ego_names, heldout_names, metric_names)
    # table_data rows are metrics; columns are the aggregate followed by each heldout agent

    # Metric name column
    metric_names_array = np.array(metric_names).reshape(-1, 1)

    # Algorithm name column. np.full sizes the string dtype from the fill value;
    # np.full_like would reuse the fixed width of the metric-name dtype and silently
    # truncate an algorithm name longer than the longest metric name.
    algo_name = config["algorithm"]["ALG"]
    algo_name_array = np.full(metric_names_array.shape, str(algo_name))

    table_data_with_names = np.hstack((algo_name_array, metric_names_array, table_data))

    # Additionally log each metric separately for parameter sweep analysis.
    # Column 1 is the metric name; column 2 is "<point> (<lower>, <upper>)" for the
    # aggregate over all partners, so its first token is the point estimate.
    for row in table_data_with_names:
        logger.log_item(f"HeldoutEval/FinalEgoVsHeldout/{row[1]}/mean", float(row[2].split()[0]))

    aggregate_stat = config["global_heldout_settings"]["AGGREGATE_STAT"]
    logger.log_xp_matrix(f"HeldoutEval/FinalEgoVsHeldout-{aggregate_stat.capitalize()}-CI", table_data_with_names,
                         columns=["Algorithm", "Metric", f"{aggregate_stat.capitalize()} (all)"] + list(heldout_names), commit=True)

    # Saving artifacts
    savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    out_savepath = save_train_run(eval_metrics, savedir, savename="heldout_eval_metrics")
    if config["logger"]["log_eval_out"]:
        logger.log_artifact(name="heldout_eval_metrics", path=out_savepath, type_name="eval_metrics")

    # Cleanup of the locally saved output.
    # NOTE(review): rmtree assumes save_train_run returns a directory path -- confirm.
    if not config["local_logger"]["save_eval_out"]:
        shutil.rmtree(out_savepath)
def heldout_metrics_1d(config, logger, eval_metrics,
                       ego_names, heldout_names, metric_names: tuple):
    '''Build the summary table for a flat (1D) collection of ego agents.

    Each metric array is expected to have shape
    (num_seeds, num_heldout_agents, num_eval_episodes, num_agents_per_game).
    The agents-per-game axis is averaged away and the seed and episode axes are
    merged, giving one sample row per (seed, episode) and one column per heldout agent.

    Args:
        config: Configuration mapping; reads
            ``config["global_heldout_settings"]["AGGREGATE_STAT"]``.
        logger: Unused here; kept for signature parity with ``heldout_metrics_2d``.
        eval_metrics: Mapping from metric name to evaluation array.
        ego_names: Unused here.
        heldout_names: Unused here.
        metric_names: Names of the metrics to tabulate; one table row each.

    Returns:
        2D numpy array of strings formatted as "point (lower, upper)". Rows are the
        metrics; the first column aggregates over all heldout agents and the remaining
        columns are per heldout agent.
    '''
    # The 4-way unpack also acts as a rank check on the first metric's array.
    _, partner_count, _, _ = eval_metrics[metric_names[0]].shape
    stat_name = config["global_heldout_settings"]["AGGREGATE_STAT"]

    def _cell(point, low, high):
        # One table cell: point estimate followed by its confidence interval.
        return f"{point:.3f} ({low:.3f}, {high:.3f})"

    rows = []
    for name in metric_names:
        # Average over agents in a game, then fold (seed, episode) into one sample axis.
        agent_mean = eval_metrics[name].mean(axis=-1)
        seed_episode_partner = agent_mean.transpose(0, 2, 1)
        samples = np.array(seed_episode_partner.reshape(-1, partner_count))

        # Per-heldout-agent aggregate statistic with confidence intervals.
        partner_points, partner_intervals = compute_aggregate_stat_and_ci_per_task(
            samples, stat_name, return_interval_est=True)
        partner_lows = partner_intervals[:, 0]
        partner_highs = partner_intervals[:, 1]
        partner_cells = [
            _cell(point, partner_lows[idx], partner_highs[idx])
            for idx, point in enumerate(partner_points)
        ]

        # Aggregate statistic and confidence interval over all heldout agents.
        overall_point, overall_interval = compute_aggregate_stat_and_ci(
            samples, stat_name, return_interval_est=True)
        overall_cell = _cell(overall_point, overall_interval[0], overall_interval[1])

        rows.append([overall_cell] + partner_cells)
    return np.array(rows)
def heldout_metrics_2d(config, logger, eval_metrics,
                       ego_names, heldout_names, metric_names: tuple):
    '''Treat the first two dimensions of eval_metrics as (seeds, iters, ...) dimensions.

    For every metric, logs a curve over the iters dimension (the aggregate statistic
    over all seeds, episodes and heldout agents at each iter), then tabulates the
    per-heldout-agent statistics of the LAST iter.

    Args:
        config: Configuration mapping; reads
            ``config["global_heldout_settings"]["AGGREGATE_STAT"]``.
        logger: Object providing ``log_item(name, value, iter=...)``.
        eval_metrics: Mapping from metric name to an array of shape
            (num_seeds, num_oel_iter, num_heldout_agents, num_eval_episodes,
            num_agents_per_game).
        ego_names: Unused here.
        heldout_names: Unused here.
        metric_names: Names of the metrics to log and tabulate; one table row each.

    Returns:
        2D numpy array of strings formatted as "point (lower, upper)". Rows are the
        metrics; the first column aggregates over all heldout agents and the remaining
        columns are per heldout agent.

    Raises:
        ValueError: If the iters axis is empty, since there is no last iter to tabulate.
    '''
    # The 5-way unpack also acts as a rank check on the first metric's array.
    _, num_oel_iter, num_heldout_agents, _, _ = eval_metrics[metric_names[0]].shape
    if num_oel_iter == 0:
        raise ValueError("heldout_metrics_2d requires at least one ego iteration, but the iters axis is empty.")

    def _format_cell(point, lower, upper):
        # One table cell: point estimate followed by its confidence interval.
        return f"{point:.3f} ({lower:.3f}, {upper:.3f})"

    table_data = []
    aggregate_stat = config["global_heldout_settings"]["AGGREGATE_STAT"]
    aggregate_stat_fn = get_aggregate_stat_fn(aggregate_stat)
    for metric_name in metric_names:
        metric_values = eval_metrics[metric_name]
        last_iter_data = None
        for iter_idx in range(num_oel_iter):
            # Average over agents in a game, then fold (seed, episode) into one sample
            # axis: final shape (num_seeds*num_eval_episodes, num_heldout_agents)
            iter_data = np.array(
                metric_values[:, iter_idx].mean(axis=-1
                                         ).transpose(0, 2, 1
                                         ).reshape(-1, num_heldout_agents))
            # log curve aggregated over all heldout agents
            logger.log_item(f"HeldoutEval/AvgEgo_{metric_name}_", aggregate_stat_fn(iter_data), iter=iter_idx)
            # Track the most recent iter explicitly instead of relying on the loop
            # variable surviving past the end of the loop.
            last_iter_data = iter_data

        # per-heldout-agent aggregate stat+CIs corresponding to the LAST ego iter
        point_est_per_task, interval_ests_per_task = compute_aggregate_stat_and_ci_per_task(
            last_iter_data, aggregate_stat, return_interval_est=True)
        lower_ci = interval_ests_per_task[:, 0]
        upper_ci = interval_ests_per_task[:, 1]
        col_strs = [
            _format_cell(point_est, lower_ci[task_idx], upper_ci[task_idx])
            for task_idx, point_est in enumerate(point_est_per_task)
        ]

        # aggregate stat+CI over all heldout agents goes in the first column
        point_est_all, interval_ests_all = compute_aggregate_stat_and_ci(
            last_iter_data, aggregate_stat, return_interval_est=True)
        col_strs.insert(0, _format_cell(point_est_all, interval_ests_all[0], interval_ests_all[1]))
        table_data.append(col_strs)
    return np.array(table_data)