File size: 6,542 Bytes
50fa85c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | """Generate MI workshop Figure 1: grokking vs standard probe heatmap.
Reads M1 probe outputs from experiments/runs/*/mechinterp/m1_probe_data.json,
produces a 2x2 grid (rows = hospital/tumor probe; cols = grokking/standard)
with epoch-x-layer heatmaps. Hospital = Reds (want fading); Tumor = Greens (want rising).
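Picks the strongest grokking run by best_ood and the strongest standard control by
best_ood (falling back to ood_test_acc). Only runs that already have M1 probe data
are considered (probed from periodic checkpoints, or final.pt if that's all that's available).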
Usage:
python -m experiments.figure_mi_comparison \
[--grok-run experiments/runs/<id>] [--std-run experiments/runs/<id>] \
[--out paper_figures/figure1_MI_probe_comparison]
"""
from __future__ import annotations
import argparse
import glob
import json
import os
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
ROOT = Path(__file__).resolve().parent.parent
def _load_summary(run_dir: Path) -> dict:
p = run_dir / "results" / "summary.json"
if not p.exists():
return {}
try:
return json.loads(p.read_text())
except Exception:
return {}
def _pick_best_grok_run() -> Path | None:
    """Return the grokking run directory with the highest ``best_ood``.

    Only runs under ``ROOT/experiments/runs`` that contain
    ``mechinterp/m1_probe_data.json`` and whose summary has
    ``condition == "grokking"`` are considered. A missing or falsy
    ``best_ood`` counts as 0.

    Returns:
        The winning run directory, or ``None`` if no run qualifies.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    scored = []
    for probe_file in glob.iglob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(probe_file).parent.parent
        summary = _load_summary(run_dir)
        if summary.get("condition") == "grokking":
            scored.append((summary.get("best_ood", 0) or 0, run_dir))
    if scored:
        # Descending order on (score, path); the first entry is the winner.
        return sorted(scored, reverse=True)[0][1]
    return None
def _pick_std_run() -> Path | None:
    """Return the standard-training run directory with the best OOD score.

    Only runs under ``ROOT/experiments/runs`` that contain
    ``mechinterp/m1_probe_data.json`` and whose summary has
    ``condition == "standard"`` are considered. The score is ``best_ood``
    when it is recorded, otherwise ``ood_test_acc``, otherwise 0.

    Returns:
        The winning run directory, or ``None`` if no run qualifies.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    candidates = []
    for probe_file in glob.glob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(probe_file).parent.parent
        summary = _load_summary(run_dir)
        if summary.get("condition") != "standard":
            continue
        score = summary.get("best_ood")
        if score is None:
            # A present-but-null best_ood must not hide a recorded
            # ood_test_acc (dict.get's default only applies to missing keys).
            score = summary.get("ood_test_acc")
        candidates.append((score or 0, run_dir))
    if not candidates:
        return None
    # Highest (score, path) tuple wins; ties fall back to path ordering.
    return max(candidates)[1]
def _load_probe(run_dir: Path) -> dict | None:
p = run_dir / "mechinterp" / "m1_probe_data.json"
if not p.exists():
return None
return json.loads(p.read_text())
def _heatmap(ax, data: dict, key: str, title: str, cmap: str):
    """Render one probe heatmap (epochs on x, layers on y) onto *ax*.

    Args:
        ax: Axes to draw on.
        data: Probe payload providing ``"epochs"``, ``"layers"`` and ``key``;
            ``None`` produces a "No data" placeholder panel instead.
        key: Name of the entry in *data* holding the values to plot.
        title: Panel title.
        cmap: Colormap name passed to ``imshow``.

    Returns:
        The image returned by ``ax.imshow``, or ``None`` for the placeholder.
    """
    if data is None:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        for clear_ticks in (ax.set_xticks, ax.set_yticks):
            clear_ticks([])
        ax.set_title(title, fontweight="bold", fontsize=10)
        return None

    epoch_labels = data["epochs"]
    layer_labels = data["layers"]
    grid = np.array(data[key])
    if grid.ndim != 2:
        # Force the (n_epochs, n_layers) layout when the values are not 2-D.
        grid = grid.reshape(len(epoch_labels), len(layer_labels))

    # Transposed so layers run along y; the colour scale is pinned to [0, 1].
    image_opts = {
        "aspect": "auto",
        "cmap": cmap,
        "vmin": 0.0,
        "vmax": 1.0,
        "interpolation": "nearest",
        "origin": "lower",
    }
    image = ax.imshow(grid.T, **image_opts)

    ax.set_xticks(range(len(epoch_labels)))
    ax.set_xticklabels(epoch_labels, rotation=45, ha="right", fontsize=7)
    ax.set_yticks(range(len(layer_labels)))
    ax.set_yticklabels(layer_labels, fontsize=8)
    ax.set_xlabel("Epoch", fontsize=9)
    ax.set_title(title, fontweight="bold", fontsize=10)
    return image
def _probe_key(data: dict | None, preferred: str, fallback: str) -> str:
    """Return *preferred* when *data* contains it, otherwise *fallback*."""
    if data is not None and preferred in data:
        return preferred
    return fallback


def make_figure(grok_dir: Path, std_dir: Path | None, out_base: Path):
    """Render the 2x2 grokking-vs-standard probe comparison and save PNG + PDF.

    Rows are the hospital probe (Reds) and the tumor probe (Greens); columns
    are the grokking run and the standard run. A panel whose run has no probe
    data shows a "No data" placeholder.

    Args:
        grok_dir: Grokking run directory; must contain M1 probe data.
        std_dir: Standard run directory, or ``None`` to leave that column empty.
        out_base: Output path without extension; ``.png`` and ``.pdf`` are
            appended. A trailing ``.png``/``.pdf`` on *out_base* is tolerated.

    Raises:
        SystemExit: With status 1 when *grok_dir* has no probe data.
    """
    grok = _load_probe(grok_dir)
    if grok is None:
        print(
            f"[ERROR] no probe data at {grok_dir}/mechinterp/m1_probe_data.json",
            file=sys.stderr,
        )
        sys.exit(1)
    # Loaded only after the mandatory grokking data is known to be present.
    std = _load_probe(std_dir) if std_dir is not None else None

    # Hospital probe on H3 (held-in held-out hospital) — has signal.
    # H4 version is degenerate (probe class set excludes H4 by construction → ≡ 0).
    hosp_keys = ("hospital_probe_id", "hospital_probe")
    # Tumor probe on H4 (truly OOD hospital) — measures causal-feature transfer.
    tumor_keys = ("tumor_probe_ood", "tumor_probe")

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    grok_title = f"Grokking ({grok_dir.name[-30:]})"
    std_title = f"Standard ({std_dir.name[-30:]})" if std_dir is not None else "Standard (no data)"
    hosp_caption = "Hospital probe on H3 (shortcut recoverability, ↓ good)"
    tumor_caption = "Tumor probe on H4 (causal transfer, ↑ good)"

    panels = [
        (axes[0][0], grok, hosp_keys, f"{grok_title}\n{hosp_caption}", "Reds"),
        (axes[0][1], std, hosp_keys, f"{std_title}\n{hosp_caption}", "Reds"),
        (axes[1][0], grok, tumor_keys, f"{grok_title}\n{tumor_caption}", "Greens"),
        (axes[1][1], std, tumor_keys, f"{std_title}\n{tumor_caption}", "Greens"),
    ]
    for ax, data, (preferred, fallback), title, cmap in panels:
        # Resolve the key per run: the two probe files may use different key
        # names, and a key chosen from the grokking file alone can be absent
        # from the standard one.
        im = _heatmap(ax, data, _probe_key(data, preferred, fallback), title, cmap)
        if im is not None:
            fig.colorbar(im, ax=ax, fraction=0.04, pad=0.02)

    fig.suptitle(
        "Figure 1 — Layer-wise circuit analysis: grokking-favorable vs standard training\n"
        "Grokking: deep-layer hospital recoverability (Reds) drops over training while tumor recoverability (Greens) is preserved.\n"
        "Standard: no localized scrubbing of hospital information.",
        fontsize=11,
        y=1.005,
        fontweight="bold",
    )
    fig.tight_layout()

    out_base.parent.mkdir(parents=True, exist_ok=True)
    # Append the extension rather than using with_suffix(), which would
    # truncate a dotted base name ("fig.v2" -> "fig.png"). An explicit
    # .png/.pdf extension is still stripped first.
    stem = out_base.stem if out_base.suffix.lower() in {".png", ".pdf"} else out_base.name
    png = out_base.with_name(f"{stem}.png")
    pdf = out_base.with_name(f"{stem}.pdf")
    fig.savefig(png, bbox_inches="tight", dpi=200)
    fig.savefig(pdf, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {png}")
    print(f"Saved {pdf}")
def main():
    """Parse CLI arguments, resolve the two runs, and build the figure.

    Runs not given on the command line are auto-selected. A missing standard
    run is tolerated; a missing grokking run is fatal.

    Raises:
        SystemExit: With status 2 when no grokking run can be found.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--grok-run", default=None, help="Path to grokking run dir; auto-pick best if omitted")
    ap.add_argument("--std-run", default=None, help="Path to standard run dir; auto-pick best if omitted")
    ap.add_argument("--out", default="paper_figures/figure1_MI_probe_comparison")
    args = ap.parse_args()

    grok_dir = Path(args.grok_run) if args.grok_run else _pick_best_grok_run()
    std_dir = Path(args.std_run) if args.std_run else _pick_std_run()
    if grok_dir is None:
        # Diagnostics go to stderr so stdout carries only the normal report.
        print("[ERROR] No grokking run with M1 probe data found.", file=sys.stderr)
        print(" Run experiments/mechinterp_m1.py on a grokking run first.", file=sys.stderr)
        sys.exit(2)

    print(f"Grokking run : {grok_dir}")
    print(f"Standard run : {std_dir if std_dir else '(none — figure will show only grokking)'}")
    # --out is resolved against ROOT (an absolute --out overrides it).
    out_base = ROOT / args.out
    make_figure(grok_dir, std_dir, out_base)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|