File size: 10,832 Bytes
580868a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import copy
import json
import os
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter


def append_leafstats(tree1: Dict, tree2: Dict):
    """
    Append each corresponding leaf of tree2 to tree1, modifying tree1 in place.

    Merge rules, applied to every key of tree2:
    - key missing from tree1: a deep copy of tree2's value is stored;
    - value is a dict: recurse into tree1[key] (which must itself be a dict);
    - value is a list: its items are added to tree1[key], which is first
      wrapped in a list when it is not one already;
    - any other value: it is added to tree1[key], wrapped the same way.

    Args:
        tree1: Destination leafstats; mutated in place.
        tree2: Source leafstats; left untouched.
    """
    for key, value in tree2.items():
        if key not in tree1:
            # Copy instead of aliasing: later calls extend tree1's lists in
            # place, and that must not leak back into tree2.
            tree1[key] = copy.deepcopy(value)
        elif isinstance(value, dict):
            append_leafstats(tree1[key], value)
        elif isinstance(tree1[key], list):
            if isinstance(value, list):
                tree1[key].extend(value)
            else:
                tree1[key].append(value)
        else:
            # Existing leaf is a bare value: promote it to a list first.
            incoming = value if isinstance(value, list) else [value]
            tree1[key] = [tree1[key]] + incoming


def _mean_of_each_leaf(tree: Dict) -> Dict:
    """Return a new leafstats where every list leaf is replaced by its mean."""
    result = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            result[key] = _mean_of_each_leaf(value)
        elif isinstance(value, list):
            cleaned_values = [v for v in value if v is not None]
            result[key] = np.mean(cleaned_values) if cleaned_values else None
        else:
            # Non-list leaves are passed through unchanged.
            result[key] = value
    return result


def _elementwise_mean(values_at_key: List) -> List:
    """Average leaves index by index; non-list leaves count as 1-item lists."""
    columns = [v if isinstance(v, list) else [v] for v in values_at_key]
    longest = max(len(column) for column in columns)
    means = []
    for i in range(longest):
        # Shorter lists and None entries simply do not contribute at index i.
        present = [c[i] for c in columns if i < len(c) and c[i] is not None]
        means.append(np.mean(present) if present else None)
    return means


def get_mean_leafstats(trees: Union[Dict, List[Dict]]) -> Dict:
    """
    Computes means over leafstats, either within one tree or across several.

    Two call forms are supported:

    - A single leafstats dict: every list leaf is replaced by the mean of its
      non-None values (None if there are none); other leaves are kept as is.
    - A list of leafstats dicts: computes an element-wise mean across them.
      For each path that exists in any of the input trees:
        - if any tree holds a list there, the result is a list whose i-th
          entry is the mean of the non-None i-th entries that exist (lists of
          different lengths are allowed; None where no value is available);
        - if every tree holds a dict there, the rule is applied recursively;
        - otherwise the mean of the non-None values is used.
      A path present in only some trees is averaged over those trees only.

    Args:
        trees: One leafstats dict, or a list of leafstats dicts.

    Returns:
        A leafstats with the means described above. An empty list yields {},
        and a one-element list yields that element itself (not a copy).
    """
    # Single-tree form; this must be checked first because a dict is also
    # sized and iterable, and would otherwise be misread as a list of trees.
    if isinstance(trees, dict):
        return _mean_of_each_leaf(trees)

    if not trees:
        return {}

    if len(trees) == 1:
        return trees[0]

    # Union of keys at this level, in first-seen order so that the result's
    # key order is deterministic from run to run.
    all_keys = dict.fromkeys(key for tree in trees for key in tree)

    result = {}
    for key in all_keys:
        values_at_key = [tree[key] for tree in trees if key in tree]

        if all(isinstance(v, dict) for v in values_at_key):
            result[key] = get_mean_leafstats(values_at_key)
        elif any(isinstance(v, list) for v in values_at_key):
            result[key] = _elementwise_mean(values_at_key)
        else:
            non_none_values = [v for v in values_at_key if v is not None]
            result[key] = np.mean(non_none_values) if non_none_values else None

    return result


def get_var_leafstats(tree: Dict) -> Dict:
    """
    Returns a leafstats where each leaf is replaced by the variance of its values.

    List leaves become the variance (np.var) of their non-None entries, or
    None when no such entry exists. Any non-list, non-dict leaf becomes 0.
    """
    variances = {}
    for name, leaf in tree.items():
        if isinstance(leaf, dict):
            variances[name] = get_var_leafstats(leaf)
            continue
        if not isinstance(leaf, list):
            # A lone value has no spread.
            variances[name] = 0
            continue
        kept = [entry for entry in leaf if entry is not None]
        variances[name] = np.var(kept) if kept else None
    return variances


def plot_leafstats(tree: Dict, folder: str, path: str = ""):
    """
    Plots every list leaf of the leafstats and saves one PNG per leaf in `folder`.

    Args:
        tree: Leafstats to walk; nested dicts are visited recursively.
        folder: Output directory, created if missing.
        path: Label prefix accumulated during recursion.
    """
    os.makedirs(folder, exist_ok=True)
    for name, node in tree.items():
        # NOTE(review): levels are joined with "_" here (leading "_" included),
        # whereas the EMA/SMA/save/TensorBoard walkers in this file join with
        # "/" -- confirm whether this naming is intentional.
        label = f"{path}_{name}"
        if isinstance(node, dict):
            plot_leafstats(node, folder, label)
            continue
        if not isinstance(node, list):
            continue
        png_name = f"{label.replace('/', '_')}.png"
        plt.figure()
        plt.plot(node)
        plt.title(label)
        plt.savefig(os.path.join(folder, png_name))
        plt.close()


def _bias_corrected_ema(values: List, alpha: float) -> np.ndarray:
    """Return the exponential moving average of `values` as a float array.

    Entry t is sum_i (1-alpha)^i * x[t-i] / sum_i (1-alpha)^i, i.e. a true
    weighted average, so a constant series maps to itself. None entries are
    skipped: they yield NaN (a gap in the plot) and do not affect the average.
    """
    decay = 1.0 - alpha
    smoothed = np.full(len(values), np.nan, dtype=float)
    weighted_sum = 0.0
    weight_total = 0.0
    # Running recursion rather than dividing by explicit powers of `decay`:
    # decay**n underflows to 0 for long series and would produce NaN/inf.
    for i, v in enumerate(values):
        if v is None:
            continue
        weighted_sum = decay * weighted_sum + v
        weight_total = decay * weight_total + 1.0
        smoothed[i] = weighted_sum / weight_total
    return smoothed


def plot_EMA_leafstats(tree: Dict, folder: str, path: str = "", alpha: float = 0.1):
    """
    Plots the exponential moving average of the leaves of the leafstats and saves them to the specified folder.

    Args:
        tree: Leafstats to walk; nested dicts are visited recursively.
        folder: Output directory, created if missing.
        path: "/"-joined key prefix accumulated during recursion.
        alpha: Smoothing factor; each step back in time is down-weighted by
            (1 - alpha).
    """
    os.makedirs(folder, exist_ok=True)
    for key, value in tree.items():
        new_path = f"{path}/{key}" if path else key
        if isinstance(value, dict):
            plot_EMA_leafstats(value, folder, new_path, alpha)
        elif isinstance(value, list):
            smoothed = _bias_corrected_ema(value, alpha)
            out_path = f"EMA_alpha_{alpha}_{path}_{key}"
            out_path = os.path.join(folder, f"{out_path.replace('/', '_')}.png")
            plt.figure()
            plt.plot(smoothed)
            plt.title(new_path)
            plt.savefig(out_path)
            plt.close()


def _centered_moving_average(values: List, window: int) -> np.ndarray:
    """Return the centered simple moving average of `values`.

    Entry i is the mean of the non-None samples among
    values[i - window//2 : i + window//2 + 1], clipped to the series bounds,
    so the result always has the same length as `values`. Positions whose
    window holds no sample are NaN.
    """
    # Even numbers are annoying for centered windows
    assert window % 2 == 1, "window must be odd"

    if not values:
        # np.convolve rejects empty inputs.
        return np.empty(0, dtype=float)

    data = np.array([np.nan if v is None else v for v in values], dtype=float)
    present = ~np.isnan(data)
    half = window // 2
    kernel = np.ones(window)
    stop = half + len(data)
    # "full" mode zero-pads both ends; slicing [half:stop] re-centres the
    # window and keeps exactly len(data) points even when the series is
    # shorter than the window ("same" mode would return `window` points).
    sums = np.convolve(np.where(present, data, 0.0), kernel, mode="full")[half:stop]
    counts = np.convolve(present.astype(float), kernel, mode="full")[half:stop]

    # Divide by the number of samples actually inside each window so that the
    # edges (and gaps) are averaged correctly instead of being pulled to zero.
    smoothed = np.full(len(data), np.nan)
    np.divide(sums, counts, out=smoothed, where=counts > 0.5)
    return smoothed


def plot_SMA_leafstats(tree: Dict, folder: str, path: str = "", window: int = 33):
    """
    Plots the simple moving average of the leaves of the leafstats and saves them to the specified folder.

    Args:
        tree: Leafstats to walk; nested dicts are visited recursively.
        folder: Output directory, created if missing.
        path: "/"-joined key prefix accumulated during recursion.
        window: Odd number of samples in the centered averaging window.
    """
    os.makedirs(folder, exist_ok=True)
    for key, value in tree.items():
        new_path = f"{path}/{key}" if path else key
        if isinstance(value, dict):
            plot_SMA_leafstats(value, folder, new_path, window)
        elif isinstance(value, list):
            smoothed = _centered_moving_average(value, window)

            out_path = f"SMA_window_{window}_{path}_{key}"
            out_path = os.path.join(folder, f"{out_path.replace('/', '_')}.png")
            plt.figure()
            plt.plot(smoothed)
            plt.title(new_path)
            plt.savefig(out_path)
            plt.close()


def _numpy_to_builtin(obj):
    """json `default` hook: convert numpy arrays/scalars to plain Python."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_leafstats(tree: Dict, folder: str, path: str = ""):
    """
    Saves the leaves of the leafstats to the specified folder.

    Each leaf that is a list, numpy array, float or int is written to
    `<folder>/<keys joined by "_">.json`; leaves of any other type are skipped.

    Args:
        tree: Leafstats to walk; nested dicts are visited recursively.
        folder: Output directory, created if missing.
        path: "/"-joined key prefix accumulated during recursion.
    """
    os.makedirs(folder, exist_ok=True)
    for key, value in tree.items():
        new_path = f"{path}/{key}" if path else key
        if isinstance(value, dict):
            save_leafstats(value, folder, new_path)
        elif isinstance(value, (list, np.ndarray, float, int)):
            # Serialise before opening the file so a failure cannot leave a
            # truncated JSON file behind; the default hook makes numpy arrays
            # (and numpy scalars nested in lists) serialisable.
            payload = json.dumps(value, indent=4, default=_numpy_to_builtin)
            with open(
                os.path.join(folder, f"{new_path.replace('/', '_')}.json"), "w"
            ) as f:
                f.write(payload)


def tb_leafstats(tree: Dict, writer, path: str = ""):
    """
    Logs the list leaves of the leafstats to TensorBoard.

    Each non-None entry of a list leaf is sent through
    `writer.add_scalar(tag, value, index)`, where the tag is the "/"-joined
    key path and the index is the entry's position in the list. Leaves that
    are neither dicts nor lists are ignored.
    """
    for name, node in tree.items():
        tag = name if not path else f"{path}/{name}"
        if isinstance(node, dict):
            tb_leafstats(node, writer, tag)
            continue
        if not isinstance(node, list):
            continue
        for step, scalar in enumerate(node):
            if scalar is None:
                continue
            writer.add_scalar(tag, scalar, step)


def update_agent_statistics(input_path, output_file):
    """
    Computes statistics for the current iteration and updates the global statistics file.

    Every `*.json` file in `input_path` is merged into one leafstats with
    `append_leafstats`; the merged tree is passed to `get_mean_leafstats`, and
    the result is appended to the leafstats stored in `output_file` (which is
    created if it does not exist yet).

    Args:
        input_path (str): Path to the folder containing agent JSON files for the current iteration.
        output_file (str): Path to the JSON file where statistics are stored.
    """
    # Gather this iteration's per-agent stats into a single tree.
    leafstats = {}
    json_names = (name for name in os.listdir(input_path) if name.endswith(".json"))
    for name in json_names:
        with open(os.path.join(input_path, name), "r") as fh:
            append_leafstats(leafstats, json.load(fh))

    mean_leafstats = get_mean_leafstats(leafstats)

    # Load the running history (if any), extend it, and write it back.
    global_stats = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as fh:
            global_stats = json.load(fh)

    append_leafstats(global_stats, mean_leafstats)

    with open(output_file, "w") as fh:
        json.dump(global_stats, fh, indent=4)


def generate_agent_stats_plots(
    global_stats_path, matplotlib_log_dir, tensorboard_log_dir, wandb_log_dir
):
    """
    Visualizes the global statistics as matplotlib plots and TensorBoard scalars.

    Args:
        global_stats_path (str): Path to the global statistics JSON file.
        matplotlib_log_dir (str): Directory where the matplotlib PNG plots are saved.
        tensorboard_log_dir (str): Directory to save TensorBoard logs.
        wandb_log_dir (str): Directory for Weights & Biases run metadata. It is
            only created here; this function does not log anything to W&B.
    """
    for log_dir in (matplotlib_log_dir, tensorboard_log_dir, wandb_log_dir):
        os.makedirs(log_dir, exist_ok=True)

    with open(global_stats_path, "r") as f:
        global_stats = json.load(f)

    plot_leafstats(global_stats, folder=matplotlib_log_dir)

    # Log statistics to TensorBoard; close the writer even if logging fails so
    # that already-written events are flushed and the file handle is released.
    writer = SummaryWriter(tensorboard_log_dir)
    try:
        tb_leafstats(global_stats, writer)
    finally:
        writer.close()

    # NOTE: Weights & Biases logging (a `wb_leafstats` counterpart) is not wired up.