# Source: CausalGrok / code / experiments / figure_mi_comparison.py
# Uploaded by nileshsarkar-ai
# Commit message: "Upload code/experiments"
# Commit: 50fa85c (verified)
"""Generate MI workshop Figure 1: grokking vs standard probe heatmap.
Reads M1 probe outputs from experiments/runs/*/mechinterp/m1_probe_data.json,
produces a 2x2 grid (rows = hospital/tumor probe; cols = grokking/standard)
with epoch-x-layer heatmaps. Hospital = Reds (want fading); Tumor = Greens (want rising).
Picks the strongest grokking run by best_ood and the standard control with
periodic checkpoints (or final.pt if that's all that's available).
Usage:
python -m experiments.figure_mi_comparison \
[--grok-run experiments/runs/<id>] [--std-run experiments/runs/<id>] \
[--out paper_figures/figure1_MI_probe_comparison]
"""
from __future__ import annotations
import argparse
import glob
import json
import os
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Directory two levels above this file's resolved location. The run-directory
# globs and the --out path are both resolved against it.
ROOT = Path(__file__).resolve().parent.parent
def _load_summary(run_dir: Path) -> dict:
    """Load ``results/summary.json`` for a run on a best-effort basis.

    Args:
        run_dir: Run directory expected to contain ``results/summary.json``.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or its top-level value is not an object.
    """
    summary_path = run_dir / "results" / "summary.json"
    try:
        parsed = json.loads(summary_path.read_text())
    except (OSError, ValueError, RecursionError):
        # OSError: missing/unreadable file. ValueError: invalid JSON or bad
        # text encoding (JSONDecodeError and UnicodeDecodeError subclass it).
        # RecursionError: pathologically nested JSON. A bad summary must not
        # abort the scan over the remaining runs.
        return {}
    # Callers use ``.get`` on the result, so non-object JSON counts as empty.
    return parsed if isinstance(parsed, dict) else {}
def _pick_best_grok_run() -> Path | None:
    """Pick the grokking run with the highest ``best_ood`` among probed runs.

    Only run directories under ``ROOT/experiments/runs`` that contain
    ``mechinterp/m1_probe_data.json`` and whose summary reports
    ``condition == "grokking"`` are considered. A missing or falsy
    ``best_ood`` counts as 0.

    Returns:
        The selected run directory, or ``None`` when no run qualifies.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    scored = []
    for probe_file in glob.glob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(probe_file).parent.parent
        summary = _load_summary(run_dir)
        if summary.get("condition") == "grokking":
            scored.append((summary.get("best_ood", 0) or 0, run_dir))
    if scored:
        # Descending sort on (score, path); the first entry wins.
        return sorted(scored, reverse=True)[0][1]
    return None
def _pick_std_run() -> Path | None:
    """Pick the standard-training run with the highest OOD score.

    Only run directories under ``ROOT/experiments/runs`` that contain
    ``mechinterp/m1_probe_data.json`` and whose summary reports
    ``condition == "standard"`` are considered. The score is ``best_ood``
    when it has a value, otherwise ``ood_test_acc``, otherwise 0.

    Returns:
        The selected run directory, or ``None`` when no run qualifies.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    candidates = []
    for probe_file in glob.glob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(probe_file).parent.parent
        summary = _load_summary(run_dir)
        if summary.get("condition") != "standard":
            continue
        # Use the first metric that actually has a value. A ``best_ood`` that
        # is present but null must still fall through to ``ood_test_acc``
        # instead of scoring the run as 0.
        score = next(
            (
                summary[name]
                for name in ("best_ood", "ood_test_acc")
                if summary.get(name) is not None
            ),
            0,
        ) or 0
        candidates.append((score, run_dir))
    if not candidates:
        return None
    candidates.sort(reverse=True)
    return candidates[0][1]
def _load_probe(run_dir: Path) -> dict | None:
p = run_dir / "mechinterp" / "m1_probe_data.json"
if not p.exists():
return None
return json.loads(p.read_text())
def _heatmap(ax, data: dict | None, key: str, title: str, cmap: str):
    """Draw one epoch-by-layer heatmap panel on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        data: Probe payload providing ``"epochs"``, ``"layers"`` and the
            series named by ``key``; ``None`` when the run has no probe data.
        key: Name of the series in ``data`` to plot.
        title: Panel title.
        cmap: Matplotlib colormap name.

    Returns:
        The image returned by ``ax.imshow``, or ``None`` when a "No data"
        placeholder was drawn instead (``data`` is ``None`` or lacks ``key``).
    """
    # A payload without the requested series is treated like a missing payload
    # so one incomplete run cannot abort the whole figure with a KeyError.
    if data is None or key not in data:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title, fontweight="bold", fontsize=10)
        return None

    epochs = data["epochs"]
    layers = data["layers"]
    mat = np.array(data[key])  # expected shape (n_epochs, n_layers)
    if mat.ndim != 2:
        mat = mat.reshape(len(epochs), len(layers))

    # Transposed so epochs run along x and layers along y; the colour scale is
    # pinned to [0, 1] so every panel is directly comparable.
    im = ax.imshow(
        mat.T,
        aspect="auto",
        cmap=cmap,
        vmin=0.0,
        vmax=1.0,
        interpolation="nearest",
        origin="lower",
    )
    ax.set_xticks(range(len(epochs)))
    ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=7)
    ax.set_yticks(range(len(layers)))
    ax.set_yticklabels(layers, fontsize=8)
    ax.set_xlabel("Epoch", fontsize=9)
    ax.set_title(title, fontweight="bold", fontsize=10)
    return im
def _resolve_probe_key(data: dict | None, preferred: str, fallbacks: tuple[str, ...]) -> str | None:
    """Return the first of ``preferred`` then ``fallbacks`` found in ``data``, else ``None``."""
    if data is None:
        return None
    for name in (preferred, *fallbacks):
        if name in data:
            return name
    return None


def make_figure(grok_dir: Path, std_dir: Path | None, out_base: Path):
    """Render the 2x2 probe-comparison figure and save it as PNG and PDF.

    Rows are the hospital and tumor probes; columns are the grokking and the
    standard run. A panel whose data is unavailable shows a "No data"
    placeholder.

    Args:
        grok_dir: Grokking run directory; its probe data is required.
        std_dir: Standard run directory, or ``None`` to leave that column empty.
        out_base: Output path; ``.png`` and ``.pdf`` replace its suffix.

    Raises:
        SystemExit: With status 1 when ``grok_dir`` has no probe data.
    """
    grok = _load_probe(grok_dir)
    std = _load_probe(std_dir) if std_dir else None
    if grok is None:
        print(f"[ERROR] no probe data at {grok_dir}/mechinterp/m1_probe_data.json")
        sys.exit(1)

    # Hospital probe on H3 (held-in held-out hospital) — has signal.
    # H4 version is degenerate (probe class set excludes H4 by construction → ≡ 0).
    hosp_key = "hospital_probe_id" if "hospital_probe_id" in grok else "hospital_probe"
    # Tumor probe on H4 (truly OOD hospital) — measures causal-feature transfer.
    tumor_key = "tumor_probe_ood" if "tumor_probe_ood" in grok else "tumor_probe"

    # The standard payload may use a different key set than the grokking one,
    # so resolve its keys separately instead of assuming grokking's keys exist.
    std_hosp_key = _resolve_probe_key(std, hosp_key, ("hospital_probe_id", "hospital_probe"))
    std_tumor_key = _resolve_probe_key(std, tumor_key, ("tumor_probe_ood", "tumor_probe"))
    if std is not None:
        for wanted, found in ((hosp_key, std_hosp_key), (tumor_key, std_tumor_key)):
            if found is None:
                print(f"[WARN] standard run has no '{wanted}' series; panel left empty")
            elif found != wanted:
                print(f"[WARN] standard run has no '{wanted}' series; plotting '{found}' instead")

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    grok_title = f"Grokking ({grok_dir.name[-30:]})"
    std_title = f"Standard ({std_dir.name[-30:]})" if std_dir else "Standard (no data)"
    hosp_caption = "Hospital probe on H3 (shortcut recoverability, ↓ good)"
    tumor_caption = "Tumor probe on H4 (causal transfer, ↑ good)"

    # (axes, payload, series key, title, colormap) for each of the four panels.
    panels = [
        (axes[0][0], grok, hosp_key, f"{grok_title}\n{hosp_caption}", "Reds"),
        (axes[0][1], std if std_hosp_key else None, std_hosp_key or hosp_key,
         f"{std_title}\n{hosp_caption}", "Reds"),
        (axes[1][0], grok, tumor_key, f"{grok_title}\n{tumor_caption}", "Greens"),
        (axes[1][1], std if std_tumor_key else None, std_tumor_key or tumor_key,
         f"{std_title}\n{tumor_caption}", "Greens"),
    ]
    images = [(_heatmap(ax, data, key, title, cmap), ax) for ax, data, key, title, cmap in panels]
    for im, ax in images:
        # Placeholder panels return no image and therefore get no colorbar.
        if im is not None:
            plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02)

    fig.suptitle(
        "Figure 1 — Layer-wise circuit analysis: grokking-favorable vs standard training\n"
        "Grokking: deep-layer hospital recoverability (Reds) drops over training while tumor recoverability (Greens) is preserved.\n"
        "Standard: no localized scrubbing of hospital information.",
        fontsize=11,
        y=1.005,
        fontweight="bold",
    )
    plt.tight_layout()

    out_base.parent.mkdir(parents=True, exist_ok=True)
    png = out_base.with_suffix(".png")
    pdf = out_base.with_suffix(".pdf")
    fig.savefig(png, bbox_inches="tight", dpi=200)
    fig.savefig(pdf, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {png}")
    print(f"Saved {pdf}")
def main():
    """Parse CLI arguments, choose the runs to compare, and build the figure.

    ``--grok-run`` and ``--std-run`` override automatic run selection. The
    ``--out`` value is joined onto ``ROOT``. Exits with status 2 when no
    grokking run is supplied or found; a missing standard run is tolerated.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--grok-run", default=None, help="Path to grokking run dir; auto-pick best if omitted")
    ap.add_argument("--std-run", default=None, help="Path to standard run dir; auto-pick best if omitted")
    ap.add_argument("--out", default="paper_figures/figure1_MI_probe_comparison")
    args = ap.parse_args()

    grok_dir = Path(args.grok_run) if args.grok_run else _pick_best_grok_run()
    std_dir = Path(args.std_run) if args.std_run else _pick_std_run()
    if grok_dir is None:
        print("[ERROR] No grokking run with M1 probe data found.")
        print(" Run experiments/mechinterp_m1.py on a grokking run first.")
        sys.exit(2)

    print(f"Grokking run : {grok_dir}")
    print(f"Standard run : {std_dir if std_dir else '(none — figure will show only grokking)'}")
    out_base = ROOT / args.out
    make_figure(grok_dir, std_dir, out_base)


if __name__ == "__main__":
    main()