# Source: jaxaht-benchmark/evaluation/plot_xp_csv_heatmap.py
# Uploaded by lainwired: "Initial jaxaht-benchmark deployment" (commit 5146e76)
"""Generate a heatmap from an XP matrix CSV exported by evaluation/run.py.
The CSV cells are formatted like "0.50 (0.45, 0.55)". This script extracts the
mean value from each cell and renders a readable heatmap.
"""
from __future__ import annotations
import argparse
import csv
import re
from pathlib import Path
from typing import Iterable
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Matches the leading mean of a cell such as "0.50 (0.45, 0.55)": optional
# sign, integer or decimal digits, and an optional exponent so that values
# written in scientific notation (e.g. "1e-05") are not truncated to "1".
MEAN_RE = re.compile(r"^\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)")


def _parse_mean(cell: str) -> float:
    """Return the leading numeric value of *cell*, or raise ``ValueError``."""
    match = MEAN_RE.match(cell)
    if not match:
        raise ValueError(f"Could not parse mean from cell: {cell!r}")
    return float(match.group(1))


def parse_matrix(csv_path: Path) -> tuple[list[str], list[str], np.ndarray]:
    """Read an XP matrix CSV and extract the mean value of every cell.

    The first CSV row is the header; its first cell is ignored and the rest
    become the column labels. In every following row the first cell is the
    row label and the remaining cells are parsed with ``MEAN_RE``.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A ``(row_labels, col_labels, values)`` tuple where ``values`` is a
        2-D float array of shape ``(len(row_labels), len(col_labels))``.

    Raises:
        ValueError: If the file contains no rows, a data row does not have
            one value per header column, or a cell has no leading number.
    """
    with csv_path.open(newline="") as f:
        # csv.reader yields an empty list for blank lines; skip them so a
        # stray empty line does not crash on ``row[0]``.
        rows = [row for row in csv.reader(f) if row]
    if not rows:
        raise ValueError(f"CSV file is empty: {csv_path}")

    header = rows[0][1:]
    row_labels = []
    values = []
    for row in rows[1:]:
        label, cells = row[0], row[1:]
        # Fail early with context instead of letting a ragged matrix surface
        # later as an obscure numpy or plotting error.
        if len(cells) != len(header):
            raise ValueError(
                f"Row {label!r} has {len(cells)} values but the header has "
                f"{len(header)} columns"
            )
        row_labels.append(label)
        values.append([_parse_mean(cell) for cell in cells])

    # The reshape keeps the result 2-D even when there are no data rows.
    matrix = np.array(values, dtype=float).reshape(len(row_labels), len(header))
    return row_labels, header, matrix
def plot_heatmap(
    row_labels,
    col_labels,
    values,
    title: str,
    out_path: Path,
    annotate: bool = True,
):
    """Render *values* as a labelled heatmap and save it to *out_path*.

    Args:
        row_labels: Tick labels for the y axis, one per matrix row.
        col_labels: Tick labels for the x axis, one per matrix column.
        values: 2-D array of numbers to draw.
        title: Title placed above the plot.
        out_path: Destination image file; missing parent directories are
            created.
        annotate: When true, write each value (two decimals) in its cell.
    """
    # Grow the canvas with the number of labels, but never below 10 x 8.
    fig_width = max(10, len(col_labels) * 0.8)
    fig_height = max(8, len(row_labels) * 0.6)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    # Work on an explicit Figure/Axes and always close it, so a failure while
    # drawing or saving does not leave an open figure in pyplot's registry.
    try:
        sns.heatmap(
            values,
            cmap="viridis",
            annot=annotate,
            fmt=".2f",
            xticklabels=col_labels,
            yticklabels=row_labels,
            cbar_kws={"label": "Mean value"},
            ax=ax,
        )
        ax.set_title(title)
        ax.set_xlabel("Best-response policies")
        ax.set_ylabel("Held-out teammates")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        plt.setp(ax.get_yticklabels(), rotation=0)
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=220, bbox_inches="tight")
    finally:
        plt.close(fig)
def generate_heatmap_from_csv(
    csv_path: Path,
    title: str | None = None,
    out_path: Path | None = None,
    annotate: bool = True,
) -> Path:
    """Parse one XP matrix CSV, plot it, and return the image path.

    Args:
        csv_path: CSV file to read.
        title: Plot title; a falsy value (``None`` or ``""``) falls back to
            the CSV file's stem.
        out_path: Where to save the image; a falsy value falls back to
            *csv_path* with its suffix replaced by ``.png``.
        annotate: Forwarded to ``plot_heatmap`` to toggle per-cell numbers.

    Returns:
        The path the heatmap was written to.
    """
    labels_y, labels_x, matrix = parse_matrix(csv_path)
    destination = out_path if out_path else csv_path.with_suffix(".png")
    heading = title if title else csv_path.stem
    plot_heatmap(
        row_labels=labels_y,
        col_labels=labels_x,
        values=matrix,
        title=heading,
        out_path=destination,
        annotate=annotate,
    )
    return destination
def iter_csvs(input_path: Path) -> Iterable[Path]:
    """Yield the CSV file(s) designated by *input_path*.

    A path that is a regular file is yielded as-is, whatever its suffix.
    Any other path is treated as a directory and its direct ``*.csv``
    children are yielded in sorted order.
    """
    if input_path.is_file():
        candidates = [input_path]
    else:
        candidates = sorted(input_path.glob("*.csv"))
    yield from candidates
def main():
    """Command-line entry point: render a heatmap per XP matrix CSV.

    With a file as input, one heatmap is written to ``--out`` (default: the
    CSV path with a ``.png`` suffix) and titled with ``--title``. With a
    directory as input, every ``*.csv`` directly inside it is rendered and
    titled with its own file stem; ``--out``, when given, is used as the
    directory receiving the images instead of being silently ignored.
    The path of every generated image is printed to stdout.

    Raises:
        ValueError: If the input path is neither a file nor a directory, or
            if the directory contains no CSV files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_path",
        type=Path,
        help="Path to a CSV file or a directory containing XP CSV files.",
    )
    parser.add_argument(
        "--title",
        default="XP Matrix Heatmap",
        help=(
            "Plot title for a single CSV file. Directory input titles each "
            "plot with its CSV file stem."
        ),
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help=(
            "Output image path for a single CSV file, or output directory "
            "for directory input. Defaults to a .png next to each CSV."
        ),
    )
    parser.add_argument(
        "--no-annot",
        action="store_true",
        help="Disable per-cell numeric annotations.",
    )
    args = parser.parse_args()
    input_path = args.input_path
    annotate = not args.no_annot

    if input_path.is_file():
        out_path = generate_heatmap_from_csv(
            input_path,
            title=args.title,
            out_path=args.out,
            annotate=annotate,
        )
        print(out_path)
        return

    # is_dir() is already False for a missing path, so one check suffices.
    if not input_path.is_dir():
        raise ValueError(f"Input path must be a CSV file or directory: {input_path}")

    generated = 0
    for csv_path in iter_csvs(input_path):
        # Honour --out as a target directory; None keeps the default of
        # writing the image next to its CSV.
        target = None
        if args.out is not None:
            target = args.out / csv_path.with_suffix(".png").name
        out_path = generate_heatmap_from_csv(
            csv_path,
            title=csv_path.stem,
            out_path=target,
            annotate=annotate,
        )
        print(out_path)
        generated += 1
    if generated == 0:
        raise ValueError(f"No CSV files found in directory: {input_path}")


if __name__ == "__main__":
    main()