jaxaht-benchmark / evaluation /heldout_evaluator.py
lainwired's picture
Initial jaxaht-benchmark deployment
5146e76
'''This script implements evaluating ego agents against heldout agents.
Warning: ActorCritic agents that rely on auxiliary information to compute actions are not currently supported.
'''
import jax
import numpy as np
from prettytable import PrettyTable
from functools import partial
import time
import os
import hydra
from common.agent_loader_from_config import (
initialize_rl_agent_from_config,
initialize_heuristic_agent_from_config,
)
from common.run_episodes import run_episodes
from common.tree_utils import tree_stack
from common.plot_utils import get_metric_names
from common.stat_utils import compute_aggregate_stat_and_ci_per_task
from envs import make_env
from envs.log_wrapper import LogWrapper
def extract_params(params, init_params, idx_labels=None):
'''params is a pytree of n model checkpoints, where each leaf has an unknown number
of checkpoint dimensions, and the last dimension corresponds to the layer dimension.
This function extracts each of the n checkpoints and returns a list of n pytrees,
where each pytree has the same structure as init_params.
Args:
params: pytree of n checkpoints (n >= 1)
init_params: pytree corresp. to ONE checkpoint. used as a reference for the structure of the output pytrees.
idx_labels: array of string labels with the same shape as the original checkpoints. If None, numeric indices will be used.
Returns:
Tuple of:
- list of n pytrees with same structure as init_params
- list of n index labels identifying the original location of each checkpoint
'''
assert jax.tree.structure(params) == jax.tree.structure(init_params), "Params and init_params must have the same structure."
model_list = []
flattened_idx_labels = []
params_shape = jax.tree.leaves(params)[0].shape
init_params_shape = jax.tree.leaves(init_params)[0].shape
# already matches init_params_shape, no extraction needed
if params_shape == init_params_shape:
model_list = [params]
n_models = 1
if idx_labels is not None:
flattened_idx_labels = idx_labels
# multiple models, extract each one
else:
# first, flatten the params so that each leaf has shape (..., init_params_shape)
flattened_params = jax.tree.map(lambda x, y: x.reshape((-1,) + y.shape), params, init_params)
# then, extract each model
n_models = jax.tree.leaves(flattened_params)[0].shape[0]
# Now, flatten the idx_labels to match the flattened parameters
if idx_labels is not None:
flattened_idx_labels = np.array(idx_labels).reshape(n_models)
# Extract each model
for i in range(n_models):
model_i = jax.tree.map(lambda x: x[i], flattened_params)
model_list.append(model_i)
if idx_labels is None:
flattened_idx_labels = [str(i) for i in range(n_models)]
return model_list, flattened_idx_labels
def extract_performance_bounds(agent_config, n_models):
'''Flatten performance bounds dictionary into n_models dictionaries.
Each leaf has the same structure as idx_list.
'''
performance_bounds = agent_config.get("performance_bounds", None)
if performance_bounds is None:
return [None for _ in range(n_models)]
else:
ret_list = []
for i in range(n_models):
perf_i = {}
for stat_name, bound_list in performance_bounds.items():
assert len(bound_list[i]) == 2, "Performance bounds must be a list of two values (upper and lower bounds)."
perf_i[stat_name] = bound_list[i]
ret_list.append(perf_i)
return ret_list
def load_heldout_set(heldout_config, env, task_name, env_kwargs, rng):
'''Load heldout evaluation agents from config.
Returns a dictionary of agents with keys as agent names and values as tuples of
(policy, params, test_mode).
'''
heldout_agents = {}
for agent_name, agent_config in heldout_config.items():
# Allow env-specific configs to null out entries inherited from a
# base config (skip entries set to null in the task-specific block).
if agent_config is None:
continue
params_list = None
idx_labels = None
test_mode = agent_config.get("test_mode", False)
# Load RL-based agents
if "path" in agent_config:
# ensure that each rl agent has a unique initialization rng
rng, init_rng = jax.random.split(rng)
policy, params, init_params, idx_labels = initialize_rl_agent_from_config(agent_config, agent_name, env, init_rng)
# params contains multiple model checkpoints, so we need to extract each one
params_list, idx_labels = extract_params(params, init_params, idx_labels)
performance_bounds_list = extract_performance_bounds(agent_config, len(params_list))
# Load non-RL-based heuristic agents
else:
performance_bounds = agent_config.get("performance_bounds", None)
policy = initialize_heuristic_agent_from_config(
agent_config, agent_name, task_name, env_kwargs
)
# Generate agent labels
if params_list is None: # heuristic agent
heldout_agents[agent_name] = (policy, None, test_mode, performance_bounds)
else: # rl agent
for i, params_i in enumerate(params_list):
if idx_labels is None:
agent_label = f'{agent_name} ({i})'
else:
agent_label = f'{agent_name} ({idx_labels[i]})'
heldout_agents[agent_label] = (policy, params_i, test_mode, performance_bounds_list[i])
return heldout_agents
def normalize_metrics(metrics, performance_bounds):
'''For the metrics in performance_bounds, normalize the metrics in eval_metrics
using the performance bounds.'''
for k, v in performance_bounds.items():
lower, upper = v[0], v[1]
metrics[k] = (metrics[k] - lower) / (upper - lower)
return metrics
def eval_egos_vs_heldouts(config, env, rng, num_episodes, ego_policy, ego_params,
heldout_agent_list, heldout_agent_names=None, ego_test_mode=False):
'''Evaluate all ego agents against all heldout partners using vmap over egos.
Ego_params must be a pytree of shape (num_ego_agents, ...)
'''
num_agents = env.num_agents
assert num_agents == 2, "This eval code assumes exactly 2 agents."
num_ego_agents = jax.tree.leaves(ego_params)[0].shape[0]
num_partner_total = len(heldout_agent_list)
def _eval_ego_vs_one_partner(single_ego_policy, single_ego_params, rng_for_ego,
heldout_policy, heldout_params, heldout_test_mode):
return run_episodes(rng_for_ego, env,
agent_0_policy=single_ego_policy, agent_0_param=single_ego_params,
agent_1_policy=heldout_policy, agent_1_param=heldout_params,
max_episode_steps=config["global_heldout_settings"]["MAX_EPISODE_STEPS"],
num_eps=num_episodes,
agent_0_test_mode=ego_test_mode,
agent_1_test_mode=heldout_test_mode)
# Outer Python loop over heterogeneous heldout partners
all_metrics_for_partners = []
rng, sub_rng = jax.random.split(rng)
partner_rngs = jax.random.split(sub_rng, num_partner_total)
start_time = time.time()
for partner_idx in range(num_partner_total):
heldout_policy, heldout_params, heldout_test_mode, heldout_performance_bounds = heldout_agent_list[partner_idx]
ego_rngs = jax.random.split(partner_rngs[partner_idx], num_ego_agents)
# Use partial to fix the heldout agent for the function being vmapped
func_to_vmap = partial(_eval_ego_vs_one_partner,
heldout_policy=heldout_policy,
heldout_params=heldout_params,
heldout_test_mode=heldout_test_mode)
# vmap over the stacked ego agents and their RNGs
results_for_this_partner = jax.vmap(
func_to_vmap,
in_axes=(None, 0, 0) # Map over axis 0 of ego_policies, ego_params, ego_rngs
)(ego_policy, ego_params, ego_rngs)
# results_for_this_partner shape: (num_ego_agents, num_episodes, ...)
if config["global_heldout_settings"]["NORMALIZE_RETURNS"]:
if heldout_performance_bounds is not None:
results_for_this_partner = normalize_metrics(results_for_this_partner, heldout_performance_bounds)
else:
agent_name = heldout_agent_names[partner_idx] if heldout_agent_names is not None else f"partner_{partner_idx}"
print(f"Warning: no performance bounds provided for {agent_name}. Skipping normalization.")
all_metrics_for_partners.append(results_for_this_partner)
end_time = time.time()
print(f"Time taken for vmap evaluation loop: {end_time - start_time:.2f} seconds")
# Result shape: (num_partners, num_egos, num_episodes, ...)
final_metrics = tree_stack(all_metrics_for_partners)
# Transpose to (num_egos, num_partners, num_episodes, ...)
final_metrics = jax.tree.map(lambda x: x.transpose(1, 0, 2, 3), final_metrics)
return final_metrics
def run_heldout_evaluation(config, print_metrics=False):
'''Run heldout evaluation'''
# Create only one environment instance
env = make_env(config["ENV_NAME"], config["ENV_KWARGS"])
env = LogWrapper(env)
rng = jax.random.PRNGKey(config["global_heldout_settings"]["EVAL_SEED"])
rng, ego_init_rng, heldout_init_rng, eval_rng = jax.random.split(rng, 4)
# load ego agents
ego_agent_config = dict(config["ego_agent"])
ego_test_mode = ego_agent_config.get("test_mode", False)
ego_policy, ego_params, init_ego_params, ego_idx_labels = initialize_rl_agent_from_config(ego_agent_config, "ego", env, ego_init_rng)
# flatten ego params and idx labels
ego_idx_labels = np.array(ego_idx_labels).reshape(-1) # flatten the list of ego agent labels
flattened_ego_params = jax.tree.map(lambda x, y: x.reshape((-1,) + y.shape), ego_params, init_ego_params)
# load heldout agents
heldout_cfg = config["heldout_set"][config["TASK_NAME"]]
heldout_agents = load_heldout_set(heldout_cfg, env, config["TASK_NAME"], config["ENV_KWARGS"], heldout_init_rng)
heldout_agent_names = list(heldout_agents.keys())
heldout_agent_list = list(heldout_agents.values())
# run evaluation
eval_metrics = eval_egos_vs_heldouts(
config, env, eval_rng, config["global_heldout_settings"]["NUM_EVAL_EPISODES"],
ego_policy, flattened_ego_params, heldout_agent_list, heldout_agent_names, ego_test_mode)
if print_metrics:
# each leaf of eval_metrics has shape (num_ego_agents, num_heldout_agents, num_eval_episodes, num_agents_per_env)
metric_names = get_metric_names(config["ENV_NAME"])
aggregate_stat = config["global_heldout_settings"]["AGGREGATE_STAT"]
ego_names = [f"ego ({label})" for label in ego_idx_labels]
heldout_names = list(heldout_agents.keys())
for metric_name in metric_names:
print_metrics_table(eval_metrics, metric_name, ego_names, heldout_names,
aggregate_stat, config["global_heldout_settings"]["NORMALIZE_RETURNS"])
return eval_metrics
def print_metrics_table(eval_metrics, metric_name, ego_names, heldout_names,
aggregate_stat: str, normalized_metrics: bool,
save: bool = False, save_heatmap: bool = False):
'''Generate a table of the aggregate stat and CI of the metric for each ego agent and heldout agent.'''
# eval_metrics[metric_name] shape (num_ego_agents, num_heldout_agents, num_eval_episodes, num_agents_per_env)
# we first take the mean over the num_agents_per_env dimension
eval_metric_data = np.array(eval_metrics[metric_name]).mean(axis=-1) # shape (num_ego_agents, num_heldout_agents, num_eval_episodes, 2)
table = PrettyTable()
table.field_names = ["---", *heldout_names]
tidy_rows = []
for i, ego_name in enumerate(ego_names):
data = eval_metric_data[i].transpose(1, 0) # shape (num_eval_episodes, num_heldout_agents)
point_est_all, interval_ests_all = compute_aggregate_stat_and_ci_per_task(data, aggregate_stat, return_interval_est=True)
lower_ci = interval_ests_all[:, 0]
upper_ci = interval_ests_all[:, 1]
row = [ego_name] + [f"{point_est_all[j]:.2f} ({lower_ci[j]:.2f}, {upper_ci[j]:.2f})" for j in range(len(heldout_names))]
table.add_row(row)
for j, heldout_name in enumerate(heldout_names):
tidy_rows.append({
"row_agent": ego_name,
"col_agent": heldout_name,
"metric_name": metric_name,
"aggregate_stat": aggregate_stat,
"normalized": normalized_metrics,
"mean": float(point_est_all[j]),
"ci_lower": float(lower_ci[j]),
"ci_upper": float(upper_ci[j]),
})
print(f"\n{metric_name} ({aggregate_stat} ± CI):")
if normalized_metrics:
print("Metrics are normalized to [lower_bound, upper_bound].")
print(table)
if save:
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
# if not os.path.exists(output_dir):
# os.makedirs(output_dir)
# Sanitize metric_name for use in filename
safe_metric_name = "".join(c if c.isalnum() else "_" for c in metric_name)
csv_filename = os.path.join(output_dir, f"{safe_metric_name}_{aggregate_stat}_normalized={normalized_metrics}.csv")
with open(csv_filename, 'w', newline='') as f_output:
f_output.write(table.get_csv_string())
print(f"Table saved to {csv_filename}")
tidy_csv_filename = os.path.join(output_dir, f"{safe_metric_name}_{aggregate_stat}_normalized={normalized_metrics}_tidy.csv")
import csv
with open(tidy_csv_filename, 'w', newline='') as tidy_file:
writer = csv.DictWriter(
tidy_file,
fieldnames=[
"row_agent",
"col_agent",
"metric_name",
"aggregate_stat",
"normalized",
"mean",
"ci_lower",
"ci_upper",
],
)
writer.writeheader()
writer.writerows(tidy_rows)
print(f"Tidy table saved to {tidy_csv_filename}")
if save_heatmap:
try:
from pathlib import Path
from evaluation.plot_xp_csv_heatmap import generate_heatmap_from_csv
heatmap_title = f"XP Matrix: {metric_name} ({aggregate_stat})"
png_path = generate_heatmap_from_csv(Path(csv_filename), title=heatmap_title)
print(f"Heatmap saved to {png_path}")
except Exception as exc:
print(f"Warning: failed to generate heatmap for {csv_filename}: {exc}")