| """Generate MI workshop Figure 1: grokking vs standard probe heatmap. |
| |
| Reads M1 probe outputs from experiments/runs/*/mechinterp/m1_probe_data.json, |
| produces a 2x2 grid (rows = hospital/tumor probe; cols = grokking/standard) |
| with epoch-x-layer heatmaps. Hospital = Reds (want fading); Tumor = Greens (want rising). |
| |
| Picks the strongest grokking run by best_ood and the standard control with |
| periodic checkpoints (or final.pt if that's all that's available). |
| |
| Usage: |
| python -m experiments.figure_mi_comparison \ |
| [--grok-run experiments/runs/<id>] [--std-run experiments/runs/<id>] \ |
| [--out paper_figures/figure1_MI_probe_comparison] |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import glob |
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
| ROOT = Path(__file__).resolve().parent.parent |
|
|
|
|
def _load_summary(run_dir: Path) -> dict:
    """Load ``results/summary.json`` for a run on a best-effort basis.

    Args:
        run_dir: Run directory expected to contain ``results/summary.json``.

    Returns:
        The parsed summary, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not hold a JSON object.
    """
    summary_path = run_dir / "results" / "summary.json"
    try:
        summary = json.loads(summary_path.read_text())
    except (OSError, ValueError):
        # OSError: missing or unreadable file.
        # ValueError: malformed JSON (json.JSONDecodeError) or undecodable
        # bytes (UnicodeDecodeError); both are ValueError subclasses.
        return {}
    # Callers use ``.get`` on the result, so a top-level list/string/number
    # is treated as "no summary" instead of failing later.
    return summary if isinstance(summary, dict) else {}
|
|
|
|
def _pick_best_grok_run() -> Path | None:
    """Return the grokking run with probe data that has the highest ``best_ood``.

    Scans ``experiments/runs/*/mechinterp/m1_probe_data.json`` under ``ROOT``,
    keeps the runs whose summary ``condition`` is ``"grokking"`` and ranks
    them by the summary's ``best_ood`` (missing or falsy values count as 0).

    Returns:
        The top-ranked run directory, or ``None`` when no grokking run has
        probe data.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    scored = []
    for probe_file in glob.glob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(probe_file).parent.parent
        summary = _load_summary(run_dir)
        if summary.get("condition") == "grokking":
            score = summary.get("best_ood", 0) or 0
            scored.append((score, run_dir))
    if scored:
        # Descending tuple order: highest score first, ties broken by path.
        return sorted(scored, reverse=True)[0][1]
    return None
|
|
|
|
def _pick_std_run() -> Path | None:
    """Return the standard-training run with probe data and the best OOD score.

    Scans ``experiments/runs/*/mechinterp/m1_probe_data.json`` under ``ROOT``
    and keeps the runs whose summary ``condition`` is ``"standard"``. Each run
    is scored by ``best_ood``, falling back to ``ood_test_acc`` when
    ``best_ood`` is absent or null, and to 0 when neither is usable.

    NOTE(review): the module docstring mentions preferring runs with periodic
    checkpoints; this function only ranks by OOD score.

    Returns:
        The top-ranked run directory, or ``None`` when no standard run has
        probe data.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    scored = []
    for probe_file in glob.glob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(probe_file).parent.parent
        summary = _load_summary(run_dir)
        if summary.get("condition") != "standard":
            continue
        score = summary.get("best_ood")
        if score is None:
            # ``dict.get(key, default)`` only uses the default when the key is
            # absent, so a present-but-null "best_ood" would otherwise hide
            # "ood_test_acc".
            score = summary.get("ood_test_acc")
        scored.append((score or 0, run_dir))
    if not scored:
        return None
    # Descending tuple order: highest score first, ties broken by path.
    scored.sort(reverse=True)
    return scored[0][1]
|
|
|
|
def _load_probe(run_dir: Path) -> dict | None:
    """Read a run's M1 probe results.

    Args:
        run_dir: Run directory; the probe data is looked up at
            ``mechinterp/m1_probe_data.json`` inside it.

    Returns:
        The parsed JSON content, or ``None`` when the file does not exist.
        Malformed JSON is not handled here and propagates to the caller.
    """
    probe_path = run_dir / "mechinterp" / "m1_probe_data.json"
    return json.loads(probe_path.read_text()) if probe_path.exists() else None
|
|
|
|
def _heatmap(ax, data: dict | None, key: str, title: str, cmap: str):
    """Draw one epoch-by-layer probe heatmap on ``ax``.

    Args:
        ax: Axes object to draw on.
        data: Probe payload providing ``"epochs"``, ``"layers"`` and a matrix
            under ``key``; ``None`` when the run has no probe data.
        key: Name of the probe matrix inside ``data``.
        title: Panel title.
        cmap: Colormap name passed to ``imshow``.

    Returns:
        The image returned by ``ax.imshow`` (usable for a colorbar), or
        ``None`` when a "No data" placeholder was drawn because ``data`` is
        ``None`` or does not contain ``key``.
    """
    if data is None or key not in data:
        # A payload without the requested probe gets the same placeholder as
        # a missing payload, instead of aborting the figure with a KeyError.
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title, fontweight="bold", fontsize=10)
        return None

    epochs = data["epochs"]
    layers = data["layers"]
    mat = np.array(data[key])
    if mat.ndim != 2:
        # Non-2D payloads are reinterpreted as (epochs, layers).
        mat = mat.reshape(len(epochs), len(layers))

    # Transposed so epochs run along x and layers along y; presumably 2D
    # payloads are already (epochs, layers) -- confirm against the producer.
    # The colour range is pinned to [0, 1] for every panel.
    im = ax.imshow(
        mat.T,
        aspect="auto",
        cmap=cmap,
        vmin=0.0,
        vmax=1.0,
        interpolation="nearest",
        origin="lower",
    )
    ax.set_xticks(range(len(epochs)))
    ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=7)
    ax.set_yticks(range(len(layers)))
    ax.set_yticklabels(layers, fontsize=8)
    ax.set_xlabel("Epoch", fontsize=9)
    ax.set_title(title, fontweight="bold", fontsize=10)
    return im
|
|
|
|
def _resolve_probe_key(data: dict | None, preferred: str, fallback: str) -> str:
    """Return ``preferred`` when ``data`` contains it, otherwise ``fallback``."""
    return preferred if data is not None and preferred in data else fallback


def make_figure(grok_dir: Path, std_dir: Path | None, out_base: Path):
    """Render the 2x2 probe-comparison figure and save it as PNG and PDF.

    Rows are the hospital probe (``Reds``) and the tumor probe (``Greens``);
    columns are the grokking run and the standard run.

    Args:
        grok_dir: Grokking run directory; its probe data is required.
        std_dir: Standard run directory, or ``None`` to leave that column as
            "No data" panels.
        out_base: Output path; ``.png`` and ``.pdf`` files are written via
            ``Path.with_suffix``. Parent directories are created.

    Exits the process with status 1 when the grokking run has no probe data.
    """
    grok = _load_probe(grok_dir)
    if grok is None:
        print(f"[ERROR] no probe data at {grok_dir}/mechinterp/m1_probe_data.json")
        sys.exit(1)
    std = _load_probe(std_dir) if std_dir else None

    # Escapes keep this source ASCII-only: down arrow, up arrow, em dash.
    down, up, dash = "\u2193", "\u2191", "\u2014"
    hosp_label = f"Hospital probe on H3 (shortcut recoverability, {down} good)"
    tumor_label = f"Tumor probe on H4 (causal transfer, {up} good)"
    hosp_keys = ("hospital_probe_id", "hospital_probe")
    tumor_keys = ("tumor_probe_ood", "tumor_probe")

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    grok_title = f"Grokking ({grok_dir.name[-30:]})"
    std_title = f"Standard ({std_dir.name[-30:]})" if std_dir else "Standard (no data)"

    panels = [
        (axes[0][0], grok, hosp_keys, f"{grok_title}\n{hosp_label}", "Reds"),
        (axes[0][1], std, hosp_keys, f"{std_title}\n{hosp_label}", "Reds"),
        (axes[1][0], grok, tumor_keys, f"{grok_title}\n{tumor_label}", "Greens"),
        (axes[1][1], std, tumor_keys, f"{std_title}\n{tumor_label}", "Greens"),
    ]
    for ax, data, (preferred, fallback), title, cmap in panels:
        # Resolve the key per payload: the two runs may not store the same
        # key variant, and reusing the grokking run's choice for the standard
        # run would fail on a payload that only has the other one.
        key = _resolve_probe_key(data, preferred, fallback)
        im = _heatmap(ax, data, key, title, cmap)
        if im is not None:
            fig.colorbar(im, ax=ax, fraction=0.04, pad=0.02)

    fig.suptitle(
        f"Figure 1 {dash} Layer-wise circuit analysis: grokking-favorable vs standard training\n"
        "Grokking: deep-layer hospital recoverability (Reds) drops over training while tumor recoverability (Greens) is preserved.\n"
        "Standard: no localized scrubbing of hospital information.",
        fontsize=11,
        y=1.005,
        fontweight="bold",
    )
    plt.tight_layout()

    out_base.parent.mkdir(parents=True, exist_ok=True)
    png = out_base.with_suffix(".png")
    pdf = out_base.with_suffix(".pdf")
    fig.savefig(png, bbox_inches="tight", dpi=200)
    fig.savefig(pdf, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {png}")
    print(f"Saved {pdf}")
|
|
|
|
def main():
    """Parse CLI arguments, resolve both runs, and build the figure.

    ``--grok-run`` / ``--std-run`` override the automatic run selection, and
    ``--out`` is joined onto ``ROOT``. Exits with status 2 when no grokking
    run is given and none can be auto-picked.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--grok-run",
        default=None,
        help="Path to grokking run dir; auto-pick best if omitted",
    )
    parser.add_argument(
        "--std-run",
        default=None,
        help="Path to standard run dir; auto-pick best if omitted",
    )
    parser.add_argument("--out", default="paper_figures/figure1_MI_probe_comparison")
    args = parser.parse_args()

    grok_dir = Path(args.grok_run) if args.grok_run else _pick_best_grok_run()
    std_dir = Path(args.std_run) if args.std_run else _pick_std_run()

    if grok_dir is None:
        print("[ERROR] No grokking run with M1 probe data found.")
        print("        Run experiments/mechinterp_m1.py on a grokking run first.")
        sys.exit(2)

    # "\u2014" is an em dash, written as an escape to keep the source ASCII-only.
    std_label = std_dir if std_dir else "(none \u2014 figure will show only grokking)"
    print(f"Grokking run : {grok_dir}")
    print(f"Standard run : {std_label}")

    out_base = ROOT / args.out
    make_figure(grok_dir, std_dir, out_base)
|
|
|
|
# Script entry point: build the figure only when this module is executed
# directly (e.g. ``python -m experiments.figure_mi_comparison``), not on import.
if __name__ == "__main__":
    main()
|
|