Spaces:
Running
Running
| """Generate a heatmap from an XP matrix CSV exported by evaluation/run.py. | |
| The CSV cells are formatted like "0.50 (0.45, 0.55)". This script extracts the | |
| mean value from each cell and renders a readable heatmap. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import re | |
| from pathlib import Path | |
| from typing import Iterable | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
# Matches the leading number of a cell such as "0.50 (0.45, 0.55)".
# Accepts an optional sign, plain decimals ("5", "0.5", ".5", "5.") and an
# optional exponent ("1e-3"), so scientific notation is captured in full
# instead of being cut off after the mantissa.
MEAN_RE = re.compile(r"^\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)")


def parse_matrix(csv_path: Path) -> tuple[list[str], list[str], np.ndarray]:
    """Read a labelled matrix CSV and extract the leading number of each cell.

    The first non-blank row is the header; its first cell is ignored and the
    rest become the column labels. In every following row the first cell is
    the row label and each remaining cell must start with a number (for
    example ``"0.50 (0.45, 0.55)"``); only that leading number is kept.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A ``(row_labels, col_labels, values)`` tuple where ``values`` is a
        float array of shape ``(len(row_labels), len(col_labels))``.

    Raises:
        ValueError: If the file has no rows, a data row does not have one
            value per header column, or a cell does not start with a number.
    """
    with csv_path.open(newline="") as f:
        # csv.reader yields [] for blank lines; drop them so a stray empty
        # line cannot trigger an IndexError when the row label is read.
        rows = [row for row in csv.reader(f) if row]
    if not rows:
        raise ValueError(f"CSV file is empty: {csv_path}")

    header = rows[0][1:]
    row_labels = []
    values = []
    for row in rows[1:]:
        cells = row[1:]
        # Fail early with the offending row label rather than letting NumPy
        # reject a ragged list with an opaque message.
        if len(cells) != len(header):
            raise ValueError(
                f"Row {row[0]!r} has {len(cells)} value(s) but the header "
                f"has {len(header)} column(s)"
            )
        parsed = []
        for cell in cells:
            match = MEAN_RE.match(cell)
            if not match:
                raise ValueError(f"Could not parse mean from cell: {cell!r}")
            parsed.append(float(match.group(1)))
        row_labels.append(row[0])
        values.append(parsed)

    # reshape keeps the result two-dimensional even when there are no data rows.
    matrix = np.array(values, dtype=float).reshape(len(row_labels), len(header))
    return row_labels, header, matrix
def plot_heatmap(
    row_labels,
    col_labels,
    values,
    title: str,
    out_path: Path,
    annotate: bool = True,
):
    """Render ``values`` as a seaborn heatmap and save it to ``out_path``.

    Args:
        row_labels: Tick labels for the y axis, one per matrix row.
        col_labels: Tick labels for the x axis, one per matrix column.
        values: Matrix of numbers passed to ``sns.heatmap``.
        title: Title drawn above the plot.
        out_path: Image file to write; missing parent directories are created.
        annotate: When true, each cell is annotated with its value
            formatted to two decimals.
    """
    # Grow the canvas with the matrix size, but never below 10 x 8 inches.
    width = max(10, len(col_labels) * 0.8)
    height = max(8, len(row_labels) * 0.6)
    fig = plt.figure(figsize=(width, height))
    try:
        sns.heatmap(
            values,
            cmap="viridis",
            annot=annotate,
            fmt=".2f",
            xticklabels=col_labels,
            yticklabels=row_labels,
            cbar_kws={"label": "Mean value"},
        )
        plt.title(title)
        plt.xlabel("Best-response policies")
        plt.ylabel("Held-out teammates")
        plt.xticks(rotation=45, ha="right")
        plt.yticks(rotation=0)
        plt.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=220, bbox_inches="tight")
    finally:
        # Close this specific figure even when rendering or saving fails, so
        # a failed plot does not stay registered with pyplot.
        plt.close(fig)
def generate_heatmap_from_csv(
    csv_path: Path,
    title: str | None = None,
    out_path: Path | None = None,
    annotate: bool = True,
) -> Path:
    """Parse one matrix CSV, plot it as a heatmap and return the image path.

    Args:
        csv_path: CSV file to read via ``parse_matrix``.
        title: Plot title; when falsy, the CSV file's stem is used.
        out_path: Destination image; when falsy, ``csv_path`` with its suffix
            replaced by ``.png`` is used.
        annotate: Forwarded to ``plot_heatmap`` to toggle per-cell numbers.

    Returns:
        The path the heatmap was written to.
    """
    labels_y, labels_x, matrix = parse_matrix(csv_path)

    # Fill in defaults derived from the CSV's own name.
    if not out_path:
        out_path = csv_path.with_suffix(".png")
    if not title:
        title = csv_path.stem

    plot_heatmap(labels_y, labels_x, matrix, title, out_path, annotate)
    return out_path
def iter_csvs(input_path: Path) -> Iterable[Path]:
    """Yield the CSV files designated by ``input_path``.

    A path to an existing file is yielded as-is, whatever its extension.
    Otherwise the path is treated as a directory and its direct ``*.csv``
    children are yielded in sorted order (no recursion into subdirectories).
    """
    if not input_path.is_file():
        yield from sorted(input_path.glob("*.csv"))
    else:
        yield input_path
def main():
    """Command-line entry point: turn one CSV, or every CSV in a directory, into heatmaps.

    With a file argument, a single heatmap is produced using ``--title`` and
    ``--out``. With a directory argument, every ``*.csv`` directly inside it
    is rendered next to its source with the file stem as title; ``--title``
    and ``--out`` are not used in that mode. The path of each generated image
    is printed.

    Raises:
        ValueError: If the input path is neither a file nor a directory, or
            the directory contains no CSV files.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "input_path",
        type=Path,
        help="Path to a CSV file or a directory containing XP CSV files.",
    )
    cli.add_argument("--title", default="XP Matrix Heatmap")
    cli.add_argument("--out", type=Path, default=None)
    cli.add_argument(
        "--no-annot",
        action="store_true",
        help="Disable per-cell numeric annotations.",
    )
    opts = cli.parse_args()

    input_path = opts.input_path
    show_numbers = not opts.no_annot

    # Single-file mode: honour the user-supplied title and output path.
    if input_path.is_file():
        print(
            generate_heatmap_from_csv(
                input_path,
                title=opts.title,
                out_path=opts.out,
                annotate=show_numbers,
            )
        )
        return

    if not (input_path.exists() and input_path.is_dir()):
        raise ValueError(f"Input path must be a CSV file or directory: {input_path}")

    # Directory mode: one heatmap per CSV, titled and named after the file.
    generated = 0
    for generated, csv_path in enumerate(iter_csvs(input_path), start=1):
        print(
            generate_heatmap_from_csv(
                csv_path,
                title=csv_path.stem,
                out_path=None,
                annotate=show_numbers,
            )
        )

    if not generated:
        raise ValueError(f"No CSV files found in directory: {input_path}")


if __name__ == "__main__":
    main()