File size: 6,639 Bytes
"""Multi-seed mean ± std figure for M4 and M5 across the 3 n=1000 grokking seeds
(s42, s123, s456) plus single-seed standard baselines as reference lines.

Reads:
    experiments/runs/<id>/results/summary.json
    experiments/runs/<id>/mechinterp/m4_ablation_avgpool_trajectory.json
    experiments/runs/<id>/mechinterp/m5_steering_*.json

Outputs:
    paper_figures/figure_multiseed_intervention.{png,pdf}
"""
from __future__ import annotations
import glob
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Project directory used as the base for experiments/ and paper_figures/:
# the directory two levels above this script's resolved path.
_THIS_SCRIPT = Path(__file__).resolve()
ROOT = _THIS_SCRIPT.parent.parent
def _load_summary(run_dir: Path) -> dict:
    """Load ``<run_dir>/results/summary.json``.

    Args:
        run_dir: Directory of a single training run.

    Returns:
        The parsed JSON content (expected to be an object), or an empty dict
        when the file does not exist, so callers can use ``.get(...)`` uniformly.
    """
    path = run_dir / "results" / "summary.json"
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}
    return json.loads(text)
def _gather():
    """Collect per-run M4/M5 results grouped by ``(condition, n_train)``.

    Scans run directories under ``ROOT/experiments/runs`` whose names start with
    ``20260505-`` (initial 7 runs) or ``20260508-`` (4 new runs). A run is kept
    only if its ``results/summary.json`` provides both ``condition`` and
    ``n_train`` and at least one of the M4 / M5 result files exists.

    Returns:
        dict mapping ``(condition, n_train)`` -> list of run dicts with keys
        ``run_id``, ``seed``, ``best_ood``, ``m4`` (parsed M4 trajectory or
        ``None``) and ``m5`` (parsed M5 steering result or ``None``), in sorted
        run-directory order.
    """
    def read_json(path: Path):
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        return json.loads(path.read_text(encoding="utf-8"))

    runs_root = ROOT / "experiments" / "runs"
    run_dirs = sorted(
        [*runs_root.glob("20260505-*/"), *runs_root.glob("20260508-*/")]
    )

    grouped = {}
    for run_dir in run_dirs:
        summary = _load_summary(run_dir)
        cond = summary.get("condition")
        n = summary.get("n_train")
        if cond is None or n is None:
            continue

        mech_dir = run_dir / "mechinterp"
        m4_path = mech_dir / "m4_ablation_avgpool_trajectory.json"
        # glob() yields files in filesystem-dependent order; sort so the chosen
        # steering file is deterministic when a run has more than one.
        m5_paths = sorted(mech_dir.glob("m5_steering_*.json"))
        if len(m5_paths) > 1:
            print(f" [warn] {run_dir.name}: {len(m5_paths)} M5 files, "
                  f"using {m5_paths[0].name}")

        m4 = read_json(m4_path) if m4_path.exists() else None
        m5 = read_json(m5_paths[0]) if m5_paths else None
        if m4 is None and m5 is None:
            continue

        grouped.setdefault((cond, n), []).append({
            "run_id": run_dir.name,
            "seed": summary.get("seed"),
            "best_ood": summary.get("best_ood"),
            "m4": m4,
            "m5": m5,
        })
    return grouped
def _stack_m4(runs):
    """Align M4 ablation trajectories on the epochs shared by all runs.

    Args:
        runs: Run dicts as produced by ``_gather``. Runs whose ``"m4"`` entry is
            falsy (no trajectory) are ignored. Each trajectory is a list of
            records with ``"epoch"``, ``"head_ood_acc_raw"`` and
            ``"head_ood_acc_ablated"`` keys.

    Returns:
        ``(eps, raw_mat, abl_mat, delta_mat)``: ``eps`` is a 1-D array of the
        sorted shared epochs; each matrix has shape
        ``(n_runs_with_m4, n_eps)`` and ``delta_mat = abl_mat - raw_mat``.
        Four ``None`` values when no run has M4 data or no epoch is shared.
    """
    # Only runs that actually have a trajectory contribute rows; stacking every
    # run would iterate over ``None`` for runs missing M4.
    runs = [run for run in runs if run["m4"]]
    if not runs:
        return None, None, None, None

    # Index each trajectory by epoch once (first record wins, like a linear
    # scan would) instead of re-scanning the list for every epoch.
    tables = []
    for run in runs:
        table = {}
        for rec in run["m4"]:
            table.setdefault(rec["epoch"], rec)
        tables.append(table)

    eps = sorted(set.intersection(*(set(t) for t in tables)))
    if not eps:
        return None, None, None, None

    raw = np.array([[t[e]["head_ood_acc_raw"] for e in eps] for t in tables])
    abl = np.array([[t[e]["head_ood_acc_ablated"] for e in eps] for t in tables])
    return np.array(eps), raw, abl, abl - raw
def _stack_m5(runs):
    """Align M5 steering sweeps on the steering coefficients shared by all runs.

    Args:
        runs: Run dicts as produced by ``_gather``. Runs whose ``"m5"`` entry is
            falsy are ignored. Each ``run["m5"]["sweep"]`` is a list of records
            with ``"alpha"`` and ``"head_ood_acc"`` keys.

    Returns:
        ``(alphas, head_mat)``: ``alphas`` is a 1-D array of the sorted shared
        alpha values and ``head_mat`` has shape ``(n_runs_with_m5, n_alphas)``.
        ``(None, None)`` when no run has M5 data or no alpha is shared.
    """
    runs = [run for run in runs if run["m5"]]
    if not runs:
        # Two values, matching the success path, so callers can always write
        # ``alphas, head = _stack_m5(...)``.
        return None, None

    # Index each sweep by alpha once (first record wins, like a linear scan).
    tables = []
    for run in runs:
        table = {}
        for rec in run["m5"]["sweep"]:
            table.setdefault(rec["alpha"], rec["head_ood_acc"])
        tables.append(table)

    alphas = sorted(set.intersection(*(set(t) for t in tables)))
    if not alphas:
        return None, None

    head = np.array([[t[a] for a in alphas] for t in tables])
    return np.array(alphas), head
def _runs_with(runs: list, key: str) -> list:
    """Return the runs whose ``key`` entry (``"m4"`` or ``"m5"``) holds data."""
    return [run for run in runs if run[key]]


def _seed_label(n: int) -> str:
    """Return ``'n=<n> seed'`` or ``'n=<n> seeds'`` for legend labels."""
    return f"n={n} seed" if n == 1 else f"n={n} seeds"


def _plot_m4_panel(ax, grok_runs: list, standard_runs: list) -> None:
    """Draw panel A: M4 head-OOD accuracy (raw vs. ablated) across training.

    Grokking runs are drawn as mean ± std bands across seeds; standard runs as
    mean lines only (reference baseline).
    """
    grok = _runs_with(grok_runs, "m4")
    if grok:
        eps, raw_g, abl_g, _ = _stack_m4(grok)
        if eps is not None:
            # Label with the runs that actually contributed M4 data.
            n_seeds = _seed_label(raw_g.shape[0])
            # np.std defaults to ddof=0 (population std across seeds).
            raw_mu, raw_sd = raw_g.mean(0), raw_g.std(0)
            abl_mu, abl_sd = abl_g.mean(0), abl_g.std(0)
            ax.plot(eps, raw_mu, "-", color="navy", lw=2.2,
                    label=f"grok raw ({n_seeds})")
            ax.fill_between(eps, raw_mu - raw_sd, raw_mu + raw_sd,
                            color="navy", alpha=0.18)
            ax.plot(eps, abl_mu, "--", color="darkorange", lw=2.2,
                    label=f"grok ablated ({n_seeds})")
            ax.fill_between(eps, abl_mu - abl_sd, abl_mu + abl_sd,
                            color="darkorange", alpha=0.18)

    standard = _runs_with(standard_runs, "m4")
    if standard:
        eps_s, raw_s, abl_s, _ = _stack_m4(standard)
        if eps_s is not None:
            n_std = _seed_label(raw_s.shape[0])
            ax.plot(eps_s, raw_s.mean(0), "-", color="darkred", lw=1.6,
                    alpha=0.9, label=f"standard raw ({n_std})")
            ax.plot(eps_s, abl_s.mean(0), "--", color="darkred", lw=1.6,
                    alpha=0.9, label=f"standard ablated ({n_std})")

    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M4 — Shortcut subspace ablation across training\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    # legend() warns when no labelled artists exist (e.g. no M4 data at all).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=8, loc="lower right")
    ax.grid(alpha=0.3)
    # Fixed y-range chosen for the paper figure; data outside it is clipped.
    ax.set_ylim(0.42, 0.80)


def _plot_m5_panel(ax, grok_runs: list, standard_runs: list) -> None:
    """Draw panel B: M5 head-OOD accuracy versus steering coefficient alpha.

    Grokking runs are drawn as a mean ± std band across seeds; standard runs as
    a mean line only (reference baseline).
    """
    grok = _runs_with(grok_runs, "m5")
    if grok:
        a_g, head_g = _stack_m5(grok)
        if a_g is not None:
            mu, sd = head_g.mean(0), head_g.std(0)
            ax.plot(a_g, mu, "-o", color="navy", lw=2.2, ms=7,
                    label=f"grokking n=1000 ({_seed_label(head_g.shape[0])})")
            ax.fill_between(a_g, mu - sd, mu + sd, color="navy", alpha=0.20)

    standard = _runs_with(standard_runs, "m5")
    if standard:
        a_s, head_s = _stack_m5(standard)
        if a_s is not None:
            ax.plot(a_s, head_s.mean(0), "-s", color="darkred", lw=1.8, ms=6,
                    alpha=0.9,
                    label=f"standard n=1000 ({_seed_label(head_s.shape[0])})")

    # Reference line at alpha = 0 (no steering).
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M5 — Activation steering response\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=8)
    ax.grid(alpha=0.3)


def main():
    """Build the two-panel multi-seed M4/M5 figure and save it as PNG + PDF.

    Prints the seeds found for grokking/standard n=1000 runs, and aborts (after
    printing a message) when fewer than two grokking n=1000 runs are available,
    since a mean ± std over seeds needs at least two of them.
    """
    data = _gather()
    n1000_grok = data.get(("grokking", 1000), [])
    n1000_std = data.get(("standard", 1000), [])
    print(f" grokking n=1000 seeds: {[r['seed'] for r in n1000_grok]}")
    print(f" standard n=1000 seeds: {[r['seed'] for r in n1000_std]}")
    if len(n1000_grok) < 2:
        print(" Need ≥2 grokking n=1000 seeds; aborting.")
        return

    fig, (ax_m4, ax_m5) = plt.subplots(1, 2, figsize=(14, 5.5))
    _plot_m4_panel(ax_m4, n1000_grok, n1000_std)
    _plot_m5_panel(ax_m5, n1000_grok, n1000_std)

    # Seed count comes from the data instead of being hard-coded.
    fig.suptitle("Multi-seed Causal-Intervention Robustness "
                 f"(n=1000, {len(n1000_grok)} seeds for grokking)",
                 fontsize=12, fontweight="bold", y=1.01)
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_multiseed_intervention"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {out}.png + .pdf")
if __name__ == "__main__":
    # Build the figure only when executed as a script, never on import.
    main()
|