# CausalGrok / code / experiments / figure_m6_targeted_vs_random.py
# nileshsarkar-ai's picture
# Upload code/experiments
# 50fa85c verified
"""Cross-run M6 figure: targeted-shortcut vs random ablation, mean ± std bands.
Reads experiments/runs/{20260505,20260508}-*/mechinterp/m6_neuron_ablation_*.json
(extended format with random/morphology/ID controls). Plots, per condition
(grokking/standard), the change relative to each run's baseline (smallest shared K):
  Top:    Δ head OOD vs K — shortcut (red), random (black), morphology (green)
  Bottom: Δ head ID vs K — shortcut and random only, dashed style
The key reviewer question — "is targeted ablation different from random damage?"
— is answered visually by the gap between the red and black curves.
Outputs:
paper_figures/figure_m6_targeted_vs_random.{png,pdf}
"""
from __future__ import annotations
import glob, json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Base directory for both the input runs ("experiments/runs/...") and the
# output figures ("paper_figures/..."): two levels above this file's
# resolved location.
ROOT = Path(__file__).resolve().parent.parent
def _gather(only_n1000: bool = True):
    """Collect M6 neuron-ablation sweeps from disk, grouped by condition.

    Globs the M6 JSON outputs under the two dated run-directory prefixes,
    pairs each with the ``results/summary.json`` of its run directory, and
    keeps only files flagged ``include_id`` whose summary condition is
    ``"grokking"`` or ``"standard"``.

    By default only n=1000 runs are kept, since the paper's multi-seed claim
    is at n=1000 (5 grokking + 3 standard).

    Args:
        only_n1000: When true, drop runs whose summary ``n_train`` is not 1000.

    Returns:
        ``{"grokking": [...], "standard": [...]}`` where each entry is a dict
        holding the run's ``n``, ``seed``, ``epoch``, the K grid ``ks`` and
        one list per metric aligned with ``ks``. Metrics that are optional in
        the sweep rows are filled with NaN when absent.
    """

    def _required(rows, field):
        # Mandatory sweep field: a missing key raises KeyError.
        return [row[field] for row in rows]

    def _optional(rows, field):
        # Optional sweep field: a missing key becomes NaN.
        return [row.get(field, float("nan")) for row in rows]

    by_cond = {"grokking": [], "standard": []}

    patterns = (
        "experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json",
        "experiments/runs/20260508-*/mechinterp/m6_neuron_ablation_*.json",
    )
    # Each pattern's matches are sorted on their own, then concatenated.
    ablation_paths = []
    for pattern in patterns:
        ablation_paths.extend(sorted(glob.glob(str(ROOT / pattern))))

    for path_str in ablation_paths:
        ablation_path = Path(path_str)
        run_dir = ablation_path.parent.parent
        summary = json.loads((run_dir / "results" / "summary.json").read_text())
        ablation = json.loads(ablation_path.read_text())

        if not ablation.get("include_id"):
            continue
        condition = summary.get("condition")
        if condition not in by_cond:
            continue
        n_train = summary.get("n_train")
        if only_n1000 and n_train != 1000:
            continue

        rows = ablation["sweep"]
        k_grid = _required(rows, "k")
        record = {
            "n": n_train,
            "seed": summary.get("seed"),
            "epoch": ablation["epoch"],
            "ks": k_grid,
            "shortcut_ood": _required(rows, "shortcut_head_ood"),
            "shortcut_id": _optional(rows, "shortcut_head_id"),
            "random_ood_mu": _required(rows, "random_head_ood_mean"),
            "random_ood_sd": _required(rows, "random_head_ood_std"),
            "random_id_mu": _optional(rows, "random_head_id_mean"),
            "morph_ood": _optional(rows, "morphology_head_ood"),
        }
        by_cond[condition].append(record)

    return by_cond
def _stack(runs, key):
"""Align runs on shared K-grid and return (ks, matrix shape (n_runs, n_ks))."""
if not runs:
return None, None
ks_set = set.intersection(*[set(r["ks"]) for r in runs])
ks = sorted(ks_set)
mat = np.array([
[next(v for k_, v in zip(r["ks"], r[key]) if k_ == k) for k in ks]
for r in runs
])
return np.array(ks), mat
def _plot_mean_band(ax, ks, deltas, fmt, *, color, band_alpha, label, **line_kw):
    """Draw the across-run mean of *deltas* and a ±1 std band on *ax*."""
    mean = deltas.mean(0)
    std = deltas.std(0)
    ax.plot(ks, mean, fmt, label=label, **line_kw)
    ax.fill_between(ks, mean - std, mean + std, color=color, alpha=band_alpha)


def _decorate_axis(ax, ylabel, title):
    """Apply the zero line, symlog K axis, labels, title, legend and grid."""
    ax.axhline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xscale("symlog", linthresh=4)
    ax.set_xlabel("K (avgpool neurons zeroed)")
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)


def main():
    """Build and save the cross-run M6 targeted-vs-random ablation figure.

    Gathers the per-run sweeps, aligns each metric on the K grid shared by
    all runs of a condition, converts every run to a delta from its first
    (smallest-K) column, and plots mean ± std across runs: head OOD in the
    top row and head ID in the bottom row, one column per condition.

    Writes ``paper_figures/figure_m6_targeted_vs_random.png`` and ``.pdf``
    under ``ROOT``, creating the output directory if it does not exist.
    """
    data = _gather()
    print(f"Grokking runs: {len(data['grokking'])}, Standard runs: {len(data['standard'])}")

    fig, axes = plt.subplots(2, 2, figsize=(15, 9))
    for col, cond in enumerate(["grokking", "standard"]):
        runs = data[cond]
        if not runs:
            # No data for this condition: leave a placeholder in both rows.
            for ax in axes[:, col]:
                ax.text(0.5, 0.5, f"no {cond} M6 data",
                        ha="center", va="center",
                        transform=ax.transAxes, color="gray")
            continue

        # Stack each metric into an (n_runs, n_ks) matrix on the shared K grid.
        ks, sc_ood = _stack(runs, "shortcut_ood")
        _, rd_ood = _stack(runs, "random_ood_mu")
        _, mo_ood = _stack(runs, "morph_ood")
        _, sc_id = _stack(runs, "shortcut_id")
        _, rd_id = _stack(runs, "random_id_mu")

        # Convert each metric to a per-run delta from the first column.
        # NOTE(review): the axis labels call this the "K=0 baseline", which
        # assumes K=0 is present in every run's sweep — confirm upstream.
        sc_ood_d = sc_ood - sc_ood[:, 0:1]
        rd_ood_d = rd_ood - rd_ood[:, 0:1]
        mo_ood_d = mo_ood - mo_ood[:, 0:1]
        sc_id_d = sc_id - sc_id[:, 0:1]
        rd_id_d = rd_id - rd_id[:, 0:1]
        n_seeds = len(runs)

        # ── Top row: ΔOOD curves, mean ± std ──
        ax = axes[0][col]
        _plot_mean_band(ax, ks, sc_ood_d, "r-o", color="red", band_alpha=0.18,
                        label=f"top-K shortcut (n={n_seeds})", lw=2.4, ms=7)
        _plot_mean_band(ax, ks, rd_ood_d, "k-s", color="black", band_alpha=0.12,
                        label=f"K random (n={n_seeds})", lw=2.0, ms=6)
        # Morphology is optional per run (NaN when absent). Average only the
        # runs that have it: one all-NaN row would otherwise turn the whole
        # mean into NaN and leave a legend entry with no visible curve.
        has_morph = ~np.isnan(mo_ood_d).all(axis=1)
        if has_morph.any():
            ax.plot(ks, mo_ood_d[has_morph].mean(0), "g-^", lw=1.8, ms=5,
                    label=f"top-K morphology (n={int(has_morph.sum())})")
        _decorate_axis(
            ax,
            "Δ head OOD vs K=0 baseline",
            f"{cond.upper()} — change in head OOD vs K\n"
            "(positive = ablation HELPS OOD; red ≠ black = targeted ablation is selective)",
        )

        # ── Bottom row: ΔID curves ──
        ax = axes[1][col]
        _plot_mean_band(ax, ks, sc_id_d, "r--o", color="red", band_alpha=0.12,
                        label="top-K shortcut", lw=2.0, ms=6, alpha=0.85)
        _plot_mean_band(ax, ks, rd_id_d, "k--s", color="black", band_alpha=0.10,
                        label="K random", lw=1.8, ms=5, alpha=0.85)
        _decorate_axis(
            ax,
            "Δ head ID vs K=0 baseline",
            f"{cond.upper()} — change in head ID vs K\n"
            "(both should drop under heavy ablation; targeted ≈ random ID = no extra ID damage)",
        )

    # NOTE(review): the per-seed counts in this title are hard-coded text,
    # not computed from the gathered runs — update by hand if the runs change.
    fig.suptitle("M6 — Targeted Shortcut Neuron Ablation vs Random Control (n=1000)\n"
                 "Per-seed selectivity: 3/5 grokking show targeted-shortcut > random at K=64; 0/3 standard.",
                 fontsize=12, fontweight="bold", y=1.005)
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_m6_targeted_vs_random"
    # savefig does not create missing directories; make sure the target exists.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
# Script entry point: build and save the figure only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    main()