| """M6 — Neuron Ablation cross-run comparison figure. |
| |
| Reads experiments/runs/*/mechinterp/m6_neuron_ablation_*.json |
| Plots head OOD acc vs K (top-K hospital-correlated neurons zeroed) for every |
| run, grouped by condition (grokking vs standard). |
| """ |
| from __future__ import annotations |
|
|
| import glob, json |
| from pathlib import Path |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Project root: the parent of the directory holding this script. resolve()
# makes the path absolute (and follows symlinks), so it does not depend on
# the current working directory.
_SCRIPT_DIR = Path(__file__).resolve().parent
ROOT = _SCRIPT_DIR.parent
|
|
|
|
def _load_runs(root: Path, run_glob: str) -> dict[str, list[dict]]:
    """Collect M6 ablation sweeps under *root*, grouped by training condition.

    Every file matching
    ``<root>/experiments/runs/<run_glob>/mechinterp/m6_neuron_ablation_*.json``
    is paired with its run's ``results/summary.json``, which supplies
    ``condition``, ``n_train`` and ``seed``. Runs whose condition is not
    "grokking" or "standard" are ignored. Runs without a summary file are
    reported and skipped instead of aborting the whole figure.

    Returns:
        ``{"grokking": [...], "standard": [...]}``. Each record has keys
        ``n``, ``seed``, ``epoch``, ``ks`` and ``head``. ``head`` holds the
        head OOD accuracy per K, ordered by increasing K.
    """
    by_cond: dict[str, list[dict]] = {"grokking": [], "standard": []}
    pattern = (root / "experiments" / "runs" / run_glob / "mechinterp"
               / "m6_neuron_ablation_*.json")
    for match in sorted(glob.glob(str(pattern))):
        ablation_path = Path(match)
        run_dir = ablation_path.parent.parent
        summary_path = run_dir / "results" / "summary.json"
        if not summary_path.is_file():
            print(f"  [skip] missing {summary_path}")
            continue
        summary = json.loads(summary_path.read_text(encoding="utf-8"))
        cond = summary.get("condition")
        if cond not in by_cond:
            continue
        # Parse the ablation file only for runs we will actually plot.
        ablation = json.loads(ablation_path.read_text(encoding="utf-8"))
        # Order rows by K so each curve is drawn left-to-right even if the
        # sweep was stored out of order.
        sweep = sorted(ablation["sweep"], key=lambda row: row["k"])
        by_cond[cond].append({
            "n": summary.get("n_train"),
            "seed": summary.get("seed"),
            "epoch": ablation["epoch"],
            "ks": [row["k"] for row in sweep],
            "head": [row["head_ood_acc"] for row in sweep],
        })
    return by_cond


def _plot_condition(ax, cond: str, runs: list[dict], cmap) -> None:
    """Draw one panel: a head-OOD-vs-K curve per run, plus mean ± std for >= 2 runs."""
    if not runs:
        ax.text(0.5, 0.5, f"no {cond} M6 data",
                ha="center", va="center", transform=ax.transAxes,
                color="gray", fontsize=11)
        return

    for i, r in enumerate(runs):
        # Spread run shades over [0.45, 0.90] of the colormap, skipping its
        # palest end.
        color = cmap(0.45 + 0.45 * i / max(1, len(runs) - 1))
        ax.plot(r["ks"], r["head"], "-o", color=color, lw=2, ms=6,
                label=f"n={r['n']} s{r['seed']} ep{r['epoch']}")

    if len(runs) >= 2:
        # Aggregate only over K values shared by every run so that the
        # matrix columns line up.
        common_ks = sorted(set.intersection(*(set(r["ks"]) for r in runs)))
        if common_ks:
            lookups = [dict(zip(r["ks"], r["head"])) for r in runs]
            head_mat = np.array([[lk[k] for k in common_ks] for lk in lookups])
            # Population std (ddof=0) across runs, per K.
            mu, sd = head_mat.mean(axis=0), head_mat.std(axis=0)
            ax.plot(common_ks, mu, "k-", lw=2.5, label=f"mean (n={len(runs)} runs)")
            ax.fill_between(common_ks, mu - sd, mu + sd, color="black", alpha=0.15)

    # Label the title with the K range actually swept, not a hard-coded one.
    all_ks = [k for r in runs for k in r["ks"]]
    k_range = f"K={min(all_ks)}→K={max(all_ks)}" if all_ks else "K sweep"
    ax.set_xlabel("K (top-K hospital-correlated neurons zeroed)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    # symlog is linear near zero, so K=0 stays on the axis while large K
    # values are compressed logarithmically.
    ax.set_xscale("symlog", linthresh=4)
    ax.set_title(f"{cond.upper()} — head OOD vs K\n"
                 f"(drop over {k_range} measures shortcut-neuron causal weight)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8, ncol=1)
    ax.grid(alpha=0.3)


def main(*, run_glob: str = "20260505-*", root: Path | None = None) -> None:
    """Build the M6 cross-run neuron-ablation comparison figure.

    Args:
        run_glob: Glob for run directory names under ``experiments/runs``.
            The default keeps the original 2026-05-05 batch.
        root: Project root. Defaults to the module-level ``ROOT``.

    Writes ``<root>/paper_figures/figure_m6_neuron_comparison.png`` and
    ``.pdf``, creating ``paper_figures/`` if needed.
    """
    root = ROOT if root is None else Path(root)
    by_cond = _load_runs(root, run_glob)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds}
    try:
        for ax, cond in zip(axes, ("grokking", "standard")):
            _plot_condition(ax, cond, by_cond[cond], cmaps[cond])

        fig.suptitle("M6 — Neuron-level Causal Ablation across all runs",
                     fontsize=12, fontweight="bold", y=1.01)
        fig.tight_layout()
        out = root / "paper_figures" / "figure_m6_neuron_comparison"
        # savefig does not create missing directories.
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
        fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f" Saved {out}.png + .pdf")


if __name__ == "__main__":
    main()
|
|