File size: 5,986 Bytes
50fa85c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | """Compose the MI workshop's causal-intervention figure (M4 + M5 across runs).
Reads M4 trajectory + M5 sweep JSONs from every run that has them, groups by
condition (grokking vs standard), and produces a 2×2 layout:
  Top row — M4 ablation: head OOD acc (raw) vs (shortcut-ablated), trajectories
  across all available runs in each condition.
  Bottom row — M5 steering: head OOD acc as a function of α, one curve per run.
Saves: paper_figures/figure_intervention_comparison.{png,pdf}
"""
from __future__ import annotations
import glob
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Parent of the directory that contains this file; presumably the repository
# root. Run directories are searched for, and the figure is written, relative
# to this path.
ROOT = Path(__file__).resolve().parent.parent
def _load_summary(run_dir: Path) -> dict:
p = run_dir / "results" / "summary.json"
if not p.exists():
return {}
return json.loads(p.read_text())
def _gather():
    """Collect M4/M5 intervention results for every run, grouped by condition.

    Scans the directories under ``ROOT/experiments/runs`` whose names start
    with ``20260505-`` or ``20260508-``. A run is kept when its summary names
    a known condition and at least one of the two intervention files exists.

    Returns:
        dict mapping condition (``"grokking"`` / ``"standard"``) to a list of
        per-run dicts with the keys ``run_dir`` (directory name), ``n_train``,
        ``seed``, ``best_ood`` (copied from the summary, ``None`` if absent),
        ``m4`` and ``m5`` (parsed JSON, or ``None`` when the file is missing).
    """
    out = {"grokking": [], "standard": []}
    patterns = ("experiments/runs/20260505-*/", "experiments/runs/20260508-*/")
    # The trailing "/" restricts a glob to directories only on Python 3.11+;
    # the explicit is_dir() check gives the same result on older versions.
    run_dirs = sorted(
        path
        for pattern in patterns
        for path in ROOT.glob(pattern)
        if path.is_dir()
    )
    for run_dir in run_dirs:
        s = _load_summary(run_dir)
        cond = s.get("condition")
        if cond not in out:
            continue
        mech_dir = run_dir / "mechinterp"
        m4_path = mech_dir / "m4_ablation_avgpool_trajectory.json"
        # glob() yields entries in arbitrary order; sort so that the chosen
        # steering file is the same on every machine when several exist.
        m5_paths = sorted(mech_dir.glob("m5_steering_*.json"))
        m4 = (json.loads(m4_path.read_text(encoding="utf-8"))
              if m4_path.exists() else None)
        m5 = (json.loads(m5_paths[0].read_text(encoding="utf-8"))
              if m5_paths else None)
        if m4 is None and m5 is None:
            # Nothing to plot for this run.
            continue
        out[cond].append({
            "run_dir": run_dir.name,
            "n_train": s.get("n_train"),
            "seed": s.get("seed"),
            "best_ood": s.get("best_ood"),
            "m4": m4,
            "m5": m5,
        })
    return out
def _run_color(cmap, index: int, count: int):
    """Return the colour of curve *index* out of *count* curves in a panel.

    Samples *cmap* between 0.4 and 0.9 so the palest end of the colormap is
    never used; a single curve gets the 0.4 sample.
    """
    return cmap(0.4 + 0.5 * index / max(1, count - 1))


def _draw_placeholder(ax, message: str, title: str) -> None:
    """Fill an empty panel with a grey, centred *message* and a bold *title*."""
    ax.text(0.5, 0.5, message,
            ha="center", va="center", transform=ax.transAxes,
            color="gray", fontsize=11)
    ax.set_title(title, fontweight="bold")


def _plot_m4_panel(ax, cond: str, runs: list, cmap) -> None:
    """Draw raw vs shortcut-ablated head OOD accuracy over epochs, per run."""
    # Only runs that actually carry an M4 trajectory are plotted; a condition
    # that has M5 data alone gets the placeholder instead of an empty plot.
    m4_runs = [run for run in runs if run["m4"]]
    if not m4_runs:
        _draw_placeholder(ax, f"no {cond} M4 data yet",
                          f"{cond.upper()}: head OOD raw vs ablated")
        return
    for i, run in enumerate(m4_runs):
        color = _run_color(cmap, i, len(m4_runs))
        traj = run["m4"]
        eps = [r["epoch"] for r in traj]
        raw = [r["head_ood_acc_raw"] for r in traj]
        abl = [r["head_ood_acc_ablated"] for r in traj]
        label_n = f"n={run['n_train']} s{run['seed']}"
        # Same colour for both curves of a run; line style tells them apart.
        ax.plot(eps, raw, "-", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} raw")
        ax.plot(eps, abl, "--", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} ablated")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M4 ablation — {cond.upper()}\n"
                 f"raw (—) vs shortcut-ablated (- -)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=7, ncol=2)
    ax.grid(alpha=0.3)
    ax.set_ylim(0.30, 0.85)


def _plot_m5_panel(ax, cond: str, runs: list, cmap) -> None:
    """Draw head OOD accuracy against the steering coefficient α, per run."""
    m5_runs = [run for run in runs if run["m5"]]
    if not m5_runs:
        _draw_placeholder(ax, f"no {cond} M5 data yet",
                          f"{cond.upper()}: head OOD vs steering α")
        return
    for i, run in enumerate(m5_runs):
        color = _run_color(cmap, i, len(m5_runs))
        m5 = run["m5"]
        alphas = [r["alpha"] for r in m5["sweep"]]
        heads = [r["head_ood_acc"] for r in m5["sweep"]]
        ax.plot(alphas, heads, "-o", color=color, lw=2, ms=6,
                label=f"n={run['n_train']} s{run['seed']} ep{m5['epoch']}")
    # Vertical guide at α = 0, the unsteered baseline.
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M5 steering — {cond.upper()}\n"
                 f"α=0 baseline; |α|↑ = stronger shortcut activation",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    ax.set_ylim(0.45, 0.80)


def main():
    """Build the 2×2 intervention figure and save it as PNG and PDF.

    Columns are the two conditions (grokking, standard). The top row shows
    the M4 ablation trajectories and the bottom row the M5 steering sweeps.
    The figure is written to
    ``ROOT/paper_figures/figure_intervention_comparison.{png,pdf}``; the
    output directory is created when it does not exist.
    """
    data = _gather()
    n_grok = len(data["grokking"])
    n_std = len(data["standard"])
    print(f"Found {n_grok} grokking runs and {n_std} standard runs with intervention data")

    fig, axes = plt.subplots(2, 2, figsize=(15, 9))
    cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds}
    for col, cond in enumerate(("grokking", "standard")):
        _plot_m4_panel(axes[0][col], cond, data[cond], cmaps[cond])
        _plot_m5_panel(axes[1][col], cond, data[cond], cmaps[cond])

    fig.suptitle(
        "Causal Interventions on the Shortcut Subspace (avgpool, ResNet-18, Camelyon17)\n"
        "M4 = ablate-and-evaluate, M5 = steer-and-evaluate",
        fontsize=12, fontweight="bold", y=1.005,
    )
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_intervention_comparison"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out}.png + .pdf")
# Script entry point: build and save the figure only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|