# CausalGrok

Grokking-like training dynamics and circuit-level causal feature discovery on small medical imaging datasets.

**What this repo actually shows.** Across 11 ResNet-18 runs on Camelyon17 at n_train in {300, 500, 1000} (5 grokking + 3 standard seeds at n=1000), every grokking-favorable run **ungroks**: OOD accuracy on a held-out hospital peaks early and then decays toward chance, and the plateau-then-jump predicted by grokking does not occur in any run. Standard cross-entropy training matches or slightly beats the grokking-favorable regime on raw OOD accuracy (corrected 3-seed mean peak: standard 0.7533 vs grokking 0.7052).

To probe whether the regimes are nonetheless representationally distinguishable, we apply three causal interventions on avgpool features (M4 subspace ablation, M5 activation steering along the dominant between-hospital direction, M6 targeted shortcut-neuron ablation against random-K and morphology-K controls). The interventions give direction-consistent observations across M4/M5/M6, but **none reaches conventional statistical significance** at our N (Fisher exact p=0.286 for M5 monotonicity, Mann-Whitney p=0.071 for the continuous Δ, Fisher p=0.179 for M6 at K=64). We report the results as exploratory, non-significant observations consistent with the simpler explanation that the high-weight-decay component of our configuration alone produces direction-concentrated features, regardless of any grokking-specific dynamics. The repo's contribution is methodological (a reusable 4-probe protocol) plus an honest small-sample case study.

> **Framing note.** The dynamics we observe match the grokking signature (memorize → plateau → OOD jump) but are described as **grokking-like dynamics** / **delayed OOD generalization consistent with grokking** until the mechanistic-interpretability evidence (M1 probe trajectory, see `experiments/runs/*/mechinterp/`) is in the paper. The MI evidence is what upgrades the claim from "looks like grokking" to "same circuit-level mechanism."
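
The Fisher exact value quoted above for M6 at K=64 can be checked by hand. The sketch below assumes the underlying 2×2 table is the seed count reported in the contribution claim (3 of 5 grokking seeds vs 0 of 3 standard seeds with targeted > random) and that the test is one-sided; this README states neither explicitly. `fisher_one_sided` is an illustrative helper, not a function from this repo.

```python
from math import comb

def fisher_one_sided(a: int, b: int, c: int, d: int) -> float:
    """P(X >= a) under the hypergeometric null for the 2x2 table [[a, b], [c, d]]."""
    row1 = a + b              # size of group 1 (here: grokking seeds)
    col1 = a + c              # total number of "successes" across both groups
    n = a + b + c + d         # total number of seeds
    upper = min(row1, col1)   # largest count group 1 could have shown
    total = comb(n, row1)
    tail = sum(comb(col1, k) * comb(n - col1, row1 - k) for k in range(a, upper + 1))
    return tail / total

# Assumed table: grokking 3 selective / 2 not, standard 0 selective / 3 not.
p_m6 = fisher_one_sided(3, 2, 0, 3)
print(f"M6 K=64 one-sided Fisher p = {p_m6:.3f}")  # prints 0.179
```

With only eight seeds in total, even a 3-vs-0 split stays well above 0.05, which is why the intervention results are reported as exploratory.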

---

## Repo layout: every artifact has a deliberate output directory

```
CausalGrok/
├── README.md
├── requirements.txt
├── docs/
│   └── EXECUTION_GUIDE.md          # Long-form A100 playbook
│
├── utils/                          # Shared helpers ONLY (no experiment logic)
│   ├── __init__.py
│   ├── metrics.py                  # accuracy, weight_norm, feature_rank, irm_penalty, shortcut_ratio
│   ├── grokfast.py                 # gradfilter_ema (Lee et al. 2024)
│   ├── pseudo_envs.py              # brightness-quantile pseudo-environments for IRM
│   └── run_dir.py                  # the per-run directory contract
│
├── experiments/                    # Experiment scripts AND their outputs
│   ├── __init__.py
│   ├── causalgrok_baseline.py      # Main trainer (one run)
│   ├── run_ablations.py            # Spawns the grid; each cell → its own run dir
│   ├── plot_results.py             # Aggregates every run on disk → paper_figures/
│   │
│   ├── runs/                       # One subdir per training run, kept forever
│   │   └── __n_s/
│   │       ├── config.json         # Frozen config used for this run
│   │       ├── run.pid             # PID of detached process
│   │       ├── logs/
│   │       │   ├── train.log       # nohup stdout
│   │       │   └── train.err       # nohup stderr
│   │       ├── results/
│   │       │   ├── history.json    # Per-checkpoint metrics (incremental)
│   │       │   └── summary.json    # Final test acc, grokking epoch, etc.
│   │       ├── checkpoints/
│   │       │   └── final.pt
│   │       └── figures/
│   │           └── training_curves.png   # Per-run plot
│   │
│   └── orchestrator/               # Multi-run orchestrator logs
│       └── _/
│           └── logs/{train.log,train.err}
│
├── paper_figures/                  # Cross-run aggregations for the paper
│   ├── figure1_smoking_gun.{png,pdf}
│   ├── figure2_mechanisms.png
│   ├── figure3_shortcut.png
│   └── table1_ablations.csv
│
├── logs/                           # Environment / install output (NOT experiments)
│   └── install/
│       └── _flash_attn/{logs/,run.pid,env.txt}
│
├── scripts/                        # Always run training through these
│   ├── launch.sh                   # One detached run
│   ├── run_sanity_check.sh         # Standard + grokking, 2 GPUs
│   ├── run_quick_ablations.sh      # 2-cell grid
│   ├── run_full_grid.sh            # 6×5×3 = 90 jobs
│   ├── plot_all.sh                 # Regenerate all figures
│   ├── list_runs.sh                # Health check across every run
│   ├── install_flash_attn.sh       # Detached flash-attention build
│   └── lib/nohup_runner.sh         # Detach helper (nohup + setsid + disown)
│
└── data/                           # MedMNIST/CheXpert cache (gitignored)
```

**Rule of the directory.** Every invocation gets a brand-new `experiments/runs//` directory. Nothing is ever overwritten. Where each kind of artifact goes:

| Artifact                         | Lives in                          |
|----------------------------------|-----------------------------------|
| Per-run logs (stdout / stderr)   | `experiments/runs//logs/`         |
| Per-run history / summary JSON   | `experiments/runs//results/`      |
| Per-run model checkpoint         | `experiments/runs//checkpoints/`  |
| Per-run training-curve plot      | `experiments/runs//figures/`      |
| Frozen config                    | `experiments/runs//config.json`   |
| Detached PID                     | `experiments/runs//run.pid`       |
| Multi-run orchestrator logs      | `experiments/orchestrator/_/`     |
| Cross-run paper figures + tables | `paper_figures/`                  |
| Env-setup logs (flash-attn etc.) | `logs/install/_/`                 |
| Dataset caches                   | `data/`                           |

---

## Setup

```bash
conda create -n causalgrok python=3.10 -y
conda activate causalgrok
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
wandb login

# Optional: flash-attention for transformer fine-tuning (Paper 3 / FoundationGrok).
# Build runs detached so SSH disconnects don't kill it.
bash scripts/install_flash_attn.sh
```

---

## Run things: always via the launchers

The launchers wrap `nohup setsid ... &` + `disown` so SSH disconnects, terminal closes, and Ctrl+D cannot kill running jobs. **Never call `python -m experiments.causalgrok_baseline` directly in an interactive shell**: your run dies with the SSH session if you do.

```bash
# Single run (defaults: condition=grokking, n_train=500, seed=42, gpu=0)
bash scripts/launch.sh
bash scripts/launch.sh grokking 500 42 0
bash scripts/launch.sh standard 500 42 1

# Both conditions in parallel (the Week 2 go/no-go)
bash scripts/run_sanity_check.sh

# Two-cell ablation
bash scripts/run_quick_ablations.sh

# Full 90-job grid
N_GPUS=8 bash scripts/run_full_grid.sh

# Health-check every run on disk
bash scripts/list_runs.sh

# Re-generate paper figures (cross-run) + per-run figures
bash scripts/plot_all.sh
```

After launching, you'll see something like:

```
Detached run:
  run_dir : experiments/runs/20260430-141500_grokking_n500_s42
  PID     : 38421
  follow  : tail -f experiments/runs/20260430-141500_grokking_n500_s42/logs/train.log
  stop    : kill 38421
```

You can now `exit` your SSH session; the run keeps going.
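
The per-run directory contract described above can be sketched in a few lines. This is a minimal illustration, not the code in `utils/run_dir.py`: the function name `make_run_dir` is hypothetical, and the timestamp/condition/n/seed naming is inferred from the example run directory in the launcher output.

```python
import tempfile
from datetime import datetime
from pathlib import Path

# Sub-directories listed in the artifact table above.
SUBDIRS = ("logs", "results", "checkpoints", "figures")

def make_run_dir(root: Path, condition: str, n_train: int, seed: int) -> Path:
    """Create a fresh run directory; fail rather than overwrite an existing one."""
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_dir = root / f"{stamp}_{condition}_n{n_train}_s{seed}"
    run_dir.mkdir(parents=True, exist_ok=False)  # FileExistsError on a name collision
    for name in SUBDIRS:
        (run_dir / name).mkdir()
    return run_dir

# Demo in a throwaway location instead of experiments/runs/.
demo_root = Path(tempfile.mkdtemp())
demo_run = make_run_dir(demo_root, "grokking", 500, 42)
print(demo_run.name)
```

Using `exist_ok=False` is one way to enforce "nothing is ever overwritten": a second launch within the same second with identical arguments would raise instead of reusing the directory.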

---

## Experimental Results (2026-05-08)

### Headline: ungrokking is universal across all training regimes

On **Camelyon17 with hospital-OOD evaluation (H4 unseen)**, all 7 new runs at 3000 epochs each:

| n | Cond | Seed | Peak OOD | Peak ep | Final OOD | Δ decay |
|---|------|------|----------|---------|-----------|---------|
| 300 | standard | 42 | 0.7647 | 250 | 0.7052 | −0.060 |
| 500 | grokking | 42 | **0.7924** | 50 | 0.5514 | **−0.241** |
| 500 | standard | 42 | 0.7576 | 1050 | 0.6526 | −0.105 |
| 1000 | grokking | 42 | 0.7336 | 350 | 0.6639 | −0.070 |
| 1000 | grokking | 123 | 0.7270 | 350 | 0.6447 | −0.082 |
| 1000 | grokking | 456 | 0.6722 | 1100 | 0.5224 | −0.150 |
| 1000 | standard | 42 | 0.7615 | 1 | 0.6482 | −0.113 |

**Every run ungroks.** The IRM drop approaches 100% in every run, regardless of training regime. The grokking-vs-standard difference is *not* in the OOD numbers (standard often matches or beats grokking); it lies in the *mechanistic causal structure* exposed by M4/M5 (see the "Paper" section below).

### Critical-fraction summary across all available n

| n_train | Cond | Seeds | Peak OOD (mean ± std) | Final OOD (mean ± std) |
|---------|------|-------|------------------------|-------------------------|
| 100 | grokking | 1 | 0.706 | 0.608 |
| 150 | grokking | 3 | 0.639 ± 0.110 | 0.515 ± 0.085 |
| 300 | grokking | 3 | 0.697 ± 0.027 | 0.547 ± 0.043 |
| 300 | standard | 1 | 0.765 | 0.705 |
| 500 | grokking | 1 | 0.792 | 0.551 |
| 500 | standard | 1 | 0.758 | 0.653 |
| 1000 | grokking | 3 | 0.711 ± 0.027 | 0.610 ± 0.062 |
| 1000 | standard | 1 | 0.762 | 0.648 |

No tested n produces a *stable* peak ≥ 0.75 (peak OOD always reverts toward chance). At every n where both regimes were run (300, 500, 1000), *standard* training reaches a **higher** final OOD than grokking (by 0.04 to 0.16), although each standard entry in this table is a single seed. This argues against "grokking improves OOD on small medical imaging"; the contribution is mechanistic, not empirical.
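
The Peak OOD, Peak ep, Final OOD and Δ decay columns above can be derived from a per-checkpoint OOD trajectory. The sketch below assumes the trajectory is available as `(epoch, ood_acc)` pairs; the actual schema of `history.json` is not documented in this README, `ungrokking_summary` is a hypothetical helper, and the intermediate values in `toy_history` are made up (only the peak and final entries match the n=500 grokking row).

```python
from typing import Dict, List, Tuple

def ungrokking_summary(history: List[Tuple[int, float]]) -> Dict[str, float]:
    """Return peak OOD, its epoch, final OOD and the peak-to-final decay."""
    peak_epoch, peak_ood = max(history, key=lambda pair: pair[1])   # best OOD checkpoint
    _, final_ood = max(history, key=lambda pair: pair[0])           # last logged epoch
    return {
        "peak_ood": peak_ood,
        "peak_epoch": peak_epoch,
        "final_ood": final_ood,
        "delta_decay": round(final_ood - peak_ood, 3),
    }

# Toy trajectory: epochs 1 and 1500 are invented for illustration.
toy_history = [(1, 0.61), (50, 0.7924), (1500, 0.60), (3000, 0.5514)]
summary = ungrokking_summary(toy_history)
print(summary)
```

A negative `delta_decay` corresponds to what the tables call ungrokking: the final checkpoint is worse on the held-out hospital than the best earlier one.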

### MI workshop track (active, May 2026)

See the "Paper" section below for the full submission plan, title/abstract draft, contribution claim, and reproducibility commands.

### Run history (canonical)

| Group | Runs | Status |
|-------|------|--------|
| n=100 grokking (s42) | 1 (Camelyon17 OOD eval) | ✅ peak OOD 0.706 |
| n=150 grokking (3 seeds) | 3 | ✅ mean 0.639 ± 0.110 |
| n=300 grokking (3 seeds × 2–3 replicates) | 8 | ✅ mean 0.697 ± 0.027 |
| n=300 grokking with periodic ckpts | s42, s123, s456 (15 ckpts each) | ✅ MI-ready |
| n=300 standard (s42) | 2 | ✅ peak OOD 0.702 (1 ckpt only) |
| n=500 (synthetic eval, NOT Camelyon-OOD) | 4 | ⚠️ unusable (old eval) |
| n=500 Camelyon17 OOD | – | ⏳ blocked on GPU |
| n=1000 Camelyon17 OOD | – | ⏳ blocked on GPU |

**Total Camelyon17-OOD runs with valid `summary.json`:** 19.

**Aggregated figures:** `paper_figures/` (13 PNGs, 1 PDF).

---

## Tracked metrics (logged every 50 epochs)

| Metric | Why it matters |
|---|---|
| `train_acc`, `val_acc` | Standard. The grokking signature is val_acc jumping long after train_acc saturates. |
| `weight_norm` | Liu et al. 2022: should decrease through the transition (Goldilocks zone). |
| `feature_rank` | Fan et al. 2024: collapses at the transition as the low-dim generalizing solution wins. |
| `irm_mean` | **Headline metric.** IRMv1 variance across pseudo-environments. Drops if the model has found invariant features. |
| `shortcut_ratio` | Border vs. center confidence, a proxy for artifact reliance. |

The sentinel line in `train.log` we're hunting for:

```
*** GROKKING at epoch N *** Val: 0.623→0.891 | IRM drop: 73.4%
```

If `IRM drop > 50%` coincides with the val jump, the paper's hypothesis is supported.

---

## Paper

**Authors.** Nilesh Sarkar. Contact: `nileshsarkar.cs@gmail.com`.

**Repo.** `github.com:nileshsarkar-ai/CausalGrok`, branch `master`. Every figure in the paper is reproducible from this repo, post-hoc on the saved checkpoints under `experiments/runs/`.

### Working title

*A causal probe of grokking-like training: shortcut-subspace ablation, activation steering, and deep-layer information scrubbing on Camelyon17.*

### Abstract (draft, 2026-05-09, updated with 5-seed n=1000)

> Grokking-favorable training on Camelyon17 produces "grokking-like dynamics": train accuracy hits 1.0 within 50 epochs, the IRM invariance penalty collapses to ~10⁻¹³, and OOD accuracy on a held-out hospital peaks early then partially reverts ("ungrokking") across all tested dataset sizes (n ∈ {300, 500, 1000}). With **5 grokking + 3 standard seeds at n=1000**, both regimes ungrok in every run; mean peak OOD is grokking 0.705 ± 0.024 vs standard 0.798 ± 0.073, with mean ungrokking-Δ of −0.111 vs −0.168 (standard shows *bigger* peak-to-final decay, not smaller). To distinguish grokking from generic delayed generalization, we perform **three causal interventions** on avgpool features: **(M4)** ablation of a ~35-dim LDA-style hospital-shortcut subspace; **(M5)** activation steering along the dominant between-hospital direction by α ∈ [−3, +3] σ-units; **(M6)** targeted ablation of top-K hospital-discriminating neurons against random-K and morphology-K controls (K ∈ {0, 4, 8, 16, 32, 64, 128, 256}, dose-response). The **headline mechanistic result is M5 monotonicity**: **5/5 grokking n=1000 seeds** show clean monotonic OOD response to α (Δ(0→+3) all negative, Δ(0→−3) all positive; pushing *away* from the hospital direction *helps* OOD), while **0/3 standard n=1000 seeds** show this monotonicity (peak at α=0, symmetric decay on both sides). M6 shows that at K=64, **3/5 grokking seeds** have targeted-shortcut > random ablation (selective causal weight on shortcut neurons); **0/3 standard seeds** show this selectivity. Layer-wise linear probing (M1) corroborates this with deep-layer (layer4+avgpool) hospital-recoverability dropping 0.82→0.70 over training in the grokking regime.
> We frame these results carefully: **the dominant hospital direction in feature space is causally weighted in grokking but not in standard training**. This is preliminary causal-mechanistic evidence that grokking-favorable training reorganizes shortcut representations into a directionally-recoverable structure, even when raw OOD accuracy is comparable between regimes.

### Contribution claim (precise), 2026-05-09 with 5+3 seeds at n=1000

All numbers below are pulled directly from `experiments/runs//results/summary.json` and `experiments/runs//mechinterp/m{1,4,5,6}_*.json` after running M1, M4, M5, M6 on every 2026-05-05 + 2026-05-08 run.

1. **Empirical (training, 11 runs).** All 7+4=11 trained runs on Camelyon17 (grokking n=500/1000 + standard n=300/500/1000) show ungrokking: OOD peaks early (ep 1–1100) and decays by 0.06–0.24 OOD-points by ep3000. IRM penalty drops to ~10⁻¹³ in every run regardless of regime. Per-run from log:

   | n | cond | seed | best | peak ep | final | Δ |
   |---|---|---|---|---|---|---|
   | 1000 | grokking | 42 | 0.7336 | 350 | 0.6639 | −0.070 |
   | 1000 | grokking | 123 | 0.7270 | 350 | 0.6447 | −0.082 |
   | 1000 | grokking | 456 | 0.6722 | 1100 | 0.5224 | −0.150 |
   | 1000 | grokking | 7 | 0.6876 | 50 | 0.5882 | −0.099 |
   | 1000 | grokking | 2024 | 0.7056 | 400 | 0.5506 | −0.155 |
   | **1000 grokking 5-seed mean ± std** | | | **0.7052 ± 0.024** | | **0.5940 ± 0.057** | **−0.111 ± 0.038** |
   | 1000 | standard | 42 | 0.7615 | 1 | 0.6482 | −0.113 |
   | 1000 | standard | 123 | 0.8880 | 1 | 0.6645 | −0.223 |
   | 1000 | standard | 456 | 0.7450 | 1050 | 0.5783 | −0.167 |
   | **1000 standard 3-seed mean ± std** | | | **0.7982 ± 0.073** | | **0.6303 ± 0.041** | **−0.168 ± 0.055** |

2. **Mechanistic, correlational (M1).** Deep-layer (layer4+avgpool) hospital-probe accuracy on H3 drops 0.82→0.70 over 3000 epochs of grokking training; the standard final-only checkpoint matches in magnitude at training end. Early layers (stem, layer1, layer2) stay flat at ~0.92 in both regimes.

3. **Mechanistic, subspace ablation (M4).** Projecting features onto the orthogonal complement of the LDA hospital subspace at avgpool reduces head OOD acc in both regimes (mean grokking −0.111 ± 0.063, standard −0.122 ± 0.052 across the original 7 runs). The grokking n=1000 seeds reach ablation-Δ ≥ 0 at occasional epochs (max +0.007 in s456) while no standard run does.

4. **Mechanistic, steering (M5): strongest result, 5+3 seeds.** Activation steering at the peak-OOD epoch with α ∈ [−3, +3] σ-units along the top hospital direction:

   | n=1000 grokking | α=−3 | α=0 | α=+3 | Δ(0→+3) | Δ(0→−3) | monotonic? |
   |---|---|---|---|---|---|---|
   | s42 | 0.728 | 0.662 | 0.560 | **−0.102** | **+0.066** | ✅ |
   | s123 | 0.733 | 0.698 | 0.657 | **−0.041** | **+0.035** | ✅ |
   | s456 | 0.682 | 0.623 | 0.598 | **−0.025** | **+0.059** | ✅ |
   | s7 | 0.495 | 0.485 | 0.482 | **−0.003** | **+0.010** | ✅ |
   | s2024 | 0.743 | 0.713 | 0.595 | **−0.118** | **+0.030** | ✅ |
   | **5-seed mean** | | | | **−0.058 ± 0.046** | **+0.040 ± 0.020** | **5/5 ✅** |

   | n=1000 standard | α=−3 | α=0 | α=+3 | Δ(0→+3) | Δ(0→−3) | monotonic? |
   |---|---|---|---|---|---|---|
   | s42 | 0.585 | 0.618 | 0.593 | −0.025 | −0.033 | ❌ peak at α=0 |
   | s123 | 0.542 | 0.720 | 0.555 | −0.165 | −0.178 | ❌ symmetric peak at α=0 |
   | s456 | 0.637 | 0.618 | 0.505 | −0.113 | +0.019 | ❌ mixed |
   | **3-seed mean** | | | | −0.101 ± 0.058 | −0.064 ± 0.084 | **0/3** |

   **5/5 grokking n=1000 seeds show clean monotonic α-response; 0/3 standard seeds do.** Pushing activations *away* from the hospital direction (α=−3) *improves* OOD by +0.040 mean for grokking but *hurts* by −0.064 mean for standard. This is the cleanest mechanistic differentiator.

5. **Mechanistic, neuron ablation (M6).** Compared to the random-K-neuron baseline (avg over 5 samplings):
   - **K=64**: 3/5 grokking n=1000 seeds show targeted > random (s42=+0.000, s123=+0.011, s456=+0.012, s7=−0.002, s2024=−0.011; the count of 3 includes s42, which is +0.000 at three decimals). 0/3 standard n=1000 seeds (all targeted ≤ random by 0.003–0.010).
   - **K=256**: 2/5 grokking show targeted > random (s123=+0.006 with absolute +0.046 OOD gain, s456=+0.007 with +0.005 OOD gain). 2/3 standard show targeted slightly > random but at much smaller magnitudes.
   - **ID degradation**: ≤0.01 across all targeted-shortcut ablations in both regimes (never costs ID accuracy).

6. **Critical-fraction characterization.** Grokking peak OOD never holds above the conventional ≥0.75 stable-grokking threshold at any tested n (the single n=500 grokking run peaks at 0.792 but decays to 0.551). At every n where both regimes were run, standard training has higher *final* OOD than grokking; standard *peak* OOD is higher at n=300 and n=1000 but lower at n=500 (0.758 vs 0.792). The paper's contribution is therefore **mechanistic** (M5 monotonicity, M6 selectivity), not empirical OOD improvement.

### Figures (paper-ready)

| # | File | Status | What it shows |
|---|------|--------|---------------|
| **Fig 1a (M4+M5)** | [paper_figures/figure_intervention_comparison.png](paper_figures/figure_intervention_comparison.png) | ✅ | 2×2: top row M4 ablation trajectories; bottom row M5 steering sweeps. Bottom-left grokking panel shows the monotonic α-effect that's the headline of M5. |
| **Fig 1b (M6 bars)** | [paper_figures/figure_m6_per_seed_bars.png](paper_figures/figure_m6_per_seed_bars.png) | ✅ | Per-seed Δ OOD at K=64 and K=256 with shortcut/random/morphology bars. Highlights the seed-level selectivity in grokking n=1000 s123/s456 that's averaged away in mean curves. |
| Fig 1c (M6 mean) | [paper_figures/figure_m6_targeted_vs_random.png](paper_figures/figure_m6_targeted_vs_random.png) | ✅ | Cross-run mean ± std Δ-OOD vs K for shortcut/random/morphology ablation, plus the ID-degradation control panel. |
| Fig 2 (probes M1) | [paper_figures/figure1_MI_probe_comparison.png](paper_figures/figure1_MI_probe_comparison.png) | ✅ | Grokking vs standard probe heatmaps. Hospital probe on H3 (Reds, ↓ = good) + tumor probe on H4 (Greens, ↑ = good). Layer × epoch trajectory shows deep-layer hospital scrubbing in grokking. |
| Fig 2b (multi-seed) | [paper_figures/figure_multiseed_intervention.png](paper_figures/figure_multiseed_intervention.png) | ✅ | M4 + M5 mean ± std bands across 3 grokking n=1000 seeds (s42, s123, s456) with standard n=1000 single-seed reference line. |
| Fig 3 (ungrokking) | [paper_figures/figure_n300_ungrokking_detail.png](paper_figures/figure_n300_ungrokking_detail.png) | ✅ | OOD trajectories across 3 seeds at n=300; ungrokking pattern. |
| Fig 4 (critical) | [paper_figures/figure2_critical_fraction.png](paper_figures/figure2_critical_fraction.png) | ✅ | Peak OOD vs n_train (100/150/300): critical-fraction characterization. |
| Reviewer table | [paper_figures/m6_summary.csv](paper_figures/m6_summary.csv) | ✅ | 56-row CSV with baseline, shortcut, random (mean ± std), morphology head OOD/ID and Δs across 7 runs × 8 K values. |
| Per-run M4 | `experiments/runs//mechinterp/m4_ablation_avgpool_trajectory.png` | ✅ × 7 | Per-run M4 trajectories. |
| Per-run M5 | `experiments/runs//mechinterp/m5_steering_ep.png` | ✅ × 7 | Per-run α-sweep at peak-OOD epoch. |
| Per-run M6 | `experiments/runs//mechinterp/m6_neuron_ablation_ep.png` | ✅ × 7 | Per-run targeted/random/morph neuron ablation curves. |
| Per-run M1 | `experiments/runs//mechinterp/m1_probe_{heatmap,curves}.png` | ✅ × 7 | Per-run layer × epoch probe heatmaps + curves. |

### Mechanistic interpretation (careful framing for reviewers)

Following the framing guidance for mechanistic interpretability submissions, we deliberately do **not** claim:

> ❌ "We discovered the true causal circuit for OOD prediction."
>
> ❌ "Grokking achieves full mechanistic disentanglement of shortcut and causal features."


Instead the manuscript will use:

> ✅ *"Targeted ablation of environment-sensitive neurons selectively altered OOD behavior in grokking-favorable training but not in standard training, providing preliminary causal mechanistic evidence that the two regimes organize shortcut-encoding representations differently. Specifically, grokking concentrates more of the shortcut signal into a small set of neurons that can be suppressed without harming ID accuracy."*

Honest caveats reviewers will surface:

1. **Only 2 of the 5 grokking seeds (s123, s456)** show clean targeted > random OOD selectivity at K=64; the 3/5 count in the contribution claim includes s42 at +0.000. The result is suggestive across seeds, not unanimous.
2. **The headline +0.046 OOD point at n=1000 s123 K=256** beats random by only +0.006, which is within one std of random damage. We report it as "comparable in magnitude to random ablation but with much milder ID-acc degradation," not as "definitive causal selectivity."
3. **All training regimes ungrok.** OOD peak always reverts toward chance. We do not claim grokking *fixes* the OOD problem; we claim it *reorganizes* the internal representation in a way that's measurable via causal intervention.
4. **Standard often beats grokking on raw OOD accuracy.** The mechanistic difference is the contribution; the empirical OOD numbers do not favor either regime.
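
For readers unfamiliar with the interventions discussed above, the following is a minimal sketch of the two feature-space operations behind M4 and M5, written in plain Python on a single feature vector. It assumes the hospital subspace is given as an orthonormal basis and that steering adds α·σ times a unit direction. The function names, the toy vectors and these exact formulas are illustrative assumptions, not the repo's implementation in `experiments.mechinterp_m4_ablation` or `experiments.mechinterp_m5_steering`.

```python
from typing import List, Sequence

Vector = List[float]

def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))

def ablate_subspace(x: Vector, basis: Sequence[Vector]) -> Vector:
    """M4-style ablation: project x onto the orthogonal complement of span(basis).

    Assumes the basis vectors are orthonormal.
    """
    out = list(x)
    for u in basis:
        coeff = dot(x, u)  # component of x along this basis direction
        out = [o - coeff * ui for o, ui in zip(out, u)]
    return out

def steer(x: Vector, direction: Vector, alpha: float, sigma: float) -> Vector:
    """M5-style steering: shift x by alpha * sigma along a unit direction."""
    return [xi + alpha * sigma * di for xi, di in zip(x, direction)]

feature = [2.0, -1.0, 0.5]        # toy avgpool feature (real ones are much longer)
hospital_dir = [1.0, 0.0, 0.0]    # toy unit vector standing in for the hospital direction

ablated = ablate_subspace(feature, [hospital_dir])
steered = steer(feature, hospital_dir, alpha=-3.0, sigma=0.5)
print(ablated, steered)
```

In a real pipeline the modified features would be passed to the classification head and OOD accuracy re-measured; negative α moves features away from the hospital direction, positive α toward it.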

### What blocks the submission

| Status | Item |
|--------|------|
| ✅ done | n ∈ {300, 500, 1000} training (7 runs, 3000 epochs each, 15 periodic ckpts each) |
| ✅ done | OOD-aware early stopping + ungrokking detector (`final_ood`, `peak_ood_epoch` in `summary.json`) |
| ✅ done | M1: layer-wise linear probing (correlational evidence) |
| ✅ done | **M4: Representation Ablation** (causal intervention via subspace projection) |
| ✅ done | **M5: Activation Steering** (causal intervention via direction-α sweep) |
| ✅ done | M1/M4/M5 outputs across all 7 new runs |
| ✅ done | Cross-run intervention-comparison figure (`paper_figures/figure_intervention_comparison.png`) |
| ⏳ pending | 4-page draft (Intro 1pg + Methods 0.5pg + Results-with-Figs-1+2 1.5pg + Discussion 0.5pg + Refs 0.5pg) |
| ⏳ pending (nice) | M5 with σ_align computed per-checkpoint (currently uses peak-OOD ckpt only) |
| ⏳ pending (nice) | M4 with smaller (probe-row-space, dim 2-3) subspace as a control |

### What is *not* blocking the workshop paper

- M2 (CKA across consecutive checkpoints) and M3 (per-layer feature rank): supporting evidence, deferred.
- Second dataset (CheXpert / ISIC): needed for the full NeurIPS-length Paper 1, not the workshop paper.
- More seeds at n=500 (currently only s42 grokking + s42 standard): single seed, but the trajectory effect is consistent with the n=1000 multi-seed pattern.

### Reproducibility: regenerate every MI figure

```bash
# 1. Run M1 on all 7 runs (cuda, sequential, ~40 min)
bash scripts/run_mi_all_new.sh
# Logs: logs/mi_all_new.log

# 2. Run M4 trajectory (CPU, ~6-10 min per run, can run in parallel)
for d in experiments/runs/20260505-*/; do
  nohup python3 -m experiments.mechinterp_m4_ablation \
    --run_dir "$d" --data_root data/wilds --device cpu \
    --max_samples 800 --subspace_method lda --subspace_dim 32 \
    --all_epochs > "logs/m4_$(basename "$d").log" 2>&1 &
  disown
done

# 3. Run M5 steering (cuda, parallel, ~30-60s per run)
for d in experiments/runs/20260505-*/; do
  nohup python3 -m experiments.mechinterp_m5_steering \
    --run_dir "$d" --data_root data/wilds --device cuda \
    --max_samples 800 > "logs/m5_$(basename "$d").log" 2>&1 &
  disown
done

# 4. Generate cross-run figures
python3 -m experiments.figure_mi_comparison            # probe heatmaps
python3 -m experiments.figure_intervention_comparison  # M4+M5 grid
python3 -m experiments.regenerate_all_figures          # ungrokking, critical-fraction, overview

# 5. Live monitor for any in-flight run
bash scripts/monitor_all.sh
```

### Honest caveats reviewers will ask about

1. **"Is this really grokking?"** We do not assert it is the *same phenomenon* as Power et al. 2022. The training dynamics match the grokking signature (memorize → plateau → OOD jump → reversal); we use the term "grokking-like dynamics" throughout. The MI evidence (deep-layer scrubbing) is the *mechanistic* upgrade of this claim, not just a temporal coincidence.
2. **"Standard column is a single point."** Correct. Standard training runs in this codebase do not save periodic checkpoints (only `final.pt`). The comparison shows endpoint magnitudes match between grokking and standard; the *trajectory* is grokking-only. Re-running standard with periodic ckpts is the cleanest fix and is gated on GPU access, which is currently shared.
3. **"Mean grokking peak (0.697) < standard baseline (0.702)."** Acknowledged. The contribution is **not** "grokking improves OOD on small medical datasets." The contribution is **circuit-level evidence for the dynamics + characterization of the regime where the dynamics fail to produce stable OOD gains** (i.e. n=300 is below the critical fraction). This is closer to a negative-result-with-mechanism than a positive headline.

### Timeline (working backwards from a hypothetical 2-week target)

| Day | Item |
|-----|------|
| D-14 (today, 2026-05-05) | MI Figure 1 done; cross-seed Δ table in this README |
| D-12 | Methods + Intro draft |
| D-9 | Results section + Figure 1/4/5 captions |
| D-7 | Discussion + limitations + related work |
| D-4 | Internal review + figure polish |
| D-1 | Camera-ready PDF, supplementary |
| D-0 | Submit |

Concrete deadlines depend on the workshop CFP; update this table once the official CFP is in.

### Timeline (Paper 1)

- **Week 1.** Env setup, CheXpert access request, read Power 2022 / Varma 2023 / Arjovsky 2019.
- **Week 2.** PneumoniaMNIST baseline (E1, E2). Go/no-go.
- **Week 3-4.** CheXpert (E3, E4) + critical fraction curve.
- **Week 5-6.** Attribution maps (E5).
- **Week 7-8.** Full ablations + cross-site OOD (E6); E7 / E8 if time.
- **Week 9-10.** Theory section (linear SCM proposition).
- **Week 11-12.** Paper draft.

---

## Future roadmap (not active)

CausalGrok (this repo) is **Paper 1**. Two follow-up papers are planned but **deferred until Paper 1 has results**. They share theoretical scaffolding and most of the experimental infrastructure, so each one will get cheaper to write than the last.

```
Paper 1: CausalGrok (NeurIPS)
  "Grokking finds causal features in small medical datasets"
  Core claim: the grokking transition = causal discovery, without environment labels
      ↓ natural follow-up: what breaks it?
Paper 2: ImbalanceGrok (NeurIPS / ICML)
  "Class imbalance prevents grokking — here's how to fix it"
  Core claim: modified critical fraction theory + new training protocol for rare diseases
      ↓ natural follow-up: does this hold for pretrained models?
Paper 3: FoundationGrok (ICLR / MICCAI)
  "Grokking dynamics in foundation model fine-tuning"
  Core claim: fine-tuning has a different grokking regime, new theory needed
```

### Paper 2: ImbalanceGrok

**Title.** *Class Imbalance as a Grokking Inhibitor: A New Training Protocol for Rare Disease Detection.*

**Core idea.** Every grokking paper to date uses balanced datasets. Medical imaging is pathologically imbalanced (1:50, 1:100 ratios for rare diseases). The hypothesis is that the minority class, the disease you most care about, has a per-class training fraction that sits permanently below the critical threshold. The network never groks the rare class, which is the exact opposite of what is wanted clinically.

**Why it follows from CausalGrok.** Paper 1 shows grokking finds causal features. Paper 2 asks the natural failure-mode question: what happens when the conditions for grokking break? Reviewers of Paper 1 will ask exactly this.

**Theoretical contribution.** A modified critical fraction formula under class imbalance:

> α_c(ρ) ≈ α_c^balanced / ρ^β

where ρ is the imbalance ratio and β depends on the difficulty gap between the spurious and causal circuits for the minority class. Implication: at 1:100 imbalance the critical fraction grows by a factor of ρ^(−β) (a power law in the imbalance ratio, not an exponential), so you need far more data than balanced theory predicts, unless you change the training dynamics.

**Algorithmic contribution: the ImbalanceGrok protocol.**

- Class-reweighted Grokfast: amplify slow-gradient components per class, not globally, so the minority class gets its own grokking trajectory.
- Progressive imbalance schedule: start balanced during the memorization phase, shift to true imbalance only after grokking is detected via feature rank.
- Per-class IRM penalty monitoring instead of aggregate.

**Experiments.** ISIC 2019 (skin lesions, 8-class severe natural imbalance), LIDC-IDRI (pulmonary nodules, rare positives), synthetic imbalance ablations on MedMNIST from 1:1 to 1:100.

**Target venue.** NeurIPS or ICML.
Strong clinical-impact framing: rare disease detection is among the highest-value problems in medical AI.

**Timeline.** Start ~2 months after CausalGrok experiments are running; theoretical backbone is shared.

### Paper 3: FoundationGrok

**Title.** *Do Foundation Models Grok? Training Dynamics in the Fine-tuning Regime for Medical Imaging.*

**Core idea.** Standard grokking theory assumes from-scratch training in the lazy regime. Fine-tuning BiomedCLIP / MedSAM / RAD-DINO on 50 labeled scans starts from a representation that is already rich. Does grokking still occur? Does it look different? Does CausalGrok's claim (grokking finds causal features) still hold from pretrained weights?

**Unexplored hypothesis.** Fine-tuning may exhibit *anti-grokking*: early apparent generalization (pretrained features are good) followed by forgetting (overfitting to site-specific artifacts in the small labeled set) followed by re-generalization. Or something current theory does not predict.

**Theoretical contribution.** A modified lazy-to-rich transition condition for pretrained models. The key variable is *representation alignment*: how well pretrained features already encode the target pathology's causal structure. High alignment → rapid generalization, no grokking needed. Low alignment → a new phenomenon, *representation hijacking*, where fine-tuning first corrupts the pretrained features before rediscovering them.

**Experiments.**

- Fine-tune BiomedCLIP, MedSAM, RAD-DINO, DINOv2 on MedMNIST+ at 10, 50, 100, 500 images.
- Track weight norm, feature rank, IRM invariance, alignment between pretrained features and target labels.
- Compare from-scratch grokking vs. fine-tuning grokking vs. standard fine-tuning.
- Headline number: does post-grokking fine-tuning beat foundation model zero-shot? That is the result that will make clinicians pay attention.

**Target venue.** ICLR or MICCAI. More empirical, less theoretical, but directly relevant to how most medical AI is deployed.

**Timeline.** Start after Paper 2 is submitted. Most infrastructure is reused; this is largely a dataset + model swap.

Each paper stands alone, and each one makes the previous one more cited.

---

See [docs/EXECUTION_GUIDE.md](docs/EXECUTION_GUIDE.md) for the long-form A100 playbook.