File size: 6,912 Bytes
"""Cross-run M6 figure: targeted-shortcut vs random ablation, mean ± std bands.
Reads experiments/runs/*/mechinterp/m6_neuron_ablation_*.json (extended format
with random/morphology/ID controls). Plots, per condition (grokking/standard):
    Top:    head OOD vs K — shortcut (red), random (black), morphology (green)
    Bottom: head ID vs K — same conditions, dashed style
The key reviewer question — "is targeted ablation different from random damage?"
— is answered visually by the gap between the red and black curves.
Outputs:
    paper_figures/figure_m6_targeted_vs_random.{png,pdf}
"""
from __future__ import annotations
import glob, json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Base directory for all inputs and outputs: the parent of the directory that
# contains this script (symlinks resolved first).
ROOT = Path(__file__).resolve().parent.parent
# Default run selection: the two date-stamped batches, globbed relative to the
# base directory. Order matters only for the order of runs in the result.
_M6_RUN_GLOBS = (
    "experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json",
    "experiments/runs/20260508-*/mechinterp/m6_neuron_ablation_*.json",
)


def _gather(only_n1000: bool = True, *, root=None, run_globs=_M6_RUN_GLOBS):
    """Collect per-run M6 ablation sweeps, grouped by training condition.

    By default the result is restricted to n=1000 since the paper's
    multi-seed claim is at n=1000 (5 grokking + 3 standard).

    For every ablation JSON matched by ``run_globs`` the sibling
    ``<run>/results/summary.json`` is read for the run's ``condition``,
    ``n_train`` and ``seed``. A file is kept only if its ``include_id`` flag is
    truthy, its condition is ``"grokking"`` or ``"standard"``, and (when
    ``only_n1000`` is set) ``n_train == 1000``.

    Args:
        only_n1000: Keep only runs whose summary reports ``n_train == 1000``.
        root: Directory the glob patterns are resolved against. ``None``
            (the default) uses the module-level ``ROOT``.
        run_globs: Glob patterns, relative to ``root``, selecting the ablation
            JSON files. Matches are sorted within each pattern and the
            patterns are processed in the given order.

    Returns:
        ``{"grokking": [...], "standard": [...]}`` where each entry is a dict
        with keys ``n``, ``seed``, ``epoch``, ``ks`` and the per-K lists
        ``shortcut_ood``, ``shortcut_id``, ``random_ood_mu``,
        ``random_ood_sd``, ``random_id_mu``, ``morph_ood``. Optional metrics
        absent from a sweep row are filled with NaN.

    Raises:
        KeyError: If a kept file lacks ``sweep``/``epoch`` or a sweep row
            lacks one of the required metrics.
    """
    base = ROOT if root is None else Path(root)
    by_cond = {"grokking": [], "standard": []}

    files = []
    for pattern in run_globs:
        # Sort per pattern (not globally) so runs keep the batch-by-batch order.
        files.extend(sorted(glob.glob(str(base / pattern))))

    for f in files:
        ablation_path = Path(f)
        # Layout: <run>/mechinterp/<ablation>.json and <run>/results/summary.json
        summary_path = ablation_path.parent.parent / "results" / "summary.json"
        if not summary_path.is_file():
            # One incomplete run directory should not abort the whole figure.
            print(f"  WARNING: skipping {ablation_path}: missing {summary_path}")
            continue

        s = json.loads(summary_path.read_text(encoding="utf-8"))
        d = json.loads(ablation_path.read_text(encoding="utf-8"))

        if not d.get("include_id"):
            continue
        cond = s.get("condition")
        if cond not in by_cond:
            continue
        if only_n1000 and s.get("n_train") != 1000:
            continue

        sweep = d["sweep"]
        nan = float("nan")
        by_cond[cond].append({
            "n": s.get("n_train"),
            "seed": s.get("seed"),
            "epoch": d["epoch"],
            "ks": [r["k"] for r in sweep],
            "shortcut_ood": [r["shortcut_head_ood"] for r in sweep],
            "shortcut_id": [r.get("shortcut_head_id", nan) for r in sweep],
            "random_ood_mu": [r["random_head_ood_mean"] for r in sweep],
            "random_ood_sd": [r["random_head_ood_std"] for r in sweep],
            "random_id_mu": [r.get("random_head_id_mean", nan) for r in sweep],
            "morph_ood": [r.get("morphology_head_ood", nan) for r in sweep],
        })
    return by_cond
def _stack(runs, key):
    """Build a runs-by-K matrix of one metric over the K values all runs share.

    Args:
        runs: Run dicts, each holding a ``"ks"`` list and a parallel list
            under ``key``.
        key: Name of the per-K metric list to extract from every run.

    Returns:
        ``(ks, matrix)`` where ``ks`` is the ascending array of K values
        present in every run and ``matrix[i, j]`` is run ``i``'s metric at
        ``ks[j]``. ``(None, None)`` when ``runs`` is empty.
    """
    if not runs:
        return None, None

    # Only K values present in every run can be averaged across runs.
    shared = set(runs[0]["ks"])
    for run in runs[1:]:
        shared &= set(run["ks"])
    grid = sorted(shared)

    rows = []
    for run in runs:
        # If a run lists the same K twice, the first occurrence wins.
        first_value = {}
        for k_val, metric in zip(run["ks"], run[key]):
            first_value.setdefault(k_val, metric)
        rows.append([first_value[k] for k in grid])

    return np.array(grid), np.array(rows)
def _plot_mean_band(ax, ks, deltas, fmt, *, band_color, band_alpha, **line_kwargs):
    """Plot the across-run mean of ``deltas`` with a shaded mean +/- std band."""
    mu = deltas.mean(0)
    sd = deltas.std(0)  # NumPy default (ddof=0)
    ax.plot(ks, mu, fmt, **line_kwargs)
    ax.fill_between(ks, mu - sd, mu + sd, color=band_color, alpha=band_alpha)


def _decorate_axes(ax, ylabel, title):
    """Apply the zero line, symlog K axis, labels, legend and grid used by every panel."""
    ax.axhline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xscale("symlog", linthresh=4)
    ax.set_xlabel("K (avgpool neurons zeroed)")
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)


def main():
    """Render the 2x2 targeted-vs-random ablation figure and save it as PNG and PDF.

    Columns are the two conditions (grokking, standard). The top row shows the
    change in head OOD and the bottom row the change in head ID, each as the
    across-run mean with a +/- std band, relative to the first column of the
    shared K grid. Runs come from ``_gather()`` with its defaults.

    Output goes to ``ROOT / "paper_figures"``; the directory is created when
    missing.
    """
    data = _gather()
    print(f"Grokking runs: {len(data['grokking'])}, Standard runs: {len(data['standard'])}")

    # NOTE(review): the 'Ξ' and 'β' glyphs in the label/title strings below look
    # like mis-decoded non-ASCII characters (a delta, dashes, relation symbols).
    # They are kept verbatim here; confirm the intended characters.
    fig, axes = plt.subplots(2, 2, figsize=(15, 9))
    for col, cond in enumerate(["grokking", "standard"]):
        runs = data[cond]
        if not runs:
            # Leave a placeholder in both panels of this column.
            for row in range(2):
                axes[row][col].text(0.5, 0.5, f"no {cond} M6 data",
                                    ha="center", va="center",
                                    transform=axes[row][col].transAxes, color="gray")
            continue

        # One (n_runs, n_ks) matrix per metric, aligned on the shared K grid.
        ks, sc_ood = _stack(runs, "shortcut_ood")
        _, rd_ood = _stack(runs, "random_ood_mu")
        _, mo_ood = _stack(runs, "morph_ood")
        _, sc_id = _stack(runs, "shortcut_id")
        _, rd_id = _stack(runs, "random_id_mu")

        if ks.size == 0 or ks[0] != 0:
            # The deltas below subtract column 0, which is the K=0 baseline
            # only when every run's sweep contains K=0.
            print(f"  WARNING: {cond} runs share no K=0 entry "
                  f"(shared K grid: {ks.tolist()}); deltas are relative to "
                  "the smallest shared K instead")

        # Per-run change relative to the first column of the shared grid.
        sc_ood_d = sc_ood - sc_ood[:, 0:1]
        rd_ood_d = rd_ood - rd_ood[:, 0:1]
        mo_ood_d = mo_ood - mo_ood[:, 0:1]
        sc_id_d = sc_id - sc_id[:, 0:1]
        rd_id_d = rd_id - rd_id[:, 0:1]
        n_seeds = len(runs)

        # Top row: change in head OOD, mean +/- std across runs.
        ax = axes[0][col]
        _plot_mean_band(ax, ks, sc_ood_d, "r-o", band_color="red", band_alpha=0.18,
                        lw=2.4, ms=7, label=f"top-K shortcut (n={n_seeds})")
        _plot_mean_band(ax, ks, rd_ood_d, "k-s", band_color="black", band_alpha=0.12,
                        lw=2.0, ms=6, label=f"K random (n={n_seeds})")
        # Morphology is optional per run. Average only runs that have it at
        # every shared K; a single NaN row would otherwise turn the whole mean
        # into NaN and leave a legend entry with no visible curve.
        has_morph = ~np.isnan(mo_ood_d).any(axis=1)
        if ks.size and has_morph.any():
            ax.plot(ks, mo_ood_d[has_morph].mean(0), "g-^", lw=1.8, ms=5,
                    label=f"top-K morphology (n={int(has_morph.sum())})")
        _decorate_axes(
            ax,
            "Ξ head OOD vs K=0 baseline",
            f"{cond.upper()} β change in head OOD vs K\n"
            f"(positive = ablation HELPS OOD; red β black = targeted ablation is selective)",
        )

        # Bottom row: change in head ID, same conditions, dashed lines.
        ax = axes[1][col]
        _plot_mean_band(ax, ks, sc_id_d, "r--o", band_color="red", band_alpha=0.12,
                        lw=2.0, ms=6, alpha=0.85, label="top-K shortcut")
        _plot_mean_band(ax, ks, rd_id_d, "k--s", band_color="black", band_alpha=0.10,
                        lw=1.8, ms=5, alpha=0.85, label="K random")
        _decorate_axes(
            ax,
            "Ξ head ID vs K=0 baseline",
            f"{cond.upper()} β change in head ID vs K\n"
            f"(both should drop under heavy ablation; targeted β random ID = no extra ID damage)",
        )

    # NOTE(review): the run counts in this title are a hard-coded literal, not
    # computed from the gathered data; re-check them when the run set changes.
    fig.suptitle("M6 β Targeted Shortcut Neuron Ablation vs Random Control (n=1000)\n"
                 "Per-seed selectivity: 3/5 grokking show targeted-shortcut > random at K=64; 0/3 standard.",
                 fontsize=12, fontweight="bold", y=1.005)
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_m6_targeted_vs_random"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
# Script entry point: build the figure only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|