CausalGrok
Grokking-like training dynamics and circuit-level causal feature discovery on small medical imaging datasets.
What this repo actually shows. Across 11 ResNet-18 runs on Camelyon17 at n_train in {300, 500, 1000} (5 grokking + 3 standard seeds at n=1000), every grokking-favorable run ungrokks: OOD accuracy on a held-out hospital peaks early, then decays toward chance, and the plateau-then-jump that grokking predicts does not occur in any run. Standard cross-entropy training matches or slightly beats the grokking-favorable regime on raw OOD (corrected mean peak: standard 0.7533 over 3 seeds vs grokking 0.7052 over 5 seeds). To probe whether the regimes are nonetheless representationally distinguishable, we apply three causal interventions on avgpool features: M4 subspace ablation, M5 activation steering along the dominant between-hospital direction, and M6 targeted shortcut-neuron ablation against random-K and morphology-K controls. The three interventions give direction-consistent observations, but none reaches conventional statistical significance at our N (Fisher exact p=0.286 for M5 monotonicity, Mann-Whitney p=0.071 for the continuous Δ, Fisher p=0.179 for M6 at K=64). We therefore report them as exploratory, non-significant observations, consistent with a simpler explanation: the high-weight-decay component of our configuration alone produces direction-concentrated features, regardless of any grokking-specific dynamics. The repo's contribution is methodological (a reusable 4-probe protocol) plus an honest small-sample case study.
Framing note. The dynamics we observe resemble the grokking signature (memorize → plateau → OOD jump), so we describe them as "grokking-like dynamics" or "delayed OOD generalization consistent with grokking" until the mechanistic-interpretability evidence (the M1 probe trajectory; see experiments/runs/*/mechinterp/) is in the paper. The MI evidence is what upgrades the claim from "looks like grokking" to "same circuit-level mechanism."
Repo layout: every artifact has a deliberate output directory
```
CausalGrok/
├── README.md
├── requirements.txt
├── docs/
│   └── EXECUTION_GUIDE.md          # Long-form A100 playbook
│
├── utils/                          # Shared helpers ONLY (no experiment logic)
│   ├── __init__.py
│   ├── metrics.py                  # accuracy, weight_norm, feature_rank, irm_penalty, shortcut_ratio
│   ├── grokfast.py                 # gradfilter_ema (Lee et al. 2024)
│   ├── pseudo_envs.py              # brightness-quantile pseudo-environments for IRM
│   └── run_dir.py                  # the per-run directory contract
│
├── experiments/                    # Experiment scripts AND their outputs
│   ├── __init__.py
│   ├── causalgrok_baseline.py      # Main trainer (one run)
│   ├── run_ablations.py            # Spawns the grid; each cell → its own run dir
│   ├── plot_results.py             # Aggregates every run on disk → paper_figures/
│   │
│   ├── runs/                       # One subdir per training run, kept forever
│   │   └── <YYYYMMDD-HHMMSS>_<condition>_n<N>_s<seed>/
│   │       ├── config.json         # Frozen config used for this run
│   │       ├── run.pid             # PID of detached process
│   │       ├── logs/
│   │       │   ├── train.log       # nohup stdout
│   │       │   └── train.err       # nohup stderr
│   │       ├── results/
│   │       │   ├── history.json    # Per-checkpoint metrics (incremental)
│   │       │   └── summary.json    # Final test acc, grokking epoch, etc.
│   │       ├── checkpoints/
│   │       │   └── final.pt
│   │       └── figures/
│   │           └── training_curves.png   # Per-run plot
│   │
│   └── orchestrator/               # Multi-run orchestrator logs
│       └── <stamp>_<grid_name>/
│           └── logs/{train.log,train.err}
│
├── paper_figures/                  # Cross-run aggregations for the paper
│   ├── figure1_smoking_gun.{png,pdf}
│   ├── figure2_mechanisms.png
│   ├── figure3_shortcut.png
│   └── table1_ablations.csv
│
├── logs/                           # Environment / install output (NOT experiments)
│   └── install/
│       └── <stamp>_flash_attn/{logs/,run.pid,env.txt}
│
├── scripts/                        # Always run training through these
│   ├── launch.sh                   # One detached run
│   ├── run_sanity_check.sh         # Standard + grokking, 2 GPUs
│   ├── run_quick_ablations.sh      # 2-cell grid
│   ├── run_full_grid.sh            # 6×5×3 = 90 jobs
│   ├── plot_all.sh                 # Regenerate all figures
│   ├── list_runs.sh                # Health check across every run
│   ├── install_flash_attn.sh       # Detached flash-attention build
│   └── lib/nohup_runner.sh         # Detach helper (nohup + setsid + disown)
│
└── data/                           # MedMNIST/CheXpert cache (gitignored)
```
Rule of the directory. Every invocation gets a brand-new experiments/runs/<run_id>/ directory. Nothing is ever overwritten. Where each kind of artifact goes:
| Artifact | Lives in |
|---|---|
| Per-run logs (stdout / stderr) | experiments/runs/<run_id>/logs/ |
| Per-run history / summary JSON | experiments/runs/<run_id>/results/ |
| Per-run model checkpoint | experiments/runs/<run_id>/checkpoints/ |
| Per-run training-curve plot | experiments/runs/<run_id>/figures/ |
| Frozen config | experiments/runs/<run_id>/config.json |
| Detached PID | experiments/runs/<run_id>/run.pid |
| Multi-run orchestrator logs | experiments/orchestrator/<stamp>_<grid>/ |
| Cross-run paper figures + tables | paper_figures/ |
| Env-setup logs (flash-attn etc.) | logs/install/<stamp>_<thing>/ |
| Dataset caches | data/ |
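The "brand-new directory, never overwritten" rule can be enforced by letting directory creation fail on a name collision. The sketch below is a hypothetical helper, not the API of utils/run_dir.py; it only assumes the `<YYYYMMDD-HHMMSS>_<condition>_n<N>_s<seed>` naming shown in the layout.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional

# Per-run subdirectories, mirroring the artifact table above.
RUN_SUBDIRS = ("logs", "results", "checkpoints", "figures")


def make_run_dir(root: Path, condition: str, n_train: int, seed: int,
                 now: Optional[datetime] = None) -> Path:
    """Create <root>/<YYYYMMDD-HHMMSS>_<condition>_n<N>_s<seed>/ (hypothetical helper).

    exist_ok=False makes a name collision raise FileExistsError instead of
    silently reusing, and later overwriting, an existing run directory.
    """
    stamp = (now or datetime.now()).strftime("%Y%m%d-%H%M%S")
    run_dir = Path(root) / f"{stamp}_{condition}_n{n_train}_s{seed}"
    run_dir.mkdir(parents=True, exist_ok=False)
    for sub in RUN_SUBDIRS:
        (run_dir / sub).mkdir()
    return run_dir
```

Failing loudly on a collision is the point of the design: two launches in the same second with identical arguments cannot share a directory.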
Setup
```bash
conda create -n causalgrok python=3.10 -y
conda activate causalgrok
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
wandb login

# Optional: flash-attention for transformer fine-tuning (Paper 3 / FoundationGrok).
# The build runs detached so SSH disconnects don't kill it.
bash scripts/install_flash_attn.sh
```
Run things: always via the launchers
The launchers wrap each job in nohup setsid ... & followed by disown, so SSH disconnects, closed terminals, and Ctrl+D cannot kill running jobs. Never call python -m experiments.causalgrok_baseline directly in an interactive shell: if you do, the run dies with the SSH session.
```bash
# Single run (defaults: condition=grokking, n_train=500, seed=42, gpu=0)
bash scripts/launch.sh
bash scripts/launch.sh grokking 500 42 0
bash scripts/launch.sh standard 500 42 1

# Both conditions in parallel (the Week 2 go/no-go)
bash scripts/run_sanity_check.sh

# Two-cell ablation
bash scripts/run_quick_ablations.sh

# Full 90-job grid
N_GPUS=8 bash scripts/run_full_grid.sh

# Health-check every run on disk
bash scripts/list_runs.sh

# Regenerate paper figures (cross-run) + per-run figures
bash scripts/plot_all.sh
```
After launching, you'll see something like:
```
Detached run:
  run_dir : experiments/runs/20260430-141500_grokking_n500_s42
  PID     : 38421
  follow  : tail -f experiments/runs/20260430-141500_grokking_n500_s42/logs/train.log
  stop    : kill 38421
```
You can now exit your SSH session; the run keeps going.
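For readers who prefer to launch from Python, the same detach idea can be sketched with the standard library's subprocess module. This is a hypothetical analogue of scripts/lib/nohup_runner.sh, not a replacement for it, and it is POSIX-only.

```python
import subprocess
from pathlib import Path


def launch_detached(cmd: list, run_dir: Path) -> subprocess.Popen:
    """Hypothetical Python analogue of the nohup + setsid + disown launcher.

    start_new_session=True makes the child call setsid(), so it has no
    controlling terminal and an SSH hangup does not deliver SIGHUP to it.
    stdout/stderr are appended to the run's logs/ and the PID goes to run.pid.
    """
    logs = Path(run_dir) / "logs"
    logs.mkdir(parents=True, exist_ok=True)
    with open(logs / "train.log", "ab") as out, open(logs / "train.err", "ab") as err:
        proc = subprocess.Popen(cmd, stdin=subprocess.DEVNULL, stdout=out,
                                stderr=err, start_new_session=True)
    # The parent's file handles can close; the child keeps its own copies.
    (Path(run_dir) / "run.pid").write_text(f"{proc.pid}\n")
    return proc
```

Passing [sys.executable, "-m", "experiments.causalgrok_baseline", ...] with the trainer's own arguments would mirror launch.sh; the trainer's CLI flags are not documented in this README, so they are left elided here.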
Experimental Results (2026-05-08)
Headline: ungrokking is universal across all training regimes
On Camelyon17 with hospital-OOD evaluation (H4 unseen), all 7 new runs at 3000 epochs each:
| n | Cond | Seed | Peak OOD | Peak ep | Final OOD | Δ decay |
|---|---|---|---|---|---|---|
| 300 | standard | 42 | 0.7647 | 250 | 0.7052 | −0.060 |
| 500 | grokking | 42 | 0.7924 | 50 | 0.5514 | −0.241 |
| 500 | standard | 42 | 0.7576 | 1050 | 0.6526 | −0.105 |
| 1000 | grokking | 42 | 0.7336 | 350 | 0.6639 | −0.070 |
| 1000 | grokking | 123 | 0.7270 | 350 | 0.6447 | −0.082 |
| 1000 | grokking | 456 | 0.6722 | 1100 | 0.5224 | −0.150 |
| 1000 | standard | 42 | 0.7615 | 1 | 0.6482 | −0.113 |
Every run ungrokks. The IRM drop is ≈ 100% in every single run, regardless of training regime. The grokking-vs-standard difference is not in the OOD numbers (standard often matches or beats grokking); it is in the mechanistic causal structure exposed by M4/M5 (see the "Paper" section below).
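The Δ decay column is final OOD minus peak OOD, so it is never positive. A minimal sketch of that bookkeeping, assuming the OOD history is a list of (epoch, accuracy) pairs; the `tol` threshold for flagging a run is a hypothetical choice, not the repo's ungrokking detector.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class OODSummary:
    peak_ood: float
    peak_epoch: int
    final_ood: float

    @property
    def delta(self) -> float:
        """Ungrokking Δ = final OOD − peak OOD (≤ 0, since the peak is a max)."""
        return self.final_ood - self.peak_ood


def summarize_ood(history: List[Tuple[int, float]]) -> OODSummary:
    """history: (epoch, ood_acc) pairs in training order; ties keep the earliest peak."""
    peak_epoch, peak = max(history, key=lambda pair: pair[1])
    return OODSummary(peak_ood=peak, peak_epoch=peak_epoch, final_ood=history[-1][1])


def ungrokks(summary: OODSummary, tol: float = 0.02) -> bool:
    """Flag a run whose final OOD ends more than `tol` below its peak (tol is an assumption)."""
    return summary.delta < -tol
```

With the n=500 grokking row (peak 0.7924 at epoch 50, final 0.5514) this yields Δ = −0.241, matching the table.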
Critical-fraction summary across all available n
| n_train | Cond | Seeds | Peak OOD (mean ± std) | Final OOD (mean ± std) |
|---|---|---|---|---|
| 100 | grokking | 1 | 0.706 | 0.608 |
| 150 | grokking | 3 | 0.639 ± 0.110 | 0.515 ± 0.085 |
| 300 | grokking | 3 | 0.697 ± 0.027 | 0.547 ± 0.043 |
| 300 | standard | 1 | 0.765 | 0.705 |
| 500 | grokking | 1 | 0.792 | 0.551 |
| 500 | standard | 1 | 0.758 | 0.653 |
| 1000 | grokking | 3 | 0.711 ± 0.027 | 0.610 ± 0.062 |
| 1000 | standard | 1 | 0.762 | 0.648 |
No tested n produces a stable peak ≥ 0.75 (peak OOD always reverts toward chance). At every n with a standard run, standard training reaches a higher final OOD than grokking (by 0.04 to 0.16). This rules out "grokking improves OOD on small medical imaging"; the contribution is mechanistic, not empirical.
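The "mean ± std" cells aggregate per-seed values. Which std convention the tables use is not stated, so the sketch below assumes the population std (statistics.pstdev); applied to the three n=1000 grokking peaks from the run table it gives about 0.711 ± 0.0275, close to the reported 0.711 ± 0.027, whereas the sample std would give about 0.034.

```python
import statistics


def mean_std(values):
    """Mean and population std over seeds (the pstdev convention is an assumption)."""
    if len(values) == 1:
        return values[0], 0.0
    return statistics.mean(values), statistics.pstdev(values)


# Peak OOD for the n=1000 grokking seeds s42, s123, s456 (from the run table).
peaks = [0.7336, 0.7270, 0.6722]
m, s = mean_std(peaks)
print(f"{m:.3f} ± {s:.4f}")
```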
MI workshop track (active, May 2026)
See the "Paper" section below for the full submission plan, title/abstract draft, contribution claim, and reproducibility commands.
Run history (canonical)
| Group | Runs | Status |
|---|---|---|
| n=100 grokking (s42) | 1 (Camelyon17 OOD eval) | ✅ peak OOD 0.706 |
| n=150 grokking (3 seeds) | 3 | ✅ mean 0.639 ± 0.110 |
| n=300 grokking (3 seeds × 2-3 replicates) | 8 | ✅ mean 0.697 ± 0.027 |
| n=300 grokking with periodic ckpts | s42, s123, s456 (15 ckpts each) | ✅ MI-ready |
| n=300 standard (s42) | 2 | ✅ peak OOD 0.702 (1 ckpt only) |
| n=500 (synthetic eval, NOT Camelyon-OOD) | 4 | ⚠️ unusable (old eval) |
| n=500 Camelyon17 OOD | n/a | ⏳ blocked on GPU |
| n=1000 Camelyon17 OOD | n/a | ⏳ blocked on GPU |
Total Camelyon17-OOD runs with valid summary.json: 19. Aggregated figures: paper_figures/ (13 PNGs, 1 PDF).
Tracked metrics (logged every 50 epochs)
| Metric | Why it matters |
|---|---|
train_acc, val_acc |
Standard. The grokking signature is val_acc jumping long after train_acc saturates. |
weight_norm |
Liu et al. 2022 β should decrease through the transition (Goldilocks zone). |
feature_rank |
Fan et al. 2024 β collapses at the transition as the low-dim generalizing solution wins. |
irm_mean |
Headline metric. IRMv1 variance across pseudo-environments. Drops if the model has found invariant features. |
shortcut_ratio |
Border vs. center confidence β proxy for artifact reliance. |
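The irm_mean metric rests on two ingredients: pseudo-environments and the IRMv1 penalty (Arjovsky et al. 2019). The sketch below assumes brightness tertiles for the environments and averages the per-environment IRMv1 penalty for a binary head (an assumption suggested by the name irm_mean; utils/pseudo_envs.py and the irm_penalty in utils/metrics.py may differ in detail).

```python
import numpy as np


def brightness_envs(images: np.ndarray, n_envs: int = 3) -> np.ndarray:
    """Pseudo-environment label per image from its mean-brightness quantile.

    images: array of shape (N, ...). Returns integer labels in [0, n_envs).
    """
    brightness = images.reshape(len(images), -1).mean(axis=1)
    inner_edges = np.quantile(brightness, np.linspace(0.0, 1.0, n_envs + 1)[1:-1])
    return np.searchsorted(inner_edges, brightness, side="right")


def irmv1_penalty(logits: np.ndarray, labels: np.ndarray) -> float:
    """IRMv1 penalty for binary cross-entropy: squared gradient of the mean loss
    with respect to a scalar multiplier w on the logits, evaluated at w = 1.
    For BCE(sigmoid(w * z), y), d/dw = (sigmoid(w * z) - y) * z."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    grad = np.mean((probs - labels) * logits)
    return float(grad ** 2)


def irm_mean(logits: np.ndarray, labels: np.ndarray, envs: np.ndarray) -> float:
    """Average IRMv1 penalty over pseudo-environments."""
    return float(np.mean([irmv1_penalty(logits[envs == e], labels[envs == e])
                          for e in np.unique(envs)]))
```

A penalty near zero means rescaling the classifier would not lower the loss in any environment, which is the "invariant features" reading used in the table.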
The sentinel line in train.log we're hunting for:
```
*** GROKKING at epoch N ***
Val: 0.623→0.891 | IRM drop: 73.4%
```
If an IRM drop above 50% coincides with the val jump, the paper's hypothesis is confirmed.
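Scanning train.log for that sentinel is a small regex job. This sketch assumes the two-line format shown above (and also accepts an ASCII "->" in case the arrow is written that way); it is not the trainer's own detector.

```python
import re

GROK_RE = re.compile(r"\*\*\* GROKKING at epoch (\d+) \*\*\*")
# The arrow between the two val accuracies may be Unicode "→" or ASCII "->".
VAL_RE = re.compile(r"Val: ([0-9.]+)\s*(?:→|->)\s*([0-9.]+) \| IRM drop: ([0-9.]+)%")


def find_grokking_event(log_text: str, min_irm_drop: float = 50.0):
    """Return (epoch, val_before, val_after, irm_drop_pct) for the first sentinel
    whose following line reports an IRM drop above min_irm_drop, else None."""
    lines = log_text.splitlines()
    for i, line in enumerate(lines):
        m = GROK_RE.search(line)
        if m is None or i + 1 >= len(lines):
            continue
        v = VAL_RE.search(lines[i + 1])
        if v and float(v.group(3)) > min_irm_drop:
            return int(m.group(1)), float(v.group(1)), float(v.group(2)), float(v.group(3))
    return None
```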
Paper
Authors. Nilesh Sarkar. Contact: nileshsarkar.cs@gmail.com.
Repo. github.com:nileshsarkar-ai/CausalGrok, branch master. Every figure in the paper is reproducible from this repo, post hoc, from the saved checkpoints under experiments/runs/.
Working title
A causal probe of grokking-like training: shortcut-subspace ablation, activation steering, and deep-layer information scrubbing on Camelyon17.
Abstract (draft, 2026-05-09, updated with 5-seed n=1000)
Grokking-favorable training on Camelyon17 produces "grokking-like dynamics": train accuracy hits 1.0 within 50 epochs, the IRM invariance penalty collapses to ~10⁻¹³, and OOD accuracy on a held-out hospital peaks early, then partially reverts ("ungrokking") at every tested dataset size (n ∈ {300, 500, 1000}). With 5 grokking + 3 standard seeds at n=1000, both regimes ungrokk in every run; mean peak OOD is 0.706 ± 0.024 for grokking vs 0.798 ± 0.073 for standard, with mean ungrokking-Δ of −0.114 vs −0.168 (standard shows the bigger peak-to-final decay, not the smaller). To distinguish grokking from generic delayed generalization, we perform three causal interventions on avgpool features: (M4) ablation of a ~35-dim LDA-style hospital-shortcut subspace; (M5) activation steering along the dominant between-hospital direction by α ∈ [−3, +3] σ-units; (M6) targeted ablation of the top-K hospital-discriminating neurons against random-K and morphology-K controls (K ∈ {0, 4, 8, 16, 32, 64, 128, 256}, dose-response). The headline mechanistic result is M5 monotonicity: 5/5 grokking n=1000 seeds show a clean monotonic OOD response to α (Δ(0→+3) all negative, Δ(0→−3) all positive, so pushing away from the hospital direction helps OOD), while 0/3 standard n=1000 seeds show this monotonicity (peak at α=0, symmetric decay on both sides). M6 shows that at K=64, 3/5 grokking seeds have targeted-shortcut > random ablation (selective causal weight on shortcut neurons), while 0/3 standard seeds show this selectivity. Layer-wise linear probing (M1) corroborates this: deep-layer (layer4+avgpool) hospital recoverability drops 0.82→0.70 over training in the grokking regime.
We frame these results carefully: the dominant hospital direction in feature space is causally weighted in grokking but not in standard training. This is preliminary causal-mechanistic evidence that grokking-favorable training reorganizes shortcut representations into a directionally-recoverable structure, even when raw OOD accuracy is comparable between regimes.
Contribution claim (precise), 2026-05-09, with 5+3 seeds at n=1000
All numbers below are pulled directly from experiments/runs/<id>/results/summary.json and experiments/runs/<id>/mechinterp/m{1,4,5,6}_*.json after running M1, M4, M5, M6 on every 2026-05-05 + 2026-05-08 run.
Empirical (training, 11 runs). All 7+4=11 trained runs on Camelyon17 (grokking n=500/1000 + standard n=300/500/1000) show ungrokking: OOD peaks early (epoch 1 to 1100) and decays by 0.06 to 0.24 OOD points by epoch 3000. The IRM penalty drops to ~10⁻¹³ in every run, regardless of regime. Per-run values from the logs:
| n | Cond | Seed | Best peak | Peak ep | Final | Δ |
|---|---|---|---|---|---|---|
| 1000 | grokking | 42 | 0.7336 | 350 | 0.6639 | −0.070 |
| 1000 | grokking | 123 | 0.7270 | 350 | 0.6447 | −0.082 |
| 1000 | grokking | 456 | 0.6722 | 1100 | 0.5224 | −0.150 |
| 1000 | grokking | 7 | 0.6876 | 50 | 0.5882 | −0.099 |
| 1000 | grokking | 2024 | 0.7056 | 400 | 0.5506 | −0.155 |
| 1000 | grokking | 5-seed mean ± std | 0.7052 ± 0.024 | | 0.5940 ± 0.057 | −0.111 ± 0.038 |
| 1000 | standard | 42 | 0.7615 | 1 | 0.6482 | −0.113 |
| 1000 | standard | 123 | 0.8880 | 1 | 0.6645 | −0.223 |
| 1000 | standard | 456 | 0.7450 | 1050 | 0.5783 | −0.167 |
| 1000 | standard | 3-seed mean ± std | 0.7982 ± 0.073 | | 0.6303 ± 0.041 | −0.168 ± 0.055 |

Mechanistic, correlational (M1). Deep-layer (layer4+avgpool) hospital-probe accuracy on H3 drops 0.82→0.70 over 3000 epochs of grokking training; the standard final-only checkpoint matches it in magnitude at the end of training. Early layers (stem, layer1, layer2) stay flat at ~0.92 in both regimes.
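M1 fits one linear probe per layer (and per checkpoint) to recover the hospital label from frozen features. A minimal sketch of such a probe, assuming multinomial logistic regression trained by full-batch gradient descent in NumPy; the repo's actual probe (solver, regularization, train/eval splits) may differ.

```python
import numpy as np


def linear_probe_accuracy(train_x, train_y, test_x, test_y,
                          lr=0.1, epochs=500, l2=1e-3, seed=0) -> float:
    """Fit a softmax-regression probe and return held-out accuracy.

    Features are standardized with training statistics; labels are ints 0..C-1.
    """
    rng = np.random.default_rng(seed)
    mu, sd = train_x.mean(0), train_x.std(0) + 1e-8
    xtr, xte = (train_x - mu) / sd, (test_x - mu) / sd
    n_cls = int(max(train_y.max(), test_y.max())) + 1
    w = rng.normal(scale=0.01, size=(xtr.shape[1], n_cls))
    b = np.zeros(n_cls)
    onehot = np.eye(n_cls)[train_y]
    for _ in range(epochs):
        logits = xtr @ w + b
        logits -= logits.max(1, keepdims=True)  # numerical stability
        p = np.exp(logits)
        p /= p.sum(1, keepdims=True)
        g = (p - onehot) / len(xtr)  # gradient of mean cross-entropy w.r.t. logits
        w -= lr * (xtr.T @ g + l2 * w)
        b -= lr * g.sum(0)
    return float(((xte @ w + b).argmax(1) == test_y).mean())
```

Running this on each layer's pooled activations at each saved checkpoint yields a layer × epoch grid like the M1 heatmaps.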
Mechanistic, subspace ablation (M4). Projecting features onto the orthogonal complement of the LDA hospital subspace at avgpool reduces head OOD accuracy in both regimes (mean Δ: grokking −0.111 ± 0.063, standard −0.122 ± 0.052, across the original 7 runs). The grokking n=1000 seeds reach ablation-Δ ≥ 0 at occasional epochs (max +0.007 in s456), while no standard run does.
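The projection step can be sketched in NumPy as below, assuming a textbook LDA construction (generalized eigenvectors of between- vs within-hospital scatter). One caveat: with H hospitals the between-hospital scatter has rank at most H − 1, so only that many directions carry signal here; the repo's ~35-dim subspace (the reproducibility commands pass --subspace_dim 32) must therefore be built differently, which this sketch does not attempt to reproduce.

```python
import numpy as np


def hospital_subspace(feats: np.ndarray, hospital: np.ndarray, dim: int,
                      reg: float = 1e-3) -> np.ndarray:
    """Top-`dim` LDA directions as an orthonormal basis of shape (D, dim)."""
    d = feats.shape[1]
    mean = feats.mean(0)
    s_w, s_b = np.zeros((d, d)), np.zeros((d, d))
    for h in np.unique(hospital):
        xh = feats[hospital == h]
        mh = xh.mean(0)
        s_w += (xh - mh).T @ (xh - mh)
        s_b += len(xh) * np.outer(mh - mean, mh - mean)
    s_w += reg * np.trace(s_w) / d * np.eye(d)  # ridge keeps S_w positive definite
    # Solve S_b v = λ S_w v by whitening with the Cholesky factor of S_w.
    l_inv = np.linalg.inv(np.linalg.cholesky(s_w))
    evals, evecs = np.linalg.eigh(l_inv @ s_b @ l_inv.T)
    top = evecs[:, np.argsort(evals)[::-1][:dim]]
    q, _ = np.linalg.qr(l_inv.T @ top)  # orthonormalize for projection
    return q


def ablate_subspace(feats: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Project features onto the orthogonal complement of span(basis)."""
    return feats - (feats @ basis) @ basis.T
```

The ablated features are then fed through the unchanged classification head to measure the OOD change.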
Mechanistic, steering (M5): strongest result, 5+3 seeds. Activation steering at the peak-OOD epoch with α ∈ [−3, +3] σ-units along the top hospital direction:
| n=1000 grokking | α=−3 | α=0 | α=+3 | Δ(0→+3) | Δ(0→−3) | Monotonic? |
|---|---|---|---|---|---|---|
| s42 | 0.728 | 0.662 | 0.560 | −0.102 | +0.066 | ✅ |
| s123 | 0.733 | 0.698 | 0.657 | −0.041 | +0.035 | ✅ |
| s456 | 0.682 | 0.623 | 0.598 | −0.025 | +0.059 | ✅ |
| s7 | 0.495 | 0.485 | 0.482 | −0.003 | +0.010 | ✅ |
| s2024 | 0.743 | 0.713 | 0.595 | −0.118 | +0.030 | ✅ |
| 5-seed mean | | | | −0.058 ± 0.046 | +0.040 ± 0.020 | 5/5 ✅ |

| n=1000 standard | α=−3 | α=0 | α=+3 | Δ(0→+3) | Δ(0→−3) | Monotonic? |
|---|---|---|---|---|---|---|
| s42 | 0.585 | 0.618 | 0.593 | −0.025 | −0.033 | ❌ peak at α=0 |
| s123 | 0.542 | 0.720 | 0.555 | −0.165 | −0.178 | ❌ symmetric peak at α=0 |
| s456 | 0.637 | 0.618 | 0.505 | −0.113 | +0.019 | ❌ mixed |
| 3-seed mean | | | | −0.101 ± 0.058 | −0.064 ± 0.084 | 0/3 |

5/5 grokking n=1000 seeds show a clean monotonic α-response; 0/3 standard seeds do. Pushing activations away from the hospital direction (α=−3) improves OOD by +0.040 on average for grokking but changes it by −0.064 on average for standard. This is the cleanest mechanistic differentiator.
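The steering operation and the monotonicity check can be sketched as follows, assuming σ is the standard deviation of the features' projections onto the unit hospital direction (the README says "σ-units" without defining σ further). The checker is meant to run over every α in the sweep, not just the three columns shown above.

```python
import numpy as np
from typing import Dict


def steer(feats: np.ndarray, direction: np.ndarray, alpha: float) -> np.ndarray:
    """Shift every feature vector by alpha σ-units along a unit direction, where σ
    is the std of the features' projections onto that direction (an assumption)."""
    u = direction / np.linalg.norm(direction)
    sigma = (feats @ u).std()
    return feats + alpha * sigma * u


def is_monotonic_decreasing(accs_by_alpha: Dict[float, float]) -> bool:
    """True if OOD accuracy strictly falls as alpha increases across the whole sweep."""
    accs = [accs_by_alpha[a] for a in sorted(accs_by_alpha)]
    return all(x > y for x, y in zip(accs, accs[1:]))
```

For example, the grokking s42 row ({−3: 0.728, 0: 0.662, +3: 0.560}) passes, while the standard s42 row, which peaks at α=0, does not.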
Mechanistic, neuron ablation (M6). Compared with a random-K-neuron baseline (averaged over 5 samplings):
- K=64: 3/5 grokking n=1000 seeds show targeted > random (s42=+0.000, s123=+0.011, s456=+0.012, s7=−0.002, s2024=−0.011). 0/3 standard n=1000 seeds do (in all three, targeted trails random by 0.003 to 0.010).
- K=256: 2/5 grokking seeds show targeted > random (s123=+0.006, with an absolute +0.046 OOD gain; s456=+0.007, with a +0.005 OOD gain). 2/3 standard seeds show targeted slightly > random, but at much smaller magnitudes.
- ID degradation: ≤0.01 across all targeted-shortcut ablations in both regimes (targeted ablation never costs more than 0.01 ID accuracy).
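The targeted-vs-random comparison can be sketched as follows. The per-neuron score is an assumption (an F-like ratio of between-hospital to within-hospital variance); `evaluate` is a caller-supplied function that zeroes the given neuron indices in held-out features and returns, say, head OOD accuracy. Scoring with tumour labels in place of hospital labels would give a morphology-K control in the same way.

```python
import numpy as np


def shortcut_scores(feats: np.ndarray, hospital: np.ndarray) -> np.ndarray:
    """Per-neuron F-like ratio: variance of hospital means / mean within-hospital variance."""
    groups = [feats[hospital == h] for h in np.unique(hospital)]
    between = np.var([g.mean(0) for g in groups], axis=0)
    within = np.mean([g.var(0) for g in groups], axis=0) + 1e-8
    return between / within


def ablate_neurons(feats: np.ndarray, idx: np.ndarray) -> np.ndarray:
    """Zero the selected neurons (columns) in a copy of the feature matrix."""
    out = feats.copy()
    out[:, idx] = 0.0
    return out


def targeted_vs_random(score_feats, hospital, k, evaluate, n_random=5, seed=0):
    """Return (targeted_metric, mean_random_metric) for zeroing k neurons.

    `evaluate(idx)` is supplied by the caller: it should apply ablate_neurons to
    held-out features with `idx` and return a metric such as head OOD accuracy.
    """
    top = np.argsort(shortcut_scores(score_feats, hospital))[::-1][:k]
    rng = np.random.default_rng(seed)
    d = score_feats.shape[1]
    rand = [evaluate(rng.choice(d, size=k, replace=False)) for _ in range(n_random)]
    return evaluate(top), float(np.mean(rand))
```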
Critical-fraction characterization. In the grokking regime, OOD accuracy never stays above the conventional ≥0.75 stable-grokking threshold at any tested n (the n=500 peak of 0.792 is transient). At every n, standard training ends with a higher final OOD than grokking, and it also has the higher peak everywhere except n=500. The paper's contribution is therefore mechanistic (M5 monotonicity, M6 selectivity), not an empirical OOD improvement.
Figures (paper-ready)
| # | File | Status | What it shows |
|---|---|---|---|
| Fig 1a (M4+M5) | paper_figures/figure_intervention_comparison.png | ✅ | 2×2 grid: top row, M4 ablation trajectories; bottom row, M5 steering sweeps. The bottom-left grokking panel shows the monotonic α-effect that is the headline of M5. |
| Fig 1b (M6 bars) | paper_figures/figure_m6_per_seed_bars.png | ✅ | Per-seed Δ OOD at K=64 and K=256 with shortcut/random/morphology bars. Highlights the seed-level selectivity in grokking n=1000 s123/s456 that is averaged away in the mean curves. |
| Fig 1c (M6 mean) | paper_figures/figure_m6_targeted_vs_random.png | ✅ | Cross-run mean ± std Δ-OOD vs K for shortcut/random/morphology ablation, plus the ID-degradation control panel. |
| Fig 2 (probes M1) | paper_figures/figure1_MI_probe_comparison.png | ✅ | Grokking vs standard probe heatmaps: hospital probe on H3 (Reds, ↓ = good) and tumor probe on H4 (Greens, ↑ = good). The layer × epoch trajectory shows deep-layer hospital scrubbing in grokking. |
| Fig 2b (multi-seed) | paper_figures/figure_multiseed_intervention.png | ✅ | M4 + M5 mean ± std bands across 3 grokking n=1000 seeds (s42, s123, s456) with a standard n=1000 single-seed reference line. |
| Fig 3 (ungrokking) | paper_figures/figure_n300_ungrokking_detail.png | ✅ | OOD trajectories across 3 seeds at n=300, showing the ungrokking pattern. |
| Fig 4 (critical) | paper_figures/figure2_critical_fraction.png | ✅ | Peak OOD vs n_train (100/150/300): critical-fraction characterization. |
| Reviewer table | paper_figures/m6_summary.csv | ✅ | 56-row CSV with baseline, shortcut, random (mean ± std), and morphology head OOD/ID and Δs across 7 runs × 8 K values. |
| Per-run M4 | experiments/runs/<id>/mechinterp/m4_ablation_avgpool_trajectory.png | ✅ × 7 | Per-run M4 trajectories. |
| Per-run M5 | experiments/runs/<id>/mechinterp/m5_steering_ep<N>.png | ✅ × 7 | Per-run α-sweep at the peak-OOD epoch. |
| Per-run M6 | experiments/runs/<id>/mechinterp/m6_neuron_ablation_ep<N>.png | ✅ × 7 | Per-run targeted/random/morphology neuron-ablation curves. |
| Per-run M1 | experiments/runs/<id>/mechinterp/m1_probe_{heatmap,curves}.png | ✅ × 7 | Per-run layer × epoch probe heatmaps and curves. |
Mechanistic interpretation (careful framing for reviewers)
Following the framing guidance for mechanistic interpretability submissions, we deliberately do not claim:
β "We discovered the true causal circuit for OOD prediction." β "Grokking achieves full mechanistic disentanglement of shortcut and causal features."
Instead the manuscript will use:
β "Targeted ablation of environment-sensitive neurons selectively altered OOD behavior in grokking-favorable training but not in standard training, providing preliminary causal mechanistic evidence that the two regimes organize shortcut-encoding representations differently β specifically, grokking concentrates more of the shortcut signal into a small set of neurons that can be suppressed without harming ID accuracy."
Honest caveats reviewers will surface:
- At K=64, 3/5 grokking n=1000 seeds, not 5/5, show targeted > random OOD selectivity, and one of those three (s42) only at +0.000. The result is suggestive across seeds, not unanimous.
- The headline +0.046 OOD point at n=1000 s123 K=256 beats random by only +0.006, within one std of random damage. We report it as "comparable in magnitude to random ablation but with much milder ID-acc degradation," not as "definitive causal selectivity."
- All training regimes ungrokk. OOD peak always reverts toward chance. We do not claim grokking fixes the OOD problem; we claim it reorganizes the internal representation in a way that's measurable via causal intervention.
- Standard often beats grokking on raw OOD accuracy. The mechanistic difference is the contribution; the empirical OOD numbers do not favor either regime.
What blocks the submission
| Status | Item |
|---|---|
| ✅ done | n ∈ {300, 500, 1000} training (7 runs, 3000 epochs each, 15 periodic ckpts each) |
| ✅ done | OOD-aware early stopping + ungrokking detector (final_ood, peak_ood_epoch in summary.json) |
| ✅ done | M1: layer-wise linear probing (correlational evidence) |
| ✅ done | M4: Representation Ablation (causal intervention via subspace projection) |
| ✅ done | M5: Activation Steering (causal intervention via direction-α sweep) |
| ✅ done | M1/M4/M5 outputs across all 7 new runs |
| ✅ done | Cross-run intervention-comparison figure (paper_figures/figure_intervention_comparison.png) |
| ⏳ pending | 4-page draft (Intro 1pg + Methods 0.5pg + Results-with-Figs-1+2 1.5pg + Discussion 0.5pg + Refs 0.5pg) |
| ⏳ pending (nice) | M5 with σ_align computed per checkpoint (currently uses the peak-OOD ckpt only) |
| ⏳ pending (nice) | M4 with a smaller (probe-row-space, dim 2-3) subspace as a control |
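Since every number in the contribution claim is pulled from summary.json, a small collector is handy. This sketch parses condition, n and seed from the documented run-directory name and reads only final_ood and peak_ood_epoch, the two summary.json fields named in the table above; any other key would be an assumption about the file's schema.

```python
import json
import re
from pathlib import Path

# Run-dir naming contract: <YYYYMMDD-HHMMSS>_<condition>_n<N>_s<seed>
RUN_RE = re.compile(r"^(\d{8}-\d{6})_(\w+?)_n(\d+)_s(\d+)$")


def collect_summaries(runs_root: Path) -> list:
    """One row per run directory that follows the naming contract and has a summary."""
    rows = []
    for summary_path in sorted(Path(runs_root).glob("*/results/summary.json")):
        run_id = summary_path.parent.parent.name
        m = RUN_RE.match(run_id)
        if m is None:
            continue  # not a run dir following the naming contract
        _stamp, condition, n_train, seed = m.groups()
        summary = json.loads(summary_path.read_text())
        rows.append({"run_id": run_id, "condition": condition,
                     "n_train": int(n_train), "seed": int(seed),
                     "final_ood": summary.get("final_ood"),
                     "peak_ood_epoch": summary.get("peak_ood_epoch")})
    return rows
```

Calling collect_summaries(Path("experiments/runs")) returns plain dicts that can be grouped by (condition, n_train) to rebuild the seed tables.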
What is not blocking the workshop paper
- M2 (CKA across consecutive checkpoints) and M3 (per-layer feature rank): supporting evidence, deferred.
- Second dataset (CheXpert / ISIC): needed for the full NeurIPS-length Paper 1, not for the workshop paper.
- More seeds at n=500 (currently only s42 grokking + s42 standard): a single seed, but its trajectory is consistent with the n=1000 multi-seed pattern.
Reproducibility: regenerate every MI figure
```bash
# 1. Run M1 on all 7 runs (CUDA, sequential, ~40 min)
bash scripts/run_mi_all_new.sh
# Logs: logs/mi_all_new.log

# 2. Run the M4 trajectory (CPU, ~6-10 min per run; runs can go in parallel)
for d in experiments/runs/20260505-*/; do
  nohup python3 -m experiments.mechinterp_m4_ablation \
    --run_dir "$d" --data_root data/wilds --device cpu \
    --max_samples 800 --subspace_method lda --subspace_dim 32 \
    --all_epochs > "logs/m4_$(basename "$d").log" 2>&1 &
  disown
done

# 3. Run M5 steering (CUDA, parallel, ~30-60 s per run)
for d in experiments/runs/20260505-*/; do
  nohup python3 -m experiments.mechinterp_m5_steering \
    --run_dir "$d" --data_root data/wilds --device cuda \
    --max_samples 800 > "logs/m5_$(basename "$d").log" 2>&1 &
  disown
done

# 4. Generate cross-run figures
python3 -m experiments.figure_mi_comparison            # probe heatmaps
python3 -m experiments.figure_intervention_comparison  # M4+M5 grid
python3 -m experiments.regenerate_all_figures          # ungrokking, critical-fraction, overview

# 5. Live monitor for any in-flight run
bash scripts/monitor_all.sh
```
Honest caveats reviewers will ask about
- "Is this really grokking?" β We do not assert it is the same phenomenon as Power et al. 2022. The training dynamics match the grokking signature (memorize β plateau β OOD jump β reversal); we use the term "grokking-like dynamics" throughout. The MI evidence (deep-layer scrubbing) is the mechanistic upgrade of this claim, not just a temporal coincidence.
- "Standard column is a single point." β Correct. Standard training runs in this codebase do not save periodic checkpoints (only
final.pt). The comparison shows endpoint magnitudes match between grokking and standard; the trajectory is grokking-only. Re-running standard with periodic ckpts is the cleanest fix and is gated on GPU access, which is currently shared. - "Mean grokking peak (0.697) < standard baseline (0.702)." β Acknowledged. The contribution is not "grokking improves OOD on small medical datasets." The contribution is circuit-level evidence for the dynamics + characterization of the regime where the dynamics fail to produce stable OOD gains (i.e. n=300 is below the critical fraction). This is closer to a negative-result-with-mechanism than a positive headline.
Timeline (working backwards from a hypothetical 2-week target)
| Day | Item |
|---|---|
| D-14 (today, 2026-05-05) | MI Figure 1 done; cross-seed Ξ table in this README |
| D-12 | Methods + Intro draft |
| D-9 | Results section + Figure 1/4/5 captions |
| D-7 | Discussion + limitations + related work |
| D-4 | Internal review + figure polish |
| D-1 | Camera-ready PDF, supplementary |
| D-0 | Submit |
Concrete deadlines depend on the workshop CFP β update this table once the official CFP is in.
Timeline (Paper 1)
- Week 1. Env setup, CheXpert access request, read Power 2022 / Varma 2023 / Arjovsky 2019.
- Week 2. PneumoniaMNIST baseline (E1, E2). Go/no-go.
- Week 3-4. CheXpert (E3, E4) + critical fraction curve.
- Week 5-6. Attribution maps (E5).
- Week 7-8. Full ablations + cross-site OOD (E6); E7 / E8 if time.
- Week 9-10. Theory section (linear SCM proposition).
- Week 11-12. Paper draft.
Future roadmap (not active)
CausalGrok (this repo) is Paper 1. Two follow-up papers are planned but deferred until Paper 1 has results. They share theoretical scaffolding and most of the experimental infrastructure, so each one will get cheaper to write than the last.
Paper 1: CausalGrok (NeurIPS)
"Grokking finds causal features in small medical datasets"
Core claim: the grokking transition = causal discovery, without environment labels
→ natural follow-up: what breaks it?
Paper 2: ImbalanceGrok (NeurIPS / ICML)
"Class imbalance prevents grokking: here's how to fix it"
Core claim: modified critical fraction theory + new training protocol for rare diseases
→ natural follow-up: does this hold for pretrained models?
Paper 3: FoundationGrok (ICLR / MICCAI)
"Grokking dynamics in foundation model fine-tuning"
Core claim: fine-tuning has a different grokking regime, so new theory is needed
Paper 2: ImbalanceGrok
Title. Class Imbalance as a Grokking Inhibitor: A New Training Protocol for Rare Disease Detection.
Core idea. To our knowledge, every grokking paper to date uses balanced datasets. Medical imaging is pathologically imbalanced (1:50 or 1:100 ratios for rare diseases). The hypothesis is that the minority class (the disease you most care about) has a per-class training fraction that sits permanently below the critical threshold. The network never groks the rare class, which is the exact opposite of what is wanted clinically.
Why it follows from CausalGrok. Paper 1 shows grokking finds causal features. Paper 2 asks the natural failure-mode question: what happens when the conditions for grokking break? Reviewers of Paper 1 will literally ask this.
Theoretical contribution. A modified critical-fraction formula under class imbalance:

$$\alpha_c(\rho) \approx \frac{\alpha_c^{\text{balanced}}}{\rho^{\beta}}$$

where ρ ∈ (0, 1] is the minority-to-majority imbalance ratio (ρ = 0.01 at 1:100) and β > 0 depends on the difficulty gap between the spurious and causal circuits for the minority class. Implication: at 1:100 imbalance you need 100^β times more data than balanced theory predicts (a power-law blow-up in the imbalance ratio) unless you change the training dynamics.
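A quick numeric reading of the proposed scaling. β is a free parameter that this README does not estimate, so the values below are purely illustrative; the sketch only evaluates the formula α_c(ρ) ≈ α_c^balanced / ρ^β.

```python
def required_fraction(alpha_c_balanced: float, rho: float, beta: float) -> float:
    """Critical training fraction under imbalance: alpha_c_balanced / rho**beta.

    rho is the minority:majority ratio in (0, 1] (0.01 for 1:100); beta > 0 is an
    illustrative free parameter, not a value estimated anywhere in this repo.
    """
    if not 0.0 < rho <= 1.0:
        raise ValueError("rho must lie in (0, 1]")
    if beta <= 0.0:
        raise ValueError("beta must be positive")
    return alpha_c_balanced / rho ** beta


for beta in (0.5, 1.0, 2.0):
    factor = required_fraction(1.0, 0.01, beta)
    print(f"beta={beta}: {factor:g}x the balanced critical fraction")
```

At 1:100 the multiplier is 10× for β=0.5, 100× for β=1 and 10,000× for β=2, which is why pinning down β empirically matters.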
Algorithmic contribution β the ImbalanceGrok protocol.
- Class-reweighted Grokfast: amplify slow-gradient components per class, not globally, so the minority class gets its own grokking trajectory.
- Progressive imbalance schedule: start balanced during the memorization phase, shift to true imbalance only after grokking is detected via feature rank.
- Per-class IRM penalty monitoring instead of aggregate.
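The class-reweighted Grokfast idea can be sketched as keeping one exponential moving average of the gradient per class and amplifying each with its own factor. Everything here is hypothetical: gradients are flat NumPy vectors, per-class gradients are assumed to come from separate per-class losses (one backward pass per class), the EMA is initialized with the first gradient seen, and alpha=0.98 is a placeholder. It is not the interface of gradfilter_ema in utils/grokfast.py.

```python
import numpy as np
from typing import Dict


class PerClassEMAFilter:
    """Hypothetical per-class variant of an EMA ("slow gradient") filter.

    For each class c: ema_c <- alpha * ema_c + (1 - alpha) * g_c, and the class
    contributes g_c + lamb_c * ema_c to the update, so the minority class can get
    a larger amplification lamb_c than the majority class.
    """

    def __init__(self, lambs: Dict[int, float], alpha: float = 0.98):
        self.alpha = alpha
        self.lambs = lambs
        self.ema: Dict[int, np.ndarray] = {}

    def step(self, per_class_grads: Dict[int, np.ndarray]) -> np.ndarray:
        total = None
        for c, g in per_class_grads.items():
            prev = self.ema.get(c, g)  # initialize with the first gradient seen
            self.ema[c] = self.alpha * prev + (1.0 - self.alpha) * g
            filtered = g + self.lambs[c] * self.ema[c]
            total = filtered if total is None else total + filtered
        return total
```

For example, PerClassEMAFilter({0: 2.0, 1: 5.0}) amplifies the slow gradient of class 1 (say, the rare disease) 2.5× more strongly than that of class 0.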
Experiments. ISIC 2019 (skin lesions, 7-class severe natural imbalance), LIDC-IDRI (pulmonary nodules, rare positives), synthetic imbalance ablations on MedMNIST from 1:1 to 1:100.
Target venue. NeurIPS or ICML. Strong clinical-impact framing β rare disease detection is among the highest-value problems in medical AI.
Timeline. Start ~2 months after CausalGrok experiments are running; theoretical backbone is shared.
Paper 3: FoundationGrok
Title. Do Foundation Models Grok? Training Dynamics in the Fine-tuning Regime for Medical Imaging.
Core idea. Standard grokking theory assumes from-scratch training in the lazy regime. Fine-tuning BiomedCLIP / MedSAM / RAD-DINO on 50 labeled scans starts from a representation that is already rich. Does grokking still occur? Does it look different? Does CausalGrok's claim β grokking finds causal features β still hold from pretrained weights?
Unexplored hypothesis. Fine-tuning may exhibit anti-grokking: early apparent generalization (pretrained features are good) followed by forgetting (overfitting to site-specific artifacts in the small labeled set) followed by re-generalization. Or something current theory does not predict.
Theoretical contribution. A modified lazy-to-rich transition condition for pretrained models. The key variable is representation alignment: how well pretrained features already encode the target pathology's causal structure. High alignment → rapid generalization, no grokking needed. Low alignment → a new phenomenon, representation hijacking, where fine-tuning first corrupts the pretrained features before rediscovering them.
Experiments.
- Fine-tune BiomedCLIP, MedSAM, RAD-DINO, DINOv2 on MedMNIST+ at 10, 50, 100, 500 images.
- Track weight norm, feature rank, IRM invariance, alignment between pretrained features and target labels.
- Compare from-scratch grokking vs. fine-tuning grokking vs. standard fine-tuning.
- Headline number: does post-grokking fine-tuning beat foundation model zero-shot? That is the result that will make clinicians pay attention.
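One concrete way to track "alignment between pretrained features and target labels" is kernel-target alignment with a linear kernel. This is a swapped-in, standard measure chosen for illustration, not a metric the plan commits to; the sketch assumes binary labels.

```python
import numpy as np


def kernel_target_alignment(feats: np.ndarray, labels: np.ndarray) -> float:
    """Kernel-target alignment <K, yy^T>_F / (||K||_F ||yy^T||_F) for a linear
    kernel K = X X^T on centred features, with binary labels mapped to ±1."""
    x = feats - feats.mean(0)
    y = np.where(labels > 0, 1.0, -1.0)
    k = x @ x.T
    yy = np.outer(y, y)
    return float(np.sum(k * yy) / (np.linalg.norm(k) * np.linalg.norm(yy)))
```

Values near 1 mean the frozen features already separate the classes along a single direction (the "high alignment" case above); values near 0 mean fine-tuning has to build that structure.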
Target venue. ICLR or MICCAI. More empirical, less theoretical, but directly relevant to how most medical AI is deployed (by fine-tuning pretrained models).
Timeline. Start after Paper 2 is submitted. Most infrastructure is reused; this is largely a dataset + model swap.
Each paper stands alone, and each one makes the previous one more cited.
See docs/EXECUTION_GUIDE.md for the long-form A100 playbook.