# CausalGrok / code / experiments / plot_results.py
# nileshsarkar-ai's picture
# Upload code/experiments
# 50fa85c verified
"""
CausalGrok — Paper Figure Generator

Reads every experiments/runs/<run_id>/results/history.json on disk and
produces:

    paper_figures/figure1_smoking_gun.png|pdf  ← IRM penalty + val acc
    paper_figures/figure2_mechanisms.png       ← weight norm + feature rank
    paper_figures/figure3_shortcut.png         ← shortcut ratio over training
    paper_figures/table1_ablations.csv         ← summary across runs

Per-run figures are also saved into experiments/runs/<run_id>/figures/.

Run after experiments complete:

    bash scripts/plot_all.sh
    # or:
    python -m experiments.plot_results
"""
from __future__ import annotations
import argparse
import glob
import json
import os
from typing import Dict, List
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from utils.run_dir import DEFAULT_BASE
# Module-wide matplotlib defaults: base font size and figure DPI for every
# figure created after this point.
matplotlib.rcParams.update({"font.size": 12, "figure.dpi": 150})
# ──────────────────────────────────────────────
# LOADING
# ──────────────────────────────────────────────
def discover_runs(runs_dir: str = DEFAULT_BASE) -> List[Dict]:
    """Collect one record per run directory that has a readable history.

    Every direct child of ``runs_dir`` is checked for
    ``results/history.json``; children without that file are ignored.  A
    history that cannot be read or converted to a DataFrame is skipped with
    a message on stdout so one corrupt run does not abort the scan.
    ``config.json`` is optional: when it is missing, unreadable, or does not
    contain a JSON object, the run's ``cfg`` is an empty dict.

    Args:
        runs_dir: Directory whose children are individual run directories.

    Returns:
        A list of dicts with keys ``run_dir``, ``df``, ``cfg`` and
        ``run_id`` (the basename of ``run_dir``), in sorted path order.
    """
    runs: List[Dict] = []
    for run_dir in sorted(glob.glob(os.path.join(runs_dir, "*"))):
        hist_path = os.path.join(run_dir, "results", "history.json")
        if not os.path.isfile(hist_path):
            continue

        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(hist_path, encoding="utf-8") as fh:
                df = pd.DataFrame(json.load(fh))
        except Exception as exc:
            # Best-effort scan: report the bad run and keep going.
            print(f"  Skipping {run_dir}: unreadable history.json ({exc})")
            continue

        # Normalize column names for v1 vs v2 compatibility
        #   v1 uses: val_acc, train_acc
        #   v2 uses: id_val_acc, ood_acc, train_acc
        if "id_val_acc" in df.columns and "val_acc" not in df.columns:
            df = df.rename(columns={"id_val_acc": "val_acc"})

        cfg: Dict = {}
        cfg_path = os.path.join(run_dir, "config.json")
        if os.path.isfile(cfg_path):
            try:
                with open(cfg_path, encoding="utf-8") as fh:
                    loaded = json.load(fh)
            except (OSError, ValueError) as exc:
                # json.JSONDecodeError is a ValueError subclass.
                print(f"  Ignoring unreadable config.json in {run_dir} ({exc})")
            else:
                # Callers use cfg.get(...), so only a JSON object is usable.
                if isinstance(loaded, dict):
                    cfg = loaded

        runs.append(dict(run_dir=run_dir, df=df, cfg=cfg,
                         run_id=os.path.basename(run_dir)))
    return runs
def average_by_condition(runs: List[Dict]) -> Dict[str, pd.DataFrame]:
    """Average per-epoch metrics over runs sharing a condition and size.

    Runs are grouped by (condition, n_train) so we never average across
    incompatible dataset sizes.  The condition comes from the run's config;
    when absent it is inferred from the run id ("grokking" if that word
    appears in it, otherwise "standard").  A missing ``n_train`` counts
    as 0.

    Runs whose history is empty or has no ``epoch`` column are skipped:
    they carry nothing to average, and would otherwise make the
    ``groupby("epoch")`` below fail for the whole group.

    Args:
        runs: Records with ``df``, ``cfg`` and ``run_id`` keys, as built by
            ``discover_runs``.

    Returns:
        Mapping from "<condition>_n<N>" to a DataFrame holding ``epoch``
        plus the per-epoch mean of every numeric column.
    """
    grouped: Dict[tuple, List[pd.DataFrame]] = {}
    for r in runs:
        df = r["df"]
        if df.empty or "epoch" not in df.columns:
            continue
        cond = r["cfg"].get("condition")
        if cond is None:
            cond = "grokking" if "grokking" in r["run_id"] else "standard"
        n_train = r["cfg"].get("n_train", 0)
        grouped.setdefault((cond, n_train), []).append(df)

    out: Dict[str, pd.DataFrame] = {}
    for (cond, n), dfs in grouped.items():
        merged = pd.concat(dfs, ignore_index=True)
        # Only numeric columns can be averaged; epoch is the grouping key.
        numeric_cols = [c for c in merged.columns
                        if c != "epoch"
                        and pd.api.types.is_numeric_dtype(merged[c])]
        out[f"{cond}_n{n}"] = (
            merged.groupby("epoch")[numeric_cols].mean().reset_index()
        )
    return out
def pick_headline_curves(data: Dict[str, pd.DataFrame]):
    """Pick one grokking key and one standard key for the headline figure.

    Heuristic: prefer n=500 (the canonical small-data regime for this
    paper); otherwise fall back to the smallest n_train available.
    Large-dataset runs grok fast and the plateau disappears, washing out
    the visual story.

    Keys are expected to look like "<condition>_n<N>".  A key whose suffix
    after "<condition>_n" is not an integer (for example "grokking_nNone"
    when the config had no usable n_train) is only chosen when no
    integer-suffixed key exists, instead of raising ``ValueError``.

    Args:
        data: Mapping keyed by "<condition>_n<N>".

    Returns:
        A ``(grokking_key, standard_key)`` tuple; either element is
        ``None`` when no key for that condition exists.
    """
    def best(cond_prefix):
        prefix = f"{cond_prefix}_n"
        candidates = [k for k in data if k.startswith(prefix)]
        if not candidates:
            return None
        target = f"{prefix}500"
        if target in candidates:
            return target

        def rank(key):
            # Integer sizes sort first, smallest n_train winning; anything
            # unparseable is a last resort.
            try:
                return (0, int(key[len(prefix):]))
            except ValueError:
                return (1, 0)

        return min(candidates, key=rank)

    return best("grokking"), best("standard")
# ──────────────────────────────────────────────
# FIGURE 1 — THE SMOKING GUN
# ──────────────────────────────────────────────
def figure1_smoking_gun(data: Dict[str, pd.DataFrame], save_dir: str):
    """Draw the two-panel headline figure and save it as PNG and PDF.

    The left panel shows the grokking curve and the right panel the
    standard curve, both chosen by ``pick_headline_curves``.  Each panel
    plots accuracy on the left axis and the IRM penalty on a twin right
    axis, and marks the first epoch whose ``grokking_detected`` value is
    truthy.

    ``val_acc``, ``ood_acc``, ``irm_mean`` and ``grokking_detected`` are
    each drawn only when the column exists, so a history that lacks one of
    them still produces a figure.  A panel with no data for its condition
    shows a placeholder message naming that condition.

    Args:
        data: Mapping from "<condition>_n<N>" to a per-epoch DataFrame with
            an ``epoch`` column.
        save_dir: Directory that receives ``figure1_smoking_gun.png`` and
            ``figure1_smoking_gun.pdf``.
    """
    irm_color = "#F59E0B"
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    grok_key, std_key = pick_headline_curves(data)
    panels = [
        (axes[0], grok_key, "grokking", "#2563EB",
         f"Grokking-Favorable Training\n({grok_key or 'no data'})"),
        (axes[1], std_key, "standard", "#DC2626",
         f"Standard Training\n({std_key or 'no data'})"),
    ]
    for ax, key, cond_name, color, title in panels:
        if key is None or key not in data:
            # Name the condition; the key itself is None on this path.
            ax.text(0.5, 0.5, f"No {cond_name} data yet",
                    ha="center", va="center", transform=ax.transAxes)
            ax.set_title(title, fontweight="bold")
            continue

        df = data[key]
        ax2 = ax.twinx()
        if "val_acc" in df.columns:
            ax.plot(df["epoch"], df["val_acc"], color=color, lw=2.5,
                    label="ID Val Accuracy (H3)", zorder=3)
        # For v2 runs: also show OOD accuracy (the actual grokking signal)
        if "ood_acc" in df.columns:
            ax.plot(df["epoch"], df["ood_acc"], color=color, lw=2.5, ls="--",
                    alpha=0.7, label="OOD Accuracy (H4)", zorder=3)
        if "irm_mean" in df.columns:
            ax2.plot(df["epoch"], df["irm_mean"], color=irm_color, lw=2,
                     ls="--", label="IRM Penalty ↓", zorder=2)

        if "grokking_detected" in df.columns:
            grok = df[df["grokking_detected"].astype(bool)]
            if len(grok):
                ep = int(grok["epoch"].min())
                ax.axvline(ep, color="gray", ls=":", lw=1.5)
                ax.annotate(f"Grokking\nep.{ep}",
                            xy=(ep, 0.5),
                            xytext=(ep + ep * 0.05, 0.3),
                            fontsize=9, color="gray",
                            arrowprops=dict(arrowstyle="->", color="gray"))

        ax.set_xlabel("Epoch")
        ax.set_ylabel("Val Accuracy", color=color)
        ax2.set_ylabel("IRM Penalty (↓ = causal)", color=irm_color)
        ax.set_title(title, fontweight="bold")
        ax.tick_params(axis="y", labelcolor=color)
        ax2.tick_params(axis="y", labelcolor=irm_color)
        ax.set_ylim([0, 1.05])
        ax.grid(alpha=0.3)

        # One combined legend for both y-axes of the panel.
        h1, l1 = ax.get_legend_handles_labels()
        h2, l2 = ax2.get_legend_handles_labels()
        if h1 or h2:
            ax.legend(h1 + h2, l1 + l2, loc="center left", fontsize=9)

    fig.suptitle(
        "Figure 1 — IRM Invariance Penalty Drops at the Grokking Transition\n"
        "Causal feature discovery and delayed generalization are the same event",
        fontsize=12, y=1.02
    )
    fig.tight_layout()
    # Save through the figure object rather than pyplot's implicit
    # "current figure" state.
    fig.savefig(os.path.join(save_dir, "figure1_smoking_gun.png"), bbox_inches="tight")
    fig.savefig(os.path.join(save_dir, "figure1_smoking_gun.pdf"), bbox_inches="tight")
    print(" Figure 1 saved")
    plt.close(fig)
def figure2_mechanisms(data: Dict[str, pd.DataFrame], save_dir: str):
    """Plot validation accuracy, weight norm and feature rank on one chart.

    Uses the grokking curve chosen by ``pick_headline_curves``; when there
    is none, prints a notice and returns without writing anything.  The
    three series share the epoch axis but each has its own y-axis, with the
    third axis pushed outward so its ticks do not overlap the second.

    Args:
        data: Mapping from "<condition>_n<N>" to a per-epoch DataFrame.
        save_dir: Directory that receives ``figure2_mechanisms.png``.
    """
    headline_key = pick_headline_curves(data)[0]
    if headline_key is None:
        print(" Skipping Figure 2 (no grokking data)")
        return
    curves = data[headline_key]

    fig, acc_ax = plt.subplots(figsize=(10, 5))
    norm_ax = acc_ax.twinx()
    rank_ax = acc_ax.twinx()
    # Offset the third y-axis so it does not sit on top of the second.
    rank_ax.spines["right"].set_position(("outward", 60))

    blue, green, amber = "#2563EB", "#10B981", "#F59E0B"
    epochs = curves["epoch"]
    acc_ax.plot(epochs, curves["val_acc"], blue, lw=2.5, label="Val Acc")
    norm_ax.plot(epochs, curves["weight_norm"], green, lw=2, ls="--",
                 label="Weight Norm ‖W‖")
    rank_ax.plot(epochs, curves["feature_rank"], amber, lw=2, ls="-.",
                 label="Feature Rank")

    acc_ax.set_xlabel("Epoch")
    axis_styles = (
        (acc_ax, "Val Accuracy", blue),
        (norm_ax, "Weight Norm", green),
        (rank_ax, "Feature Rank", amber),
    )
    for axis, ylabel, colour in axis_styles:
        axis.set_ylabel(ylabel, color=colour)
        axis.tick_params(axis="y", labelcolor=colour)

    # Merge the legend entries of all three axes into one legend.
    handles, labels = [], []
    for axis, _, _ in axis_styles:
        axis_handles, axis_labels = axis.get_legend_handles_labels()
        handles = handles + axis_handles
        labels = labels + axis_labels
    acc_ax.legend(handles, labels, loc="center left", fontsize=9)

    acc_ax.set_title(
        "Figure 2 — Training Dynamics: Weight Norm + Feature Rank as Progress Measures",
        fontweight="bold")
    acc_ax.grid(alpha=0.3)
    plt.tight_layout()
    target = os.path.join(save_dir, "figure2_mechanisms.png")
    plt.savefig(target, bbox_inches="tight")
    print(" Figure 2 saved")
    plt.close()
def figure3_shortcut(data: Dict[str, pd.DataFrame], save_dir: str):
    """Plot center/border confidence and the shortcut ratio over training.

    Uses the grokking curve chosen by ``pick_headline_curves``.  The
    figure is skipped with a notice when there is no grokking data, or
    when the history holds none of ``center_conf``, ``border_conf`` or
    ``shortcut_ratio``.  A subset of those columns is plotted as-is, so a
    history without every shortcut metric does not raise ``KeyError``.

    Args:
        data: Mapping from "<condition>_n<N>" to a per-epoch DataFrame with
            an ``epoch`` column.
        save_dir: Directory that receives ``figure3_shortcut.png``.
    """
    grok_key, _ = pick_headline_curves(data)
    if grok_key is None:
        print(" Skipping Figure 3 (no grokking data)")
        return
    df = data[grok_key]

    # (column, colour, line style, legend label)
    series = [
        ("center_conf", "#2563EB", "-", "Center (anatomy) confidence"),
        ("border_conf", "#DC2626", "--", "Border (artifact) confidence"),
        ("shortcut_ratio", "#F59E0B", "-.", "Shortcut ratio (border/center)"),
    ]
    available = [spec for spec in series if spec[0] in df.columns]
    if not available:
        print(" Skipping Figure 3 (no shortcut metrics in history)")
        return

    fig, ax = plt.subplots(figsize=(10, 5))
    for column, color, line_style, label in available:
        ax.plot(df["epoch"], df[column], color=color, lw=2, ls=line_style,
                label=label)
    # Reference line: a ratio of 1 means border and center are relied on
    # equally.
    ax.axhline(1.0, color="gray", ls=":", lw=1, alpha=0.7,
               label="Ratio = 1 (equal reliance)")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Confidence / Ratio")
    ax.set_title(
        "Figure 3 — Shortcut Reliance: Model shifts from artifacts to anatomy at grokking",
        fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, "figure3_shortcut.png"), bbox_inches="tight")
    print(" Figure 3 saved")
    plt.close(fig)
def table1_ablations(runs: List[Dict], save_dir: str):
    """Write a one-row-per-run summary CSV and print it.

    Runs with an empty history are skipped.  When no run remains, a notice
    is printed and no file is written.  Metrics whose source column is
    missing are reported as NaN; epochs that cannot be determined are
    reported as -1.

    Columns written: run_id, condition, n_train, seed, best_val_acc,
    grokking_epoch, irm_drop_epoch, epoch_gap, irm_drop_pct,
    final_shortcut_ratio, run_dir.  Rows are sorted by best_val_acc,
    highest first.

    Args:
        runs: Records with ``df``, ``cfg``, ``run_id`` and ``run_dir``
            keys, as built by ``discover_runs``.
        save_dir: Directory that receives ``table1_ablations.csv``.
    """
    nan = float("nan")
    rows = []
    for r in runs:
        df = r["df"]
        if df.empty:
            continue
        has_epoch = "epoch" in df

        if "grokking_detected" in df:
            grok_rows = df[df["grokking_detected"].astype(bool)]
        else:
            grok_rows = df.iloc[:0]
        grok_ep = int(grok_rows["epoch"].min()) if has_epoch and len(grok_rows) else -1

        # Co-movement: compare the epoch where val_acc jumped (grokking
        # transition) vs. the epoch where IRM dropped fastest. Small gap
        # ⇒ same event ⇒ paper's central claim. Large gap ⇒ separate
        # events ⇒ weaker, lagged claim.
        irm0 = irm_min = nan
        irm_drop_ep = -1
        if "irm_mean" in df:
            irm = df["irm_mean"]
            irm0 = irm.iloc[0]
            irm_min = irm.min()
            # Signed differences: the steepest *decrease* is the most
            # negative step.  Taking abs() would also match a spike upward.
            irm_delta = irm.diff()
            if has_epoch and irm_delta.notna().any() and irm_delta.min() < 0:
                irm_drop_ep = int(df.loc[irm_delta.idxmin(), "epoch"])

        # -1 is the "unknown" sentinel; epoch 0 is a valid epoch.
        if grok_ep >= 0 and irm_drop_ep >= 0:
            epoch_gap = abs(grok_ep - irm_drop_ep)
        else:
            epoch_gap = -1

        rows.append({
            "run_id": r["run_id"],
            "condition": r["cfg"].get("condition", ""),
            "n_train": r["cfg"].get("n_train"),
            "seed": r["cfg"].get("seed"),
            "best_val_acc": df["val_acc"].max() if "val_acc" in df else nan,
            "grokking_epoch": grok_ep,
            "irm_drop_epoch": irm_drop_ep,
            "epoch_gap": epoch_gap,
            # The small epsilon avoids division by zero when irm0 is 0.
            "irm_drop_pct": (irm0 - irm_min) / (irm0 + 1e-8) * 100,
            "final_shortcut_ratio": (df["shortcut_ratio"].iloc[-1]
                                     if "shortcut_ratio" in df else nan),
            "run_dir": r["run_dir"],
        })

    if not rows:
        print(" No runs to summarize.")
        return

    table = pd.DataFrame(rows).sort_values("best_val_acc", ascending=False)
    out_path = os.path.join(save_dir, "table1_ablations.csv")
    table.to_csv(out_path, index=False)
    print(f"\nTable 1 ({len(table)} runs):")
    print(table.to_string(index=False))
    print(f"\n Saved → {out_path}")
def per_run_figure(r: Dict):
    """Save a training-curve figure inside one run's ``figures`` directory.

    Accuracy curves go on the left axis and the IRM penalty on a twin right
    axis.  Each of ``val_acc``, ``train_acc`` and ``irm_mean`` is drawn only
    when the column exists, so one run with a partial history does not stop
    the remaining runs from being plotted.  Nothing is written when the
    history is empty or has no ``epoch`` column.

    Args:
        r: Run record with ``df``, ``run_dir`` and ``run_id`` keys.
    """
    df = r["df"]
    if df.empty or "epoch" not in df.columns:
        return

    out_dir = os.path.join(r["run_dir"], "figures")
    # Nothing in this module creates the directory beforehand, and savefig
    # does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    out = os.path.join(out_dir, "training_curves.png")

    fig, ax = plt.subplots(figsize=(9, 4.5))
    ax2 = ax.twinx()
    if "val_acc" in df.columns:
        ax.plot(df["epoch"], df["val_acc"], color="#2563EB", lw=2,
                label="Val Acc")
    if "train_acc" in df.columns:
        ax.plot(df["epoch"], df["train_acc"], color="#9CA3AF", lw=1, ls=":",
                label="Train Acc")
    if "irm_mean" in df.columns:
        ax2.plot(df["epoch"], df["irm_mean"], color="#F59E0B", lw=2, ls="--",
                 label="IRM")

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax2.set_ylabel("IRM penalty")
    ax.set_title(r["run_id"], fontsize=10)
    ax.grid(alpha=0.3)

    # One combined legend for both y-axes.
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    if h1 or h2:
        ax.legend(h1 + h2, l1 + l2, loc="center left", fontsize=8)

    fig.tight_layout()
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
def main():
    """Command-line entry point: build every per-run and cross-run artifact.

    Parses ``--runs_dir`` and ``--save_dir``, creates the save directory,
    discovers runs, and returns early when none are found.  Otherwise it
    draws one figure per run, averages runs by condition, then writes
    Figures 1-3 and Table 1 into the save directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--runs_dir", default=DEFAULT_BASE)
    parser.add_argument("--save_dir", default="paper_figures")
    opts = parser.parse_args()

    os.makedirs(opts.save_dir, exist_ok=True)

    runs = discover_runs(opts.runs_dir)
    print(f"Found {len(runs)} runs in {opts.runs_dir}/")
    if not runs:
        return

    # Per-run figures first; they live inside each run directory.
    for run in runs:
        per_run_figure(run)

    data = average_by_condition(runs)
    print(f"Conditions averaged: {sorted(data.keys())}")

    # Cross-run figures, in paper order.
    for make_figure in (figure1_smoking_gun, figure2_mechanisms, figure3_shortcut):
        make_figure(data, opts.save_dir)
    table1_ablations(runs, opts.save_dir)

    print(f"\nAll cross-run artifacts in {opts.save_dir}/")
# Run the pipeline only when this file is executed as a script or via
# ``python -m``, not when it is imported.
if __name__ == "__main__":
    main()