File size: 6,542 Bytes
50fa85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""Generate MI workshop Figure 1: grokking vs standard probe heatmap.

Reads M1 probe outputs from experiments/runs/*/mechinterp/m1_probe_data.json,
produces a 2x2 grid (rows = hospital/tumor probe; cols = grokking/standard)
with epoch-x-layer heatmaps. Hospital = Reds (want fading); Tumor = Greens (want rising).

Unless run directories are passed explicitly, picks the grokking run with the
highest best_ood and the standard run with the highest best_ood (falling back
to ood_test_acc), considering only runs that already have M1 probe data.

Usage:
    python -m experiments.figure_mi_comparison \
        [--grok-run experiments/runs/<id>] [--std-run experiments/runs/<id>] \
        [--out paper_figures/figure1_MI_probe_comparison]
"""
from __future__ import annotations

import argparse
import glob
import json
import os
import sys
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Parent of the directory that contains this file. The "experiments/runs/..."
# lookups and the --out path are resolved against it.
ROOT = Path(__file__).resolve().parent.parent


def _load_summary(run_dir: Path) -> dict:
    """Load ``<run_dir>/results/summary.json`` on a best-effort basis.

    Args:
        run_dir: Directory of a single experiment run.

    Returns:
        The parsed summary mapping, or an empty dict when the file is
        missing, unreadable, not valid JSON, or does not hold a JSON object.
    """
    summary_path = run_dir / "results" / "summary.json"
    try:
        summary = json.loads(summary_path.read_text())
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; ValueError covers
        # malformed JSON (json.JSONDecodeError) and undecodable bytes
        # (UnicodeDecodeError). Either way the run simply has no summary.
        return {}
    # Callers use ``.get`` on the result, so a top-level list/null/number is
    # treated like a missing summary instead of crashing them later.
    return summary if isinstance(summary, dict) else {}


def _pick_best_grok_run() -> Path | None:
    """Return the grokking run with the highest ``best_ood`` score.

    Only run directories under ``ROOT/experiments/runs`` that contain
    ``mechinterp/m1_probe_data.json`` and whose summary has
    ``condition == "grokking"`` are considered. A missing or falsy
    ``best_ood`` counts as 0.

    Returns:
        The selected run directory, or ``None`` when nothing qualifies.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    scored = []
    for probe_file in glob.glob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(probe_file).parents[1]
        summary = _load_summary(run_dir)
        if summary.get("condition") == "grokking":
            scored.append((summary.get("best_ood", 0) or 0, run_dir))
    if scored:
        # Descending sort on (score, path): ties on score go to the greater path.
        ranked = sorted(scored, reverse=True)
        return ranked[0][1]
    return None


def _pick_std_run() -> Path | None:
    """Return the standard-training run with the highest OOD score.

    Only run directories under ``ROOT/experiments/runs`` that contain
    ``mechinterp/m1_probe_data.json`` and whose summary has
    ``condition == "standard"`` are considered. The score is ``best_ood``;
    when that is absent or null, ``ood_test_acc`` is used instead, and a run
    with neither scores 0.

    Returns:
        The selected run directory, or ``None`` when nothing qualifies.
    """
    pattern = str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")
    candidates = []
    for f in glob.glob(pattern):
        # <run>/mechinterp/m1_probe_data.json -> <run>
        run_dir = Path(f).parent.parent
        s = _load_summary(run_dir)
        if s.get("condition") != "standard":
            continue
        score = s.get("best_ood")
        if score is None:
            # ``dict.get(key, default)`` only applies the default when the key
            # is absent, so an explicit ``"best_ood": null`` must be handled
            # here or the recorded ood_test_acc would be ignored.
            score = s.get("ood_test_acc")
        candidates.append((score or 0, run_dir))
    if not candidates:
        return None
    # Descending sort on (score, path): ties on score go to the greater path.
    candidates.sort(reverse=True)
    return candidates[0][1]


def _load_probe(run_dir: Path) -> dict | None:
    """Read the M1 probe results stored under a run directory.

    Args:
        run_dir: Directory of a single experiment run.

    Returns:
        The parsed content of ``<run_dir>/mechinterp/m1_probe_data.json``, or
        ``None`` when that file does not exist. Read and JSON parse errors
        are not caught here and propagate to the caller.
    """
    probe_path = run_dir / "mechinterp" / "m1_probe_data.json"
    if probe_path.exists():
        return json.loads(probe_path.read_text())
    return None


def _heatmap(ax, data: dict | None, key: str, title: str, cmap: str):
    """Draw one epoch-by-layer probe heatmap on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        data: Probe results holding ``"epochs"``, ``"layers"`` and the ``key``
            entry, or ``None`` when the run has no probe data.
        key: Name of the entry in ``data`` to plot; its values are read as an
            (n_epochs, n_layers) matrix.
        title: Title shown above the panel.
        cmap: Matplotlib colormap name.

    Returns:
        The image returned by ``ax.imshow``, or ``None`` when a "No data"
        placeholder was drawn (``data`` is ``None`` or lacks ``key``).
    """
    # The caller picks ``key`` from one run and reuses it for the other, so a
    # run whose file lacks that entry gets the placeholder instead of a
    # KeyError that would abort the whole figure.
    if data is None or key not in data:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title, fontweight="bold", fontsize=10)
        return None

    epochs = data["epochs"]
    layers = data["layers"]
    mat = np.array(data[key])  # shape (n_epochs, n_layers)
    if mat.ndim != 2:
        # Flat or otherwise-shaped input is coerced to the expected grid.
        mat = mat.reshape(len(epochs), len(layers))

    # Transposed so epochs run along x and layers along y; the colour range is
    # pinned to [0, 1] so panels are directly comparable.
    im = ax.imshow(
        mat.T,
        aspect="auto",
        cmap=cmap,
        vmin=0.0,
        vmax=1.0,
        interpolation="nearest",
        origin="lower",
    )
    ax.set_xticks(range(len(epochs)))
    ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=7)
    ax.set_yticks(range(len(layers)))
    ax.set_yticklabels(layers, fontsize=8)
    ax.set_xlabel("Epoch", fontsize=9)
    ax.set_title(title, fontweight="bold", fontsize=10)
    return im


def make_figure(grok_dir: Path, std_dir: Path | None, out_base: Path):
    """Render the 2x2 probe comparison figure and save it as PNG and PDF.

    Rows are the hospital probe (top, "Reds") and the tumor probe (bottom,
    "Greens"); columns are the grokking run (left) and the standard run
    (right).

    Args:
        grok_dir: Grokking run directory; its probe data file must exist.
        std_dir: Standard run directory, or ``None``. When it is ``None`` or
            has no probe data file, the right column shows "No data".
        out_base: Output path; its suffix is replaced with ``.png`` and
            ``.pdf``. Missing parent directories are created.

    Raises:
        SystemExit: With status 1 when the grokking run has no probe data.
    """
    grok = _load_probe(grok_dir)
    std = _load_probe(std_dir) if std_dir else None

    if grok is None:
        print(
            f"[ERROR] no probe data at {grok_dir}/mechinterp/m1_probe_data.json",
            file=sys.stderr,
        )
        sys.exit(1)

    # Hospital probe on H3 (held-in held-out hospital) - has signal.
    # The H4 version is degenerate: the probe class set excludes H4 by
    # construction, so it is identically 0.
    hosp_key = "hospital_probe_id" if "hospital_probe_id" in grok else "hospital_probe"
    # Tumor probe on H4 (truly OOD hospital) - measures causal-feature transfer.
    tumor_key = "tumor_probe_ood" if "tumor_probe_ood" in grok else "tumor_probe"
    # Both keys are chosen from the grokking data and reused for the standard
    # run so that each row shows the same probe in both columns.

    grok_title = f"Grokking ({grok_dir.name[-30:]})"
    std_title = f"Standard ({std_dir.name[-30:]})" if std_dir else "Standard (no data)"

    rows = [
        (hosp_key, "Hospital probe on H3 (shortcut recoverability, ↓ good)", "Reds"),
        (tumor_key, "Tumor probe on H4 (causal transfer, ↑ good)", "Greens"),
    ]
    cols = [(grok, grok_title), (std, std_title)]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        for r, (key, caption, cmap) in enumerate(rows):
            for c, (data, run_title) in enumerate(cols):
                ax = axes[r][c]
                im = _heatmap(ax, data, key, f"{run_title}\n{caption}", cmap)
                # Placeholder panels return None and get no colorbar.
                if im is not None:
                    fig.colorbar(im, ax=ax, fraction=0.04, pad=0.02)

        fig.suptitle(
            "Figure 1 — Layer-wise circuit analysis: grokking-favorable vs standard training\n"
            "Grokking: deep-layer hospital recoverability (Reds) drops over training while tumor recoverability (Greens) is preserved.\n"
            "Standard: no localized scrubbing of hospital information.",
            fontsize=11,
            y=1.005,
            fontweight="bold",
        )
        fig.tight_layout()

        out_base.parent.mkdir(parents=True, exist_ok=True)
        png = out_base.with_suffix(".png")
        pdf = out_base.with_suffix(".pdf")
        fig.savefig(png, bbox_inches="tight", dpi=200)
        fig.savefig(pdf, bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"Saved {png}")
    print(f"Saved {pdf}")


def main(argv: list[str] | None = None):
    """Parse command-line options, resolve both runs, and build the figure.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour.

    Raises:
        SystemExit: With status 2 when no grokking run is given and none can
            be auto-picked.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--grok-run", default=None, help="Path to grokking run dir; auto-pick best if omitted")
    ap.add_argument("--std-run", default=None, help="Path to standard run dir; auto-pick best if omitted")
    ap.add_argument("--out", default="paper_figures/figure1_MI_probe_comparison")
    args = ap.parse_args(argv)

    grok_dir = Path(args.grok_run) if args.grok_run else _pick_best_grok_run()
    std_dir = Path(args.std_run) if args.std_run else _pick_std_run()

    if grok_dir is None:
        print("[ERROR] No grokking run with M1 probe data found.", file=sys.stderr)
        print("        Run experiments/mechinterp_m1.py on a grokking run first.", file=sys.stderr)
        sys.exit(2)

    print(f"Grokking run : {grok_dir}")
    print(f"Standard run : {std_dir if std_dir else '(none — figure will show only grokking)'}")

    # A relative --out is resolved against ROOT rather than the working
    # directory; an absolute --out is used unchanged.
    out_base = ROOT / args.out
    make_figure(grok_dir, std_dir, out_base)


# Run the command-line entry point only when executed as a script or with
# ``python -m``; importing the module has no such side effect.
if __name__ == "__main__":
    main()