# CausalGrok / code/experiments/figure_multiseed_intervention.py
# Author: nileshsarkar-ai — "Upload code/experiments" (commit 50fa85c, verified)
"""Multi-seed mean ± std figure for M4 and M5 across the three n=1000 grokking
seeds (s42, s123, s456), with single-seed standard baselines as reference lines.
Reads:
    experiments/runs/<id>/results/summary.json
    experiments/runs/<id>/mechinterp/m4_ablation_avgpool_trajectory.json
    experiments/runs/<id>/mechinterp/m5_steering_*.json
Outputs:
    paper_figures/figure_multiseed_intervention.{png,pdf}
"""
from __future__ import annotations
import glob
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Directory two levels above this script; all experiment inputs and figure
# outputs below are resolved relative to it.
ROOT = Path(__file__).resolve().parent.parent


def _load_summary(run_dir: Path) -> dict:
    """Return the parsed ``results/summary.json`` of *run_dir*, or ``{}``.

    Args:
        run_dir: Run directory (``Path`` or ``str``).

    Returns:
        The decoded JSON object, or an empty dict when the file does not
        exist, so callers can uniformly use ``.get(...)`` and skip the run.
    """
    p = Path(run_dir) / "results" / "summary.json"
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    return json.loads(p.read_text(encoding="utf-8")) if p.exists() else {}
# Run directories (relative to ROOT) included in the figure:
# 2026-05-05 holds the initial 7 runs, 2026-05-08 the 4 new ones.
_RUN_GLOBS = (
    "experiments/runs/20260505-*/",
    "experiments/runs/20260508-*/",
)


def _gather(root: Path | None = None, run_globs: tuple[str, ...] = _RUN_GLOBS):
    """Return per-(cond, n) → list of run dicts with M4 traj + M5 sweep.

    Args:
        root: Base directory the globs are resolved against; defaults to ROOT.
        run_globs: Glob patterns selecting run directories under *root*.

    Returns:
        dict mapping ``(condition, n_train)`` (from each run's summary) to a
        list of dicts with keys ``run_id``, ``seed``, ``best_ood``, ``m4``
        (parsed M4 trajectory JSON or None) and ``m5`` (parsed M5 steering
        JSON or None). Runs whose summary lacks ``condition``/``n_train`` or
        that have neither M4 nor M5 results are skipped.
    """
    base = ROOT if root is None else Path(root)
    # A set avoids processing a directory twice if patterns overlap.
    run_dirs = sorted({d for pattern in run_globs for d in base.glob(pattern)})
    out = {}
    for run_dir in run_dirs:
        summary = _load_summary(run_dir)
        cond = summary.get("condition")
        n = summary.get("n_train")
        if cond is None or n is None:
            continue
        mech_dir = run_dir / "mechinterp"
        m4_path = mech_dir / "m4_ablation_avgpool_trajectory.json"
        # Path.glob yields in filesystem order; sort so the chosen steering
        # file is deterministic when a run has several of them.
        m5_paths = sorted(mech_dir.glob("m5_steering_*.json"))
        if len(m5_paths) > 1:
            print(f"  note: {run_dir.name} has {len(m5_paths)} M5 files; "
                  f"using {m5_paths[0].name}")
        m4 = (json.loads(m4_path.read_text(encoding="utf-8"))
              if m4_path.exists() else None)
        m5 = (json.loads(m5_paths[0].read_text(encoding="utf-8"))
              if m5_paths else None)
        if m4 is None and m5 is None:
            continue
        out.setdefault((cond, n), []).append({
            "run_id": run_dir.name,
            "seed": summary.get("seed"),
            "best_ood": summary.get("best_ood"),
            "m4": m4, "m5": m5,
        })
    return out
def _stack_m4(runs):
    """Align M4 trajectories on the epochs shared by every run that has them.

    Runs whose ``m4`` entry is None/empty are skipped; ``_gather`` keeps runs
    that only have M5 results, so this must not assume every run has M4.

    Args:
        runs: list of run dicts; ``run["m4"]`` is a list of records with
            ``epoch``, ``head_ood_acc_raw`` and ``head_ood_acc_ablated`` keys.

    Returns:
        ``(eps, raw_mat, abl_mat, delta_mat)``: ``eps`` has shape ``(n_eps,)``;
        each matrix has shape ``(n_runs_with_m4, n_eps)`` and
        ``delta_mat = abl_mat - raw_mat``. Four ``None`` when no run has M4
        data or the runs share no epoch.
    """
    # One epoch -> record table per run. Iterating in reverse makes the first
    # record win for a duplicated epoch, matching a front-to-back search.
    tables = [
        {rec["epoch"]: rec for rec in reversed(run["m4"])}
        for run in runs if run["m4"]
    ]
    if not tables:
        return None, None, None, None
    eps = sorted(set(tables[0]).intersection(*tables[1:]))
    if not eps:
        return None, None, None, None
    raw = np.array([[t[e]["head_ood_acc_raw"] for e in eps] for t in tables],
                   dtype=float)
    abl = np.array([[t[e]["head_ood_acc_ablated"] for e in eps] for t in tables],
                   dtype=float)
    return np.array(eps), raw, abl, abl - raw
def _stack_m5(runs):
    """Align M5 steering sweeps on the α values shared by every run that has them.

    Runs whose ``m5`` entry is None/empty are skipped.

    Args:
        runs: list of run dicts; ``run["m5"]["sweep"]`` is a list of records
            with ``alpha`` and ``head_ood_acc`` keys.

    Returns:
        ``(alphas, head_mat)``: ``alphas`` has shape ``(n_alpha,)`` and
        ``head_mat`` has shape ``(n_runs_with_m5, n_alpha)``. ``(None, None)``
        when no run has M5 data or the runs share no α — always a 2-tuple so
        callers can unpack it unconditionally.
    """
    # One alpha -> record table per run; reversed so the first record wins
    # for a duplicated alpha, matching a front-to-back search.
    tables = [
        {rec["alpha"]: rec for rec in reversed(run["m5"]["sweep"])}
        for run in runs if run["m5"]
    ]
    if not tables:
        return None, None
    alphas = sorted(set(tables[0]).intersection(*tables[1:]))
    if not alphas:
        return None, None
    head = np.array([[t[a]["head_ood_acc"] for a in alphas] for t in tables],
                    dtype=float)
    return np.array(alphas), head
def _seed_word(k: int) -> str:
    """Return "seed" for k == 1 and "seeds" otherwise (legend labels)."""
    return "seed" if k == 1 else "seeds"


def _plot_m4_panel(ax, grok_runs, std_runs) -> None:
    """Draw panel A: M4 raw vs. ablated head-OOD accuracy over training epochs.

    Grokking runs are drawn as mean ± std bands across seeds (numpy's default
    population std, ddof=0); standard runs as mean lines only.
    """
    eps, raw_g, abl_g, _ = _stack_m4(grok_runs)
    if eps is not None:
        # Count runs that actually contributed M4 data; this can be fewer than
        # len(grok_runs) when some runs only have M5 results.
        k = raw_g.shape[0]
        for mat, style, color, kind in ((raw_g, "-", "navy", "raw"),
                                        (abl_g, "--", "darkorange", "ablated")):
            mu, sd = mat.mean(0), mat.std(0)
            ax.plot(eps, mu, style, color=color, lw=2.2,
                    label=f"grok {kind} (n={k} {_seed_word(k)})")
            ax.fill_between(eps, mu - sd, mu + sd, color=color, alpha=0.18)
    if std_runs:
        eps_s, raw_s, abl_s, _ = _stack_m4(std_runs)
        if eps_s is not None:
            k = raw_s.shape[0]
            for mat, style, kind in ((raw_s, "-", "raw"), (abl_s, "--", "ablated")):
                ax.plot(eps_s, mat.mean(0), style, color="darkred", lw=1.6,
                        alpha=0.9,
                        label=f"standard {kind} (n={k} {_seed_word(k)})")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M4 — Shortcut subspace ablation across training\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    # Only add a legend when something labelled was drawn (avoids a warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=8, loc="lower right")
    ax.grid(alpha=0.3)
    ax.set_ylim(0.42, 0.80)


def _plot_m5_panel(ax, grok_runs, std_runs) -> None:
    """Draw panel B: M5 head-OOD accuracy versus steering coefficient α."""
    a_g, head_g = _stack_m5(grok_runs)
    # A mean ± std band needs at least two seeds that have M5 data.
    if a_g is not None and head_g.shape[0] >= 2:
        k = head_g.shape[0]
        mu, sd = head_g.mean(0), head_g.std(0)
        ax.plot(a_g, mu, "-o", color="navy", lw=2.2, ms=7,
                label=f"grokking n=1000 (n={k} {_seed_word(k)})")
        ax.fill_between(a_g, mu - sd, mu + sd, color="navy", alpha=0.20)
    if std_runs:
        a_s, head_s = _stack_m5(std_runs)
        if a_s is not None:
            k = head_s.shape[0]
            ax.plot(a_s, head_s.mean(0), "-s", color="darkred", lw=1.8, ms=6,
                    alpha=0.9, label=f"standard n=1000 (n={k} {_seed_word(k)})")
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M5 — Activation steering response\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=8)
    ax.grid(alpha=0.3)


def main():
    """Build the two-panel multi-seed M4/M5 figure and save it as PNG + PDF.

    Requires at least two grokking n=1000 runs; otherwise prints a message and
    returns without writing anything. Output goes to
    ``ROOT/paper_figures/figure_multiseed_intervention.{png,pdf}``.
    """
    data = _gather()
    n1000_grok = data.get(("grokking", 1000), [])
    n1000_std = data.get(("standard", 1000), [])
    print(f" grokking n=1000 seeds: {[r['seed'] for r in n1000_grok]}")
    print(f" standard n=1000 seeds: {[r['seed'] for r in n1000_std]}")
    if len(n1000_grok) < 2:
        print(" Need ≥2 grokking n=1000 seeds; aborting.")
        return
    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    try:
        _plot_m4_panel(axes[0], n1000_grok, n1000_std)
        _plot_m5_panel(axes[1], n1000_grok, n1000_std)
        fig.suptitle("Multi-seed Causal-Intervention Robustness "
                     f"(n=1000, {len(n1000_grok)} seeds for grokking)",
                     fontsize=12, fontweight="bold", y=1.01)
        fig.tight_layout()
        out = ROOT / "paper_figures" / "figure_multiseed_intervention"
        # savefig does not create missing directories.
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
        fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f" Saved {out}.png + .pdf")


if __name__ == "__main__":
    main()