| """Per-seed bar chart for M6: shows the seed-level signal that gets averaged |
| away in mean curves. For each run, plot ΔOOD at K=64 and K=256 for the three |
| ablation strategies (shortcut/random/morphology). |
| |
| Outputs: |
| paper_figures/figure_m6_per_seed_bars.{png,pdf} |
| """ |
| from __future__ import annotations |
|
|
| import glob, json |
| from pathlib import Path |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Repository root: the resolved directory two levels above this script file
# (the script lives in a sub-directory of the repo).
ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
def _load_rows(files: list[str]) -> list[dict]:
    """Turn M6 neuron-ablation JSON files into per-(run, K) plot rows.

    Each ablation file is expected under ``<run_dir>/mechinterp/`` and is
    paired with ``<run_dir>/results/summary.json``, which supplies
    ``condition``, ``n_train`` and ``seed`` for the tick label.

    Files whose ``include_id`` flag is falsy are skipped. Sweeps without a
    K=0 entry are also skipped, with a message.

    Every delta is measured against the K=0 ``shortcut_head_ood`` value. A
    missing ``morphology_head_ood`` becomes NaN, and its bar is simply not
    drawn or labelled.
    """
    rows = []
    for f in files:
        ablation = json.loads(Path(f).read_text())
        # Filter first, so excluded runs are not required to have a summary.
        if not ablation.get("include_id"):
            continue
        run_dir = Path(f).parent.parent
        summary = json.loads((run_dir / "results" / "summary.json").read_text())
        sweep = {entry["k"]: entry for entry in ablation["sweep"]}
        if 0 not in sweep:
            # Without the un-ablated K=0 point there is nothing to diff against.
            print(f" [skip] {f}: no K=0 baseline in sweep")
            continue
        base = sweep[0]["shortcut_head_ood"]
        for K in (64, 256):
            r = sweep.get(K)
            if r is None:
                continue
            rows.append({
                # Tick label: "<first letter of condition>n<n_train>\ns<seed>".
                "label": f"{summary['condition'][0]}n{summary['n_train']}\ns{summary['seed']}",
                "condition": summary["condition"],
                "K": K,
                "delta_shortcut": r["shortcut_head_ood"] - base,
                "delta_random_mu": r["random_head_ood_mean"] - base,
                # Spread of the random-ablation draws, used as the error bar.
                "delta_random_sd": r["random_head_ood_std"],
                "delta_morphology": r.get("morphology_head_ood", float("nan")) - base,
            })
    return rows


def _annotate_bars(ax, bars, values, errors=None) -> None:
    """Write each bar's signed value just beyond the bar end (or its error bar).

    Labels go above non-negative bars and below negative ones. They are
    offset in points, so the gap does not depend on the y-axis scale. NaN
    values get no label.
    """
    for i, (bar, v) in enumerate(zip(bars, values)):
        if np.isnan(v):
            continue
        up = v >= 0
        err = 0.0 if errors is None or np.isnan(errors[i]) else abs(errors[i])
        # Anchor at the error-bar tip so the text does not sit on the whisker.
        tip = v + err if up else v - err
        ax.annotate(
            f"{v:+.3f}",
            xy=(bar.get_x() + bar.get_width() / 2, tip),
            xytext=(0, 2 if up else -2),
            textcoords="offset points",
            ha="center",
            va="bottom" if up else "top",
            fontsize=7,
            color="darkgreen" if v > 0 else "darkred",
        )


def _plot_panel(ax, rows: list[dict], K: int) -> None:
    """Draw grouped shortcut / random / morphology delta bars for one K."""
    rs = [r for r in rows if r["K"] == K]
    x = np.arange(len(rs))
    w = 0.27

    sc = [r["delta_shortcut"] for r in rs]
    rd_mu = [r["delta_random_mu"] for r in rs]
    rd_sd = [r["delta_random_sd"] for r in rs]
    mo = [r["delta_morphology"] for r in rs]

    # Grokking runs get a slightly brighter red for the shortcut bar.
    b1 = ax.bar(x - w, sc, w, label="top-K shortcut",
                color=["#c0392b" if r["condition"] == "grokking" else "#922b21"
                       for r in rs])
    b2 = ax.bar(x, rd_mu, w, yerr=rd_sd, capsize=3,
                label="K random",
                color="#34495e")
    b3 = ax.bar(x + w, mo, w, label="top-K morphology",
                color="#27ae60")

    ax.axhline(0, color="black", lw=0.6)
    ax.set_xticks(x)
    ax.set_xticklabels([r["label"] for r in rs], fontsize=8, rotation=0)
    ax.set_ylabel("Δ head OOD vs K=0 baseline")
    ax.set_title(f"K = {K} neurons zeroed", fontweight="bold", fontsize=11)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3, axis="y")

    _annotate_bars(ax, b1, sc)
    _annotate_bars(ax, b2, rd_mu, rd_sd)
    _annotate_bars(ax, b3, mo)


def main(run_dates: tuple[str, ...] = ("20260505", "20260508")) -> None:
    """Collect per-seed M6 ablation deltas and save the two-panel bar figure.

    Args:
        run_dates: date prefixes of the ``experiments/runs/<date>-*`` run
            directories to scan. Files are taken in the order the prefixes
            are given, and sorted by path within each prefix. The default
            reproduces the original hard-coded pair.

    Writes ``paper_figures/figure_m6_per_seed_bars.{png,pdf}`` under ROOT,
    creating ``paper_figures`` if it does not exist.
    """
    files = []
    for date in run_dates:
        pattern = ROOT / f"experiments/runs/{date}-*/mechinterp/m6_neuron_ablation_*.json"
        files.extend(sorted(glob.glob(str(pattern))))
    rows = _load_rows(files)
    if not rows:
        print(" [warn] no M6 ablation rows found; figure will be empty")

    fig, axes = plt.subplots(1, 2, figsize=(15, 5.5))
    for ax, K in zip(axes, (64, 256)):
        _plot_panel(ax, rows, K)

    fig.suptitle("M6 — Per-Seed Δ OOD from Targeted vs Random Ablation\n"
                 "(positive bar = ablation helps OOD; targeted > random = selective causal intervention)",
                 fontsize=12, fontweight="bold", y=1.02)
    plt.tight_layout()
    out = ROOT / "paper_figures" / "figure_m6_per_seed_bars"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
|
|
|
|
if __name__ == "__main__":  # build the figure only when run as a script, not on import
    main()
|
|