File size: 6,328 Bytes
5146e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
from functools import partial

import jax
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt

def get_metric_names(env_name):
    '''
    Return the tuple of metric names to report for the given environment.

    Every environment reports "returned_episode_returns". Environments whose
    name contains "lbf", and the exact names "overcooked-v1" and "dsse",
    additionally report one environment-specific metric, listed first.
    Any other name (including "hanabi") yields only the episode return.
    '''
    episode_return = "returned_episode_returns"

    # LBF variants are recognised by substring, not by exact name.
    if "lbf" in env_name:
        return ("percent_eaten", episode_return)

    # Remaining special cases require an exact name match.
    exact_matches = (
        ("overcooked-v1", "base_return"),
        ("dsse", "targets_found"),
    )
    for known_name, extra_metric in exact_matches:
        if env_name == known_name:
            return (extra_metric, episode_return)

    return (episode_return,)

@partial(jax.jit, static_argnames=['stats'])
def get_stats(metrics, stats: tuple):
    '''
    Computes the mean and std of each metric of interest, using only the
    final steps of episodes. Note that each rollout contains multiple episodes.

    metrics is a dict-like pytree. Each requested leaf is either
        - an array of shape (..., rollout_length, num_envs), which is reduced
          over its last two axes using the boolean-like mask stored under
          metrics["returned_episode"], or
        - a 2-D array, which is treated as already averaged over rollout
          steps and envs; it is passed through as the mean with a std of 0.
    stats is a tuple of strings, each corresponding to a metric of interest
    in metrics. A single bare string is accepted and treated as a 1-tuple.

    Returns a dict mapping each stat name to an array whose last axis has
    size 2: index 0 holds the mean and index 1 the (population) std.

    Raises KeyError if a metric needs masking but metrics has no
    "returned_episode" entry.
    '''
    # A bare string would otherwise be iterated character by character.
    if isinstance(stats, str):
        stats = (stats,)

    # Mask for final steps of episodes. It is only required for metrics that
    # still carry the (rollout_length, num_envs) axes, so its absence is
    # tolerated until such a metric is actually encountered.
    mask = metrics.get("returned_episode", None)
    mask_sum = None  # computed lazily, shared by all metrics that need it

    all_stats = {}
    for stat_name in stats:
        metric_data = metrics[stat_name]

        if metric_data.ndim == 2:
            # Data is already pre-averaged over rollout and envs
            means = metric_data
            stds = jnp.zeros_like(metric_data)
        else:
            if mask is None:
                raise KeyError(
                    f"metrics['returned_episode'] is required to reduce '{stat_name}' "
                    "over the rollout and env axes"
                )
            if mask_sum is None:
                # Clamp to 1 so updates with no finished episode divide by 1, not 0
                mask_sum = jnp.maximum(1, mask.sum(axis=(-2, -1)))

            # Zero out every entry that is not a final episode step
            masked_vals = jnp.where(mask, metric_data, 0)
            means = masked_vals.sum(axis=(-2, -1)) / mask_sum

            # Re-apply the mask so the zeroed entries do not contribute
            # (0 - mean)^2 to the variance
            squared_diff = (masked_vals - means[..., None, None]) ** 2
            variance = jnp.where(mask, squared_diff, 0).sum(axis=(-2, -1)) / mask_sum
            stds = jnp.sqrt(variance)

        # Stack means and stds along a new trailing axis
        all_stats[stat_name] = jnp.stack([means, stds], axis=-1)

    return all_stats


def plot_train_metrics(all_stats, 
                       num_rollout_steps, num_envs,
                       savedir=None, savename=None,
                       show_plots=False
                       ):
    '''
    Plots one learning curve per metric, with one line (plus a mean +/- std
    band) per seed, and prints the final-update mean and std for each seed.

    Each key in all_stats is a metric name, and the value is an array of shape
    (num_seeds, num_updates, 2), where [..., 0] is the mean and [..., 1] the std.
    The x-axis is the update index multiplied by num_envs * num_rollout_steps.

    If both savedir and savename are given, each figure is saved to
    "<savedir>/<savename>_<Metric Title>.pdf". If show_plots is True, each
    figure is shown before being closed.

    Returns (figures, savepath): figures maps the title-cased metric name to
    its (already closed) figure; savepath is the path of the last figure
    saved, or None if nothing was saved (including when all_stats is empty).
    '''
    figures = {}
    # Bound before the loop so an empty all_stats returns ({}, None)
    # instead of failing on an unbound local at the return statement.
    savepath = None
    for stat_name, stat_values in all_stats.items():
        stat_name = stat_name.replace("_", " ").title()
        num_seeds, num_updates, _ = stat_values.shape
        # The x-axis is the same for every seed, so compute it once per metric
        xs = jnp.arange(num_updates) * num_envs * num_rollout_steps
        for i in range(num_seeds):
            print("Seed: ", i)
            print(f"Mean {stat_name} (Last Episode Step): ", stat_values[i, -1, 0])
            print(f"Std {stat_name} (Last Episode Step): ", stat_values[i, -1, 1])
            means = stat_values[i, :, 0]
            stds = stat_values[i, :, 1]

            # Calculate upper and lower bounds for the shaded region
            upper_bound = means + stds
            lower_bound = means - stds

            # Create the plot
            plt.plot(xs, means, label=f"Seed {i}")

            # Shade the region between the bounds
            plt.fill_between(xs, lower_bound, upper_bound, 
                            alpha=0.3)
                        
        plt.xlabel("Time Step")
        plt.ylabel(stat_name)
        plt.title(f"Learning Curve for {stat_name}")
        plt.legend()
        
        # Get the current figure
        fig = plt.gcf()
        
        # Save the figure if requested
        if savedir is not None and savename is not None:
            savepath = os.path.join(savedir, f"{savename}_{stat_name}.pdf")
            plt.savefig(savepath)
        figures[stat_name] = fig
        if show_plots:
            plt.show()
        
        # Close so the next metric starts on a fresh current figure
        plt.close(fig)
    
    return figures, savepath


def plot_xp_matrix(xp_matrix, xlabel, ylabel, title, 
                   higher_is_better=True,
                   savedir=None, savename=None,
                   show_plots=False
    ):
    '''
    Draws xp_matrix as a seaborn heatmap on a new 6x5 figure with an
    inverted y-axis.

    higher_is_better selects the colormap ("coolwarm_r" when True,
    "coolwarm" otherwise) and the arrow appended to the title
    (up-arrow when True, down-arrow otherwise).

    If both savedir and savename are given, the figure is saved to
    "<savedir>/<savename>.pdf". If show_plots is True the figure is shown.
    The figure is closed before returning.

    Returns (fig, savepath), where savepath is None if nothing was saved.
    '''
    colormap, arrow_str = (
        ("coolwarm_r", r" ($\uparrow$)")
        if higher_is_better
        else ("coolwarm", r" ($\downarrow$)")
    )

    # The new figure becomes the current one, so the pyplot calls below target it
    fig = plt.figure(figsize=(6, 5))
    sns.heatmap(xp_matrix, cmap=colormap, annot=False)
    plt.gca().invert_yaxis()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title + arrow_str)

    # Saving requires both a directory and a base name
    want_save = savedir is not None and savename is not None
    savepath = os.path.join(savedir, f"{savename}.pdf") if want_save else None
    if want_save:
        plt.savefig(savepath)

    if show_plots:
        plt.show()

    plt.close(fig)

    return fig, savepath

def plot_xp_from_eval_metrics(eval_metrics, metric_name, higher_is_better=True, agent_idx=0,
                      savedir=None, savename=None,
                      show_plots=True):
    '''
    Plots a cross-play heatmap of one evaluation metric for a single agent.

    Note that the FCP agent is always agent 0, the partner is agent 1. 

    eval_metrics is a dictionary with keys corresponding to metric names 
    and values as arrays of shape (num_seeds, num_fcp_checkpoints, num_eval_checkpoints, num_episodes, num_agents)

    The entry for metric_name is sliced at agent_idx on its last axis and
    averaged over the seed and episode axes; the resulting matrix is passed
    to plot_xp_matrix, whose (fig, savepath) result is returned unchanged.
    '''
    # Keep only the requested agent's values
    agent_values = eval_metrics[metric_name][:, :, :, :, agent_idx]
    # Average over seeds (axis 0) and episodes (axis 3)
    heatmap_data = jnp.mean(agent_values, axis=(0, 3))

    pretty_metric = metric_name.replace('_', ' ').title()
    return plot_xp_matrix(
        heatmap_data,
        xlabel="Eval Checkpoint",
        ylabel="Ego Agent Checkpoint",
        title=f"Average {pretty_metric}",
        higher_is_better=higher_is_better,
        savedir=savedir,
        savename=savename,
        show_plots=show_plots,
    )