Spaces:
Running
Running
| import os | |
| from functools import partial | |
| import jax | |
| import jax.numpy as jnp | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
def get_metric_names(env_name):
    '''
    Return the tuple of metric names to report for the environment env_name.

    "returned_episode_returns" is always the last entry. Environments whose
    name contains "lbf", and the environments named exactly "overcooked-v1"
    or "dsse", additionally report one environment-specific metric first.
    Every other name (including "hanabi") only reports the episode returns.
    '''
    returns_metric = "returned_episode_returns"

    # Environment-specific metrics that are matched on the exact name.
    exact_matches = (
        ("overcooked-v1", "base_return"),
        ("dsse", "targets_found"),
    )

    extra_metric = None
    if "lbf" in env_name:
        # Substring match: any environment name containing "lbf" qualifies.
        extra_metric = "percent_eaten"
    else:
        for known_env, metric in exact_matches:
            if env_name == known_env:
                extra_metric = metric
                break

    if extra_metric is None:
        return (returns_metric,)
    return (extra_metric, returns_metric)
def get_stats(metrics, stats: tuple):
    '''
    Computes mean and std of metrics of interest for each seed and update,
    using only the final steps of episodes. Note that each rollout contains multiple episodes.

    metrics is a dict-like pytree. Each requested leaf is either
      - a 2-dimensional array, which is treated as already averaged over the
        rollout and env dimensions (its std is reported as zero), or
      - an array of shape (..., rollout_length, num_envs), which is reduced
        over its last two axes using the "returned_episode" mask.
    metrics["returned_episode"] marks the final steps of episodes. It is only
    required when at least one requested metric is not pre-averaged.

    stats is a tuple of strings, each corresponding to a metric of interest in metrics.
    A bare string is accepted and treated as a single metric name.

    Returns a dict mapping each metric name to an array whose last axis holds
    [mean, std].

    Raises KeyError if a requested metric is missing, or if a metric needs the
    episode mask and "returned_episode" is absent from metrics.
    '''
    # A bare string would otherwise be iterated character by character.
    if isinstance(stats, str):
        stats = (stats,)

    # Mask for final steps of episodes; may legitimately be absent when all
    # requested metrics are pre-averaged.
    mask = metrics.get("returned_episode", None)
    # Number of final-episode steps per leading index, computed on first use.
    # Clamped to at least 1 so that rollouts without a finished episode yield
    # 0 instead of a division by zero.
    mask_sum = None

    all_stats = {}
    for stat_name in stats:
        metric_data = metrics[stat_name]

        if metric_data.ndim == 2:
            # Data is already pre-averaged over rollout and envs.
            means = metric_data
            stds = jnp.zeros_like(metric_data)
        else:
            if mask is None:
                raise KeyError(
                    f"metrics['returned_episode'] is required to compute stats for {stat_name!r}"
                )
            if mask_sum is None:
                mask_sum = jnp.maximum(1, mask.sum(axis=(-2, -1)))

            # Zero out every step that is not the final step of an episode.
            masked_vals = jnp.where(mask, metric_data, 0)
            means = masked_vals.sum(axis=(-2, -1)) / mask_sum

            # Population variance over the masked entries only; the second
            # jnp.where removes the (0 - mean)^2 terms of non-final steps.
            squared_diff = (masked_vals - means[..., None, None]) ** 2
            variance = jnp.where(mask, squared_diff, 0).sum(axis=(-2, -1)) / mask_sum
            stds = jnp.sqrt(variance)

        # Stack means and stds along a new trailing axis.
        all_stats[stat_name] = jnp.stack([means, stds], axis=-1)

    return all_stats
def plot_train_metrics(all_stats,
                       num_rollout_steps, num_envs,
                       savedir=None, savename=None,
                       show_plots=False
                       ):
    '''
    Plot one learning curve figure per metric, with one line per seed.

    Each key in all_stats is a metric name, and the value is an array of shape
    (num_seeds, num_updates, 2), where the last axis holds [mean, std].
    The x axis is the update index multiplied by num_envs * num_rollout_steps.
    A band of one std around the mean is shaded for every seed.

    If both savedir and savename are given, each figure is saved to
    "<savedir>/<savename>_<Metric Name>.pdf". If show_plots is true, each
    figure is shown before it is closed.

    Returns (figures, savepath): figures maps the title-cased metric name to
    its (already closed) figure, and savepath is the path of the last figure
    saved, or None if nothing was saved (including when all_stats is empty).
    '''
    figures = {}
    # Defined up front so the return below is valid even if all_stats is empty.
    savepath = None

    for stat_name, stats in all_stats.items():
        stat_name = stat_name.replace("_", " ").title()
        num_seeds, num_updates, _ = stats.shape

        # The x axis is shared by all seeds of this metric.
        xs = jnp.arange(num_updates) * num_envs * num_rollout_steps

        for i in range(num_seeds):
            print("Seed: ", i)
            print(f"Mean {stat_name} (Last Episode Step): ", stats[i, -1, 0])
            print(f"Std {stat_name} (Last Episode Step): ", stats[i, -1, 1])

            means = stats[i, :, 0]
            stds = stats[i, :, 1]

            # Mean curve with a shaded +/- one std region around it.
            plt.plot(xs, means, label=f"Seed {i}")
            plt.fill_between(xs, means - stds, means + stds,
                             alpha=0.3)

        plt.xlabel("Time Step")
        plt.ylabel(stat_name)
        plt.title(f"Learning Curve for {stat_name}")
        plt.legend()

        # Get the current figure
        fig = plt.gcf()

        # Save the figure if requested
        if savedir is not None and savename is not None:
            savepath = os.path.join(savedir, f"{savename}_{stat_name}.pdf")
            plt.savefig(savepath)

        figures[stat_name] = fig
        if show_plots:
            plt.show()
        # Close so the next metric starts on a fresh figure.
        plt.close(fig)

    return figures, savepath
def plot_xp_matrix(xp_matrix, xlabel, ylabel, title,
                   higher_is_better=True,
                   savedir=None, savename=None,
                   show_plots=False
                   ):
    '''
    Draw xp_matrix as a seaborn heatmap on a new 6x5 figure.

    The y axis is inverted after drawing. The colormap is "coolwarm_r" when
    higher_is_better is true and "coolwarm" otherwise, and a matching up or
    down arrow is appended to the title.

    If both savedir and savename are given, the figure is saved to
    "<savedir>/<savename>.pdf". If show_plots is true, the figure is shown.
    The figure is closed before returning.

    Returns (fig, savepath), where savepath is None if nothing was saved.
    '''
    # Pick the colormap and the title suffix together.
    cmap_name, title_suffix = (
        ("coolwarm_r", r" ($\uparrow$)")
        if higher_is_better
        else ("coolwarm", r" ($\downarrow$)")
    )

    # Plot as heatmap
    plt.figure(figsize=(6, 5))
    sns.heatmap(xp_matrix, cmap=cmap_name, annot=False)
    plt.gca().invert_yaxis()

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title + title_suffix)

    figure = plt.gcf()

    # Save only when both a directory and a file name were supplied.
    if savedir is None or savename is None:
        savepath = None
    else:
        savepath = os.path.join(savedir, f"{savename}.pdf")
        plt.savefig(savepath)

    if show_plots:
        plt.show()
    plt.close(figure)

    return figure, savepath
def plot_xp_from_eval_metrics(eval_metrics, metric_name, higher_is_better=True, agent_idx=0,
                              savedir=None, savename=None,
                              show_plots=True):
    '''
    Plot a cross-play heatmap of one evaluation metric for one agent.

    eval_metrics is a dictionary with keys corresponding to metric names
    and values as arrays of shape
    (num_seeds, num_fcp_checkpoints, num_eval_checkpoints, num_episodes, num_agents).
    Note that the FCP agent is always agent 0, the partner is agent 1.

    The data of agent agent_idx is averaged over seeds and episodes (axes 0
    and 3), and the result is passed to plot_xp_matrix together with the
    save/show options.

    Returns the (fig, savepath) pair produced by plot_xp_matrix.
    '''
    # Keep only the selected agent's values (last axis).
    agent_values = eval_metrics[metric_name][:, :, :, :, agent_idx]
    # Average over seeds and episodes.
    checkpoint_means = jnp.mean(agent_values, axis=(0, 3))

    readable_name = metric_name.replace('_', ' ').title()

    fig, savepath = plot_xp_matrix(
        checkpoint_means,
        xlabel="Eval Checkpoint",
        ylabel="Ego Agent Checkpoint",
        title=f"Average {readable_name}",
        higher_is_better=higher_is_better,
        savedir=savedir,
        savename=savename,
        show_plots=show_plots,
    )
    return fig, savepath