CausalGrok / code /experiments /figure_intervention_comparison.py
nileshsarkar-ai's picture
Upload code/experiments
50fa85c verified
"""Compose the MI workshop's causal-intervention figure (M4 + M5 across runs).
Reads M4 trajectory + M5 sweep JSONs from every run that has them, groups by
condition (grokking vs standard), and produces a 2×2 layout:
Top row — M4 ablation: head OOD accuracy, raw vs shortcut-ablated, as
trajectories across all available runs in each condition.
Bottom row — M5 steering: head OOD accuracy as a function of α, one curve per run.
Saves: paper_figures/figure_intervention_comparison.{png,pdf}
"""
from __future__ import annotations
import glob
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Base directory two levels above this file (resolved, symlinks followed).
# Run directories are looked up under ROOT/experiments/runs and figures are
# written under ROOT/paper_figures.
ROOT = Path(__file__).resolve().parent.parent
def _load_summary(run_dir: Path) -> dict:
    """Load ``<run_dir>/results/summary.json``.

    Args:
        run_dir: directory of a single training run.

    Returns:
        The parsed JSON content, or an empty dict when the summary file does
        not exist, so callers can use ``.get`` on the result unconditionally.
    """
    path = run_dir / "results" / "summary.json"
    if not path.exists():
        return {}
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    return json.loads(path.read_text(encoding="utf-8"))
def _gather(
    run_patterns: tuple[str, ...] = ("20260505-*", "20260508-*"),
) -> dict[str, list[dict]]:
    """Collect M4/M5 intervention results for every run, grouped by condition.

    Scans ``ROOT/experiments/runs`` for directories matching any glob in
    *run_patterns*. A run is kept only if its summary's ``condition`` is
    ``"grokking"`` or ``"standard"`` and at least one of the M4 ablation
    trajectory or an M5 steering sweep JSON exists under ``mechinterp/``.

    Args:
        run_patterns: glob patterns (relative to ``experiments/runs``) that
            select run directories. The default reproduces the two run
            batches this figure was originally built from.

    Returns:
        dict mapping condition -> list of per-run dicts with keys
        ``run_dir`` (directory name), ``n_train``, ``seed``, ``best_ood``
        (taken from the summary, ``None`` if missing), ``m4`` (parsed M4
        trajectory JSON or ``None``) and ``m5`` (parsed M5 steering JSON or
        ``None``).
    """
    out = {"grokking": [], "standard": []}
    runs_root = ROOT / "experiments" / "runs"
    # Set de-duplicates directories matched by more than one pattern; sorting
    # keeps run order (and therefore plot colours) stable.
    run_dirs = sorted({
        d for pattern in run_patterns for d in runs_root.glob(pattern) if d.is_dir()
    })
    for run_dir in run_dirs:
        summary = _load_summary(run_dir)
        cond = summary.get("condition")
        if cond not in out:
            continue
        mi_dir = run_dir / "mechinterp"
        m4_path = mi_dir / "m4_ablation_avgpool_trajectory.json"
        # glob() yields entries in arbitrary filesystem order; sort so the
        # chosen steering file is deterministic when several exist.
        m5_paths = sorted(mi_dir.glob("m5_steering_*.json"))
        if len(m5_paths) > 1:
            print(f"  {run_dir.name}: {len(m5_paths)} M5 files, using {m5_paths[0].name}")
        m4 = json.loads(m4_path.read_text(encoding="utf-8")) if m4_path.exists() else None
        m5 = json.loads(m5_paths[0].read_text(encoding="utf-8")) if m5_paths else None
        if m4 is None and m5 is None:
            continue
        out[cond].append({
            "run_dir": run_dir.name,
            "n_train": summary.get("n_train"),
            "seed": summary.get("seed"),
            "best_ood": summary.get("best_ood"),
            "m4": m4,
            "m5": m5,
        })
    return out
def _run_color(cmap, i, n):
    """Return the colour for run *i* of *n*, spread over [0.4, 0.9] of *cmap*."""
    return cmap(0.4 + 0.5 * i / max(1, n - 1))


def _mark_empty(ax, message, title):
    """Show a centred grey placeholder *message* on *ax* and set its *title*."""
    ax.text(0.5, 0.5, message,
            ha="center", va="center", transform=ax.transAxes,
            color="gray", fontsize=11)
    ax.set_title(title, fontweight="bold")


def _plot_m4_panel(ax, cond, runs, cmap):
    """Draw M4 raw vs shortcut-ablated head-OOD trajectories for one condition."""
    # Only runs that actually carry an M4 trajectory; a condition may have
    # M5-only runs, which must not suppress the placeholder or skew colours.
    m4_runs = [run for run in runs if run["m4"]]
    if not m4_runs:
        _mark_empty(ax, f"no {cond} M4 data yet",
                    f"{cond.upper()}: head OOD raw vs ablated")
        return
    for i, run in enumerate(m4_runs):
        color = _run_color(cmap, i, len(m4_runs))
        traj = run["m4"]
        eps = [r["epoch"] for r in traj]
        raw = [r["head_ood_acc_raw"] for r in traj]
        abl = [r["head_ood_acc_ablated"] for r in traj]
        label_n = f"n={run['n_train']} s{run['seed']}"
        ax.plot(eps, raw, "-", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} raw")
        ax.plot(eps, abl, "--", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} ablated")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M4 ablation — {cond.upper()}\n"
                 "raw (—) vs shortcut-ablated (- -)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=7, ncol=2)
    ax.grid(alpha=0.3)
    # NOTE(review): fixed limits keep panels comparable but clip values
    # outside [0.30, 0.85].
    ax.set_ylim(0.30, 0.85)


def _plot_m5_panel(ax, cond, runs, cmap):
    """Draw M5 head-OOD accuracy vs steering coefficient α for one condition."""
    m5_runs = [run for run in runs if run["m5"]]
    if not m5_runs:
        _mark_empty(ax, f"no {cond} M5 data yet",
                    f"{cond.upper()}: head OOD vs steering α")
        return
    for i, run in enumerate(m5_runs):
        color = _run_color(cmap, i, len(m5_runs))
        m5 = run["m5"]
        alphas = [r["alpha"] for r in m5["sweep"]]
        heads = [r["head_ood_acc"] for r in m5["sweep"]]
        ax.plot(alphas, heads, "-o", color=color, lw=2, ms=6,
                label=f"n={run['n_train']} s{run['seed']} ep{m5['epoch']}")
    # Reference line at α = 0 (the unsteered baseline).
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M5 steering — {cond.upper()}\n"
                 "α=0 baseline; |α|↑ = stronger shortcut activation",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    # NOTE(review): fixed limits; values outside [0.45, 0.80] are clipped.
    ax.set_ylim(0.45, 0.80)


def main():
    """Build the 2×2 M4/M5 intervention figure and save it as PNG and PDF.

    The top row holds the M4 ablation panels and the bottom row the M5 steering
    panels, with the grokking condition on the left and standard on the right.
    Output goes to ``ROOT/paper_figures/figure_intervention_comparison``
    (``.png`` and ``.pdf``), and the directory is created if it is missing.
    """
    data = _gather()
    n_grok = len(data["grokking"])
    n_std = len(data["standard"])
    print(f"Found {n_grok} grokking runs and {n_std} standard runs with intervention data")

    fig, axes = plt.subplots(2, 2, figsize=(15, 9))
    cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds}
    for col, cond in enumerate(("grokking", "standard")):
        _plot_m4_panel(axes[0][col], cond, data[cond], cmaps[cond])
        _plot_m5_panel(axes[1][col], cond, data[cond], cmaps[cond])

    fig.suptitle(
        "Causal Interventions on the Shortcut Subspace (avgpool, ResNet-18, Camelyon17)\n"
        "M4 = ablate-and-evaluate, M5 = steer-and-evaluate",
        fontsize=12, fontweight="bold", y=1.005,
    )
    fig.tight_layout()

    out = ROOT / "paper_figures" / "figure_intervention_comparison"
    # savefig does not create missing directories.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out}.png + .pdf")
if __name__ == "__main__":
    # Build and save the figure only when executed as a script, not on import.
    main()