File size: 5,986 Bytes
50fa85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Compose the MI workshop's causal-intervention figure (M4 + M5 across runs).

Reads the M4 trajectory and M5 sweep JSONs from every run that has them, groups
the runs by condition (grokking vs. standard), and produces a 2×2 layout:

  Top row    — M4 ablation: head OOD accuracy, raw vs. shortcut-ablated, as
                training trajectories across all available runs in each condition.
  Bottom row — M5 steering: head OOD accuracy as a function of α, one curve per run.

Saves: paper_figures/figure_intervention_comparison.{png,pdf}
"""
from __future__ import annotations

import glob
import json
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Repository root: the parent of the directory that contains this script.
# Run directories (experiments/runs/...) and the output folder (paper_figures/)
# are resolved relative to this path.
ROOT = Path(__file__).resolve().parent.parent


def _load_summary(run_dir: Path) -> dict:
    """Return the parsed ``results/summary.json`` of *run_dir*, or ``{}`` if it is absent.

    Invalid JSON is not caught: ``json.JSONDecodeError`` propagates to the caller.
    """
    summary_file = run_dir.joinpath("results", "summary.json")
    return json.loads(summary_file.read_text()) if summary_file.exists() else {}


def _gather(root: Path | None = None,
            run_prefixes: tuple[str, ...] = ("20260505-", "20260508-"),
            ) -> dict[str, list[dict]]:
    """Collect M4/M5 intervention results from run directories, grouped by condition.

    Scans ``<root>/experiments/runs/<prefix>*/`` for every prefix in
    *run_prefixes* and reads each run's summary via ``_load_summary``. A run is
    kept only if its ``condition`` is ``"grokking"`` or ``"standard"`` and it has
    at least one of ``mechinterp/m4_ablation_avgpool_trajectory.json`` or
    ``mechinterp/m5_steering_*.json``.

    Args:
        root: Repository root to scan; ``None`` (default) uses the module-level ``ROOT``.
        run_prefixes: Run-directory name prefixes to include. The defaults
            reproduce the original hard-coded run batches.

    Returns:
        ``{"grokking": [...], "standard": [...]}``. Each entry is a dict with
        ``run_dir`` (directory name), ``n_train``, ``seed`` and ``best_ood``
        (copied from the summary, ``None`` when missing), ``m4`` (parsed M4
        trajectory JSON or ``None``) and ``m5`` (parsed M5 sweep JSON or
        ``None``). Runs within each condition are in sorted path order.
    """
    base = (ROOT if root is None else Path(root)) / "experiments" / "runs"
    # A set avoids listing a directory twice when prefixes overlap.
    run_dirs = sorted({d for prefix in run_prefixes for d in base.glob(f"{prefix}*/")})

    out: dict[str, list[dict]] = {"grokking": [], "standard": []}
    for run_dir in run_dirs:
        summary = _load_summary(run_dir)
        cond = summary.get("condition")
        if cond not in out:
            continue

        mi_dir = run_dir / "mechinterp"
        m4_path = mi_dir / "m4_ablation_avgpool_trajectory.json"
        # glob() order is filesystem-dependent; sort so that, when several M5
        # sweeps exist, the same (lexicographically first) file is always used.
        m5_paths = sorted(mi_dir.glob("m5_steering_*.json"))

        m4 = json.loads(m4_path.read_text()) if m4_path.exists() else None
        m5 = json.loads(m5_paths[0].read_text()) if m5_paths else None
        if m4 is None and m5 is None:
            continue

        out[cond].append({
            "run_dir": run_dir.name,
            "n_train": summary.get("n_train"),
            "seed": summary.get("seed"),
            "best_ood": summary.get("best_ood"),
            "m4": m4,
            "m5": m5,
        })
    return out


def _empty_panel(ax, message: str, title: str) -> None:
    """Mark *ax* as having no data: a centred grey *message* plus a bold *title*."""
    ax.text(0.5, 0.5, message,
            ha="center", va="center", transform=ax.transAxes,
            color="gray", fontsize=11)
    ax.set_title(title, fontweight="bold")


def _run_color(cmap, i: int, n: int):
    """Return the i-th of *n* colours from *cmap*, spread over [0.4, 0.9] of the map."""
    # Starting at 0.4 keeps lines away from the near-white low end of Blues/Reds.
    return cmap(0.4 + 0.5 * i / max(1, n - 1))


def _covering_ylim(ax, values, lo: float, hi: float, pad: float = 0.02) -> None:
    """Set *ax* y-limits to at least [lo, hi], widened so every finite value stays visible.

    ``None`` and non-finite entries in *values* are ignored. A fixed range alone
    would silently clip any run whose accuracy falls outside it.
    """
    arr = np.asarray([v for v in values if v is not None], dtype=float)
    arr = arr[np.isfinite(arr)]
    if arr.size:
        lo = min(lo, float(arr.min()) - pad)
        hi = max(hi, float(arr.max()) + pad)
    ax.set_ylim(lo, hi)


def _plot_m4_panel(ax, cond: str, runs: list[dict], cmap) -> None:
    """Plot raw (solid) vs shortcut-ablated (dashed) head-OOD trajectories for *cond*."""
    m4_runs = [run for run in runs if run["m4"]]
    if not m4_runs:
        # Also covers conditions whose runs only have M5 data.
        _empty_panel(ax, f"no {cond} M4 data yet",
                     f"{cond.upper()}: head OOD raw vs ablated")
        return

    plotted = []
    for i, run in enumerate(m4_runs):
        # Spread colours over the runs actually drawn in this panel.
        color = _run_color(cmap, i, len(m4_runs))
        traj = run["m4"]
        eps = [r["epoch"] for r in traj]
        raw = [r["head_ood_acc_raw"] for r in traj]
        abl = [r["head_ood_acc_ablated"] for r in traj]
        plotted += raw + abl
        label_n = f"n={run['n_train']} s{run['seed']}"
        ax.plot(eps, raw, "-", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} raw")
        ax.plot(eps, abl, "--", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} ablated")

    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M4 ablation — {cond.upper()}\n"
                 "raw (—) vs shortcut-ablated (- -)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=7, ncol=2)
    ax.grid(alpha=0.3)
    _covering_ylim(ax, plotted, 0.30, 0.85)


def _plot_m5_panel(ax, cond: str, runs: list[dict], cmap) -> None:
    """Plot head-OOD accuracy vs steering coefficient α, one curve per run, for *cond*."""
    m5_runs = [run for run in runs if run["m5"]]
    if not m5_runs:
        _empty_panel(ax, f"no {cond} M5 data yet",
                     f"{cond.upper()}: head OOD vs steering α")
        return

    plotted = []
    for i, run in enumerate(m5_runs):
        color = _run_color(cmap, i, len(m5_runs))
        m5 = run["m5"]
        alphas = [r["alpha"] for r in m5["sweep"]]
        heads = [r["head_ood_acc"] for r in m5["sweep"]]
        plotted += heads
        ax.plot(alphas, heads, "-o", color=color, lw=2, ms=6,
                label=f"n={run['n_train']} s{run['seed']}  ep{m5['epoch']}")

    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M5 steering — {cond.upper()}\n"
                 "α=0 baseline; |α|↑ = stronger shortcut activation",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    _covering_ylim(ax, plotted, 0.45, 0.80)


def main():
    """Build the 2×2 intervention figure and save it under ``<ROOT>/paper_figures``.

    Top row: M4 ablation trajectories. Bottom row: M5 steering sweeps. The left
    column shows grokking runs (blue shades) and the right column standard runs
    (red shades). Writes ``figure_intervention_comparison.png`` (180 dpi) and
    ``.pdf``, creating the output directory if it does not exist.
    """
    data = _gather()
    n_grok = len(data["grokking"])
    n_std = len(data["standard"])
    print(f"Found {n_grok} grokking runs and {n_std} standard runs with intervention data")

    fig, axes = plt.subplots(2, 2, figsize=(15, 9))
    cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds}

    for col, cond in enumerate(["grokking", "standard"]):
        _plot_m4_panel(axes[0][col], cond, data[cond], cmaps[cond])
        _plot_m5_panel(axes[1][col], cond, data[cond], cmaps[cond])

    fig.suptitle(
        "Causal Interventions on the Shortcut Subspace (avgpool, ResNet-18, Camelyon17)\n"
        "M4 = ablate-and-evaluate, M5 = steer-and-evaluate",
        fontsize=12, fontweight="bold", y=1.005,
    )
    plt.tight_layout()
    out = ROOT / "paper_figures" / "figure_intervention_comparison"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out}.png + .pdf")


# Script entry point: build and save the figure only when run directly, not on import.
if __name__ == "__main__":
    main()