File size: 6,639 Bytes
50fa85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Multi-seed mean ± std figure for M4 and M5 across the 3 n=1000 grokking seeds
(s42, s123, s456), plus single-seed standard baselines as reference lines.

Reads:
  experiments/runs/<id>/results/summary.json
  experiments/runs/<id>/mechinterp/m4_ablation_avgpool_trajectory.json
  experiments/runs/<id>/mechinterp/m5_steering_*.json

Outputs:
  paper_figures/figure_multiseed_intervention.{png,pdf}
"""
from __future__ import annotations

import glob
import json
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Base directory for every run input and figure output in this script:
# the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parent.parent


def _load_summary(run_dir: Path) -> dict:
    p = run_dir / "results" / "summary.json"
    return json.loads(p.read_text()) if p.exists() else {}


def _gather():
    """Group the dated runs' mechinterp results by ``(condition, n_train)``.

    Scans ``experiments/runs/20260505-*/`` (initial 7 runs) and
    ``experiments/runs/20260508-*/`` (4 new runs) under ``ROOT``. A run is
    kept only if its summary has both ``condition`` and ``n_train`` and at
    least one of the M4 trajectory / M5 steering JSON files exists.

    Returns:
        dict mapping ``(condition, n_train)`` -> list of run dicts with keys
        ``run_id``, ``seed``, ``best_ood``, ``m4`` (parsed M4 JSON or
        ``None``) and ``m5`` (parsed M5 JSON or ``None``), in sorted
        run-directory order.
    """
    out = {}
    run_dirs = sorted([
        *ROOT.glob("experiments/runs/20260505-*/"),
        *ROOT.glob("experiments/runs/20260508-*/"),
    ])
    for run_dir in run_dirs:
        s = _load_summary(run_dir)
        cond = s.get("condition")
        n = s.get("n_train")
        if cond is None or n is None:
            continue
        mi_dir = run_dir / "mechinterp"
        m4_p = mi_dir / "m4_ablation_avgpool_trajectory.json"
        # glob() order is filesystem-dependent; sort so the same M5 file is
        # chosen on every machine when several steering files exist.
        m5_paths = sorted(mi_dir.glob("m5_steering_*.json"))
        m4 = (json.loads(m4_p.read_text(encoding="utf-8"))
              if m4_p.exists() else None)
        m5 = (json.loads(m5_paths[0].read_text(encoding="utf-8"))
              if m5_paths else None)
        if m4 is None and m5 is None:
            continue
        out.setdefault((cond, n), []).append({
            "run_id":   run_dir.name,
            "seed":     s.get("seed"),
            "best_ood": s.get("best_ood"),
            "m4": m4, "m5": m5,
        })
    return out


def _stack_m4(runs):
    """Align M4 trajectories on shared epochs and return:
       (eps, raw_mat, abl_mat, delta_mat) β€” each (n_runs, n_eps)."""
    if not runs:
        return None, None, None, None
    eps_set = set.intersection(*[
        {r["epoch"] for r in run["m4"]} for run in runs if run["m4"]
    ])
    eps = sorted(eps_set)
    raw = np.array([
        [next(r for r in run["m4"] if r["epoch"] == e)["head_ood_acc_raw"]
         for e in eps] for run in runs
    ])
    abl = np.array([
        [next(r for r in run["m4"] if r["epoch"] == e)["head_ood_acc_ablated"]
         for e in eps] for run in runs
    ])
    return np.array(eps), raw, abl, abl - raw


def _stack_m5(runs):
    """Align M5 sweeps on shared Ξ± and return (alphas, head_mat, hosp_mat, tum_mat)."""
    runs = [run for run in runs if run["m5"]]
    if not runs:
        return None, None, None, None
    alphas_set = set.intersection(*[
        {r["alpha"] for r in run["m5"]["sweep"]} for run in runs
    ])
    alphas = sorted(alphas_set)
    head = np.array([
        [next(r for r in run["m5"]["sweep"] if r["alpha"] == a)["head_ood_acc"]
         for a in alphas] for run in runs
    ])
    return np.array(alphas), head


def _n_seeds(k):
    """Return a seed-count label such as ``"1 seed"`` or ``"3 seeds"``."""
    return f"{k} seed" if k == 1 else f"{k} seeds"


def _plot_m4_panel(ax, grok_runs, std_runs):
    """Draw panel A: M4 raw vs. ablated head-OOD accuracy across training."""
    eps, raw_g, abl_g, _ = _stack_m4(grok_runs)
    if eps is not None:
        # Count only the seeds that actually contributed an M4 trajectory.
        k = raw_g.shape[0]
        for mat, ls, color, tag in ((raw_g, "-", "navy", "raw"),
                                    (abl_g, "--", "darkorange", "ablated")):
            # ndarray.std defaults to the population std (ddof=0).
            mu, sd = mat.mean(0), mat.std(0)
            ax.plot(eps, mu, ls, color=color, lw=2.2,
                    label=f"grok {tag} (n={_n_seeds(k)})")
            ax.fill_between(eps, mu - sd, mu + sd, color=color, alpha=0.18)

    if std_runs:
        eps_s, raw_s, abl_s, _ = _stack_m4(std_runs)
        if eps_s is not None:
            k = raw_s.shape[0]
            ax.plot(eps_s, raw_s.mean(0), "-", color="darkred", lw=1.6,
                    alpha=0.9, label=f"standard raw (n={_n_seeds(k)})")
            ax.plot(eps_s, abl_s.mean(0), "--", color="darkred", lw=1.6,
                    alpha=0.9, label=f"standard ablated (n={_n_seeds(k)})")

    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M4 — Shortcut subspace ablation across training\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8, loc="lower right")
    ax.grid(alpha=0.3)
    ax.set_ylim(0.42, 0.80)


def _plot_m5_panel(ax, grok_runs, std_runs):
    """Draw panel B: M5 head-OOD accuracy as a function of steering alpha."""
    a_g, head_g = _stack_m5(grok_runs)
    if a_g is not None:
        # Count only the seeds that actually contributed an M5 sweep.
        k = head_g.shape[0]
        mu, sd = head_g.mean(0), head_g.std(0)
        ax.plot(a_g, mu, "-o", color="navy", lw=2.2, ms=7,
                label=f"grokking n=1000 (n={_n_seeds(k)})")
        ax.fill_between(a_g, mu - sd, mu + sd, color="navy", alpha=0.20)
    if std_runs:
        a_s, head_s = _stack_m5(std_runs)
        if a_s is not None:
            ax.plot(a_s, head_s.mean(0), "-s", color="darkred", lw=1.8, ms=6,
                    alpha=0.9,
                    label=f"standard n=1000 (n={_n_seeds(head_s.shape[0])})")
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("M5 — Activation steering response\n"
                 "(mean ± std over seeds for n=1000)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)


def main():
    """Build and save the multi-seed M4/M5 intervention figure (PNG + PDF).

    Needs at least two grokking n=1000 runs. Otherwise it prints a message
    and returns without writing anything. The output directory is created
    if missing.
    """
    data = _gather()
    n1000_grok = data.get(("grokking", 1000), [])
    n1000_std = data.get(("standard", 1000), [])
    print(f"  grokking n=1000 seeds: {[r['seed'] for r in n1000_grok]}")
    print(f"  standard n=1000 seeds: {[r['seed'] for r in n1000_std]}")

    if len(n1000_grok) < 2:
        # ASCII ">=" so the message prints on any console encoding.
        print("  Need >=2 grokking n=1000 seeds; aborting.")
        return

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    _plot_m4_panel(axes[0], n1000_grok, n1000_std)
    _plot_m5_panel(axes[1], n1000_grok, n1000_std)

    fig.suptitle("Multi-seed Causal-Intervention Robustness "
                 f"(n=1000, {len(n1000_grok)} seeds for grokking)",
                 fontsize=12, fontweight="bold", y=1.01)
    plt.tight_layout()
    out = ROOT / "paper_figures" / "figure_multiseed_intervention"
    # savefig does not create missing parent directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {out}.png + .pdf")


# Script entry point: build the figure only when run directly, not on import.
if __name__ == "__main__":
    main()