| """ |
| CausalGrok — Paper Figure Generator |
| |
| Reads every experiments/runs/<run_id>/results/history.json on disk and |
| produces: |
| paper_figures/figure1_smoking_gun.png|pdf ← IRM penalty + val acc |
| paper_figures/figure2_mechanisms.png ← weight norm + feature rank |
| paper_figures/figure3_shortcut.png ← shortcut ratio over training |
| paper_figures/table1_ablations.csv ← summary across runs |
| |
| Per-run figures are also saved into experiments/runs/<run_id>/figures/. |
| |
| Run after experiments complete: |
| bash scripts/plot_all.sh |
| # or: |
| python -m experiments.plot_results |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import glob |
| import json |
| import os |
| from typing import Dict, List |
|
|
| import matplotlib |
| import matplotlib.pyplot as plt |
| import pandas as pd |
|
|
| from utils.run_dir import DEFAULT_BASE |
|
|
# Module-wide matplotlib defaults applied to every figure this script draws.
matplotlib.rcParams.update({"font.size": 12, "figure.dpi": 150})
|
|
|
|
| |
| |
| |
|
|
def discover_runs(runs_dir: str = DEFAULT_BASE) -> List[Dict]:
    """
    Collect one record per run directory that has a readable history.json.

    Scans the immediate children of ``runs_dir`` in sorted order. For each
    child containing ``results/history.json``, the history is loaded into a
    DataFrame. The run's ``config.json`` is loaded as well when present.

    Args:
        runs_dir: directory whose immediate children are run directories.

    Returns:
        A list of dicts with these keys:

        - ``run_dir``: the run directory path.
        - ``df``: the history as a DataFrame.
        - ``cfg``: the parsed config.json, or ``{}`` if it is absent,
          unreadable or not a JSON object.
        - ``run_id``: the run directory's basename.

        Runs whose history.json cannot be read or parsed are skipped with a
        printed warning rather than silently.
    """
    runs = []
    for run_dir in sorted(glob.glob(os.path.join(runs_dir, "*"))):
        hist_path = os.path.join(run_dir, "results", "history.json")
        cfg_path = os.path.join(run_dir, "config.json")
        if not os.path.isfile(hist_path):
            continue
        try:
            with open(hist_path, encoding="utf-8") as fh:
                df = pd.DataFrame(json.load(fh))
        except (OSError, ValueError, TypeError) as err:
            # A corrupt or half-written history should not block plotting of
            # the other runs, but the user should know it was dropped.
            print(f"  Skipping {run_dir}: cannot read history.json ({err})")
            continue

        # Histories that log "id_val_acc" instead of "val_acc" are normalised
        # to "val_acc", the column the plotting functions read.
        if "id_val_acc" in df.columns and "val_acc" not in df.columns:
            df = df.rename(columns={"id_val_acc": "val_acc"})

        cfg: Dict = {}
        if os.path.isfile(cfg_path):
            try:
                with open(cfg_path, encoding="utf-8") as fh:
                    loaded = json.load(fh)
            except (OSError, ValueError) as err:
                print(f"  Warning: {run_dir}: cannot read config.json ({err})")
            else:
                # Callers use cfg.get(...), so anything but a JSON object is ignored.
                if isinstance(loaded, dict):
                    cfg = loaded
                else:
                    print(f"  Warning: {run_dir}: config.json is not a JSON object")

        runs.append(dict(run_dir=run_dir, df=df, cfg=cfg,
                         run_id=os.path.basename(run_dir)))
    return runs
|
|
|
|
def average_by_condition(runs: List[Dict]) -> Dict[str, pd.DataFrame]:
    """
    Average per-epoch metrics across runs sharing (condition, n_train).

    Runs are grouped by dataset size as well as condition, so curves from
    incompatible dataset sizes are never averaged together. The returned
    key is "<condition>_n<N>".

    How the key parts are chosen:

    - The condition comes from ``cfg["condition"]``. Runs without one are
      labelled "grokking" if their run_id contains "grokking", otherwise
      "standard".
    - A missing or null ``cfg["n_train"]`` is treated as 0, so every key
      ends in an integer.

    Args:
        runs: records as produced by ``discover_runs``, with keys
            ``run_id``, ``cfg`` and ``df``.

    Returns:
        One DataFrame per group. It has an ``epoch`` column plus the mean of
        every other numeric column at that epoch. Runs whose history has no
        ``epoch`` column are skipped with a message, because they cannot be
        aligned with the other runs.
    """
    grouped: Dict[tuple, List[pd.DataFrame]] = {}
    for r in runs:
        df = r["df"]
        if "epoch" not in df.columns:
            print(f"  Skipping {r['run_id']} when averaging (no 'epoch' column)")
            continue
        cfg = r["cfg"]
        cond = cfg.get("condition")
        if cond is None:
            # Legacy runs without a condition in config: infer it from the id.
            cond = "grokking" if "grokking" in r["run_id"] else "standard"
        # `or 0` also maps an explicit null, which would otherwise yield "_nNone".
        n_train = cfg.get("n_train") or 0
        grouped.setdefault((cond, n_train), []).append(df)

    out: Dict[str, pd.DataFrame] = {}
    for (cond, n), dfs in grouped.items():
        merged = pd.concat(dfs, ignore_index=True)
        numeric_cols = [c for c in merged.columns
                        if c != "epoch" and pd.api.types.is_numeric_dtype(merged[c])]
        out[f"{cond}_n{n}"] = (merged.groupby("epoch")[numeric_cols]
                               .mean()
                               .reset_index())
    return out
|
|
|
|
def pick_headline_curves(data: Dict[str, pd.DataFrame]):
    """
    Pick one grokking curve and one standard curve for the headline figure.

    Keys are expected in the "<condition>_n<N>" form produced by
    ``average_by_condition``.

    Heuristic: prefer n=500, the canonical small-data regime for this paper.
    Otherwise fall back to the smallest n_train available. Large-dataset
    runs grok fast and the plateau disappears, which washes out the visual
    story.

    Matching rules:

    - A key counts for a condition only when the part before its last "_n"
      equals that condition exactly. For example, "standard_noise_n100" is
      never picked as a "standard" curve.
    - Keys with a non-integer size suffix (e.g. "_nNone") remain eligible
      but sort after every integer size instead of raising.

    Returns:
        ``(grok_key, std_key)``. Either element is ``None`` when ``data``
        has no matching key.
    """
    def size_of(key: str, cond_prefix: str):
        # None -> key belongs to another condition; inf -> unparsable size.
        cond, sep, size = key.rpartition("_n")
        if not sep or cond != cond_prefix:
            return None
        return int(size) if size.isdigit() else float("inf")

    def best(cond_prefix):
        candidates = []
        for key in data:
            n = size_of(key, cond_prefix)
            if n is not None:
                candidates.append((n, key))
        if not candidates:
            return None
        target = f"{cond_prefix}_n500"
        if target in data:
            return target
        return min(candidates)[1]

    return best("grokking"), best("standard")
|
|
|
|
| |
| |
| |
|
|
def figure1_smoking_gun(data: Dict[str, pd.DataFrame], save_dir: str):
    """
    Figure 1: validation accuracy against the IRM penalty, side by side.

    The left panel shows the headline grokking-condition curve and the right
    panel the headline standard-condition curve; both are chosen by
    ``pick_headline_curves``. Each panel draws:

    - ID val accuracy on the left y-axis, plus OOD accuracy when logged.
    - The mean IRM penalty on a twin right y-axis, when logged.
    - A dotted vertical line at the first epoch flagged by
      ``grokking_detected``.

    A condition with no data gets a placeholder panel.

    Args:
        data: averaged curves keyed "<condition>_n<N>" (see
            ``average_by_condition``). Each frame needs ``epoch`` and
            ``val_acc`` columns.
        save_dir: directory receiving ``figure1_smoking_gun.png`` and
            ``figure1_smoking_gun.pdf``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    grok_key, std_key = pick_headline_curves(data)
    panels = [
        (axes[0], grok_key, "grokking", "#2563EB",
         f"Grokking-Favorable Training\n({grok_key or 'no data'})"),
        (axes[1], std_key, "standard", "#DC2626",
         f"Standard Training\n({std_key or 'no data'})"),
    ]

    for ax, key, cond_name, color, title in panels:
        ax.set_title(title, fontweight="bold")
        if key is None or key not in data:
            # Name the condition: the key itself may be None here.
            ax.text(0.5, 0.5, f"No {cond_name} data yet",
                    ha="center", va="center", transform=ax.transAxes)
            continue

        df = data[key]
        ax2 = ax.twinx()

        ax.plot(df["epoch"], df["val_acc"], color=color, lw=2.5,
                label="ID Val Accuracy (H3)", zorder=3)
        if "ood_acc" in df.columns:
            ax.plot(df["epoch"], df["ood_acc"], color=color, lw=2.5, ls="--",
                    alpha=0.7, label="OOD Accuracy (H4)", zorder=3)
        # Guarded like ood_acc so a run without IRM logging still plots.
        if "irm_mean" in df.columns:
            ax2.plot(df["epoch"], df["irm_mean"], color="#F59E0B", lw=2,
                     ls="--", label="IRM Penalty ↓", zorder=2)

        if "grokking_detected" in df.columns:
            flag = df["grokking_detected"]
            # NaN is truthy under astype(bool); mask missing values out so
            # they are not reported as a detection.
            grok = df[flag.notna() & flag.astype(bool)]
            if len(grok):
                ep = int(grok["epoch"].min())
                ax.axvline(ep, color="gray", ls=":", lw=1.5)
                ax.annotate(f"Grokking\nep.{ep}",
                            xy=(ep, 0.5),
                            xytext=(ep + ep * 0.05, 0.3),
                            fontsize=9, color="gray",
                            arrowprops=dict(arrowstyle="->", color="gray"))

        ax.set_xlabel("Epoch")
        ax.set_ylabel("Val Accuracy", color=color)
        ax2.set_ylabel("IRM Penalty (↓ = causal)", color="#F59E0B")
        ax.tick_params(axis="y", labelcolor=color)
        ax2.tick_params(axis="y", labelcolor="#F59E0B")
        ax.set_ylim([0, 1.05])
        ax.grid(alpha=0.3)

        # Merge both y-axes' entries into one legend on the primary axis.
        h1, l1 = ax.get_legend_handles_labels()
        h2, l2 = ax2.get_legend_handles_labels()
        ax.legend(h1 + h2, l1 + l2, loc="center left", fontsize=9)

    fig.suptitle(
        "Figure 1 — IRM Invariance Penalty Drops at the Grokking Transition\n"
        "Causal feature discovery and delayed generalization are the same event",
        fontsize=12, y=1.02
    )
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, "figure1_smoking_gun.png"), bbox_inches="tight")
    fig.savefig(os.path.join(save_dir, "figure1_smoking_gun.pdf"), bbox_inches="tight")
    print(" Figure 1 saved")
    plt.close(fig)
|
|
|
|
def figure2_mechanisms(data: Dict[str, pd.DataFrame], save_dir: str):
    """
    Figure 2: val accuracy, weight norm and feature rank on three y-axes.

    Uses the headline grokking curve (see ``pick_headline_curves``) and
    writes ``figure2_mechanisms.png`` into ``save_dir``.

    The figure is skipped, with a message, when there is no grokking data or
    the curve lacks any of ``val_acc``, ``weight_norm`` or ``feature_rank``.
    Skipping means one unlogged metric does not abort the remaining figures
    and the table.
    """
    grok_key, _ = pick_headline_curves(data)
    if grok_key is None:
        print(" Skipping Figure 2 (no grokking data)")
        return
    df = data[grok_key]
    missing = [c for c in ("val_acc", "weight_norm", "feature_rank")
               if c not in df.columns]
    if missing:
        print(f" Skipping Figure 2 (missing columns: {', '.join(missing)})")
        return

    fig, ax1 = plt.subplots(figsize=(10, 5))
    ax2 = ax1.twinx()
    ax3 = ax1.twinx()
    # Offset the third y-axis outward so it doesn't sit on top of ax2's.
    ax3.spines["right"].set_position(("outward", 60))

    ax1.plot(df["epoch"], df["val_acc"], "#2563EB", lw=2.5, label="Val Acc")
    ax2.plot(df["epoch"], df["weight_norm"], "#10B981", lw=2, ls="--",
             label="Weight Norm ‖W‖")
    ax3.plot(df["epoch"], df["feature_rank"], "#F59E0B", lw=2, ls="-.",
             label="Feature Rank")

    ax1.set_xlabel("Epoch")
    for ax, ylabel, color in ((ax1, "Val Accuracy", "#2563EB"),
                              (ax2, "Weight Norm", "#10B981"),
                              (ax3, "Feature Rank", "#F59E0B")):
        ax.set_ylabel(ylabel, color=color)
        ax.tick_params(axis="y", labelcolor=color)

    # One combined legend for all three axes.
    handles, labels = [], []
    for ax in (ax1, ax2, ax3):
        ax_handles, ax_labels = ax.get_legend_handles_labels()
        handles += ax_handles
        labels += ax_labels
    ax1.legend(handles, labels, loc="center left", fontsize=9)
    ax1.set_title(
        "Figure 2 — Training Dynamics: Weight Norm + Feature Rank as Progress Measures",
        fontweight="bold")
    ax1.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, "figure2_mechanisms.png"), bbox_inches="tight")
    print(" Figure 2 saved")
    plt.close(fig)
|
|
|
|
def figure3_shortcut(data: Dict[str, pd.DataFrame], save_dir: str):
    """
    Figure 3: center vs. border confidence and their ratio over training.

    Uses the headline grokking curve (see ``pick_headline_curves``) and
    writes ``figure3_shortcut.png`` into ``save_dir``. A dotted reference
    line marks ratio = 1.

    The figure is skipped, with a message, when there is no grokking data or
    the curve lacks any of ``center_conf``, ``border_conf`` or
    ``shortcut_ratio``.
    """
    grok_key, _ = pick_headline_curves(data)
    if grok_key is None:
        print(" Skipping Figure 3 (no grokking data)")
        return
    df = data[grok_key]
    missing = [c for c in ("center_conf", "border_conf", "shortcut_ratio")
               if c not in df.columns]
    if missing:
        print(f" Skipping Figure 3 (missing columns: {', '.join(missing)})")
        return

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(df["epoch"], df["center_conf"], "#2563EB", lw=2,
            label="Center (anatomy) confidence")
    ax.plot(df["epoch"], df["border_conf"], "#DC2626", lw=2, ls="--",
            label="Border (artifact) confidence")
    ax.plot(df["epoch"], df["shortcut_ratio"], "#F59E0B", lw=2, ls="-.",
            label="Shortcut ratio (border/center)")
    ax.axhline(1.0, color="gray", ls=":", lw=1, alpha=0.7,
               label="Ratio = 1 (equal reliance)")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Confidence / Ratio")
    ax.set_title(
        "Figure 3 — Shortcut Reliance: Model shifts from artifacts to anatomy at grokking",
        fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, "figure3_shortcut.png"), bbox_inches="tight")
    print(" Figure 3 saved")
    plt.close(fig)
|
|
|
|
def table1_ablations(runs: List[Dict], save_dir: str):
    """
    Summarise every run as one row and write ``table1_ablations.csv``.

    Per run the table records:

    - ``best_val_acc``: the best validation accuracy.
    - ``grokking_epoch``: the first epoch whose ``grokking_detected`` is
      truthy. NaN / missing entries do not count as detections.
    - ``irm_drop_epoch``: the epoch of the steepest single-step *decrease*
      of ``irm_mean``. Increases are never reported as the drop.
    - ``epoch_gap``: the absolute gap between those two epochs.
    - ``irm_drop_pct``: the drop from the first logged IRM value to its
      minimum, in percent.
    - ``final_shortcut_ratio``: the last logged shortcut ratio.

    Missing metrics are NaN, and epochs that cannot be determined are -1.
    Any epoch >= 0, including epoch 0, is a valid value.

    Rows are sorted by ``best_val_acc`` (descending), printed, and saved to
    ``<save_dir>/table1_ablations.csv``. Runs with an empty history are
    skipped. If no rows remain, a message is printed and nothing is written.
    """
    nan = float("nan")
    rows = []
    for r in runs:
        df = r["df"]
        if df.empty:
            continue
        # Epoch-valued statistics need the epoch column; without it they stay -1.
        has_epoch = "epoch" in df

        grok_ep = -1
        if has_epoch and "grokking_detected" in df:
            flag = df["grokking_detected"]
            # NaN is truthy under astype(bool); mask it out explicitly.
            detected = df.loc[flag.notna() & flag.astype(bool), "epoch"]
            if len(detected):
                grok_ep = int(detected.min())

        irm0 = irm_min = nan
        irm_drop_ep = -1
        if "irm_mean" in df:
            irm = df["irm_mean"]
            irm0, irm_min = irm.iloc[0], irm.min()
            # Positive values are decreases of the penalty; the first
            # element is NaN (no predecessor) and is ignored by idxmax.
            drop = -irm.diff()
            if has_epoch and (drop > 0).any():
                irm_drop_ep = int(df.loc[drop.idxmax(), "epoch"])

        if grok_ep >= 0 and irm_drop_ep >= 0:
            epoch_gap = abs(grok_ep - irm_drop_ep)
        else:
            epoch_gap = -1

        rows.append({
            "run_id": r["run_id"],
            "condition": r["cfg"].get("condition", ""),
            "n_train": r["cfg"].get("n_train"),
            "seed": r["cfg"].get("seed"),
            "best_val_acc": df["val_acc"].max() if "val_acc" in df else nan,
            "grokking_epoch": grok_ep,
            "irm_drop_epoch": irm_drop_ep,
            "epoch_gap": epoch_gap,
            "irm_drop_pct": (irm0 - irm_min) / (irm0 + 1e-8) * 100,
            "final_shortcut_ratio": (df["shortcut_ratio"].iloc[-1]
                                     if "shortcut_ratio" in df else nan),
            "run_dir": r["run_dir"],
        })
    if not rows:
        print(" No runs to summarize.")
        return
    table = pd.DataFrame(rows).sort_values("best_val_acc", ascending=False)
    out_path = os.path.join(save_dir, "table1_ablations.csv")
    table.to_csv(out_path, index=False)
    print(f"\nTable 1 ({len(table)} runs):")
    print(table.to_string(index=False))
    print(f"\n Saved → {out_path}")
|
|
|
|
def per_run_figure(r: Dict):
    """
    Save a quick diagnostic plot of one run's training curves.

    The plot shows val accuracy and train accuracy on the left axis and the
    IRM penalty on a twin right axis. Each series is drawn only if the run's
    history has that column.

    The output is ``<run_dir>/figures/training_curves.png``; the ``figures``
    directory is created if it does not exist. Runs with an empty history or
    no ``epoch`` column are skipped.

    Args:
        r: a record from ``discover_runs`` (uses ``df``, ``run_dir``,
            ``run_id``).
    """
    df = r["df"]
    if df.empty or "epoch" not in df:
        return
    fig_dir = os.path.join(r["run_dir"], "figures")
    # savefig does not create missing directories.
    os.makedirs(fig_dir, exist_ok=True)
    out = os.path.join(fig_dir, "training_curves.png")

    fig, ax = plt.subplots(figsize=(9, 4.5))
    ax2 = ax.twinx()
    if "val_acc" in df:
        ax.plot(df["epoch"], df["val_acc"], "#2563EB", lw=2, label="Val Acc")
    if "train_acc" in df:
        ax.plot(df["epoch"], df["train_acc"], "#9CA3AF", lw=1, ls=":", label="Train Acc")
    if "irm_mean" in df:
        ax2.plot(df["epoch"], df["irm_mean"], "#F59E0B", lw=2, ls="--", label="IRM")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax2.set_ylabel("IRM penalty")
    ax.set_title(r["run_id"], fontsize=10)
    ax.grid(alpha=0.3)
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax.legend(h1 + h2, l1 + l2, loc="center left", fontsize=8)
    fig.tight_layout()
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
def main():
    """
    Command-line entry point: build per-run and cross-run paper artifacts.

    Parses ``--runs_dir`` and ``--save_dir``, then:

    1. Discovers runs and saves one diagnostic figure per run.
    2. Averages runs by condition and renders Figures 1-3 from the averages.
    3. Writes Table 1 from the individual (non-averaged) runs.

    Returns early, after printing the run count, if no runs are found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--runs_dir", default=DEFAULT_BASE)
    parser.add_argument("--save_dir", default="paper_figures")
    opts = parser.parse_args()
    runs_dir, save_dir = opts.runs_dir, opts.save_dir

    os.makedirs(save_dir, exist_ok=True)

    found = discover_runs(runs_dir)
    print(f"Found {len(found)} runs in {runs_dir}/")
    if len(found) == 0:
        return

    for run in found:
        per_run_figure(run)

    averaged = average_by_condition(found)
    print(f"Conditions averaged: {sorted(averaged)}")

    # The headline figures are drawn from the seed-averaged curves ...
    for make_figure in (figure1_smoking_gun, figure2_mechanisms, figure3_shortcut):
        make_figure(averaged, save_dir)
    # ... while the table keeps one row per individual run.
    table1_ablations(found, save_dir)
    print(f"\nAll cross-run artifacts in {save_dir}/")
|
|
|
|
# Entry point for direct execution, e.g. `python -m experiments.plot_results`.
if __name__ == "__main__":
    main()
|
|