import json import os from collections import Counter from typing import Any, Dict, List, Optional, Tuple import matplotlib.pyplot as plt import numpy as np from torch.utils.tensorboard import SummaryWriter def append_leafstats(tree1: Dict, tree2: Dict): """ Append each corresponding leaf of tree2 to tree1. """ for key, value in tree2.items(): if key not in tree1: tree1[key] = value elif isinstance(value, dict): append_leafstats(tree1[key], value) elif isinstance(value, list): if isinstance(tree1[key], list): tree1[key].extend(value) else: tree1[key] = [tree1[key]] + value else: if isinstance(tree1[key], list): tree1[key].append(value) else: tree1[key] = [tree1[key], value] def get_mean_leafstats(tree: Dict) -> Dict: """ scores a leafstats where each leaf is replaced by the mean of the leaf values. """ result = {} for key, value in tree.items(): if isinstance(value, dict): result[key] = get_mean_leafstats(value) elif isinstance(value, list): cleaned_values = [v for v in value if v is not None] result[key] = np.mean(cleaned_values) if cleaned_values else None else: result[key] = value return result def get_mean_leafstats(trees: List[Dict]) -> Dict: """ Computes a smart element-wise mean across multiple leafstats. For each path that exists in any of the input trees: - If the path contains arrays in multiple trees, computes the element-wise mean - Handles arrays of different lengths by taking means of available values at each index - If a path exists in only some trees, still computes the mean using available data Args: trees: List of leafstats dictionaries to compute means from Returns: A new leafstats where each leaf is the element-wise mean of the corresponding array leaves in the input trees """ if not trees: return {} if len(trees) == 1: return trees[0] # Collect all possible keys at this level all_keys = set() for tree in trees: all_keys.update(tree.keys()) result = {} for key in all_keys: # Collect all values at this key position values_at_key = [] for tree in trees: if key in tree: values_at_key.append(tree[key]) # If all values are dictionaries, recursively compute means if all(isinstance(v, dict) for v in values_at_key): result[key] = get_mean_leafstats(values_at_key) # If any value is a list, compute element-wise means elif any(isinstance(v, list) for v in values_at_key): # First, convert non-list values to lists (singleton) list_values = [] for v in values_at_key: if isinstance(v, list): list_values.append(v) else: list_values.append([v]) # Find the maximum length among all lists max_length = max(len(lst) for lst in list_values) # Initialize result array mean_array = [] # Compute element-wise means for i in range(max_length): # Collect values at this index that are not None values_at_index = [ lst[i] for lst in list_values if i < len(lst) and lst[i] is not None ] # If we have valid values, compute mean; otherwise, use None if values_at_index: mean_array.append(np.mean(values_at_index)) else: mean_array.append(None) result[key] = mean_array # If all values are scalars (not dict or list), compute their mean else: # Filter out None values non_none_values = [v for v in values_at_key if v is not None] if non_none_values: result[key] = np.mean(non_none_values) else: result[key] = None return result def get_var_leafstats(tree: Dict) -> Dict: """ scores a leafstats where each leaf is replaced by the variance of the leaf values. """ result = {} for key, value in tree.items(): if isinstance(value, dict): result[key] = get_var_leafstats(value) elif isinstance(value, list): cleaned_values = [v for v in value if v is not None] result[key] = np.var(cleaned_values) if cleaned_values else None else: result[key] = 0 # Single value has variance 0 return result def plot_leafstats(tree: Dict, folder: str, path: str = ""): """ Plots the leaves of the leafstats and saves them to the specified folder. """ os.makedirs(folder, exist_ok=True) for key, value in tree.items(): new_path = f"{path}_{key}" if isinstance(value, dict): plot_leafstats(value, folder, new_path) elif isinstance(value, list): plt.figure() plt.plot(value) plt.title(new_path) plt.savefig(os.path.join(folder, f"{new_path.replace('/', '_')}.png")) plt.close() def plot_EMA_leafstats(tree: Dict, folder: str, path: str = "", alpha: float = 0.1): """ Plots the exponential moving average of the leaves of the leafstats and saves them to the specified folder. """ os.makedirs(folder, exist_ok=True) for key, value in tree.items(): new_path = f"{path}/{key}" if path else key if isinstance(value, dict): plot_EMA_leafstats(value, folder, new_path, alpha) elif isinstance(value, list): value = np.array(value) nb_elements = len(value) coefficients = (1 - alpha) ** np.arange(nb_elements, 0, -1) value = np.cumsum(value * coefficients) value /= (1 - alpha) ** np.arange(nb_elements - 1, -1, -1) # renormalize out_path = f"EMA_alpha_{alpha}_{path}_{key}" out_path = os.path.join(folder, f"{out_path.replace('/', '_')}.png") plt.figure() plt.plot(value) plt.title(new_path) plt.savefig(out_path) plt.close() def plot_SMA_leafstats(tree: Dict, folder: str, path: str = "", window: int = 33): """ Plots the simple moving average of the leaves of the leafstats and saves them to the specified folder. """ os.makedirs(folder, exist_ok=True) for key, value in tree.items(): new_path = f"{path}/{key}" if path else key if isinstance(value, dict): plot_SMA_leafstats(value, folder, new_path, window) elif isinstance(value, list): value = np.array(value) nb_elements = len(value) assert window % 2 == 1 # Even numbers are annoying for centered windows value = np.convolve(v=value, a=np.ones(window) / window, mode="same") # Adjust out of window for start and finish value[: window // 2] *= window / np.arange(window // 2 + 1, window) value[-window // 2 + 1 :] *= window / np.arange(window - 1, window // 2, -1) out_path = f"SMA_window_{window}_{path}_{key}" out_path = os.path.join(folder, f"{out_path.replace('/', '_')}.png") plt.figure() plt.plot(value) plt.title(new_path) plt.savefig(out_path) plt.close() def save_leafstats(tree: Dict, folder: str, path: str = ""): """ Saves the leaves of the leafstats to the specified folder. """ os.makedirs(folder, exist_ok=True) for key, value in tree.items(): new_path = f"{path}/{key}" if path else key if isinstance(value, dict): save_leafstats(value, folder, new_path) elif ( isinstance(value, list) or isinstance(value, np.ndarray) or isinstance(value, float) or isinstance(value, int) ): with open( os.path.join(folder, f"{new_path.replace('/', '_')}.json"), "w" ) as f: json.dump(value, f, indent=4) def tb_leafstats(tree: Dict, writer, path: str = ""): """ Logs the leaves of the leafstats to TensorBoard. """ for key, value in tree.items(): new_path = f"{path}/{key}" if path else key if isinstance(value, dict): tb_leafstats(value, writer, new_path) elif isinstance(value, list): for i, v in enumerate(value): if v is not None: writer.add_scalar(new_path, v, i) def update_agent_statistics(input_path, output_file): """ Computes statistics for the current iteration and updates the global statistics file. Args: input_path (str): Path to the folder containing agent JSON files for the current iteration. output_file (str): Path to the JSON file where statistics are stored. """ # Build leafstats by appending each dict from JSON files in "input_path" folder leafstats = {} for filename in os.listdir(input_path): if filename.endswith(".json"): with open(os.path.join(input_path, filename), "r") as f: data = json.load(f) append_leafstats(leafstats, data) # Get epoch mean leafstats mean_leafstats = get_mean_leafstats(leafstats) # Add mean leafstats to global stats file if os.path.exists(output_file): with open(output_file, "r") as f: global_stats = json.load(f) else: global_stats = {} append_leafstats(global_stats, mean_leafstats) with open(output_file, "w") as f: json.dump(global_stats, f, indent=4) def generate_agent_stats_plots( global_stats_path, matplotlib_log_dir, tensorboard_log_dir, wandb_log_dir ): """ Visualizes the global statistics by logging them to TensorBoard and Weights & Biases. Args: global_stats_path (str): Path to the global statistics JSON file. tensorboard_log_dir (str): Directory to save TensorBoard logs. wandb_log_dir (str): Directory for Weights & Biases run metadata. """ os.makedirs(matplotlib_log_dir, exist_ok=True) os.makedirs(tensorboard_log_dir, exist_ok=True) os.makedirs(wandb_log_dir, exist_ok=True) with open(global_stats_path, "r") as f: global_stats = json.load(f) plot_leafstats(global_stats, folder=matplotlib_log_dir) # Log statistics to TensorBoard writer = SummaryWriter(tensorboard_log_dir) tb_leafstats(global_stats, writer) writer.close() # wb_leafstats(global_stats)