Spaces:
Running
Running
File size: 6,328 Bytes
5146e76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 | import os
from functools import partial
import jax
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt
def get_metric_names(env_name):
    '''
    Return the tuple of metric names tracked for the given environment.

    Every environment reports "returned_episode_returns"; some environments
    additionally report one environment-specific metric, which is listed first.
    '''
    shared_metric = "returned_episode_returns"

    # LBF variants are matched by substring, so check them first.
    if "lbf" in env_name:
        return ("percent_eaten", shared_metric)

    # Remaining environments are matched by exact name.
    exact_matches = (
        ("overcooked-v1", "base_return"),
        ("dsse", "targets_found"),
    )
    for known_name, extra_metric in exact_matches:
        if env_name == known_name:
            return (extra_metric, shared_metric)

    # "hanabi" and any unrecognised environment only report the shared metric.
    return (shared_metric,)
@partial(jax.jit, static_argnames=['stats'])
def get_stats(metrics, stats: tuple):
    '''
    Computes mean and std of metrics of interest for each seed and update,
    using only the final steps of episodes. Note that each rollout contains multiple episodes.

    metrics is a dict-like pytree where each leaf has shape
        (..., rollout_length, num_envs)
    or, for metrics that were already averaged over rollout and envs, a 2-D array
    (the original comment describes this as (num_seeds, num_updates)).
    metrics["returned_episode"] is the mask selecting final episode steps; it is
    only required when at least one requested metric is not 2-D.

    stats is a tuple of strings, each corresponding to a metric of interest in metrics.
    A bare string is treated as a single metric name.

    Returns a dict mapping each metric name to an array whose last axis has
    size 2 and holds [mean, std].

    Raises ValueError if a non-2-D metric is requested but the
    "returned_episode" mask is missing.
    '''
    # A bare string (e.g. a one-element "tuple" written without the trailing
    # comma) would otherwise be iterated character by character.
    if isinstance(stats, str):
        stats = (stats,)

    # Mask for final steps of episodes; may be absent for pre-averaged metrics.
    mask = metrics.get("returned_episode", None)

    # Number of final episode steps per leading index. It depends only on the
    # mask, so it is computed at most once and shared across metrics.
    mask_sum = None

    all_stats = {}
    for stat_name in stats:
        metric_data = metrics[stat_name]

        if metric_data.ndim == 2:
            # Data is already pre-averaged over rollout and envs: no spread to report.
            means = metric_data
            stds = jnp.zeros_like(metric_data)
        else:
            if mask is None:
                raise ValueError(
                    'metrics must contain "returned_episode" to compute '
                    f'masked statistics for "{stat_name}"'
                )
            if mask_sum is None:
                # Clamp to 1 so updates with no finished episode do not divide by zero.
                mask_sum = jnp.maximum(1, mask.sum(axis=(-2, -1)))

            # Masked mean over the (rollout_length, num_envs) axes.
            masked_vals = jnp.where(mask, metric_data, 0)
            means = masked_vals.sum(axis=(-2, -1)) / mask_sum

            # Masked population variance around that mean.
            squared_diff = (masked_vals - means[..., None, None]) ** 2
            variance = jnp.where(mask, squared_diff, 0).sum(axis=(-2, -1)) / mask_sum
            stds = jnp.sqrt(variance)

        # Stack means and stds along a new trailing axis.
        all_stats[stat_name] = jnp.stack([means, stds], axis=-1)

    return all_stats
def plot_train_metrics(all_stats,
                       num_rollout_steps, num_envs,
                       savedir=None, savename=None,
                       show_plots=False
                       ):
    '''
    Plot one learning curve per metric, with one line (and a +/- 1 std band) per seed.

    Each key in all_stats is a metric name, and the value is an array of shape
    (num_seeds, num_updates, 2), where the last axis holds [mean, std].
    The x-axis is the update index multiplied by num_envs * num_rollout_steps.

    If both savedir and savename are given, each figure is written to
    "<savedir>/<savename>_<Metric Title>.pdf".

    Returns (figures, savepath): figures maps each title-cased metric name to its
    figure (each figure is closed before returning), and savepath is the path of
    the last figure saved, or None if nothing was saved.
    '''
    figures = {}
    # Bound before the loop so an empty all_stats returns ({}, None)
    # instead of failing on an unbound local at the return statement.
    savepath = None
    should_save = savedir is not None and savename is not None

    for stat_name, stats in all_stats.items():
        stat_name = stat_name.replace("_", " ").title()
        num_seeds, num_updates, _ = stats.shape

        # The x-axis is the same for every seed of this metric.
        xs = jnp.arange(num_updates) * num_envs * num_rollout_steps

        # Use a dedicated figure so nothing is drawn onto a figure the caller
        # may already have open.
        fig, ax = plt.subplots()

        for i in range(num_seeds):
            means = stats[i, :, 0]
            stds = stats[i, :, 1]

            print("Seed: ", i)
            print(f"Mean {stat_name} (Last Episode Step): ", stats[i, -1, 0])
            print(f"Std {stat_name} (Last Episode Step): ", stats[i, -1, 1])

            ax.plot(xs, means, label=f"Seed {i}")
            # Shade the region within one std of the mean.
            ax.fill_between(xs, means - stds, means + stds, alpha=0.3)

        ax.set_xlabel("Time Step")
        ax.set_ylabel(stat_name)
        ax.set_title(f"Learning Curve for {stat_name}")
        ax.legend()

        # Save the figure if requested
        if should_save:
            savepath = os.path.join(savedir, f"{savename}_{stat_name}.pdf")
            fig.savefig(savepath)

        figures[stat_name] = fig

        if show_plots:
            plt.show()
        plt.close(fig)

    return figures, savepath
def plot_xp_matrix(xp_matrix, xlabel, ylabel, title,
                   higher_is_better=True,
                   savedir=None, savename=None,
                   show_plots=False
                   ):
    '''
    Draw xp_matrix as a seaborn heatmap with a flipped y-axis.

    higher_is_better selects the colormap ("coolwarm_r" when True, "coolwarm"
    otherwise) and appends an up or down arrow to the title.

    If both savedir and savename are given, the figure is written to
    "<savedir>/<savename>.pdf".

    Returns (fig, savepath); fig is closed before returning and savepath is
    None when nothing was saved.
    '''
    colormap, arrow_str = (
        ("coolwarm_r", r" ($\uparrow$)")
        if higher_is_better
        else ("coolwarm", r" ($\downarrow$)")
    )

    # Plot as heatmap on a fresh figure
    fig = plt.figure(figsize=(6, 5))
    ax = sns.heatmap(xp_matrix, cmap=colormap, annot=False)
    ax.invert_yaxis()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title + arrow_str)

    # Save the figure if requested
    wants_save = not (savedir is None or savename is None)
    savepath = os.path.join(savedir, f"{savename}.pdf") if wants_save else None
    if wants_save:
        fig.savefig(savepath)

    if show_plots:
        plt.show()
    plt.close(fig)
    return fig, savepath
def plot_xp_from_eval_metrics(eval_metrics, metric_name, higher_is_better=True, agent_idx=0,
                              savedir=None, savename=None,
                              show_plots=True):
    '''
    Plot a cross-play heatmap of one evaluation metric for a single agent.

    Note that the FCP agent is always agent 0, the partner is agent 1.
    eval_metrics is a dictionary with keys corresponding to metric names
    and values as arrays of shape
    (num_seeds, num_fcp_checkpoints, num_eval_checkpoints, num_episodes, num_agents).

    The data for agent_idx is averaged over seeds and episodes, then passed to
    plot_xp_matrix, whose (fig, savepath) result is returned.
    '''
    # Pick out the requested agent along the last (agent) axis.
    per_agent = eval_metrics[metric_name][:, :, :, :, agent_idx]
    # Average over seeds (axis 0) and episodes (axis 3).
    heatmap_data = jnp.mean(per_agent, axis=(0, 3))

    pretty_name = metric_name.replace('_', ' ').title()
    return plot_xp_matrix(
        heatmap_data,
        xlabel="Eval Checkpoint",
        ylabel="Ego Agent Checkpoint",
        title=f"Average {pretty_name}",
        higher_is_better=higher_is_better,
        savedir=savedir,
        savename=savename,
        show_plots=show_plots,
    )
|