"""Compose the MI workshop's causal-intervention figure (M4 + M5 across runs).

Reads M4 trajectory + M5 sweep JSONs from every run that has them, groups by
condition (grokking vs standard), and produces a 2×2 layout:

  Top row — M4 ablation: head OOD acc (raw) vs (shortcut-ablated), trajectories
            across all available runs in each condition.
  Bottom row — M5 steering: head OOD acc as a function of α, one curve per run.

Saves: paper_figures/figure_intervention_comparison.{png,pdf}
"""
| from __future__ import annotations |
|
|
| import glob |
| import json |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Base directory two levels above this file (presumably the repository root —
# confirm against the checkout layout). The run-directory globs and the
# paper_figures/ output path used below are all resolved against it.
ROOT = Path(__file__).resolve().parent.parent
|
|
|
|
def _load_summary(run_dir: Path) -> dict:
    """Load ``<run_dir>/results/summary.json``.

    Args:
        run_dir: Directory of a single run.

    Returns:
        The parsed JSON content, or an empty dict when the summary file is
        absent (or the path is not a regular file).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    summary_path = run_dir / "results" / "summary.json"
    # is_file() rather than exists(): a directory with that name is treated as
    # "no summary" instead of crashing in read_text().
    if not summary_path.is_file():
        return {}
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding when decoding it.
    return json.loads(summary_path.read_text(encoding="utf-8"))
|
|
|
|
def _gather(
    run_globs: tuple = (
        "experiments/runs/20260505-*/",
        "experiments/runs/20260508-*/",
    ),
) -> dict:
    """Collect M4/M5 intervention results, grouped by training condition.

    Every directory under ``ROOT`` matching one of ``run_globs`` is inspected.
    A run is kept when its summary (read with ``_load_summary``) has a
    ``condition`` of ``"grokking"`` or ``"standard"`` and at least one of these
    files exists in its ``mechinterp/`` sub-directory:

    * ``m4_ablation_avgpool_trajectory.json``
    * ``m5_steering_*.json`` (the first match in sorted order is used)

    Args:
        run_globs: Glob patterns, relative to ``ROOT``, selecting the run
            directories to scan. Defaults to the two run-date prefixes this
            figure was originally built from.

    Returns:
        dict: cond → list of run records. Each record is a dict with keys
        ``run_dir`` (directory name), ``n_train``, ``seed``, ``best_ood``
        (copied from the summary, ``None`` when missing), ``m4`` and ``m5``
        (parsed JSON, or ``None`` when the corresponding file is absent).
    """
    out = {"grokking": [], "standard": []}

    # De-duplicate in case caller-supplied patterns overlap, then sort so runs
    # are processed (and later plotted) in a stable order.
    run_dirs = sorted(
        {path for pattern in run_globs for path in ROOT.glob(pattern)}
    )

    for run_dir in run_dirs:
        summary = _load_summary(run_dir)
        cond = summary.get("condition")
        if cond not in out:
            continue

        mech_dir = run_dir / "mechinterp"
        m4_path = mech_dir / "m4_ablation_avgpool_trajectory.json"
        # Path.glob yields matches in arbitrary order; sort so that the file
        # picked when several sweeps exist is the same on every invocation.
        m5_paths = sorted(mech_dir.glob("m5_steering_*.json"))

        m4 = (
            json.loads(m4_path.read_text(encoding="utf-8"))
            if m4_path.is_file()
            else None
        )
        m5 = (
            json.loads(m5_paths[0].read_text(encoding="utf-8"))
            if m5_paths
            else None
        )
        if m4 is None and m5 is None:
            # Nothing to plot for this run.
            continue

        out[cond].append({
            "run_dir": run_dir.name,
            "n_train": summary.get("n_train"),
            "seed": summary.get("seed"),
            "best_ood": summary.get("best_ood"),
            "m4": m4,
            "m5": m5,
        })
    return out
|
|
|
|
def _shade(cmap, index, count):
    """Return shade ``index`` of ``count``, spread over the 0.4–0.9 range of ``cmap``."""
    return cmap(0.4 + 0.5 * index / max(1, count - 1))


def _mark_empty(ax, message, title):
    """Put a centred gray placeholder message and a title on an empty panel."""
    ax.text(0.5, 0.5, message,
            ha="center", va="center", transform=ax.transAxes,
            color="gray", fontsize=11)
    ax.set_title(title, fontweight="bold")


def main():
    """Build the 2×2 intervention figure and save it as PNG and PDF.

    Columns are the two conditions (grokking, standard). The top row plots, for
    every run with M4 data, head OOD accuracy over training epochs both raw
    (solid) and shortcut-ablated (dashed). The bottom row plots, for every run
    with M5 data, head OOD accuracy against the steering coefficient α. A panel
    with no data for its row gets a gray placeholder message instead.

    Output goes to ``ROOT/paper_figures/figure_intervention_comparison.{png,pdf}``;
    the directory is created if it does not exist.
    """
    data = _gather()
    n_grok = len(data["grokking"])
    n_std = len(data["standard"])
    print(f"Found {n_grok} grokking runs and {n_std} standard runs with intervention data")

    fig, axes = plt.subplots(2, 2, figsize=(15, 9))
    panels = (("grokking", plt.cm.Blues), ("standard", plt.cm.Reds))

    # ---- Top row: M4 ablation trajectories --------------------------------
    for col, (cond, cmap) in enumerate(panels):
        ax = axes[0][col]
        # Filter on M4 availability (not merely "any run"): a condition whose
        # runs only carry M5 data must get the placeholder, not a blank panel
        # with an empty legend.
        m4_runs = [run for run in data[cond] if run["m4"]]
        if not m4_runs:
            _mark_empty(ax, f"no {cond} M4 data yet",
                        f"{cond.upper()}: head OOD raw vs ablated")
            continue

        for i, run in enumerate(m4_runs):
            color = _shade(cmap, i, len(m4_runs))
            traj = run["m4"]
            eps = [rec["epoch"] for rec in traj]
            raw = [rec["head_ood_acc_raw"] for rec in traj]
            abl = [rec["head_ood_acc_ablated"] for rec in traj]
            label_n = f"n={run['n_train']} s{run['seed']}"
            ax.plot(eps, raw, "-", color=color, alpha=0.8, lw=1.6,
                    label=f"{label_n} raw")
            ax.plot(eps, abl, "--", color=color, alpha=0.8, lw=1.6,
                    label=f"{label_n} ablated")
        ax.set_xlabel("Training epoch")
        ax.set_ylabel("Head OOD (H4) accuracy")
        ax.set_title(f"M4 ablation — {cond.upper()}\n"
                     f"raw (—) vs shortcut-ablated (- -)",
                     fontweight="bold", fontsize=10)
        ax.legend(fontsize=7, ncol=2)
        ax.grid(alpha=0.3)
        ax.set_ylim(0.30, 0.85)

    # ---- Bottom row: M5 steering sweeps -----------------------------------
    for col, (cond, cmap) in enumerate(panels):
        ax = axes[1][col]
        m5_runs = [run for run in data[cond] if run["m5"]]
        if not m5_runs:
            _mark_empty(ax, f"no {cond} M5 data yet",
                        f"{cond.upper()}: head OOD vs steering α")
            continue

        for i, run in enumerate(m5_runs):
            color = _shade(cmap, i, len(m5_runs))
            m5 = run["m5"]
            alphas = [rec["alpha"] for rec in m5["sweep"]]
            heads = [rec["head_ood_acc"] for rec in m5["sweep"]]
            ax.plot(alphas, heads, "-o", color=color, lw=2, ms=6,
                    label=f"n={run['n_train']} s{run['seed']} ep{m5['epoch']}")
        # Vertical guide at α=0, the unsteered baseline.
        ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
        ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
        ax.set_ylabel("Head OOD (H4) accuracy")
        ax.set_title(f"M5 steering — {cond.upper()}\n"
                     f"α=0 baseline; |α|↑ = stronger shortcut activation",
                     fontweight="bold", fontsize=10)
        ax.legend(fontsize=8)
        ax.grid(alpha=0.3)
        ax.set_ylim(0.45, 0.80)

    fig.suptitle(
        "Causal Interventions on the Shortcut Subspace (avgpool, ResNet-18, Camelyon17)\n"
        "M4 = ablate-and-evaluate, M5 = steer-and-evaluate",
        fontsize=12, fontweight="bold", y=1.005,
    )
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_intervention_comparison"
    # savefig does not create missing directories; make sure the target exists.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out}.png + .pdf")
|
|
|
|
# Script entry point: build and save the figure only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|