# CausalGrok / code / experiments / figure_m6_per_seed_bars.py
# nileshsarkar-ai's picture
# Upload code/experiments
# 50fa85c verified
"""Per-seed bar chart for M6: shows the seed-level signal that gets averaged
away in mean curves. For each run, plot ΔOOD at K=64 and K=256 for the three
ablation strategies (shortcut/random/morphology).
Outputs:
paper_figures/figure_m6_per_seed_bars.{png,pdf}
"""
from __future__ import annotations
import glob, json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Base directory for all inputs and outputs: the parent of the directory
# that contains this script (resolved to an absolute path).
ROOT = Path(__file__).resolve().parents[1]
# Glob patterns (relative to ROOT) of the ablation result files to plot.
# Matches of each pattern are sorted, and patterns are visited in this order.
_M6_RUN_GLOBS = (
    "experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json",
    "experiments/runs/20260508-*/mechinterp/m6_neuron_ablation_*.json",
)

# Sweep sizes K shown in the figure, one panel per value.
_M6_PANEL_KS = (64, 256)


def _collect_rows():
    """Return one dict per (ablation file, K) with the deltas to plot.

    Each ablation JSON is expected to hold a ``sweep`` list whose entries are
    keyed by ``k``; the ``k == 0`` entry supplies the baseline that every
    delta is measured against. Files whose ``include_id`` field is missing or
    falsy are skipped. Run metadata (condition, n_train, seed) comes from
    ``results/summary.json`` two directories above the ablation file.

    Raises:
        KeyError: if a selected file has no ``k == 0`` sweep entry, or if a
            required field is missing.
    """
    rows = []
    for pattern in _M6_RUN_GLOBS:
        for path in sorted(glob.glob(str(ROOT / pattern))):
            ablation = json.loads(Path(path).read_text(encoding="utf-8"))
            if not ablation.get("include_id"):
                continue
            # Only read the run summary for files that are actually plotted.
            run_dir = Path(path).parent.parent
            summary = json.loads(
                (run_dir / "results" / "summary.json").read_text(encoding="utf-8")
            )
            sweep = {entry["k"]: entry for entry in ablation["sweep"]}
            if 0 not in sweep:
                # Same exception type as a plain lookup, but names the file.
                raise KeyError(f"{path}: sweep has no K=0 baseline entry")
            base = sweep[0]["shortcut_head_ood"]
            for K in _M6_PANEL_KS:
                entry = sweep.get(K)
                if entry is None:
                    continue
                rows.append({
                    "label": f"{summary['condition'][0]}n{summary['n_train']}\ns{summary['seed']}",
                    "condition": summary["condition"],
                    "K": K,
                    "delta_shortcut": entry["shortcut_head_ood"] - base,
                    "delta_random_mu": entry["random_head_ood_mean"] - base,
                    "delta_random_sd": entry["random_head_ood_std"],
                    # A missing morphology value becomes NaN and is not annotated.
                    "delta_morphology": entry.get("morphology_head_ood", float("nan")) - base,
                })
    return rows


def _plot_panel(ax, rows, K):
    """Draw the grouped shortcut/random/morphology bars for one K on *ax*."""
    labels = [r["label"] for r in rows]
    x = np.arange(len(rows))
    w = 0.27
    sc = [r["delta_shortcut"] for r in rows]
    rd_mu = [r["delta_random_mu"] for r in rows]
    rd_sd = [r["delta_random_sd"] for r in rows]
    mo = [r["delta_morphology"] for r in rows]
    # Shortcut bars use a brighter red for the "grokking" condition.
    shortcut_colors = [
        "#c0392b" if r["condition"] == "grokking" else "#922b21" for r in rows
    ]

    b1 = ax.bar(x - w, sc, w, label="top-K shortcut", color=shortcut_colors)
    b2 = ax.bar(x, rd_mu, w, yerr=rd_sd, capsize=3,
                label="K random",
                color="#34495e")
    b3 = ax.bar(x + w, mo, w, label="top-K morphology",
                color="#27ae60")

    ax.axhline(0, color="black", lw=0.6)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=8, rotation=0)
    ax.set_ylabel("Δ head OOD vs K=0 baseline")
    ax.set_title(f"K = {K} neurons zeroed", fontweight="bold", fontsize=11)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3, axis="y")

    # Annotate values just beyond the end of each bar (above if >= 0, else below).
    for bars, vals in zip([b1, b2, b3], [sc, rd_mu, mo]):
        for bar, v in zip(bars, vals):
            if np.isnan(v):
                continue
            y = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2,
                    y + (0.003 if y >= 0 else -0.008),
                    f"{v:+.3f}", ha="center", fontsize=7,
                    color="darkgreen" if v > 0 else "darkred")


def main():
    """Render the M6 per-seed bar chart and save it as PNG and PDF.

    Writes ``paper_figures/figure_m6_per_seed_bars.{png,pdf}`` under ROOT,
    creating the output directory if needed. If no ablation file is selected,
    prints a notice and returns without writing anything.
    """
    rows = _collect_rows()
    if not rows:
        print(" No M6 ablation runs found; nothing to plot.")
        return

    fig, axes = plt.subplots(1, len(_M6_PANEL_KS), figsize=(15, 5.5))
    for ax, K in zip(axes, _M6_PANEL_KS):
        _plot_panel(ax, [r for r in rows if r["K"] == K], K)

    fig.suptitle("M6 — Per-Seed Δ OOD from Targeted vs Random Ablation\n"
                 "(positive bar = ablation helps OOD; targeted > random = selective causal intervention)",
                 fontsize=12, fontweight="bold", y=1.02)
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_m6_per_seed_bars"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
# Generate the figure only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()