File size: 3,308 Bytes
50fa85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""M6 — Neuron Ablation cross-run comparison figure.

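Reads experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json, plus
each run's results/summary.json for its condition, n_train and seed.
Plots head OOD accuracy vs K (top-K hospital-correlated neurons zeroed) for
every run, in one panel per condition (grokking vs standard), overlaying a
mean ± std band when a condition has at least two runs. Saves the figure to
paper_figures/figure_m6_neuron_comparison.{png,pdf}.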
"""
from __future__ import annotations

import glob, json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Repository root: two directory levels above this (symlink-resolved) file.
# The experiment inputs and the paper_figures output path are built from it.
ROOT = Path(__file__).resolve().parent.parent


def _load_runs(run_glob: str) -> dict[str, list[dict]]:
    """Collect each run's M6 neuron-ablation sweep, grouped by condition.

    Args:
        run_glob: glob matched against run-directory names under
            ``ROOT/experiments/runs``.

    Returns:
        ``{"grokking": [...], "standard": [...]}``. Each record holds:
        ``n`` (summary ``n_train``), ``seed``, and the ablation ``epoch``.
        It also holds the per-K sweep lists ``ks``, ``head`` (``head_ood_acc``)
        and ``tum`` (``tumor_probe``).
        Runs whose summary ``condition`` is not one of the two keys are
        ignored. Runs without ``results/summary.json`` are skipped with a
        printed notice.
    """
    by_cond: dict[str, list[dict]] = {"grokking": [], "standard": []}
    # Escape ROOT so glob metacharacters in the checkout path ("[", "*", "?")
    # are matched literally; only run_glob and the file name act as patterns.
    pattern = str(
        Path(glob.escape(str(ROOT)))
        / "experiments/runs"
        / run_glob
        / "mechinterp/m6_neuron_ablation_*.json"
    )
    # Sort the string paths so run order -- and hence curve colour -- is
    # deterministic.
    for f in sorted(glob.glob(pattern)):
        ablation_path = Path(f)
        run_dir = ablation_path.parent.parent  # <run>/mechinterp/<file>
        summary_path = run_dir / "results" / "summary.json"
        try:
            summary = json.loads(summary_path.read_text())
        except FileNotFoundError:
            # Without the summary the run's condition is unknown; skip it
            # rather than aborting the whole figure.
            print(f"  [skip] {run_dir.name}: no results/summary.json")
            continue
        cond = summary.get("condition")
        if cond not in by_cond:
            continue
        ablation = json.loads(ablation_path.read_text())
        rows = ablation["sweep"]
        by_cond[cond].append({
            "n": summary.get("n_train"),
            "seed": summary.get("seed"),
            "epoch": ablation["epoch"],
            "ks": [r["k"] for r in rows],
            "head": [r["head_ood_acc"] for r in rows],
            # Tumor-probe values are collected but not plotted in this figure.
            "tum": [r["tumor_probe"] for r in rows],
        })
    return by_cond


def _plot_condition(ax, cond: str, runs: list[dict], cmap) -> None:
    """Draw one condition's per-run head-OOD-vs-K curves (plus mean ± std) on *ax*."""
    if not runs:
        ax.text(0.5, 0.5, f"no {cond} M6 data",
                ha="center", va="center", transform=ax.transAxes,
                color="gray", fontsize=11)
        return

    for i, r in enumerate(runs):
        # Spread runs over the upper 45%-90% of the colormap so even the
        # first (lightest) curve stays visible.
        color = cmap(0.45 + 0.45 * i / max(1, len(runs) - 1))
        ax.plot(r["ks"], r["head"], "-o", color=color, lw=2, ms=6,
                label=f"n={r['n']} s{r['seed']}  ep{r['epoch']}")

    if len(runs) >= 2:
        # Aggregate only over K values present in EVERY run's sweep (the
        # intersection of the K grids), so each mean has one value per run.
        common_ks = sorted(set.intersection(*(set(r["ks"]) for r in runs)))
        head_mat = np.array([
            [next(h for k_, h in zip(r["ks"], r["head"]) if k_ == k)
             for k in common_ks]
            for r in runs
        ])
        mu = head_mat.mean(0)
        sd = head_mat.std(0)  # population std (NumPy default ddof=0)
        ax.plot(common_ks, mu, "k-", lw=2.5, label=f"mean (n={len(runs)} runs)")
        ax.fill_between(common_ks, mu - sd, mu + sd, color="black", alpha=0.15)

    ax.set_xlabel("K (top-K hospital-correlated neurons zeroed)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    # Symlog keeps K=0 on the axis while spreading out the larger K values.
    ax.set_xscale("symlog", linthresh=4)
    ax.set_title(f"{cond.upper()} — head OOD vs K\n"
                 f"(decreasing K=0→K=256 measures shortcut-neuron causal weight)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8, ncol=1)
    ax.grid(alpha=0.3)


def main(*, run_glob: str = "20260505-*") -> None:
    """Build the M6 neuron-ablation comparison figure and save it as PNG + PDF.

    The figure has one panel per condition (grokking, standard). Each run's
    head OOD accuracy is plotted against K, the number of top
    hospital-correlated neurons zeroed. When a condition has two or more
    runs, a black mean ± std band is overlaid.
    Output goes to ``ROOT/paper_figures/figure_m6_neuron_comparison.{png,pdf}``.
    The output directory is created if it is missing.

    Args:
        run_glob: glob for run-directory names under ``ROOT/experiments/runs``.
            Defaults to the ``20260505-*`` runs; pass ``"*"`` to include all.
    """
    by_cond = _load_runs(run_glob)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds}
        for ax, cond in zip(axes, ("grokking", "standard")):
            _plot_condition(ax, cond, by_cond[cond], cmaps[cond])

        fig.suptitle("M6 — Neuron-level Causal Ablation across all runs",
                     fontsize=12, fontweight="bold", y=1.01)
        fig.tight_layout()
        out = ROOT / "paper_figures" / "figure_m6_neuron_comparison"
        # savefig does not create missing directories.
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
        fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"  Saved {out}.png + .pdf")


if __name__ == "__main__":
    # Build and save the comparison figure when executed as a script.
    main()