| """Cross-run M6 figure: targeted-shortcut vs random ablation, mean ± std bands. |
| |
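| Reads experiments/runs/<run>/mechinterp/m6_neuron_ablation_*.json (extended format |
| with random/morphology/ID controls) from the run directories globbed in _gather() |
| (by default those named 20260505-* and 20260508-*). Plots, per condition |
| (grokking/standard), the change relative to each run's value at the first shared K: |
| Top: Δ head OOD vs K — shortcut (red), random (black), morphology (green) |
| Bottom: Δ head ID vs K — shortcut and random only, dashed style |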
| |
| The key reviewer question — "is targeted ablation different from random damage?" |
| — is answered visually by the gap between the red and black curves. |
| |
| Outputs: |
| paper_figures/figure_m6_targeted_vs_random.{png,pdf} |
| """ |
| from __future__ import annotations |
|
|
| import glob, json |
| from pathlib import Path |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Base directory for every input/output path in this script: the grandparent
# directory of this file's resolved location.
_THIS_FILE = Path(__file__).resolve()
ROOT = _THIS_FILE.parent.parent
|
|
|
|
def _gather(only_n1000: bool = True, *, run_globs: "tuple[str, ...] | None" = None):
    """Collect per-run M6 ablation sweeps, grouped by training condition.

    By default only n=1000 runs are kept, since the paper's multi-seed claim
    is at n=1000 (5 grokking + 3 standard).

    Args:
        only_n1000: When true, skip runs whose summary ``n_train`` is not 1000.
        run_globs: Glob patterns, relative to ``ROOT``, selecting the M6 JSON
            files to read. ``None`` selects the ``20260505-*`` and
            ``20260508-*`` run directories.

    Returns:
        ``{"grokking": [...], "standard": [...]}``. Each entry is a dict with
        the run's ``n``/``seed``/``epoch``, its K grid (``ks``) and one list
        per metric aligned with ``ks``. Optional metrics that a sweep row
        does not carry are filled with NaN.

    Raises:
        FileNotFoundError: If an extended-format M6 file has no sibling
            ``results/summary.json`` two levels up.
        KeyError: If an M6 file lacks ``sweep``/``epoch`` or a sweep row lacks
            a required field.
    """
    if run_globs is None:
        run_globs = (
            "experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json",
            "experiments/runs/20260508-*/mechinterp/m6_neuron_ablation_*.json",
        )

    # Matches are sorted within each pattern and patterns are concatenated in
    # the order given; dict.fromkeys drops repeats when patterns overlap.
    matches = []
    for pattern in run_globs:
        matches.extend(sorted(glob.glob(str(ROOT / pattern))))
    matches = list(dict.fromkeys(matches))

    nan = float("nan")
    by_cond = {"grokking": [], "standard": []}
    for match in matches:
        m6_path = Path(match)
        d = json.loads(m6_path.read_text(encoding="utf-8"))
        # Only the extended format (flagged by "include_id") is plotted.
        # Filtering before touching summary.json lets old-format runs be
        # skipped even when they have no summary file.
        if not d.get("include_id"):
            continue

        run_dir = m6_path.parent.parent
        summary_path = run_dir / "results" / "summary.json"
        s = json.loads(summary_path.read_text(encoding="utf-8"))

        cond = s.get("condition")
        if cond not in by_cond:
            continue
        if only_n1000 and s.get("n_train") != 1000:
            continue

        sweep = d["sweep"]
        by_cond[cond].append({
            "n": s.get("n_train"),
            "seed": s.get("seed"),
            "epoch": d["epoch"],
            "ks": [r["k"] for r in sweep],
            "shortcut_ood": [r["shortcut_head_ood"] for r in sweep],
            "shortcut_id": [r.get("shortcut_head_id", nan) for r in sweep],
            "random_ood_mu": [r["random_head_ood_mean"] for r in sweep],
            "random_ood_sd": [r["random_head_ood_std"] for r in sweep],
            "random_id_mu": [r.get("random_head_id_mean", nan) for r in sweep],
            "morph_ood": [r.get("morphology_head_ood", nan) for r in sweep],
        })
    return by_cond
|
|
|
|
def _stack(runs, key):
    """Align runs on their shared K grid and stack one metric into a matrix.

    Args:
        runs: Run dicts, each holding a ``"ks"`` list and a list under *key*
            whose entries pair up positionally with ``"ks"``.
        key: Name of the per-K metric list to stack.

    Returns:
        ``(ks, mat)`` where ``ks`` is the sorted array of K values present in
        every run and ``mat`` has shape ``(n_runs, n_ks)``, rows in the order
        of *runs*. ``(None, None)`` when *runs* is empty.

    Raises:
        KeyError: If a run's metric list is shorter than its ``"ks"`` list so
            that a shared K has no paired value.
    """
    if not runs:
        return None, None

    shared_ks = sorted(set.intersection(*[set(run["ks"]) for run in runs]))

    rows = []
    for run in runs:
        # One K -> value table per run instead of rescanning the lists for
        # every K. setdefault keeps the first value if a K is repeated.
        value_at = {}
        for k, value in zip(run["ks"], run[key]):
            value_at.setdefault(k, value)
        rows.append([value_at[k] for k in shared_ks])

    return np.array(shared_ks), np.array(rows)
|
|
|
|
def _baseline_delta(mat):
    """Return *mat* with each row's first-column value subtracted from that row."""
    return mat - mat[:, 0:1]


def _plot_mean_band(ax, ks, deltas, fmt, band_color, band_alpha, **line_kw):
    """Plot the across-run mean of *deltas* plus a shaded ±1 std band (NumPy default ddof)."""
    mu = deltas.mean(0)
    sd = deltas.std(0)
    ax.plot(ks, mu, fmt, **line_kw)
    ax.fill_between(ks, mu - sd, mu + sd, color=band_color, alpha=band_alpha)


def _finish_axis(ax, ylabel, title):
    """Apply the zero line, symlog K axis, labels, title, legend and grid shared by all panels."""
    ax.axhline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xscale("symlog", linthresh=4)
    ax.set_xlabel("K (avgpool neurons zeroed)")
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)


def main():
    """Render the 2x2 M6 figure and save it as PNG and PDF.

    Columns are the grokking and standard conditions; the top row shows the
    change in head OOD and the bottom row the change in head ID, each relative
    to the run's value at the first shared K. Curves are across-run means with
    ±1 std bands. A column with no runs gets a gray placeholder text instead.

    Writes ``ROOT/paper_figures/figure_m6_targeted_vs_random.{png,pdf}``
    (creating the directory if needed) and prints run counts and the output
    path.
    """
    data = _gather()
    print(f"Grokking runs: {len(data['grokking'])}, Standard runs: {len(data['standard'])}")

    fig, axes = plt.subplots(2, 2, figsize=(15, 9))

    for col, cond in enumerate(["grokking", "standard"]):
        runs = data[cond]
        if not runs:
            for row in range(2):
                axes[row][col].text(0.5, 0.5, f"no {cond} M6 data",
                                    ha="center", va="center",
                                    transform=axes[row][col].transAxes, color="gray")
            continue

        # Every metric comes from the same runs, so the shared K grid is the
        # same for each _stack call; only the first one's ks is kept.
        ks, sc_ood = _stack(runs, "shortcut_ood")
        _, rd_ood = _stack(runs, "random_ood_mu")
        _, mo_ood = _stack(runs, "morph_ood")
        _, sc_id = _stack(runs, "shortcut_id")
        _, rd_id = _stack(runs, "random_id_mu")

        # Per-run change relative to the first shared K.
        # NOTE(review): the axis labels call this the "K=0 baseline"; that
        # holds only if K=0 is present in every run's sweep — confirm.
        sc_ood_d = _baseline_delta(sc_ood)
        rd_ood_d = _baseline_delta(rd_ood)
        mo_ood_d = _baseline_delta(mo_ood)
        sc_id_d = _baseline_delta(sc_id)
        rd_id_d = _baseline_delta(rd_id)

        n_seeds = len(runs)

        # --- Top panel: change in head OOD ---
        ax = axes[0][col]
        _plot_mean_band(ax, ks, sc_ood_d, "r-o", "red", 0.18,
                        lw=2.4, ms=7, label=f"top-K shortcut (n={n_seeds})")
        _plot_mean_band(ax, ks, rd_ood_d, "k-s", "black", 0.12,
                        lw=2.0, ms=6, label=f"K random (n={n_seeds})")
        # Morphology is optional per run (missing values arrive as NaN).
        # Average only the runs with a complete morphology sweep, and label
        # the curve with that count; a plain mean over all runs would turn
        # entirely NaN as soon as one run lacked the control.
        has_morph = np.isfinite(mo_ood_d).all(axis=1)
        if mo_ood_d.size and has_morph.any():
            ax.plot(ks, mo_ood_d[has_morph].mean(0), "g-^", lw=1.8, ms=5,
                    label=f"top-K morphology (n={int(has_morph.sum())})")
        _finish_axis(
            ax,
            "Δ head OOD vs K=0 baseline",
            f"{cond.upper()} — change in head OOD vs K\n"
            f"(positive = ablation HELPS OOD; red ≠ black = targeted ablation is selective)",
        )

        # --- Bottom panel: change in head ID (shortcut and random only) ---
        ax = axes[1][col]
        _plot_mean_band(ax, ks, sc_id_d, "r--o", "red", 0.12,
                        lw=2.0, ms=6, alpha=0.85, label="top-K shortcut")
        _plot_mean_band(ax, ks, rd_id_d, "k--s", "black", 0.10,
                        lw=1.8, ms=5, alpha=0.85, label="K random")
        _finish_axis(
            ax,
            "Δ head ID vs K=0 baseline",
            f"{cond.upper()} — change in head ID vs K\n"
            f"(both should drop under heavy ablation; targeted ≈ random ID = no extra ID damage)",
        )

    # NOTE(review): the "(n=1000)" tag and the per-seed counts below are
    # literal text, not derived from the data loaded above — re-verify them
    # whenever the set of runs changes.
    fig.suptitle("M6 — Targeted Shortcut Neuron Ablation vs Random Control (n=1000)\n"
                 "Per-seed selectivity: 3/5 grokking show targeted-shortcut > random at K=64; 0/3 standard.",
                 fontsize=12, fontweight="bold", y=1.005)
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_m6_targeted_vs_random"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
|
|
|
|
# Build the figure only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|