# CausalGrok / code/experiments/figure_m6_neuron_comparison.py
# Uploaded by nileshsarkar-ai — "Upload code/experiments" (50fa85c, verified)
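"""M6 — Neuron Ablation cross-run comparison figure.
Reads experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json, together
with each run's results/summary.json (condition, n_train, seed).
Plots head OOD accuracy vs K (top-K hospital-correlated neurons zeroed) for every
run, grouped by condition (grokking vs standard), and writes
paper_figures/figure_m6_neuron_comparison.{png,pdf}.
"""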
from __future__ import annotations
import glob, json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Directory two levels above this script (the one holding experiments/ and
# paper_figures/); resolved so symlinks and relative invocations agree.
ROOT = Path(__file__).resolve().parents[1]
def _load_runs(root: Path, run_glob: str) -> dict[str, list[dict]]:
    """Collect per-run M6 sweep results, grouped by training condition.

    For every ``<root>/experiments/runs/<run_glob>/mechinterp/m6_neuron_ablation_*.json``
    the run's ``results/summary.json`` supplies ``condition``, ``n_train`` and
    ``seed``; the ablation JSON supplies ``epoch`` and the ``sweep`` list.
    Runs whose condition is not ``"grokking"`` or ``"standard"`` are skipped,
    as are runs lacking ``results/summary.json`` (reported on stdout).
    Files are visited in sorted path order, which fixes curve colors/legend order.
    """
    by_cond: dict[str, list[dict]] = {"grokking": [], "standard": []}
    # Escape the root so glob metacharacters in the checkout path ("[", "*", "?")
    # are matched literally; only the relative pattern is a real glob.
    pattern = (f"{glob.escape(str(root))}/experiments/runs/{run_glob}"
               f"/mechinterp/m6_neuron_ablation_*.json")
    for f in sorted(glob.glob(pattern)):
        ablation_path = Path(f)
        run_dir = ablation_path.parent.parent
        summary_path = run_dir / "results" / "summary.json"
        if not summary_path.is_file():
            # One incomplete run should not abort the whole figure.
            print(f"  [skip] {run_dir.name}: missing results/summary.json")
            continue
        summary = json.loads(summary_path.read_text())
        cond = summary.get("condition")
        if cond not in by_cond:
            continue
        data = json.loads(ablation_path.read_text())
        sweep = data["sweep"]
        by_cond[cond].append({
            "n": summary.get("n_train"),
            "seed": summary.get("seed"),
            "epoch": data["epoch"],
            "ks": [r["k"] for r in sweep],
            "head": [r["head_ood_acc"] for r in sweep],
            "tum": [r["tumor_probe"] for r in sweep],
        })
    return by_cond


def _plot_condition(ax, cond: str, runs: list[dict], cmap) -> None:
    """Draw one condition panel: per-run head-OOD-vs-K curves plus a mean ± std band.

    Runs are shaded light-to-dark along *cmap* in load order. The mean/std
    overlay needs at least two runs and uses only K values present in every
    run's sweep. An empty *runs* list yields a grey placeholder message.
    """
    if not runs:
        ax.text(0.5, 0.5, f"no {cond} M6 data",
                ha="center", va="center", transform=ax.transAxes,
                color="gray", fontsize=11)
        return
    for i, r in enumerate(runs):
        color = cmap(0.45 + 0.45 * i / max(1, len(runs) - 1))
        ax.plot(r["ks"], r["head"], "-o", color=color, lw=2, ms=6,
                label=f"n={r['n']} s{r['seed']} ep{r['epoch']}")
    # Aggregate only over the K values shared by every run (intersection of grids).
    common_ks = sorted(set.intersection(*(set(r["ks"]) for r in runs)))
    if len(runs) >= 2 and common_ks:
        rows = []
        for r in runs:
            first_head: dict = {}
            for k, h in zip(r["ks"], r["head"]):
                first_head.setdefault(k, h)  # first occurrence of a K wins
            rows.append([first_head[k] for k in common_ks])
        head_mat = np.array(rows)
        mu = head_mat.mean(0)
        sd = head_mat.std(0)  # population std (ddof=0)
        ax.plot(common_ks, mu, "k-", lw=2.5, label=f"mean (n={len(runs)} runs)")
        ax.fill_between(common_ks, mu - sd, mu + sd, color="black", alpha=0.15)
    ax.set_xlabel("K (top-K hospital-correlated neurons zeroed)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    # symlog keeps K=0 on the axis while compressing the large-K end.
    ax.set_xscale("symlog", linthresh=4)
    ax.set_title(f"{cond.upper()} — head OOD vs K\n"
                 "(decreasing K=0→K=256 measures shortcut-neuron causal weight)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8, ncol=1)
    ax.grid(alpha=0.3)


def main(run_glob: str = "20260505-*", root: Path | None = None) -> None:
    """Build the two-panel M6 neuron-ablation comparison figure.

    Left panel: grokking runs; right panel: standard runs. Each run is drawn as
    head OOD accuracy versus K. Panels with two or more runs also get a
    mean ± std overlay.

    Args:
        run_glob: Glob for run-directory names under ``experiments/runs``.
            Defaults to the ``20260505-*`` batch; pass ``"*"`` for every run.
        root: Directory containing ``experiments/`` and ``paper_figures/``;
            defaults to the module-level ``ROOT``.

    Writes ``<root>/paper_figures/figure_m6_neuron_comparison.png`` and ``.pdf``,
    creating ``paper_figures/`` if it does not exist.
    """
    root = ROOT if root is None else root
    by_cond = _load_runs(root, run_glob)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds}
    for ax, cond in zip(axes, ["grokking", "standard"]):
        _plot_condition(ax, cond, by_cond[cond], cmaps[cond])

    fig.suptitle("M6 — Neuron-level Causal Ablation across all runs",
                 fontsize=12, fontweight="bold", y=1.01)
    plt.tight_layout()
    out = root / "paper_figures" / "figure_m6_neuron_comparison"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
if __name__ == "__main__":
    # Build the figure only when run as a script, not when imported.
    main()