"""Multi-seed mean ± std figure for M4 and M5 across the 3 n=1000 grokking seeds
(s42, s123, s456), plus single-seed standard baselines as reference lines.

Reads:
    experiments/runs/<id>/results/summary.json
    experiments/runs/<id>/mechinterp/m4_ablation_avgpool_trajectory.json
    experiments/runs/<id>/mechinterp/m5_steering_*.json

Outputs:
    paper_figures/figure_multiseed_intervention.{png,pdf}
"""
| from __future__ import annotations |
|
|
| import glob |
| import json |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Project root: two directories above this file. The experiment-run scan and
# the figure output path in this script are both resolved relative to it.
ROOT = Path(__file__).resolve().parent.parent
|
|
|
|
| def _load_summary(run_dir: Path) -> dict: |
| p = run_dir / "results" / "summary.json" |
| return json.loads(p.read_text()) if p.exists() else {} |
|
|
|
|
def _gather(root=None, prefixes=("20260505-", "20260508-")):
    """Collect per-run M4/M5 results grouped by ``(condition, n_train)``.

    Scans ``<root>/experiments/runs/<prefix>*`` directories for every prefix,
    in sorted path order. A run is kept only if its ``results/summary.json``
    provides both ``condition`` and ``n_train`` and at least one of the M4
    trajectory / M5 sweep JSON files exists under ``mechinterp/``.

    Args:
        root: directory to scan from; defaults to the module-level ``ROOT``.
        prefixes: run-directory name prefixes to include. The default keeps
            the two date stamps this figure was originally built from.

    Returns:
        dict mapping ``(condition, n_train)`` → list of dicts with keys
        ``run_id``, ``seed``, ``best_ood``, ``m4`` and ``m5``. ``m4``/``m5``
        are the parsed JSON payloads, or ``None`` when that file is missing.
    """
    base = ROOT if root is None else Path(root)
    runs_root = base / "experiments" / "runs"
    # Filter to directories explicitly: a trailing "/" in a pathlib glob
    # pattern only restricts matches to directories on Python >= 3.11.
    run_dirs = sorted(
        d for prefix in prefixes for d in runs_root.glob(f"{prefix}*") if d.is_dir()
    )

    out = {}
    for run_dir in run_dirs:
        s = _load_summary(run_dir)
        cond = s.get("condition")
        n = s.get("n_train")
        if cond is None or n is None:
            continue
        mech_dir = run_dir / "mechinterp"
        m4_p = mech_dir / "m4_ablation_avgpool_trajectory.json"
        # Glob order is filesystem-dependent. Sort so the M5 file picked when
        # several exist is the same on every machine and every run.
        m5_paths = sorted(mech_dir.glob("m5_steering_*.json"))
        m4 = json.loads(m4_p.read_text()) if m4_p.exists() else None
        m5 = json.loads(m5_paths[0].read_text()) if m5_paths else None
        if m4 is None and m5 is None:
            continue
        out.setdefault((cond, n), []).append({
            "run_id": run_dir.name,
            "seed": s.get("seed"),
            "best_ood": s.get("best_ood"),
            "m4": m4,
            "m5": m5,
        })
    return out
|
|
|
|
| def _stack_m4(runs): |
| """Align M4 trajectories on shared epochs and return: |
| (eps, raw_mat, abl_mat, delta_mat) β each (n_runs, n_eps).""" |
| if not runs: |
| return None, None, None, None |
| eps_set = set.intersection(*[ |
| {r["epoch"] for r in run["m4"]} for run in runs if run["m4"] |
| ]) |
| eps = sorted(eps_set) |
| raw = np.array([ |
| [next(r for r in run["m4"] if r["epoch"] == e)["head_ood_acc_raw"] |
| for e in eps] for run in runs |
| ]) |
| abl = np.array([ |
| [next(r for r in run["m4"] if r["epoch"] == e)["head_ood_acc_ablated"] |
| for e in eps] for run in runs |
| ]) |
| return np.array(eps), raw, abl, abl - raw |
|
|
|
|
| def _stack_m5(runs): |
| """Align M5 sweeps on shared Ξ± and return (alphas, head_mat, hosp_mat, tum_mat).""" |
| runs = [run for run in runs if run["m5"]] |
| if not runs: |
| return None, None, None, None |
| alphas_set = set.intersection(*[ |
| {r["alpha"] for r in run["m5"]["sweep"]} for run in runs |
| ]) |
| alphas = sorted(alphas_set) |
| head = np.array([ |
| [next(r for r in run["m5"]["sweep"] if r["alpha"] == a)["head_ood_acc"] |
| for a in alphas] for run in runs |
| ]) |
| return np.array(alphas), head |
|
|
|
|
def _seeds_label(count):
    """Return ``"<count> seed"`` or ``"<count> seeds"`` for legend/title text."""
    return f"{count} seed" if count == 1 else f"{count} seeds"


def _plot_m4_panel(ax, grok_runs, std_runs):
    """Draw the M4 panel: raw vs. ablated head OOD accuracy over training."""
    # Only runs that actually carry an M4 trajectory are stacked and counted.
    grok_m4 = [r for r in grok_runs if r["m4"]]
    std_m4 = [r for r in std_runs if r["m4"]]

    if grok_m4:
        eps, raw_g, abl_g, _ = _stack_m4(grok_m4)
        n_lbl = _seeds_label(len(grok_m4))
        # ndarray.std defaults to ddof=0 (population std across seeds).
        raw_mu, raw_sd = raw_g.mean(0), raw_g.std(0)
        abl_mu, abl_sd = abl_g.mean(0), abl_g.std(0)
        ax.plot(eps, raw_mu, "-", color="navy", lw=2.2,
                label=f"grok raw (n={n_lbl})")
        ax.fill_between(eps, raw_mu - raw_sd, raw_mu + raw_sd,
                        color="navy", alpha=0.18)
        ax.plot(eps, abl_mu, "--", color="darkorange", lw=2.2,
                label=f"grok ablated (n={n_lbl})")
        ax.fill_between(eps, abl_mu - abl_sd, abl_mu + abl_sd,
                        color="darkorange", alpha=0.18)

    if std_m4:
        eps_s, raw_s, abl_s, _ = _stack_m4(std_m4)
        n_lbl = _seeds_label(len(std_m4))
        ax.plot(eps_s, raw_s.mean(0), "-", color="darkred", lw=1.6,
                alpha=0.9, label=f"standard raw (n={n_lbl})")
        ax.plot(eps_s, abl_s.mean(0), "--", color="darkred", lw=1.6,
                alpha=0.9, label=f"standard ablated (n={n_lbl})")

    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M4 — Shortcut subspace ablation across training\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8, loc="lower right")
    ax.grid(alpha=0.3)
    ax.set_ylim(0.42, 0.80)


def _plot_m5_panel(ax, grok_runs, std_runs):
    """Draw the M5 panel: head OOD accuracy vs. steering coefficient α."""
    # Only runs that actually carry an M5 sweep are stacked and counted.
    grok_m5 = [r for r in grok_runs if r["m5"]]
    std_m5 = [r for r in std_runs if r["m5"]]

    if grok_m5:
        a_g, head_g = _stack_m5(grok_m5)
        mu, sd = head_g.mean(0), head_g.std(0)
        ax.plot(a_g, mu, "-o", color="navy", lw=2.2, ms=7,
                label=f"grokking n=1000 (n={_seeds_label(len(grok_m5))})")
        ax.fill_between(a_g, mu - sd, mu + sd, color="navy", alpha=0.20)

    if std_m5:
        a_s, head_s = _stack_m5(std_m5)
        ax.plot(a_s, head_s.mean(0), "-s", color="darkred", lw=1.8, ms=6,
                alpha=0.9, label=f"standard n=1000 (n={_seeds_label(len(std_m5))})")

    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M5 — Activation steering response\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)


def main():
    """Build the two-panel multi-seed M4/M5 figure and save it as PNG + PDF.

    Left panel (M4): head OOD accuracy with and without shortcut-subspace
    ablation across training epochs. Right panel (M5): head OOD accuracy vs.
    steering coefficient α. Grokking n=1000 runs are drawn as mean ± std
    bands; standard n=1000 runs are drawn as mean reference lines. Prints a
    message and returns without plotting when fewer than two grokking n=1000
    runs are found.
    """
    data = _gather()
    n1000_grok = data.get(("grokking", 1000), [])
    n1000_std = data.get(("standard", 1000), [])
    print(f" grokking n=1000 seeds: {[r['seed'] for r in n1000_grok]}")
    print(f" standard n=1000 seeds: {[r['seed'] for r in n1000_std]}")

    if len(n1000_grok) < 2:
        print(" Need ≥2 grokking n=1000 seeds; aborting.")
        return

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    _plot_m4_panel(axes[0], n1000_grok, n1000_std)
    _plot_m5_panel(axes[1], n1000_grok, n1000_std)

    fig.suptitle("Multi-seed Causal-Intervention Robustness "
                 f"(n=1000, {_seeds_label(len(n1000_grok))} for grokking)",
                 fontsize=12, fontweight="bold", y=1.01)
    plt.tight_layout()
    out = ROOT / "paper_figures" / "figure_multiseed_intervention"
    # savefig does not create missing parent directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
|
|
|
|
if __name__ == "__main__":
    # Build the figure only when run as a script, not when imported.
    main()
|
|