# CausalGrok — Complete Training, Evaluation, and Mechanistic-Interpretability Reference
Paper: *Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training*.
This document is the **complete preservation archive** of the project. It contains the full source of every script that ran, every per-run config and summary, the full training logs of two reference runs, every M5 activation-steering result, the full M6 K-sweep data, and all formulas. All numerical values are read directly from on-disk `config.json` / `results/summary.json` / `results/history.json` / `mechinterp/*.json` / `paper_figures/m6_summary.csv`. All source code is the exact code that produced every reported result.
**Contents**
1. Environment
2. Dataset and data pipeline
3. Model architecture and initialization
4. Hyperparameter tables
5. Loss function (cross-entropy) — formula
6. IRM penalty (diagnostic) — formula
7. Grokfast EMA — formula
8. Training loop and checkpointing
9. Evaluation metrics — formulas
10. Full source: `utils/grokfast.py`
11. Full source: `experiments/causalgrok_camelyon_v2.py`
12. Full source: `experiments/mechinterp_m1.py`
13. Full source: `experiments/mechinterp_m4_ablation.py`
14. Full source: `experiments/mechinterp_m5_steering.py`
15. Full source: `experiments/mechinterp_m6_neuron_ablation.py`
16. Run inventory and summary results (14 runs)
17. Per-run `config.json` and `summary.json` (all 14 runs)
18. Full training log: grokking n=1000 seed=42
19. Full training log: standard n=1000 seed=42
20. M5 — Full activation-steering JSONs (8 runs at n=1000)
21. M5 — Aggregated sweep tables
22. M6 — Full K-sweep results (per-seed, all K)
23. Exact commands
24. Output layout
---
## 1. Environment
| Item | Value |
| --- | --- |
| Python env | `conda env: causalgrok` (Python 3.10) |
| Framework | PyTorch, `timm`, `torchvision`, `scikit-learn`, `numpy`, `wandb` (offline) |
| Model | `timm.create_model("resnet18", pretrained=False, num_classes=2)` |
| Device | CUDA (NVIDIA A100 80GB PCIe) |
| Precision | TF32 (`set_float32_matmul_precision("high")`, `cudnn.benchmark=True`, `allow_tf32=True`) |
| Params | 11,177,538 |
| Dataset | Camelyon17 via WILDS (`utils.camelyon_data.get_camelyon_subsets`, auto-download) |
| Wall time | ~5.5 h for the reference run (grokking, n=1000, seed 42, 3000 epochs, A100) |
## 2. Dataset and data pipeline
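Camelyon17 (WILDS): H&E-stained histopathology patches with a binary tumor label, from five hospitals. The standard WILDS split trains on hospitals {0, 3, 4}, draws the in-distribution (ID) validation set from those same hospitals, and holds out hospital 2 as the OOD test set. The `H0`-`H2` (training environments), `H3` (ID val) and `H4` (OOD test) tags used in the source and logs are split labels, not WILDS hospital indices.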
- **ID validation**: 33,560 images.
- **OOD test (held-out hospital)**: 85,054 images.
- **Train**: subsampled to `n_train`. The n=1000 seed-42 run drew hospitals `{0, 3, 4}` with 181 / 371 / 448 samples (positive rates 0.53 / 0.48 / 0.50).
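Transforms: `Resize((96, 96))`, `ToTensor`, `Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])`. Train loader: `batch_size=32, shuffle=True`; eval loaders: `batch_size=256, shuffle=False`; both use `num_workers=0`, `pin_memory=True`. The training subset is the first `n_train` indices of `torch.randperm(len(train_ds))` after `torch.manual_seed(seed)`.
IRM environments: one `{x, y}` dict per unique training hospital, built once from `metadata[:, 0]` (the hospital ID) through a separate loader (`batch_size=512, num_workers=4`) and held on the GPU for the diagnostic penalty.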
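The subsample-then-group step can be illustrated without WILDS or PyTorch. This is a minimal NumPy sketch that uses synthetic `metadata`/`labels` arrays in place of the WILDS dataset. It relies on NumPy's `default_rng` rather than `torch.randperm`, so it will not reproduce the indices the real runs drew. `subsample_and_group` is a hypothetical helper name.

```python
import numpy as np


def subsample_and_group(metadata: np.ndarray, labels: np.ndarray,
                        n_train: int, seed: int) -> dict:
    """Draw n_train indices with a seeded permutation, then group them by
    hospital ID (metadata column 0): one environment per hospital present."""
    rng = np.random.default_rng(seed)  # stand-in for torch.manual_seed + torch.randperm
    idx = rng.permutation(len(labels))[:n_train]
    hospitals = metadata[idx, 0]
    envs = {}
    for h in np.unique(hospitals):
        env_idx = idx[hospitals == h]
        envs[int(h)] = {
            "indices": env_idx,
            "n": int(env_idx.size),
            "pos_rate": float(labels[env_idx].mean()),
        }
    return envs


# Synthetic stand-in for the training split: 10,000 patches spread over
# hospitals 0, 3, 4 (column 0) plus a filler column, with random binary labels.
rng = np.random.default_rng(0)
meta = np.stack([rng.choice([0, 3, 4], size=10_000),
                 rng.integers(0, 50, size=10_000)], axis=1)
y = rng.integers(0, 2, size=10_000)
envs = subsample_and_group(meta, y, n_train=1000, seed=42)
for h, env in envs.items():
    print(f"hospital={h}: {env['n']} samples, positive rate={env['pos_rate']:.2f}")
```

Because the grouping is derived from the subsample itself, a small `n_train` can produce fewer environments than there are training hospitals.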
## 3. Model architecture and initialization
ResNet-18 (`timm`, no ImageNet pretraining), 96×96 input, 2-class head, 11,177,538 parameters. The grokking-favorable regime multiplies every multi-dim weight tensor by `init_scale = 4.0` at initialization; standard uses `init_scale = 1.0` (no rescaling). `avgpool` feature dimension `D = 512`. Six probed stages: `stem, layer1, layer2, layer3, layer4, avgpool`.
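The init-scale rule in `build_model` (§11) rescales only parameters whose name contains `weight` and that have more than one dimension. Conv and linear kernels are multiplied by `init_scale`. BatchNorm affine weights and all biases are left alone. The following is a minimal NumPy sketch of that filter. It uses a hypothetical dict of named arrays in place of `model.named_parameters()`, and `apply_init_scale` and `weight_norm` are illustrative names.

```python
import numpy as np


def apply_init_scale(params: dict, init_scale: float) -> dict:
    """Copy params, multiplying every multi-dim '*weight*' array by init_scale
    (the same name/dimension filter as build_model)."""
    out = {}
    for name, p in params.items():
        if init_scale != 1.0 and "weight" in name and p.ndim > 1:
            out[name] = p * init_scale
        else:
            out[name] = p.copy()
    return out


def weight_norm(params: dict) -> float:
    """||W|| = sqrt(sum_p ||p||_2^2), as in weight_norm_fn."""
    return float(np.sqrt(sum(np.sum(p ** 2) for p in params.values())))


rng = np.random.default_rng(0)
params = {
    "conv1.weight": rng.normal(size=(64, 3, 7, 7)),  # 4-D conv kernel: scaled
    "bn1.weight": np.ones(64),                        # 1-D BN gamma: untouched
    "bn1.bias": np.zeros(64),                         # bias: untouched
    "fc.weight": rng.normal(size=(2, 512)),           # 2-D linear kernel: scaled
    "fc.bias": np.zeros(2),                           # bias: untouched
}
scaled = apply_init_scale(params, init_scale=4.0)
print(weight_norm(params), weight_norm(scaled))
```

BatchNorm scales are not rescaled, so the total weight norm grows by somewhat less than 4×.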
## 4. Hyperparameter tables
| Hyperparameter | Standard | Grokking-favorable |
| --- | --- | --- |
| Optimizer | AdamW | AdamW |
| Learning rate | 1e-3 | 1e-3 |
| Weight decay | 1e-4 | 5e-3 (50×) |
| Epochs | 3000 | 3000 |
| Init scale | 1.0 | 4.0 |
| Grokfast EMA | off | on |
| Grokfast alpha (EMA decay) | — | 0.98 |
| Grokfast lamb (slow-grad amplification) | — | 2.0 |
| Gradient clip (max-norm) | 1.0 | 1.0 |
| Batch size | 32 | 32 |
| Image size | 96 × 96 | 96 × 96 |
| `log_every` (metric cadence) | 50 epochs | 50 epochs |
| Checkpoint cadence | 200 epochs | 200 epochs |
| IRM weight in loss | 0.0 | 0.0 |
**Three-axis confound**: weight decay (50× ratio), init scale (4× ratio), and Grokfast EMA (on vs off) differ simultaneously between regimes. No single-axis ablation was run.
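The confound can be made explicit by diffing the two configurations. This is a minimal sketch in which the dicts are transcribed by hand from the table above, with keys following `get_config` in §11. `config_diff` is a hypothetical helper.

```python
STANDARD = dict(lr=1e-3, weight_decay=1e-4, n_epochs=3000, init_scale=1.0,
                use_grokfast=False, grad_clip=1.0, batch_size=32, img_size=96,
                log_every=50, irm_weight=0.0)
# Grokking-favorable regime: same base, three axes changed (plus Grokfast knobs).
GROKKING = dict(STANDARD, weight_decay=5e-3, init_scale=4.0, use_grokfast=True,
                grokfast_alpha=0.98, grokfast_lamb=2.0)


def config_diff(a: dict, b: dict) -> dict:
    """Map each key whose value differs (or is missing on one side) to (a_val, b_val)."""
    keys = sorted(set(a) | set(b))
    return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}


for key, (std, grok) in config_diff(STANDARD, GROKKING).items():
    print(f"{key:15s} standard={std!s:6s} grokking={grok}")
```

`grokfast_alpha` and `grokfast_lamb` appear in the diff only because the standard regime never sets them. The substantive axes are still the three named above.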
## 5. Loss function — cross-entropy
```
L = CE(f_theta(x), y) = - (1/B) * sum_i log [ exp(z_{i, y_i}) / sum_k exp(z_{i, k}) ]
```
where `z = f_theta(x)` are the logits and `B` is the minibatch size (32). The training objective is pure CE for every reported run (`irm_weight = 0.0`).
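The formula above can be checked without any framework. This is a minimal NumPy sketch of mean cross-entropy computed through a numerically stable log-sum-exp: subtracting the row maximum before exponentiating leaves the softmax unchanged. The sketch is for illustration only; the training code uses `nn.CrossEntropyLoss`.

```python
import numpy as np


def cross_entropy(z: np.ndarray, y: np.ndarray) -> float:
    """Mean CE over a batch: -(1/B) sum_i [z_{i,y_i} - log sum_k exp(z_{i,k})]."""
    z_max = z.max(axis=1, keepdims=True)
    log_sum_exp = z_max[:, 0] + np.log(np.exp(z - z_max).sum(axis=1))
    true_logit = z[np.arange(len(y)), y]
    return float(np.mean(log_sum_exp - true_logit))


# Two-class example with B = 3: uniform logits give log(2) per sample.
print(cross_entropy(np.zeros((3, 2)), np.array([0, 1, 0])))  # ~0.6931
```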
## 6. IRM penalty (diagnostic, not in loss)
IRMv1 penalty (Arjovsky et al. 2019):
```
IRM_e = || grad_{w=1.0} CE( w · f_theta(x^e), y^e ) ||^2
IRM = (1/|E|) sum_{e in E} IRM_e
```
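Logged at every metric evaluation (every 50 epochs, plus epoch 1) as `irm_mean` and `irm_var`, the mean and unbiased variance over environments. In every run it collapses from ~0.20 to ~1e-13 within ~50–150 epochs. This collapse is a consequence of CE memorization, not of invariance learning. Once the softmax outputs on every training environment are near one-hot and correct, the derivative of each per-environment CE with respect to `w` is itself near zero. Each `IRM_e` therefore vanishes without any cross-environment constraint being enforced.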
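The gradient with respect to the dummy scale has a closed form. Differentiating the CE formula of §5 at `w = 1` gives `(1/B) sum_i (p_i - e_{y_i}) · z_i`, where `p_i = softmax(z_i)` and `e_{y_i}` is the one-hot label. The minimal NumPy sketch below checks this against a finite difference and shows why the penalty collapses once logits are confident and correct. The helper names are hypothetical.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def ce(z: np.ndarray, y: np.ndarray) -> float:
    return float(-np.mean(np.log(softmax(z)[np.arange(len(y)), y])))


def irm_env_penalty(z: np.ndarray, y: np.ndarray) -> float:
    """IRM_e = (d/dw CE(w*z, y) at w=1)^2, via the closed form
    (1/B) sum_i (softmax(z_i) - onehot(y_i)) . z_i."""
    onehot = np.eye(z.shape[1])[y]
    grad_w = np.mean(np.sum((softmax(z) - onehot) * z, axis=1))
    return float(grad_w ** 2)


def irm_penalty(envs) -> float:
    """IRM = mean of IRM_e over environments given as [(z_e, y_e), ...]."""
    return float(np.mean([irm_env_penalty(z, y) for z, y in envs]))


rng = np.random.default_rng(0)
z, y = rng.normal(size=(8, 2)), rng.integers(0, 2, size=8)
h = 1e-6  # central finite difference of CE(w*z, y) around w = 1
fd = (ce((1 + h) * z, y) - ce((1 - h) * z, y)) / (2 * h)
print(irm_env_penalty(z, y), fd ** 2)  # the two should agree closely

# A "memorized" environment: confident, correct logits give a penalty of ~0.
z_mem = np.array([[20.0, -20.0], [-20.0, 20.0]])
print(irm_penalty([(z_mem, np.array([0, 1]))]))
```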
## 7. Grokfast EMA (grokking-favorable only)
Grokfast (Lee et al. 2024, arXiv:2405.20233):
```
g_ema <- alpha * g_ema + (1 - alpha) * g # alpha = 0.98
g <- g + lamb * g_ema # lamb = 2.0
```
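Applied after `loss.backward()` and before gradient clipping and `optimizer.step()`, so the max-norm clip of §8 also caps the amplified gradient. On the first step `g_ema` is initialized to the current gradient, so the first filtered gradient is `(1 + lamb) · g = 3g`.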
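The following minimal NumPy sketch makes the ordering concrete by running one filter-then-clip step. It holds gradients in a plain dict of arrays. `grokfast_ema` and `clip_global_norm` are hypothetical names that mirror `gradfilter_ema` (§10) and the norm rule of `torch.nn.utils.clip_grad_norm_`.

```python
import numpy as np


def grokfast_ema(grads: dict, ema, alpha: float = 0.98, lamb: float = 2.0):
    """ema <- alpha*ema + (1-alpha)*g ; g <- g + lamb*ema (updates grads in place).
    On the first call (ema is None) the EMA is initialised to g itself."""
    ema = {} if ema is None else ema
    for name, g in grads.items():
        if name not in ema:
            ema[name] = g.copy()
        else:
            ema[name] = alpha * ema[name] + (1 - alpha) * g
        grads[name] = g + lamb * ema[name]
    return ema


def clip_global_norm(grads: dict, max_norm: float = 1.0) -> float:
    """Rescale all grads so their joint L2 norm is at most max_norm;
    returns the pre-clip norm."""
    total = float(np.sqrt(sum(np.sum(g ** 2) for g in grads.values())))
    coef = max_norm / (total + 1e-6)
    if coef < 1.0:
        for name in grads:
            grads[name] = grads[name] * coef
    return total


grads = {"w": np.array([0.3, 0.4])}     # ||g|| = 0.5, below the clip threshold
ema = grokfast_ema(grads, None)          # first step: g -> 3g, ||g|| = 1.5
pre = clip_global_norm(grads, max_norm=1.0)
print(pre, np.linalg.norm(grads["w"]))   # ~1.5 before clipping, ~1.0 after
```

In this example the raw gradient would not have been clipped, but the amplified one is. With `lamb = 2.0` and `max_norm = 1.0`, the clip can therefore bound how much of the Grokfast amplification reaches the optimizer.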
## 8. Training loop and checkpointing
- Loop over `epoch ∈ [1, 3000]`.
- Per minibatch: forward, CE loss, `loss.backward()`, Grokfast filter (grokking only), `clip_grad_norm_(max_norm=1.0)`, `optimizer.step()`.
- Metrics computed every 50 epochs (+ epoch 1) → `results/history.json` (61 rows/run).
- Checkpoints every 200 epochs → 15 `.pt` per run + `final.pt`.
- OOD-aware early stopping exists but defaults off; all reported runs train the full 3000 epochs.
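- Grokking detection looks for an OOD-accuracy plateau followed by a jump. It fires when at least 8 of the previous 10 logged `ood_acc` values lie within 0.01 of the most recent one and the current `ood_acc` exceeds the best previous value by more than 0.05. It never fired in any run (`grokking_epoch = -1`).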
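The cadence arithmetic and the detection rule can both be checked in a few lines. This is a minimal pure-Python sketch. `grok_fires` is a hypothetical name for the inline test in `train()` (§11): `prev_ood` holds the OOD accuracies of the previously logged rows, and `best_ood` is their maximum.

```python
def logged_epochs(n_epochs: int = 3000, log_every: int = 50) -> list:
    """Epochs at which metrics are computed: epoch 1 plus every log_every."""
    return [e for e in range(1, n_epochs + 1) if e % log_every == 0 or e == 1]


def checkpoint_epochs(n_epochs: int = 3000, every: int = 200) -> list:
    """Epochs at which an epNNNNN.pt checkpoint is written."""
    return [e for e in range(1, n_epochs + 1) if e % every == 0]


def grok_fires(prev_ood: list, ood_acc: float, best_ood: float,
               window: int = 10, eps: float = 0.01, jump: float = 0.05) -> bool:
    """Plateau-then-jump test: at least window-2 of the last `window` logged
    OOD accuracies lie within eps of the most recent one, and the current
    OOD accuracy beats the previous best by more than `jump`."""
    if len(prev_ood) < window:
        return False
    last = prev_ood[-window:]
    ref = last[-1]
    flat = sum(1 for a in last if abs(a - ref) < eps)
    return flat >= window - 2 and ood_acc > best_ood + jump


print(len(logged_epochs()), len(checkpoint_epochs()))  # 61 15
plateau = [0.70] * 10
print(grok_fires(plateau, 0.80, max(plateau)))  # True: flat history, +0.10 jump
print(grok_fires(plateau, 0.74, max(plateau)))  # False: jump below 0.05
```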
## 9. Evaluation metrics — formulas
| Metric | Definition |
| --- | --- |
| Accuracy | `argmax` accuracy on a loader; `train_acc`, `id_val_acc`, `ood_acc`. `ood_gap = id_val_acc - ood_acc`. |
| Weight norm | `||W|| = sqrt(sum_p ||p||_2^2)` |
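| Effective feature rank | `exp(- sum_i ŝ_i log ŝ_i)`, where `ŝ_i = s_i / sum_j s_j` are the normalized SVD singular values of the uncentered `layer4` features (output of the last `layer4` block, global-average-pooled to D = 512) for the first 300 ID-validation images. |
| Shortcut ratio | `min(border_conf / (center_conf + 1e-8), 10)`. `center_conf` is the mean max-softmax confidence on the central 48×48 crop, bilinearly upsampled to 96×96. `border_conf` is the same quantity on the image with that central region set to 0 in normalized space. Both are computed on the first ID-validation batch (256 images). Interpreted as `>1` = stain-reliant, `<1` = tissue-reliant. |
| IRM penalty | See §6. Mean (`irm_mean`) and unbiased variance (`irm_var`) of the per-environment penalties over the training-hospital environments, computed at each metric evaluation (every 50 epochs, plus epoch 1). |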
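The two less standard metrics in the table each reduce to a few lines. This is a minimal NumPy sketch that takes a precomputed `(N, D)` feature matrix and precomputed mean confidences as input; the crop-and-mask step of `shortcut_ratio_wilds` is not reproduced. The `1e-10` and `1e-8` epsilons match those in §11.

```python
import numpy as np


def effective_rank(feats: np.ndarray) -> float:
    """exp(entropy) of the normalised singular values of an (N, D) feature
    matrix; no mean-centering, as in feature_rank_wilds."""
    s = np.linalg.svd(feats, compute_uv=False)
    p = s / (s.sum() + 1e-10)
    return float(np.exp(-(p * np.log(p + 1e-10)).sum()))


def shortcut_ratio(center_conf: float, border_conf: float, cap: float = 10.0) -> float:
    """min(border_conf / center_conf, cap); values > 1 are read as stain reliance."""
    return min(border_conf / (center_conf + 1e-8), cap)


print(effective_rank(np.eye(4)))                                       # ~4.0
print(effective_rank(np.outer(np.arange(1.0, 301.0), np.ones(512))))  # ~1.0 (rank one)
print(shortcut_ratio(center_conf=0.60, border_conf=0.90))              # ~1.5
```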
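Summary fields (`summary.json`):
- `best_id_val`, `best_ood`: maxima over the logged rows. `peak_ood_epoch`: the logged epoch at which `best_ood` was first reached.
- `final_ood`: a separate full OOD evaluation after the last epoch. `ood_delta` = `final_ood − best_ood`; negative values are the ungrokking signal.
- `ood_improvement`: mean `ood_acc` over the last 5 logged rows minus the mean over the first 5.
- `grokking_epoch`: the epoch at which the detector in §8 fired, else −1.
- `irm_drop_pct` = `100 · (irm_first − irm_min) / irm_first`, where `irm_first` is the epoch-1 `irm_mean`. `irm_drop_epoch`: the logged epoch with the largest absolute change in `irm_mean` from the previous logged row. `epoch_gap` = `|grokking_epoch − irm_drop_epoch|`, or −1 if grokking never fired.
- `final_weight_norm`, `final_feature_rank`, `final_irm`, `final_shortcut_ratio`, `final_ood_gap`: taken from the last logged row.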
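Every derived summary field except `final_ood` can be recomputed from `results/history.json`; `final_ood` comes from a separate evaluation after the last epoch. The minimal sketch below mirrors the end-of-training block in §11. It runs on a hypothetical six-row history whose values are invented for illustration and do not come from any reported run.

```python
def summarize(history: list, final_ood: float, grok_epoch: int = -1) -> dict:
    """Recompute summary.json's derived fields from logged history rows."""
    ood = [r["ood_acc"] for r in history]
    best_ood = max(ood)
    peak_epoch = history[ood.index(best_ood)]["epoch"]  # first epoch reaching the max
    irm0 = history[0]["irm_mean"]
    irm_min = min(r["irm_mean"] for r in history)
    irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100 if irm0 else float("nan")
    irm_drop_epoch, biggest = -1, 0.0  # epoch of the largest |change| in irm_mean
    for prev, cur in zip(history[:-1], history[1:]):
        d = abs(cur["irm_mean"] - prev["irm_mean"])
        if d > biggest:
            biggest, irm_drop_epoch = d, cur["epoch"]
    if grok_epoch > 0 and irm_drop_epoch > 0:
        epoch_gap = abs(grok_epoch - irm_drop_epoch)
    else:
        epoch_gap = -1
    return dict(
        best_ood=best_ood,
        peak_ood_epoch=peak_epoch,
        ood_delta=final_ood - best_ood,  # negative = ungrokking
        ood_improvement=sum(ood[-5:]) / len(ood[-5:]) - sum(ood[:5]) / len(ood[:5]),
        irm_drop_pct=irm_drop_pct,
        irm_drop_epoch=irm_drop_epoch,
        epoch_gap=epoch_gap,
    )


toy = [dict(epoch=e, ood_acc=a, irm_mean=m) for e, a, m in [
    (1, 0.60, 0.20), (50, 0.62, 0.10), (100, 0.65, 1e-5),
    (150, 0.66, 1e-9), (200, 0.64, 1e-13), (250, 0.63, 1e-13)]]
s = summarize(toy, final_ood=0.63)
print({k: round(v, 4) if isinstance(v, float) else v for k, v in s.items()})
```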
---
## 10. Full source: `utils/grokfast.py`
```python
"""
utils.grokfast — accelerated grokking by amplifying slow-varying gradient
components (Lee et al. 2024, arXiv:2405.20233).
Maintain an EMA of gradients across steps; the slow-EMA component
corresponds to the generalising circuit. Adding it back into the live
gradient (scaled by `lamb`) accelerates the grokking transition 20-100×.
"""
from __future__ import annotations


def gradfilter_ema(model, grads_ema, alpha: float = 0.98, lamb: float = 2.0):
    """
    Call this AFTER `loss.backward()` and BEFORE `optimizer.step()`.

    Args:
        model: the network whose gradients we are filtering.
        grads_ema: dict {param_name: ema_grad}, or None on the first call.
        alpha: EMA decay (0.98 → very slow, emphasises persistent grads).
        lamb: amplification factor for the slow component.

    Returns:
        Updated `grads_ema` dict — pass it back in on the next step.
    """
    if grads_ema is None:
        grads_ema = {}
    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            if name not in grads_ema:
                grads_ema[name] = p.grad.data.detach().clone()
            else:
                grads_ema[name] = (
                    grads_ema[name] * alpha
                    + p.grad.data.detach() * (1 - alpha)
                )
            p.grad.data = p.grad.data + grads_ema[name] * lamb
    return grads_ema
```
---
## 11. Full source: `experiments/causalgrok_camelyon_v2.py`
```python
"""
CausalGrok — Camelyon17 Training Loop v2
Nilesh
KEY CHANGE FROM v1:
OOD test accuracy (H4 — unseen hospital) is now tracked at EVERY
checkpoint, not just at the end. Grokking detection watches OOD acc,
not ID val acc. This is the correct signal.
The paper claim: after ID accuracy converges (fast, expected), the model
undergoes a delayed phase transition in OOD generalization — grokking
the cross-hospital invariant causal features. This co-occurs with a drop
in IRM penalty. That is the grokking we care about for clinical deployment.
Two curves to watch:
val_acc (H3 ID val) — converges fast, expected ~0.86 by ep 50
ood_acc (H4 OOD test) — should plateau then JUMP (the grokking)
Run via:
python -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 300
"""
from __future__ import annotations

import argparse
import json
import os
import time
from datetime import datetime, timezone

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import timm

try:
    import wandb
except ImportError:
    wandb = None

from utils.grokfast import gradfilter_ema
from utils.camelyon_data import get_camelyon_subsets
from utils.run_dir import make_run_dir, ensure_run_dir, save_config


# ──────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────

def get_config(condition):
    base = dict(
        seed=42, n_train=300, batch_size=32, img_size=96,
        n_classes=2, log_every=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    if condition == "standard":
        base.update(dict(
            condition="standard",
            lr=1e-3, weight_decay=1e-4,
            # Default 3000 epochs to match grokking config and the
            # paper's reported runs; previously defaulted to 300 which
            # made the standard baseline trivially under-trained
            # relative to grokking. See paper Limitations §M3.
            n_epochs=3000, init_scale=1.0, use_grokfast=False,
        ))
    elif condition == "grokking":
        base.update(dict(
            condition="grokking",
            lr=1e-3, weight_decay=5e-3,
            n_epochs=3000, init_scale=4.0, use_grokfast=True,
            grokfast_alpha=0.98, grokfast_lamb=2.0,
        ))
    return base

# ──────────────────────────────────────────────
# WILDS-SAFE METRICS
# All handle the (imgs, labels, metadata) 3-tuple WILDS batch format.
# ──────────────────────────────────────────────

@torch.no_grad()
def accuracy_wilds(model, loader, device, max_samples=None):
    model.eval()
    correct = total = 0
    for batch in loader:
        imgs = batch[0].to(device)
        labels = batch[1].squeeze().long().to(device)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        if max_samples and total >= max_samples:
            break
    return correct / max(total, 1)


@torch.no_grad()
def weight_norm_fn(model):
    return sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5


@torch.no_grad()
def feature_rank_wilds(model, loader, device, n=300):
    model.eval()
    feats = []

    def hook_fn(module, input, output):
        avg_pool = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
        feats.append(avg_pool.view(avg_pool.size(0), -1).cpu())

    hook = model.layer4[-1].register_forward_hook(hook_fn)
    count = 0
    for batch in loader:
        model(batch[0].to(device))
        count += batch[0].size(0)
        if count >= n:
            break
    hook.remove()
    if not feats:
        return float("nan")
    F_mat = torch.cat(feats)[:n]
    try:
        _, s, _ = torch.svd(F_mat)
        s = s / (s.sum() + 1e-10)
        return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()
    except Exception:
        return float("nan")


@torch.no_grad()
def shortcut_ratio_wilds(model, loader, device, n_samples=200):
    """
    Stain shortcut proxy: compare model confidence on center crop
    (tissue — causal features) vs. border region (stain — spurious).
      sc > 1.0 = relying on border stain more than tissue (shortcut)
      sc < 1.0 = relying on tissue center more than stain (causal)
    The transition from > 1.0 to < 1.0 during training is the
    attribution-level signature of the grokking transition.
    """
    model.eval()
    cc, bc = [], []
    count = 0
    for batch in loader:
        if count >= n_samples:
            break
        imgs = batch[0].to(device)
        B, C, H, W = imgs.shape
        hs, he = H // 4, 3 * H // 4
        ws, we = W // 4, 3 * W // 4
        center = F.interpolate(
            imgs[:, :, hs:he, ws:we], size=(H, W),
            mode="bilinear", align_corners=False
        )
        border = imgs.clone()
        border[:, :, hs:he, ws:we] = 0.0
        cc.append(F.softmax(model(center), 1).max(1).values.mean().item())
        bc.append(F.softmax(model(border), 1).max(1).values.mean().item())
        count += imgs.size(0)
    cconf = float(np.mean(cc)) if cc else 0.5
    bconf = float(np.mean(bc)) if bc else 0.5
    return cconf, bconf


def irm_penalty_wilds(model, envs, device):
    """
    IRMv1 penalty across TRAINING hospital environments (H0-H2).
    Diagnostic version: uses create_graph=False, returns floats. Used as a
    monitoring metric only (logged per epoch).
    """
    model.eval()
    penalties = []
    for env in envs:
        w = torch.tensor(1.0, requires_grad=True, device=device)
        logits = model(env["x"]) * w
        loss = F.cross_entropy(logits, env["y"])
        grad = torch.autograd.grad(loss, w, create_graph=False)[0]
        penalties.append(grad.item() ** 2)
    t = torch.tensor(penalties)
    return t.mean().item(), t.var().item()


def irm_penalty_train_time(logits_list, y_list):
    """
    IRMv1 penalty for use INSIDE the training loss (differentiable).
    Splits a batch by environment, computes per-env loss with a virtual
    scale variable, takes the squared gradient of each per-env loss w.r.t.
    that scale, returns the mean across envs.

    Args:
        logits_list: list of (per-env) logits tensors
        y_list: list of (per-env) label tensors

    Returns:
        scalar tensor (differentiable), the IRM penalty contribution.
    """
    penalty = 0.0
    n = 0
    for logits, y in zip(logits_list, y_list):
        if logits.shape[0] == 0:
            continue
        scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
        loss = F.cross_entropy(logits * scale, y)
        grad = torch.autograd.grad(loss, scale, create_graph=True)[0]
        penalty = penalty + grad ** 2
        n += 1
    if n == 0:
        return torch.tensor(0.0, device=logits_list[0].device)
    return penalty / n


def eval_irm_penalty_wilds(model, id_val_loader, ood_test_loader, device):
    """
    IRM penalty evaluated on HELD-OUT environments (H3 and H4).
    This avoids the measurement artifact of training on H0-H2 where loss→0.
    HIGH penalty = model relies on hospital-discriminating features = shortcuts.
    LOW penalty = model ignores hospital labels = causal features.
    """
    model.eval()
    penalties = []
    # Create environment views from eval data
    for loader, hospital_label in [
        (id_val_loader, "H3"),
        (ood_test_loader, "H4"),
    ]:
        xs, ys = [], []
        count = 0
        with torch.no_grad():
            for batch in loader:
                imgs = batch[0].to(device)
                labels = batch[1].squeeze().long().to(device)
                xs.append(model(imgs))
                ys.append(labels)
                count += imgs.size(0)
                if count >= 500:
                    break
        if xs:
            x = torch.cat(xs)
            y = torch.cat(ys)
            w = torch.tensor(1.0, requires_grad=True, device=device)
            logits = x * w
            loss = F.cross_entropy(logits, y)
            try:
                grad = torch.autograd.grad(loss, w, create_graph=False)[0]
                penalties.append(grad.item() ** 2)
            except Exception:
                penalties.append(float("nan"))
    if penalties and not any(np.isnan(p) for p in penalties):
        return float(np.mean(penalties)), float(np.var(penalties))
    else:
        return float("nan"), float("nan")

# ──────────────────────────────────────────────
# DATA
# ──────────────────────────────────────────────

class TransformWrapper:
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def get_dataloaders(cfg, data_root):
    transform = transforms.Compose([
        transforms.Resize((cfg["img_size"], cfg["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=True)
    # Subsample training set
    torch.manual_seed(cfg["seed"])
    indices = torch.randperm(len(train_ds))[:cfg["n_train"]]
    train_subset = Subset(train_ds, indices)
    # Wrap with TransformWrapper to apply transforms
    train_subset = TransformWrapper(train_subset, transform)
    id_val_ds = TransformWrapper(id_val_ds, transform)
    ood_test_ds = TransformWrapper(ood_test_ds, transform)
    train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"],
                              shuffle=True, num_workers=0, pin_memory=True)
    id_val_loader = DataLoader(id_val_ds, batch_size=256,
                               shuffle=False, num_workers=0, pin_memory=True)
    ood_test_loader = DataLoader(ood_test_ds, batch_size=256,
                                 shuffle=False, num_workers=0, pin_memory=True)
    return train_loader, id_val_loader, ood_test_loader, train_subset


def get_hospital_environments(train_subset, device):
    """
    Build IRM environments from ground-truth hospital labels.
    Returns list of {x, y} dicts — one per unique hospital in the subset.
    Hospitals in the standard WILDS Camelyon17 train split: 0, 3, 4.
    """
    loader = DataLoader(train_subset, batch_size=512,
                        shuffle=False, num_workers=4)
    all_imgs, all_labels, all_meta = [], [], []
    for imgs, labels, meta in loader:
        all_imgs.append(imgs)
        all_labels.append(labels.squeeze().long())
        all_meta.append(meta)
    all_imgs = torch.cat(all_imgs)
    all_labels = torch.cat(all_labels)
    hospitals = torch.cat(all_meta)[:, 0].long()  # field 0 = hospital ID
    envs = []
    for h in torch.unique(hospitals):
        mask = hospitals == h
        n = mask.sum().item()
        envs.append({
            "x": all_imgs[mask].to(device),
            "y": all_labels[mask].to(device),
            "hospital": int(h),
        })
        pos_rate = all_labels[mask].float().mean().item()
        print(f" Env hospital={int(h)}: {n} samples, "
              f"positive rate={pos_rate:.2f}")
    return envs


# ──────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────

def build_model(cfg):
    model = timm.create_model("resnet18", pretrained=False,
                              num_classes=cfg["n_classes"])
    if cfg["init_scale"] != 1.0:
        with torch.no_grad():
            for name, p in model.named_parameters():
                if "weight" in name and p.dim() > 1:
                    p.data *= cfg["init_scale"]
    return model.to(cfg["device"])

# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────

def train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir):
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_id_val = 0.0
    best_ood = 0.0
    peak_ood_epoch = None  # Epoch where best_ood was achieved
    grok_epoch = None
    irm_base = None
    history_path = os.path.join(run_dir, "results", "history.json")
    grad_clip = cfg.get("grad_clip", 1.0)

    # Grokking detection parameters.
    # We watch OOD accuracy (H4), not ID val accuracy (H3).
    # ID val converges fast (expected). OOD is what should grok.
    plateau_window = 10
    plateau_eps = 0.01

    # Ungrokking early stopping parameters.
    # If OOD peaks then declines, stop at the peak rather than training to convergence.
    ood_patience = cfg.get("ood_patience", 20)      # checkpoints to wait before stopping
    ood_min_delta = cfg.get("ood_min_delta", 0.01)  # minimum improvement threshold
    use_ood_early_stop = cfg.get("use_ood_early_stop", False)

    print(f"\n{'='*60}")
    print(f" {cfg['condition'].upper()} | Camelyon17 v2 | {cfg['n_epochs']} epochs")
    print(f" WD={cfg['weight_decay']} | α={cfg['init_scale']} | n={cfg['n_train']}")
    print(f" Tracking: ID val (H3) + OOD test (H4) at every checkpoint")
    print(f" Grokking detection: watching OOD acc, not ID val acc")
    print(f" IRM envs: {len(envs)} hospitals")
    print(f"{'='*60}", flush=True)

    irm_weight = float(cfg.get("irm_weight", 0.0))
    use_irm_in_loss = irm_weight > 0.0
    if use_irm_in_loss:
        print(f" IRM-in-loss: ENABLED, alpha={irm_weight}", flush=True)
    else:
        print(f" IRM-in-loss: disabled (CE-only training; IRM penalty is diagnostic)", flush=True)

    for epoch in range(1, cfg["n_epochs"] + 1):
        # ── Train step ────────────────────────────────────────────────
        model.train()
        loss_sum = n_b = 0
        for imgs, labels, metadata in train_loader:
            imgs = imgs.to(cfg["device"])
            labels = labels.squeeze().long().to(cfg["device"])
            optimizer.zero_grad()
            logits = model(imgs)
            ce_loss = criterion(logits, labels)
            if use_irm_in_loss:
                # Split this batch by the training hospitals actually present
                # (0, 3, 4 in the standard WILDS split) and compute the IRMv1
                # penalty as a differentiable scalar.
                hosp_ids = metadata[:, 0].long().to(cfg["device"])
                logits_per_env, y_per_env = [], []
                for h in torch.unique(hosp_ids).tolist():
                    mask = (hosp_ids == h)
                    if mask.sum() < 2:
                        continue
                    logits_per_env.append(logits[mask])
                    y_per_env.append(labels[mask])
                if len(logits_per_env) >= 2:
                    irm_term = irm_penalty_train_time(logits_per_env, y_per_env)
                    loss = ce_loss + irm_weight * irm_term
                else:
                    loss = ce_loss
            else:
                loss = ce_loss
            loss.backward()
            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=grad_clip)
            optimizer.step()
            loss_sum += loss.item()
            n_b += 1

        # ── Checkpoint metrics ────────────────────────────────────────
        if epoch % cfg["log_every"] == 0 or epoch == 1:
            tr_acc = accuracy_wilds(model, train_loader, cfg["device"])
            id_acc = accuracy_wilds(model, id_val_loader, cfg["device"])
            ood_acc = accuracy_wilds(model, ood_test_loader, cfg["device"])  # KEY
            wn = weight_norm_fn(model)
            fr = feature_rank_wilds(model, id_val_loader, cfg["device"])
            irm_m, irm_v = irm_penalty_wilds(model, envs, cfg["device"])
            cconf, bconf = shortcut_ratio_wilds(
                model, id_val_loader, cfg["device"])
            if irm_base is None:
                irm_base = irm_m

            # ── OOD grokking detection ────────────────────────────────
            # Require sustained plateau in OOD acc before the jump.
            # The ID val acc plateau is expected and not grokking.
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref = last[-1]["ood_acc"]
                flat = sum(1 for r in last
                           if abs(r["ood_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and ood_acc > best_ood + 0.05:
                    grok_epoch = epoch
                    irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n *** OOD GROKKING at epoch {epoch} ***")
                    print(f" OOD: {best_ood:.3f} → {ood_acc:.3f} | "
                          f"IRM drop: {irm_drop:.1f}%", flush=True)

            if id_acc > best_id_val: best_id_val = id_acc
            if ood_acc > best_ood:
                best_ood = ood_acc
                peak_ood_epoch = epoch  # Track when peak was achieved
            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)
            # OOD gap: how much worse is OOD vs ID?
            # This should shrink at the grokking transition.
            ood_gap = id_acc - ood_acc
            row = dict(
                epoch = epoch,
                train_loss = loss_sum / n_b,
                train_acc = tr_acc,
                id_val_acc = id_acc,
                ood_acc = ood_acc,  # ← primary grokking signal
                ood_gap = ood_gap,  # ← should narrow at transition
                weight_norm = wn,
                feature_rank = fr,
                irm_mean = irm_m,
                irm_var = irm_v,
                center_conf = cconf,
                border_conf = bconf,
                shortcut_ratio = sc_ratio,
                grokking_detected = grok_epoch is not None,
            )
            history.append(row)
            if wandb:
                wandb.log(row)
            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            # Save periodic checkpoint for M1 analysis (every 200 epochs)
            if epoch % 200 == 0:
                ckpt_dir = os.path.join(run_dir, "checkpoints")
                os.makedirs(ckpt_dir, exist_ok=True)
                ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:05d}.pt")
                torch.save(model.state_dict(), ckpt_path)
                print(f" ✓ Checkpoint → ep{epoch:05d}.pt", flush=True)

            # ── OOD-aware early stopping (if ungrokking detected) ───────
            # If OOD peaks then declines, stop at the peak rather than full epochs.
            if use_ood_early_stop and peak_ood_epoch is not None and len(history) >= ood_patience:
                recent_ood = [r["ood_acc"] for r in history[-ood_patience:]]
                ood_trend = max(recent_ood) - min(recent_ood)
                if ood_acc < best_ood - ood_min_delta:
                    print(f"\n *** EARLY STOP (OOD declining) at epoch {epoch} ***", flush=True)
                    print(f" Peak OOD: {best_ood:.4f} at epoch {peak_ood_epoch}", flush=True)
                    print(f" Current: {ood_acc:.4f} ({ood_acc-best_ood:+.4f})", flush=True)
                    # Save peak checkpoint separately for clinical deployment
                    if peak_ood_epoch and peak_ood_epoch % 200 == 0:
                        peak_src = os.path.join(run_dir, "checkpoints", f"ep{peak_ood_epoch:05d}.pt")
                        peak_dst = os.path.join(run_dir, "checkpoints", "peak_ood.pt")
                        if os.path.exists(peak_src):
                            import shutil
                            shutil.copy(peak_src, peak_dst)
                            print(f" Saved peak → checkpoints/peak_ood.pt", flush=True)
                    break  # Exit training loop

            print(f" ep {epoch:5d} | "
                  f"tr {tr_acc:.3f} | "
                  f"id {id_acc:.3f} | "
                  f"ood {ood_acc:.3f} | "
                  f"gap {ood_gap:+.3f} | "  # + means OOD worse than ID
                  f"‖W‖ {wn:.1f} | "
                  f"rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | "
                  f"sc {sc_ratio:.2f}x",
                  flush=True)

    # ── Final summary ─────────────────────────────────────────────────
    # One final OOD eval at the very end
    final_ood = accuracy_wilds(model, ood_test_loader, cfg["device"])
    if wandb:
        wandb.log({"final_ood_acc": final_ood,
                   "grokking_epoch": grok_epoch or -1})

    # Decision numbers
    irm_drop_pct = float("nan")
    irm_drop_ep = epoch_gap = -1
    if history:
        irm0 = history[0]["irm_mean"]
        irm_min = min(r["irm_mean"] for r in history)
        if irm0:
            irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100
        if len(history) > 1:
            biggest = 0.0
            for prev, cur in zip(history[:-1], history[1:]):
                d = abs(cur["irm_mean"] - prev["irm_mean"])
                if d > biggest:
                    biggest = d
                    irm_drop_ep = cur["epoch"]
    if grok_epoch and irm_drop_ep > 0:
        epoch_gap = abs(grok_epoch - irm_drop_ep)

    # OOD grokking: did OOD acc improve significantly after ID convergence?
    # Measure: max OOD acc in last 20% of training vs. OOD acc when ID
    # first plateaued (epoch ~200-300 for standard training).
    ood_early = np.mean([r["ood_acc"] for r in history[:5]]) if history else 0
    ood_late = np.mean([r["ood_acc"] for r in history[-5:]]) if history else 0
    ood_improvement = ood_late - ood_early

    # Ungrokking detection: did OOD collapse after peaking?
    ood_delta = final_ood - best_ood  # Negative = ungrokking

    summary = dict(
        run_id = cfg["run_id"],
        condition = cfg["condition"],
        n_train = cfg["n_train"],
        seed = cfg["seed"],
        best_id_val = best_id_val,
        best_ood = best_ood,
        peak_ood_epoch = peak_ood_epoch or -1,  # When peak was achieved
        final_ood = final_ood,
        ood_delta = ood_delta,  # final - best (ungrokking signal)
        ood_improvement = ood_improvement,  # ← key: did OOD grok?
        grokking_epoch = grok_epoch or -1,
        irm_drop_pct = irm_drop_pct,
        irm_drop_epoch = irm_drop_ep,
        epoch_gap = epoch_gap,
        final_weight_norm = history[-1]["weight_norm"] if history else None,
        final_feature_rank = history[-1]["feature_rank"] if history else None,
        final_irm = history[-1]["irm_mean"] if history else None,
        final_shortcut_ratio = history[-1]["shortcut_ratio"] if history else None,
        final_ood_gap = history[-1]["ood_gap"] if history else None,
    )
    with open(os.path.join(run_dir, "results", "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    torch.save(model.state_dict(),
               os.path.join(run_dir, "checkpoints", "final.pt"))

    print(f"\n Best ID val (H3): {best_id_val:.4f}")
    print(f" Best OOD (H4): {best_ood:.4f}")
    print(f" OOD improvement: {ood_improvement:+.4f} ← did OOD grok?")
    print(f" Grokking at: {grok_epoch}")
    print(f" IRM drop: {irm_drop_pct:.1f}%",
          flush=True)
    return history

# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--condition", default="grokking",
                   choices=["standard", "grokking"])
    p.add_argument("--n_train", type=int, default=300)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--log_every", type=int, default=50)
    p.add_argument("--wandb_project", default="causalgrok")
    p.add_argument("--wandb_mode", default="offline",
                   choices=["online", "offline", "disabled"])
    p.add_argument("--run_dir", default=None)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--weight_decay", type=float, default=None)
    p.add_argument("--init_scale", type=float, default=None)
    p.add_argument("--n_epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--grokfast", choices=["on", "off"], default=None)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--irm_weight", type=float, default=0.0,
                   help="IRMv1 penalty weight added to training loss "
                        "(0 = pure cross-entropy / diagnostic-only IRM).")
    args = p.parse_args()

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed,
               log_every=args.log_every, grad_clip=args.grad_clip)
    if args.weight_decay is not None: cfg["weight_decay"] = args.weight_decay
    if args.init_scale is not None: cfg["init_scale"] = args.init_scale
    if args.n_epochs is not None: cfg["n_epochs"] = args.n_epochs
    if args.lr is not None: cfg["lr"] = args.lr
    if args.grokfast is not None: cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["irm_weight"] = args.irm_weight

    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        run_dir, run_id = make_run_dir(
            ["camelyon_v2", cfg["condition"],
             f"n{cfg['n_train']}", f"s{cfg['seed']}"])
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        run_id = os.path.basename(os.path.normpath(run_dir))
    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    if wandb:
        wandb.init(project=args.wandb_project, config=cfg, name=run_id,
                   mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started: {datetime.now(timezone.utc).isoformat()}", flush=True)

    train_loader, id_val_loader, ood_test_loader, train_subset = \
        get_dataloaders(cfg, args.data_root)
    envs = get_hospital_environments(train_subset, cfg["device"])
    model = build_model(cfg)
    print(f"Train: {len(train_subset)} | "
          f"ID val (H3): {len(id_val_loader.dataset)} | "
          f"OOD test (H4): {len(ood_test_loader.dataset)}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}",
          flush=True)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["lr"], weight_decay=cfg["weight_decay"])
    t0 = time.time()
    train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time()-t0)/60:.1f} min", flush=True)
    if wandb:
        wandb.finish()


if __name__ == "__main__":
    main()
```
---
## 12. Full source: `experiments/mechinterp_m1.py`
```python
"""
CausalGrok — M1: Layer-wise Linear Probing
Nilesh
The mechanistic claim:
Before grokking: hospital probe HIGH (model uses stain shortcut),
tumor probe LOW in early layers
At transition: hospital probe DROPS, tumor probe RISES
After grokking: inverted — tumor high, hospital low
If OOD acc jump + hospital probe drop + tumor probe rise
all happen at the same epoch → mechanistic claim proven.
That's Figure 2 of the paper.
Usage:
# Run on all saved checkpoints from a run
python -m experiments.mechinterp_m1 \
--run_dir experiments/runs/<run_id> \
--data_root data/wilds
# Run on latest checkpoint only (quick check while training)
python -m experiments.mechinterp_m1 \
--run_dir experiments/runs/<run_id> \
--data_root data/wilds \
--latest_only
# Run on ALL camelyon_v2 grokking runs
python -m experiments.mechinterp_m1 \
--all_runs \
--data_root data/wilds
Output per run:
experiments/runs/<run_id>/mechinterp/
m1_probe_heatmap.png ← epoch × layer, hospital probe acc
m1_tumor_heatmap.png ← epoch × layer, tumor probe acc
m1_probe_curves.png ← hospital vs tumor probe over epochs (layer 4)
m1_probe_data.json ← raw numbers for paper tables
"""
from __future__ import annotations

import argparse
import glob
import json
import os
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import matplotlib
import matplotlib.pyplot as plt
import timm
import warnings

warnings.filterwarnings("ignore")
matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150})


# ──────────────────────────────────────────────
# RESNET-18 LAYER HOOKS
# Extract features after each of the 6 measurable stages:
#   stem → layer1 → layer2 → layer3 → layer4 → avgpool
# ──────────────────────────────────────────────

LAYER_NAMES = [
    "stem",     # After initial conv + bn + relu + maxpool
    "layer1",   # ResNet block 1 (64 channels)
    "layer2",   # ResNet block 2 (128 channels)
    "layer3",   # ResNet block 3 (256 channels)
    "layer4",   # ResNet block 4 (512 channels)
    "avgpool",  # Global average pool — penultimate representation
]


def register_hooks(model):
    """
    Register forward hooks on all 6 extraction points.
    Returns (hooks, features_dict).
    """
    features = {name: [] for name in LAYER_NAMES}
    hooks = []

    def make_hook(name):
        def hook_fn(module, input, output):
            if output.dim() == 4:
                feat = output.mean(dim=[2, 3])
            else:
                feat = output.view(output.size(0), -1)
            features[name].append(feat.detach().cpu())
        return hook_fn

    hooks.append(model.maxpool.register_forward_hook(make_hook("stem")))
    hooks.append(model.layer1.register_forward_hook(make_hook("layer1")))
    hooks.append(model.layer2.register_forward_hook(make_hook("layer2")))
    hooks.append(model.layer3.register_forward_hook(make_hook("layer3")))
    hooks.append(model.layer4.register_forward_hook(make_hook("layer4")))
    hooks.append(model.global_pool.register_forward_hook(make_hook("avgpool")))
    return hooks, features


def extract_features(model, loader, device, max_samples=1000):
    """
    Run forward pass and collect features at all 6 layers.
    """
    model.eval()
    hooks, feat_dict = register_hooks(model)
    all_hospital = []
    all_tumor = []
    count = 0
    with torch.no_grad():
        for batch in loader:
            imgs = batch[0].to(device)
            labels = batch[1].squeeze().long()
            metadata = batch[2]
            model(imgs)
            all_hospital.append(metadata[:, 0].long())
            all_tumor.append(labels)
            count += imgs.size(0)
            if count >= max_samples:
                break
    for h in hooks:
        h.remove()
    features = {k: torch.cat(v).numpy() for k, v in feat_dict.items()}
    hospital_ids = torch.cat(all_hospital).numpy()
    tumor_labels = torch.cat(all_tumor).numpy()
    n = min(max_samples, len(hospital_ids))
features = {k: v[:n] for k, v in features.items()}
hospital_ids = hospital_ids[:n]
tumor_labels = tumor_labels[:n]
return features, hospital_ids, tumor_labels
def train_probe(X_train, y_train, X_val, y_val):
"""
Train logistic regression probe on frozen features.
"""
if len(np.unique(y_train)) < 2:
return 0.5
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
clf = LogisticRegression(
max_iter=500,
C=1.0,
solver="lbfgs",
multi_class="auto",
n_jobs=-1,
)
try:
clf.fit(X_train, y_train)
return clf.score(X_val, y_val)
except Exception:
return float("nan")
# ──────────────────────────────────────────────
# CHECKPOINT DISCOVERY
# ──────────────────────────────────────────────
def find_checkpoints(run_dir: str) -> List[tuple]:
"""
Find all checkpoints in a run directory.
Returns list of (epoch, checkpoint_path) sorted by epoch.
"""
ckpt_dir = os.path.join(run_dir, "checkpoints")
if not os.path.isdir(ckpt_dir):
return []
checkpoints = []
# Periodic checkpoints: ep050.pt, ep100.pt, etc.
for f in sorted(glob.glob(os.path.join(ckpt_dir, "ep*.pt"))):
epoch_str = os.path.basename(f).replace("ep", "").replace(".pt", "")
try:
epoch = int(epoch_str)
checkpoints.append((epoch, f))
except ValueError:
continue
# Final checkpoint
final = os.path.join(ckpt_dir, "final.pt")
if os.path.isfile(final):
hist_path = os.path.join(run_dir, "results", "history.json")
if os.path.isfile(hist_path):
try:
hist = json.load(open(hist_path))
epoch = hist[-1]["epoch"] if hist else 9999
except Exception:
epoch = 9999
else:
epoch = 9999
checkpoints.append((epoch, final))
return sorted(checkpoints, key=lambda x: x[0])
def load_model_from_checkpoint(ckpt_path: str, n_classes: int = 2,
device: str = "cuda") -> nn.Module:
model = timm.create_model("resnet18", pretrained=False,
num_classes=n_classes)
state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state, strict=True)
model.eval()
return model.to(device)
# ──────────────────────────────────────────────
# MAIN PROBE ANALYSIS
# ──────────────────────────────────────────────
def run_probe_analysis(run_dir: str, data_root: str,
device: str = "cuda",
max_samples: int = 800,
latest_only: bool = False) -> Optional[Dict]:
"""
For each checkpoint in a run, extract features at all 6 layers
and train hospital + tumor probes.
"""
from utils.camelyon_data import get_camelyon_subsets
cfg_path = os.path.join(run_dir, "config.json")
if not os.path.isfile(cfg_path):
print(f" No config.json in {run_dir}, skipping")
return None
cfg = json.load(open(cfg_path))
condition = cfg.get("condition", "unknown")
n_train = cfg.get("n_train", 300)
seed = cfg.get("seed", 42)
print(f"\n{'='*55}")
print(f" M1 Probe Analysis: {os.path.basename(run_dir)}")
print(f" condition={condition}, n_train={n_train}, seed={seed}")
print(f"{'='*55}")
checkpoints = find_checkpoints(run_dir)
if not checkpoints:
print(f" No checkpoints found — skipping")
return None
if latest_only:
checkpoints = checkpoints[-1:]
print(f" Found {len(checkpoints)} checkpoints: "
f"epochs {[e for e,_ in checkpoints]}")
print(f" Hospital probe: fits on training data (H0-H2), "
f"evaluates on H3 and H4 separately")
# ── Data ─────────────────────────────────────────────────────────
transform = transforms.Compose([
transforms.Resize((96, 96)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_ds, id_val_ds, ood_test_ds, full_ds = get_camelyon_subsets(
root_dir=data_root, download=False)
# Wrap datasets with transform (WILDS returns PIL images)
class _TransformWrapper:
def __init__(self, dataset, transform):
self.dataset = dataset
self.transform = transform
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
img, label, metadata = self.dataset[idx]
return self.transform(img), label, metadata
id_val_t = _TransformWrapper(id_val_ds, transform)
ood_test_t = _TransformWrapper(ood_test_ds, transform)
train_t = _TransformWrapper(train_ds, transform)
torch.manual_seed(seed)
probe_idx = torch.randperm(len(id_val_t))[:max_samples // 2]
ood_idx = torch.randperm(len(ood_test_t))[:max_samples // 2]
train_idx = torch.randperm(len(train_t))[:max_samples]
probe_loader = DataLoader(
Subset(id_val_t, probe_idx),
batch_size=128, shuffle=False, num_workers=0)
ood_loader = DataLoader(
Subset(ood_test_t, ood_idx),
batch_size=128, shuffle=False, num_workers=0)
train_loader = DataLoader(
Subset(train_t, train_idx),
batch_size=128, shuffle=False, num_workers=0)
# ── Results storage ──────────────────────────────────────────────
results = {
"run_id": os.path.basename(run_dir),
"condition": condition,
"n_train": n_train,
"seed": seed,
"epochs": [],
"layers": LAYER_NAMES,
"hospital_probe_id": [], # Hospital accuracy on H3
"hospital_probe_ood": [], # Hospital accuracy on H4
"tumor_probe_id": [], # Tumor accuracy on H3
"tumor_probe_ood": [], # Tumor accuracy on H4
}
# ── Per-checkpoint analysis ───────────────────────────────────────
for epoch, ckpt_path in checkpoints:
print(f"\n Epoch {epoch} | {os.path.basename(ckpt_path)}")
try:
model = load_model_from_checkpoint(
ckpt_path, n_classes=2, device=device)
except Exception as e:
print(f" Failed to load checkpoint: {e}")
continue
# Extract features from all three datasets
feats_train, hosp_train, tumor_train = extract_features(
model, train_loader, device, max_samples=max_samples)
feats_id, hosp_id, tumor_id = extract_features(
model, probe_loader, device, max_samples=max_samples // 2)
feats_ood, hosp_ood, tumor_ood = extract_features(
model, ood_loader, device, max_samples=max_samples // 2)
epoch_hosp_id = []
epoch_hosp_ood = []
epoch_tumor_id = []
epoch_tumor_ood = []
for layer_name in LAYER_NAMES:
# Fit probes on training features, evaluate on H3 and H4
X_train_layer = feats_train[layer_name]
X_id_layer = feats_id[layer_name]
X_ood_layer = feats_ood[layer_name]
# Hospital probe: can model distinguish hospitals H0-H2?
# If yes on H3/H4 → stain is encoded
h_acc_id = train_probe(X_train_layer, hosp_train,
X_id_layer, hosp_id)
h_acc_ood = train_probe(X_train_layer, hosp_train,
X_ood_layer, hosp_ood)
# Tumor probe: can model distinguish tumor vs normal?
t_acc_id = train_probe(X_train_layer, tumor_train,
X_id_layer, tumor_id)
t_acc_ood = train_probe(X_train_layer, tumor_train,
X_ood_layer, tumor_ood)
epoch_hosp_id.append(h_acc_id)
epoch_hosp_ood.append(h_acc_ood)
epoch_tumor_id.append(t_acc_id)
epoch_tumor_ood.append(t_acc_ood)
print(f" {layer_name:8s}: "
f"hosp_H3={h_acc_id:.3f} hosp_H4={h_acc_ood:.3f} "
f"tumor_H3={t_acc_id:.3f} tumor_H4={t_acc_ood:.3f}")
results["epochs"].append(epoch)
results["hospital_probe_id"].append(epoch_hosp_id)
results["hospital_probe_ood"].append(epoch_hosp_ood)
results["tumor_probe_id"].append(epoch_tumor_id)
results["tumor_probe_ood"].append(epoch_tumor_ood)
del model
# ── Save raw data ─────────────────────────────────────────────────
out_dir = os.path.join(run_dir, "mechinterp")
os.makedirs(out_dir, exist_ok=True)
data_path = os.path.join(out_dir, "m1_probe_data.json")
with open(data_path, "w") as f:
json.dump(results, f, indent=2)
print(f"\n Probe data → {data_path}")
# ── Plots ─────────────────────────────────────────────────────────
_plot_probe_heatmaps(results, out_dir)
_plot_probe_curves(results, out_dir)
print(f" Figures → {out_dir}/")
return results
def _plot_probe_heatmaps(results: Dict, out_dir: str):
"""
Epoch (x) × layer (y), color = probe accuracy.
Hospital probe shown on H3 (held-in held-out hospital, classes overlap
with training). The H4 version is degenerate by construction since the
probe is fit on the training-hospital class set and H4 is not in it
(hospital_probe_ood ≡ 0 across all epochs / layers).
Tumor probe shown on H4 (truly OOD hospital) since tumor labels are
binary and shared across hospitals — H4 captures the causal-feature
transferability we care about.
"""
epochs = results["epochs"]
layers = results["layers"]
if not epochs:
return
hosp_matrix = np.array(results["hospital_probe_id"]) # H3 — has signal
tumor_matrix = np.array(results["tumor_probe_ood"]) # H4 — true OOD
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, matrix, title, cmap in [
(axes[0], hosp_matrix, "Hospital probe on H3 (shortcut recoverability)\nHigh = stain still encoded = BAD", "Reds"),
(axes[1], tumor_matrix, "Tumor probe on H4 (causal, OOD)\nHigh = causal feature transfers = GOOD", "Greens"),
]:
im = ax.imshow(
matrix.T,
aspect="auto",
cmap=cmap,
vmin=0.0, vmax=1.0,
interpolation="nearest",
origin="lower",
)
ax.set_xticks(range(len(epochs)))
ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=8)
ax.set_yticks(range(len(layers)))
ax.set_yticklabels(layers, fontsize=9)
ax.set_xlabel("Training epoch")
ax.set_ylabel("ResNet layer")
ax.set_title(title, fontsize=10, fontweight="bold")
plt.colorbar(im, ax=ax, label="Probe accuracy")
fig.suptitle(
f"M1 — Layer-wise Linear Probing: {results['run_id']}\n"
"Circuit signature: deep-layer hospital-probe drop (Reds) + sustained tumor recoverability (Greens)",
fontsize=10, y=1.02
)
plt.tight_layout()
out = os.path.join(out_dir, "m1_probe_heatmap.png")
plt.savefig(out, bbox_inches="tight")
plt.close()
def _plot_probe_curves(results: Dict, out_dir: str):
"""
Line plot per-layer: hospital probe (H3 — recoverability) + tumor probe
(H4 — causal-feature transfer to truly unseen hospital), with OOD accuracy
from history.json overlaid.
"""
epochs = results["epochs"]
run_dir = os.path.join(out_dir, "..")
layers = results["layers"]
avgpool_idx = layers.index("avgpool")
layer2_idx = layers.index("layer2")
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, layer_idx, layer_label in [
(axes[0], avgpool_idx, "avgpool (penultimate)"),
(axes[1], layer2_idx, "layer2 (early)"),
]:
hosp = [results["hospital_probe_id"][i][layer_idx]
for i in range(len(epochs))]
tumor = [results["tumor_probe_ood"][i][layer_idx]
for i in range(len(epochs))]
ax.plot(epochs, hosp, "r-o", markersize=4, lw=2,
label="Hospital probe on H3 (shortcut recoverability ↓ want)")
ax.plot(epochs, tumor, "g-s", markersize=4, lw=2,
label="Tumor probe on H4 (causal transfer ↑ want)")
hist_path = os.path.join(run_dir, "results", "history.json")
if os.path.isfile(hist_path):
try:
hist = json.load(open(hist_path))
hist_eps = [r["epoch"] for r in hist]
ood_accs = [r.get("ood_acc", float("nan")) for r in hist]
ax.plot(hist_eps, ood_accs, "b--", lw=1.5, alpha=0.7,
label="OOD accuracy (H4)")
except Exception:
pass
ax.axhline(0.5, color="gray", ls=":", lw=1, alpha=0.5,
label="Chance (0.5)")
ax.set_xlabel("Training epoch")
ax.set_ylabel("Probe / OOD accuracy")
ax.set_title(f"Layer: {layer_label}", fontweight="bold")
ax.legend(fontsize=9)
ax.set_ylim([0, 1.05])
ax.grid(alpha=0.3)
fig.suptitle(
f"M1 — Probe Curves: {results['run_id']}\n"
"Hospital recoverability (H3) drops in deep layers + tumor transfers (H4) — circuit signature",
fontsize=10, y=1.02
)
plt.tight_layout()
out = os.path.join(out_dir, "m1_probe_curves.png")
plt.savefig(out, bbox_inches="tight")
plt.close()
def main():
p = argparse.ArgumentParser(
description="M1: Layer-wise linear probing for CausalGrok")
p.add_argument("--run_dir", default=None,
help="Single run directory to analyze")
p.add_argument("--all_runs", action="store_true",
help="Analyze all camelyon_v2 grokking runs")
p.add_argument("--data_root", default="data/wilds")
p.add_argument("--device", default="cuda")
p.add_argument("--max_samples", type=int, default=800)
p.add_argument("--latest_only", action="store_true",
help="Analyze only latest checkpoint (quick check)")
args = p.parse_args()
if args.all_runs:
run_dirs = sorted(glob.glob(
"experiments/runs/*camelyon_v2*grokking*"))
print(f"Found {len(run_dirs)} grokking runs")
all_results = []
for rd in run_dirs:
r = run_probe_analysis(rd, args.data_root,
device=args.device,
max_samples=args.max_samples,
latest_only=args.latest_only)
if r:
all_results.append(r)
if all_results:
os.makedirs("paper_figures", exist_ok=True)
with open("paper_figures/m1_all_probes.json", "w") as f:
json.dump(all_results, f, indent=2)
print(f"\nCombined → paper_figures/m1_all_probes.json")
elif args.run_dir:
run_probe_analysis(args.run_dir, args.data_root,
device=args.device,
max_samples=args.max_samples,
latest_only=args.latest_only)
else:
print("Specify --run_dir <path> or --all_runs")
if __name__ == "__main__":
main()
```
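The probe in `train_probe` is a standardized logistic regression. It is fit on training-split features and scored on a held-out split. The NumPy sketch below restates that logic without scikit-learn so the mechanics are visible.

It is a stand-in, not the archived probe: it uses plain full-batch gradient descent with a small L2 term instead of sklearn's `lbfgs` with `C=1.0`. The names `standardize` and `linear_probe_accuracy` are hypothetical.

The sketch also makes concrete the H4 degeneracy noted in `_plot_probe_heatmaps`. A probe can only predict labels it saw during fitting, so evaluation labels outside that set always score zero.

```python
import numpy as np


def standardize(X_train: np.ndarray, X_eval: np.ndarray):
    """Z-score both splits using statistics of the training split only
    (the role StandardScaler.fit_transform / transform plays in train_probe)."""
    mu = X_train.mean(axis=0)
    sd = X_train.std(axis=0)
    sd = np.where(sd > 0, sd, 1.0)  # leave constant features unscaled
    return (X_train - mu) / sd, (X_eval - mu) / sd


def linear_probe_accuracy(X_train, y_train, X_eval, y_eval,
                          lr=0.5, steps=1000, l2=1e-3) -> float:
    """Multinomial logistic-regression probe fit by full-batch gradient descent.

    Assumption: a NumPy stand-in for sklearn's LogisticRegression(C=1.0, lbfgs);
    the optimizer and regularization strength differ from the archived probe."""
    classes = np.unique(y_train)
    if len(classes) < 2:
        return 0.5  # same degenerate-case value train_probe returns
    Xtr, Xev = standardize(X_train, X_eval)
    Y = (y_train[:, None] == classes[None, :]).astype(float)  # one-hot, (N, K)
    W = np.zeros((Xtr.shape[1], len(classes)))
    b = np.zeros(len(classes))
    for _ in range(steps):
        z = Xtr @ W + b
        z -= z.max(axis=1, keepdims=True)   # stabilize the softmax
        p = np.exp(z)
        p /= p.sum(axis=1, keepdims=True)
        g = (p - Y) / len(Xtr)              # d(mean cross-entropy)/d(logits)
        W -= lr * (Xtr.T @ g + l2 * W)
        b -= lr * g.sum(axis=0)
    # Only labels present in y_train can ever be predicted.
    pred = classes[np.argmax(Xev @ W + b, axis=1)]
    return float((pred == y_eval).mean())


# Usage on synthetic features: three "hospitals" separated along one axis.
rng = np.random.default_rng(0)
y = rng.integers(0, 3, size=600)
X = rng.normal(size=(600, 16))
X[:, 0] += 4.0 * y
print(linear_probe_accuracy(X[:400], y[:400], X[400:], y[400:]))
```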
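The M1 docstring's criterion is about timing. The OOD-accuracy jump, the hospital-probe drop, and the tumor-probe rise should land at the same epoch.

A minimal sketch of that check follows. It assumes the three series come from one layer (for example `avgpool`, one of the two layers `_plot_probe_curves` draws) and are already aligned to the same checkpoint epochs. In the script, probe values exist only at saved checkpoints while `history.json` is per-epoch, so the OOD curve would first need to be subsampled.

`largest_step` and `transitions_coincide` are hypothetical helpers. The numbers in the usage lines are made-up illustrations, not project results.

A coincidence found this way is correlational evidence only. The M4 and M5 interventions are what probe it causally.

```python
from typing import Dict, Sequence, Tuple


def largest_step(epochs: Sequence[int], values: Sequence[float], sign: int) -> int:
    """Epoch that ends the largest single-interval change in `values`
    in direction `sign` (+1 = largest rise, -1 = largest drop)."""
    if len(epochs) != len(values) or len(epochs) < 2:
        raise ValueError("need at least two aligned (epoch, value) pairs")
    deltas = [sign * (values[i + 1] - values[i]) for i in range(len(values) - 1)]
    best = max(range(len(deltas)), key=deltas.__getitem__)
    return epochs[best + 1]


def transitions_coincide(
    epochs: Sequence[int],
    ood_acc: Sequence[float],
    hosp_probe: Sequence[float],
    tumor_probe: Sequence[float],
    tol: int = 0,
) -> Tuple[bool, Dict[str, int]]:
    """True if the three transition epochs lie within `tol` epochs of each other."""
    marks = {
        "ood_jump": largest_step(epochs, ood_acc, +1),
        "hospital_drop": largest_step(epochs, hosp_probe, -1),
        "tumor_rise": largest_step(epochs, tumor_probe, +1),
    }
    spread = max(marks.values()) - min(marks.values())
    return spread <= tol, marks


# Illustrative, made-up curves sampled at four checkpoints.
epochs = [50, 100, 150, 200]
ok, marks = transitions_coincide(
    epochs,
    ood_acc=[0.55, 0.58, 0.80, 0.82],
    hosp_probe=[0.95, 0.94, 0.60, 0.58],
    tumor_probe=[0.70, 0.71, 0.88, 0.89],
)
print(ok, marks)  # True {'ood_jump': 150, 'hospital_drop': 150, 'tumor_rise': 150}
```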
---
## 13. Full source: `experiments/mechinterp_m4_ablation.py`
```python
"""M4 — Representation Ablation: causal intervention on the shortcut subspace.
Pipeline:
1. Pick a checkpoint (peak-OOD epoch by default).
2. Extract features at avgpool (or `--layer`) for train (H0-H2) + OOD (H4) splits.
3. Fit a hospital-classification logistic-regression probe on train features.
The probe's weight rows define the *shortcut subspace* in feature space.
4. Build the projector P = W^T (W W^T)^-1 W onto that subspace and define
`ablate(h) = h - P h`.
5. Re-classify OOD images with the *same* trained classifier head, fed:
(a) raw features h — baseline OOD accuracy
(b) ablated features h - Ph — post-intervention OOD accuracy
6. Also report:
(c) shortcut accuracy (probe.score on h vs h-Ph)
(d) tumor probe accuracy on h vs h-Ph (sanity: the causal feature
should survive the intervention)
(e) head's tumor classification accuracy on H4 with raw vs ablated features
If the intervention is causal:
- shortcut probe accuracy: collapses
- OOD accuracy: improves (or at least doesn't decay as much)
- tumor probe accuracy: largely preserved
Usage
-----
python -m experiments.mechinterp_m4_ablation \\
--run_dir experiments/runs/<id> \\
--data_root data/wilds \\
--layer avgpool \\
[--epoch 50] # default: peak_ood_epoch from summary.json
[--max_samples 1000]
Output:
<run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.json
<run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.png
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
# Re-use M1 helpers — hooks, model loader, feature extraction, ckpt discovery.
from experiments.mechinterp_m1 import (
register_hooks,
extract_features,
load_model_from_checkpoint,
find_checkpoints,
)
from utils.camelyon_data import get_camelyon_subsets
class _TransformWrapper:
def __init__(self, dataset, transform):
self.dataset = dataset
self.transform = transform
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
img, label, metadata = self.dataset[idx]
return self.transform(img), label, metadata
def _build_loaders(data_root: str, max_samples: int, seed: int = 42):
transform = transforms.Compose([
transforms.Resize((96, 96)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
root_dir=data_root, download=False
)
train_t = _TransformWrapper(train_ds, transform)
ood_t = _TransformWrapper(ood_test_ds, transform)
torch.manual_seed(seed)
train_idx = torch.randperm(len(train_t))[:max_samples]
ood_idx = torch.randperm(len(ood_t))[:max_samples // 2]
train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128,
shuffle=False, num_workers=0)
ood_loader = DataLoader(Subset(ood_t, ood_idx), batch_size=128,
shuffle=False, num_workers=0)
return train_loader, ood_loader
def _select_epoch(run_dir: Path, requested: int | None) -> Tuple[int, Path]:
ckpts = find_checkpoints(str(run_dir))
if not ckpts:
raise FileNotFoundError(f"No checkpoints in {run_dir}/checkpoints/")
if requested is not None:
for ep, p in ckpts:
if ep == requested:
return ep, Path(p)
raise ValueError(f"Requested epoch {requested} not in checkpoints "
f"({[ep for ep, _ in ckpts]})")
# default: peak OOD epoch from summary.json
summary_path = run_dir / "results" / "summary.json"
peak = None
if summary_path.exists():
s = json.loads(summary_path.read_text())
peak = s.get("peak_ood_epoch", None)
if peak is not None and peak > 0:
# nearest periodic checkpoint
nearest = min(ckpts, key=lambda x: abs(x[0] - peak))
return nearest[0], Path(nearest[1])
# fall back to last checkpoint
return ckpts[-1][0], Path(ckpts[-1][1])
def _build_projector(W: np.ndarray) -> np.ndarray:
"""W has shape (k, d). Returns P (d, d) projecting onto rowspace(W)."""
# Use SVD for a stable orthonormal basis of rowspace
U, s, Vt = np.linalg.svd(W, full_matrices=False)
# rowspace basis = Vt rows where singular values > tol
tol = max(W.shape) * np.finfo(s.dtype).eps * (s.max() if s.size else 0.0)
keep = s > tol
basis = Vt[keep] # (k', d)
return basis.T @ basis # (d, d) projector onto rowspace
def _build_shortcut_subspace(
X: np.ndarray, hospital_ids: np.ndarray,
method: str = "lda", subspace_dim: int = 32
) -> np.ndarray:
"""Return a (k, d) basis whose row-span is the 'shortcut subspace'.
method='probe' — k = (n_classes - 1) probe weight rows (small subspace).
method='lda' — k = subspace_dim top between-class directions: take
per-hospital means in feature space, center them,
and run SVD. This gives a rank-bounded but data-driven
subspace that captures hospital-discriminating variance.
method='pca-class' — top-PCs of features colored by hospital (mean-removed
per class), giving us the variance directions that
mostly reflect within-hospital structure × class.
"""
if method == "probe":
clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1)
clf.fit(X, hospital_ids)
return clf.coef_
if method == "lda":
classes = np.unique(hospital_ids)
global_mean = X.mean(axis=0, keepdims=True)
between = []
for c in classes:
mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True)
between.append(mu_c - global_mean)
between = np.vstack(between) # (n_classes, d)
# Augment with random hospital-correlated directions to grow rank up
# to subspace_dim — use top PCs of *centered-by-hospital-mean* features.
if subspace_dim > between.shape[0]:
# within-hospital residuals
residuals = []
for c in classes:
mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True)
residuals.append(X[hospital_ids == c] - mu_c)
R = np.vstack(residuals)
# PCA on residuals — these are within-hospital directions; remove
# them from the shortcut subspace by KEEPING only the between-class
# directions. So we just return between as-is, plus the top PCs of
# the *original* features projected onto the orthogonal complement
# of `between` IF the user wants more dims.
U, s, Vt = np.linalg.svd(X - global_mean, full_matrices=False)
top = Vt[:subspace_dim]
# Score each PC by how much it correlates with hospital-id variance
# (one-hot expansion); keep top by that correlation.
one_hot = np.eye(len(classes))[
np.searchsorted(classes, hospital_ids)
] # (N, n_classes)
proj = (X - global_mean) @ top.T # (N, subspace_dim)
corrs = np.array([
np.max(np.abs([np.corrcoef(proj[:, k], one_hot[:, c])[0, 1]
for c in range(len(classes))]))
for k in range(subspace_dim)
])
# take the top-k most-hospital-correlated PCs
order = np.argsort(-np.nan_to_num(corrs))
top_hosp = top[order[:subspace_dim]]
# combine: between-class means + top-hospital-correlated PCs
return np.vstack([between, top_hosp])
return between
raise ValueError(f"Unknown method: {method}")
def _classifier_logits_from_features(
model: nn.Module, features: np.ndarray, layer: str, device: str
) -> np.ndarray:
"""Apply the *post-`layer`* part of the network to the (modified) features
and return the model's binary-classification logits.
For ResNet, `avgpool` features have shape (N, C). The classifier head
`model.fc` (timm: `model.get_classifier()`) maps C → 2. For non-avgpool
layers we do not currently support full propagation — caller should use
layer='avgpool' for OOD-accuracy interventions."""
if layer != "avgpool":
raise NotImplementedError(
"Re-applying the classifier head from intermediate spatial layers "
"is not yet supported. Use --layer avgpool for the head-level "
"ablation."
)
# Find the classifier head (timm convention: model.fc or model.get_classifier())
if hasattr(model, "get_classifier"):
head = model.get_classifier()
elif hasattr(model, "fc"):
head = model.fc
elif hasattr(model, "classifier"):
head = model.classifier
else:
raise RuntimeError("Could not locate classifier head on the model.")
head = head.to(device).eval()
with torch.no_grad():
x = torch.tensor(features, dtype=torch.float32, device=device)
logits = head(x).cpu().numpy()
return logits
def _accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
if logits.ndim == 1 or logits.shape[1] == 1:
pred = (logits.flatten() > 0).astype(int)
else:
pred = logits.argmax(axis=1)
return float((pred == labels).mean())
def run_ablation(
run_dir: Path,
data_root: str,
layer: str = "avgpool",
epoch: int | None = None,
max_samples: int = 1000,
device: str = "cuda",
subspace_method: str = "lda",
subspace_dim: int = 32,
) -> Dict:
epoch, ckpt_path = _select_epoch(run_dir, epoch)
print(f"\n M4 — Representation Ablation")
print(f" run_dir : {run_dir.name}")
print(f" epoch : {epoch} ({ckpt_path.name})")
print(f" layer : {layer}")
# Load model and dataloaders
model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
model.eval()
register_hooks(model)
cfg_path = run_dir / "config.json"
seed = 42
if cfg_path.exists():
seed = json.loads(cfg_path.read_text()).get("seed", 42)
train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)
# Extract features
print(f" Extracting features ({max_samples} samples per split)...")
feats_train, hosp_train, tumor_train = extract_features(
model, train_loader, device, max_samples=max_samples
)
feats_ood, hosp_ood, tumor_ood = extract_features(
model, ood_loader, device, max_samples=max_samples // 2
)
if layer not in feats_train:
raise KeyError(f"Layer '{layer}' not in extracted features "
f"({list(feats_train.keys())})")
X_tr = np.asarray(feats_train[layer]) # (N_tr, D)
X_ood = np.asarray(feats_ood[layer]) # (N_ood, D)
if X_tr.ndim > 2: # spatial map; flatten
X_tr = X_tr.reshape(X_tr.shape[0], -1)
X_ood = X_ood.reshape(X_ood.shape[0], -1)
# Normalize features (probe is sensitive to scale; classifier head was
# trained on un-normalized features so we keep two parallel pipelines).
scaler = StandardScaler().fit(X_tr)
X_tr_n = scaler.transform(X_tr)
X_ood_n = scaler.transform(X_ood)
# ──────────── 1. Fit hospital probe + build shortcut subspace
print(f" Fitting hospital probe on H0/H1/H2 train features...")
hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1)
hosp_clf.fit(X_tr_n, hosp_train)
hosp_acc_train = hosp_clf.score(X_tr_n, hosp_train)
# Build a richer shortcut subspace via LDA-style between-class +
# hospital-correlated top PCs. This catches more shortcut variance than
# the (n_classes - 1)-D probe-rowspace alone.
W = _build_shortcut_subspace(X_tr_n, np.asarray(hosp_train),
method=subspace_method,
subspace_dim=subspace_dim)
P = _build_projector(W) # (D, D)
rank_subspace = int(np.linalg.matrix_rank(P, tol=1e-8))
print(f" Shortcut subspace: dim={rank_subspace} method={subspace_method} "
f"(probe train acc {hosp_acc_train:.3f})")
# ──────────── 2. Build ablated versions of features
# Apply the projection in the *normalized* feature space, then un-scale
# for re-feeding to the classifier head (which was trained on raw features).
def ablate_norm(X_n):
return X_n - X_n @ P.T
X_ood_ablated_n = ablate_norm(X_ood_n)
# un-scale
X_ood_ablated = scaler.inverse_transform(X_ood_ablated_n)
# Sanity probe metrics
print(f" Re-fitting tumor probe on train features...")
tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1)
tumor_clf.fit(X_tr_n, tumor_train)
tumor_acc_train = tumor_clf.score(X_tr_n, tumor_train)
# Probe accuracies on raw vs ablated OOD features
hosp_acc_ood_raw = hosp_clf.score(X_ood_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan")
hosp_acc_ood_ablated = hosp_clf.score(X_ood_ablated_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan")
tumor_acc_ood_raw = tumor_clf.score(X_ood_n, tumor_ood)
tumor_acc_ood_ablated = tumor_clf.score(X_ood_ablated_n, tumor_ood)
# ──────────── 3. Head-level OOD classification accuracy
print(f" Re-classifying OOD with model head (raw vs ablated features)...")
logits_raw = _classifier_logits_from_features(model, X_ood, layer, device)
logits_ablated = _classifier_logits_from_features(model, X_ood_ablated, layer, device)
head_acc_raw = _accuracy(logits_raw, tumor_ood)
head_acc_ablated = _accuracy(logits_ablated, tumor_ood)
# ──────────── 4. Pack + report
result = {
"run_id": run_dir.name,
"epoch": epoch,
"layer": layer,
"max_samples": max_samples,
"shortcut_subspace_dim": rank_subspace,
"hospital_probe_train_acc": hosp_acc_train,
"tumor_probe_train_acc": tumor_acc_train,
"hospital_probe_ood_raw": hosp_acc_ood_raw,
"hospital_probe_ood_ablated": hosp_acc_ood_ablated,
"tumor_probe_ood_raw": tumor_acc_ood_raw,
"tumor_probe_ood_ablated": tumor_acc_ood_ablated,
"head_ood_acc_raw": head_acc_raw,
"head_ood_acc_ablated": head_acc_ablated,
"intervention_effect": {
"shortcut_collapse": hosp_acc_ood_raw - hosp_acc_ood_ablated,
"ood_improvement": head_acc_ablated - head_acc_raw,
"tumor_preservation": tumor_acc_ood_ablated - tumor_acc_ood_raw,
},
}
print(f"\n RESULTS")
print(f" hospital probe (OOD): {hosp_acc_ood_raw:.3f} → {hosp_acc_ood_ablated:.3f} "
f"(Δ {result['intervention_effect']['shortcut_collapse']:+.3f})")
print(f" tumor probe (OOD) : {tumor_acc_ood_raw:.3f} → {tumor_acc_ood_ablated:.3f} "
f"(Δ {result['intervention_effect']['tumor_preservation']:+.3f})")
print(f" head OOD acc : {head_acc_raw:.3f} → {head_acc_ablated:.3f} "
f"(Δ {result['intervention_effect']['ood_improvement']:+.3f})")
return result
def plot_ablation(result: Dict, out_path: Path):
metrics = ["hospital_probe_ood", "tumor_probe_ood", "head_ood_acc"]
raw_keys = ["hospital_probe_ood_raw", "tumor_probe_ood_raw", "head_ood_acc_raw"]
ablated_keys = ["hospital_probe_ood_ablated", "tumor_probe_ood_ablated", "head_ood_acc_ablated"]
labels = ["Hospital probe\n(↓ = causal effect)",
"Tumor probe\n(stable = good)",
"Head OOD acc\n(↑ = causal effect)"]
raws = [result[k] for k in raw_keys]
ablateds = [result[k] for k in ablated_keys]
fig, ax = plt.subplots(figsize=(9, 5))
x = np.arange(len(metrics))
w = 0.35
b1 = ax.bar(x - w / 2, raws, w, label="raw features", color="#444")
b2 = ax.bar(x + w / 2, ablateds, w, label="shortcut-ablated", color="#c33")
for bars in (b1, b2):
for b in bars:
ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.005,
f"{b.get_height():.3f}", ha="center", va="bottom", fontsize=9)
ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=9)
ax.set_ylim(0, 1.05); ax.set_ylabel("Accuracy")
ax.set_title(f"M4 — Causal Ablation of Shortcut Subspace\n"
f"{result['run_id']} • ep{result['epoch']} • layer={result['layer']} "
f"• subspace dim={result['shortcut_subspace_dim']}",
fontsize=10, fontweight="bold")
ax.legend(loc="upper right")
ax.grid(alpha=0.3, axis="y")
plt.tight_layout()
fig.savefig(out_path, dpi=180, bbox_inches="tight")
plt.close(fig)
def main():
p = argparse.ArgumentParser()
p.add_argument("--run_dir", required=True)
p.add_argument("--data_root", default="data/wilds")
p.add_argument("--layer", default="avgpool",
choices=["avgpool"]) # head-level intervention only at avgpool
p.add_argument("--epoch", type=int, default=None,
help="Specific checkpoint epoch; default = peak_ood_epoch from summary.json")
p.add_argument("--max_samples", type=int, default=1000)
p.add_argument("--device", default="cuda")
p.add_argument("--subspace_method", default="lda",
choices=["lda", "probe"],
help="lda = LDA-style between-class + hospital-correlated PCs; "
"probe = LR probe row-space (small, often only 2-D)")
p.add_argument("--subspace_dim", type=int, default=32,
help="Target subspace dim for lda method")
p.add_argument("--all_epochs", action="store_true",
help="Sweep across all periodic checkpoints")
args = p.parse_args()
run_dir = Path(args.run_dir)
out_dir = run_dir / "mechinterp"
out_dir.mkdir(parents=True, exist_ok=True)
if args.all_epochs:
# Sweep across every periodic checkpoint, build a trajectory.
ckpts = find_checkpoints(str(run_dir))
# de-duplicate (final.pt may share epoch with last ep*.pt)
seen = set(); uniq = []
for ep, p in ckpts:
if ep in seen:
continue
seen.add(ep); uniq.append((ep, p))
traj = []
for ep, _ in uniq:
try:
r = run_ablation(
run_dir=run_dir, data_root=args.data_root, layer=args.layer,
epoch=ep, max_samples=args.max_samples, device=args.device,
subspace_method=args.subspace_method,
subspace_dim=args.subspace_dim,
)
traj.append(r)
except Exception as e:
print(f" [skip ep{ep}] {e}")
out = out_dir / f"m4_ablation_{args.layer}_trajectory.json"
out.write_text(json.dumps(traj, indent=2))
plot_trajectory(traj, out.with_suffix(".png"))
print(f"\n → {out}")
print(f" → {out.with_suffix('.png')}")
return
result = run_ablation(
run_dir=run_dir,
data_root=args.data_root,
layer=args.layer,
epoch=args.epoch,
max_samples=args.max_samples,
device=args.device,
subspace_method=args.subspace_method,
subspace_dim=args.subspace_dim,
)
base = out_dir / f"m4_ablation_{args.layer}_ep{result['epoch']:05d}"
(base.with_suffix(".json")).write_text(json.dumps(result, indent=2))
plot_ablation(result, base.with_suffix(".png"))
print(f"\n → {base.with_suffix('.json')}")
print(f" → {base.with_suffix('.png')}")
def plot_trajectory(traj, out_path: Path):
"""Plot the intervention effect across training epochs."""
eps = [r["epoch"] for r in traj]
head_raw = [r["head_ood_acc_raw"] for r in traj]
head_abl = [r["head_ood_acc_ablated"] for r in traj]
tum_raw = [r["tumor_probe_ood_raw"] for r in traj]
tum_abl = [r["tumor_probe_ood_ablated"] for r in traj]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Panel A: head OOD acc raw vs ablated
ax = axes[0]
ax.plot(eps, head_raw, "k-o", lw=2, label="raw features")
ax.plot(eps, head_abl, "r-s", lw=2, label="shortcut-ablated features")
ax.fill_between(eps, head_raw, head_abl,
where=[a > b for a, b in zip(head_abl, head_raw)],
color="seagreen", alpha=0.3, label="ablation helps")
ax.fill_between(eps, head_raw, head_abl,
where=[a < b for a, b in zip(head_abl, head_raw)],
color="salmon", alpha=0.3, label="ablation hurts")
ax.set_xlabel("Training epoch"); ax.set_ylabel("OOD (H4) head accuracy")
ax.set_title("Head OOD accuracy: raw vs shortcut-ablated", fontweight="bold")
ax.legend(fontsize=9); ax.grid(alpha=0.3)
# Panel B: tumor probe survival
ax = axes[1]
ax.plot(eps, tum_raw, "k-o", lw=2, label="raw features")
ax.plot(eps, tum_abl, "g-s", lw=2, label="shortcut-ablated features")
ax.set_xlabel("Training epoch"); ax.set_ylabel("Tumor probe OOD accuracy")
ax.set_title("Tumor probe survival under ablation\n(stable line = causal feature preserved)",
fontweight="bold")
ax.legend(fontsize=9); ax.grid(alpha=0.3); ax.set_ylim(0.4, 1.0)
rid = traj[0]["run_id"] if traj else "?"
layer = traj[0]["layer"] if traj else "?"
fig.suptitle(f"M4 — Causal Ablation Trajectory: {rid} • layer={layer}",
fontsize=11, fontweight="bold")
plt.tight_layout()
fig.savefig(out_path, dpi=180, bbox_inches="tight")
plt.close(fig)
if __name__ == "__main__":
main()
```
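In `--all_epochs` mode, `find_checkpoints` can return the same epoch twice (the inline comment notes that `final.pt` may share its epoch with the last `ep*.pt`), and the `seen`/`uniq` loop keeps only the first path seen for each epoch. The following is a minimal standalone sketch of that de-duplication step, assuming checkpoints arrive as `(epoch, path)` pairs as they do in `main()`. The helper name `dedup_checkpoints` and the file names in the example are hypothetical; which file survives for a duplicated epoch depends on the ordering `find_checkpoints` returns, which is not shown here.

```python
from typing import List, Tuple


def dedup_checkpoints(ckpts: List[Tuple[int, str]]) -> List[Tuple[int, str]]:
    """Keep the first (epoch, path) pair for each epoch, preserving order.

    Mirrors the seen/uniq loop in M4's trajectory mode; the name is
    hypothetical and not part of the archived scripts.
    """
    seen = set()
    uniq = []
    for ep, path in ckpts:
        if ep in seen:
            continue  # a later entry for an already-seen epoch is dropped
        seen.add(ep)
        uniq.append((ep, path))
    return uniq


if __name__ == "__main__":
    # Illustrative file names only.
    ckpts = [(50, "ep00050.pt"), (100, "ep00100.pt"), (100, "final.pt")]
    print(dedup_checkpoints(ckpts))
```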
---
## 14. Full source: `experiments/mechinterp_m5_steering.py`
```python
"""M5 — Activation Steering: causally manipulate the shortcut direction.
For one checkpoint (default: peak_ood_epoch from summary.json), we:
1. Extract avgpool features for train (H0-H2) + OOD (H4) splits.
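2. Identify the dominant shortcut direction `v_s` as the top right singular
vector of the per-hospital mean deviations, i.e. the top eigenvector of the
unweighted between-hospital scatter (an LDA-style direction, computed
without within-class whitening).
3. Sweep α (default: {-3, -2, -1, -0.5, 0, +0.5, +1, +2, +3}) and apply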
h' = h + α · σ_align · v_s
where σ_align is the std of features projected onto v_s (so α counts
in 'standard deviations of shortcut activation').
4. Re-classify OOD with the original head.
5. Re-fit hospital + tumor probes on the steered features and report
accuracy curves.
Strong mechanistic claim if:
- tumor-head OOD acc declines monotonically as |α| grows
- hospital-probe acc on steered features rises with |α|
- tumor-probe acc on steered features stays approximately flat (the
*causal* feature isn't aligned with the shortcut direction)
Usage
-----
python -m experiments.mechinterp_m5_steering \\
--run_dir experiments/runs/<id> \\
--data_root data/wilds \\
[--epoch 50] # default: peak_ood_epoch from summary.json
[--max_samples 1000] [--alphas " -3,-2,-1,0,1,2,3 "]
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, List, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from experiments.mechinterp_m1 import (
register_hooks, extract_features, load_model_from_checkpoint,
find_checkpoints,
)
from experiments.mechinterp_m4_ablation import (
_select_epoch, _build_loaders,
_classifier_logits_from_features, _accuracy,
)
def _top_lda_direction(X: np.ndarray, hospital_ids: np.ndarray) -> np.ndarray:
"""Return a unit vector aligned with the dominant between-hospital direction
in feature space (LDA-1)."""
classes = np.unique(hospital_ids)
global_mean = X.mean(axis=0, keepdims=True)
means = np.vstack([
X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean
for c in classes
])
# SVD: rows of Vt are the orthonormal between-class directions ranked by
# singular value (variance explained between hospitals).
U, s, Vt = np.linalg.svd(means, full_matrices=False)
return Vt[0] # (D,) unit vector
def run_steering(
run_dir: Path,
data_root: str,
epoch: int | None = None,
max_samples: int = 1000,
device: str = "cuda",
alphas: List[float] = None,
) -> Dict:
if alphas is None:
alphas = [-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]
epoch, ckpt_path = _select_epoch(run_dir, epoch)
print(f"\n M5 — Activation Steering")
print(f" run_dir : {run_dir.name}")
print(f" epoch : {epoch} ({ckpt_path.name})")
print(f" alphas : {alphas}")
model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
model.eval()
register_hooks(model)
cfg_path = run_dir / "config.json"
seed = 42
if cfg_path.exists():
seed = json.loads(cfg_path.read_text()).get("seed", 42)
train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)
print(f" Extracting features...")
feats_train, hosp_train, tumor_train = extract_features(
model, train_loader, device, max_samples=max_samples
)
feats_ood, hosp_ood, tumor_ood = extract_features(
model, ood_loader, device, max_samples=max_samples // 2
)
layer = "avgpool"
X_tr = np.asarray(feats_train[layer]); X_tr = X_tr.reshape(X_tr.shape[0], -1)
X_ood = np.asarray(feats_ood[layer]); X_ood = X_ood.reshape(X_ood.shape[0], -1)
hosp_train = np.asarray(hosp_train)
hosp_ood = np.asarray(hosp_ood)
tumor_train = np.asarray(tumor_train)
tumor_ood = np.asarray(tumor_ood)
# Standardize for probe-fitting; un-standardize when feeding head
scaler = StandardScaler().fit(X_tr)
X_tr_n = scaler.transform(X_tr)
X_ood_n = scaler.transform(X_ood)
# 1. Top LDA direction in normalized feature space
v = _top_lda_direction(X_tr_n, hosp_train) # (D,) unit vec
# Std of training features projected onto v (scale unit for α)
sigma = float(np.std(X_tr_n @ v))
print(f" Top hospital direction v_s : ‖v‖={np.linalg.norm(v):.3f}, "
f"σ(X_tr·v)={sigma:.3f}")
# 2. Pre-fit reference probes on un-steered train features
hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train)
tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train)
# 3. Sweep α
sweep = []
for alpha in alphas:
# Steer features along v in normalized space, then un-scale for the head.
X_ood_steered_n = X_ood_n + alpha * sigma * v[None, :]
X_ood_steered = scaler.inverse_transform(X_ood_steered_n)
# Head OOD accuracy
logits = _classifier_logits_from_features(model, X_ood_steered, layer, device)
head_acc = _accuracy(logits, tumor_ood)
# Probe accuracies on steered features
if len(np.unique(hosp_ood)) > 1:
hosp_acc = hosp_clf.score(X_ood_steered_n, hosp_ood)
else:
hosp_acc = float("nan")
tumor_acc = tumor_clf.score(X_ood_steered_n, tumor_ood)
sweep.append({
"alpha": float(alpha),
"head_ood_acc": head_acc,
"hospital_probe": hosp_acc,
"tumor_probe": tumor_acc,
})
print(f" α={alpha:+.2f} head_ood={head_acc:.3f} "
f"hosp_probe={hosp_acc if not np.isnan(hosp_acc) else 'nan':<5} "
f"tumor_probe={tumor_acc:.3f}")
return {
"run_id": run_dir.name,
"epoch": epoch,
"layer": layer,
"max_samples": max_samples,
"v_norm": float(np.linalg.norm(v)),
"sigma": sigma,
"sweep": sweep,
}
def plot_steering(result: Dict, out_path: Path):
sweep = result["sweep"]
a = [r["alpha"] for r in sweep]
head = [r["head_ood_acc"] for r in sweep]
hosp = [r["hospital_probe"] for r in sweep]
tumor = [r["tumor_probe"] for r in sweep]
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
# Panel A — Head OOD acc vs α
ax = axes[0]
ax.plot(a, head, "k-o", lw=2, ms=7)
ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
ax.set_xlabel("Steering coefficient α (in σ-units of shortcut direction)")
ax.set_ylabel("Head OOD (H4) accuracy")
ax.set_title("Causal effect of steering activations along v_s\n"
"(monotonic decline as |α| grows = causal evidence)",
fontweight="bold", fontsize=10)
ax.grid(alpha=0.3)
ax.set_ylim(0.4, max(0.85, max(head) + 0.05))
# Panel B — Probe accuracies vs α
ax = axes[1]
ax.plot(a, hosp, "r-s", lw=2, ms=7, label="Hospital probe (↑ with |α| = good)")
ax.plot(a, tumor, "g-^", lw=2, ms=7, label="Tumor probe (flat = causal disjoint)")
ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
ax.set_xlabel("Steering coefficient α")
ax.set_ylabel("Probe accuracy")
ax.set_title("Probe responses to steering", fontweight="bold", fontsize=10)
ax.legend(loc="best", fontsize=9); ax.grid(alpha=0.3)
ax.set_ylim(0, 1.05)
fig.suptitle(f"M5 — Activation Steering: {result['run_id']} "
f"• ep{result['epoch']} • layer={result['layer']}",
fontsize=11, fontweight="bold")
plt.tight_layout()
fig.savefig(out_path, dpi=180, bbox_inches="tight")
plt.close(fig)
def main():
p = argparse.ArgumentParser()
p.add_argument("--run_dir", required=True)
p.add_argument("--data_root", default="data/wilds")
p.add_argument("--epoch", type=int, default=None)
p.add_argument("--max_samples", type=int, default=1000)
p.add_argument("--device", default="cuda")
p.add_argument("--alphas", default=None,
help="Comma-separated α values, e.g. ' -3,-2,-1,0,1,2,3 '")
p.add_argument("--all_epochs", action="store_true",
help="Sweep across all periodic checkpoints; output a trajectory")
args = p.parse_args()
alphas = None
if args.alphas is not None:
alphas = [float(x) for x in args.alphas.split(",")]
run_dir = Path(args.run_dir)
out_dir = run_dir / "mechinterp"
out_dir.mkdir(parents=True, exist_ok=True)
if args.all_epochs:
# Trajectory mode: run M5 at every periodic checkpoint
ckpts = find_checkpoints(str(run_dir))
seen = set(); uniq = []
for ep, p in ckpts:
if ep in seen:
continue
seen.add(ep); uniq.append((ep, p))
traj = []
for ep, _ in uniq:
try:
r = run_steering(
run_dir=run_dir, data_root=args.data_root, epoch=ep,
max_samples=args.max_samples, device=args.device, alphas=alphas,
)
traj.append(r)
except Exception as e:
print(f" [skip ep{ep}] {e}")
out = out_dir / "m5_steering_trajectory.json"
out.write_text(json.dumps(traj, indent=2))
print(f"\n → {out}")
return
result = run_steering(
run_dir=run_dir, data_root=args.data_root, epoch=args.epoch,
max_samples=args.max_samples, device=args.device, alphas=alphas,
)
base = out_dir / f"m5_steering_ep{result['epoch']:05d}"
base.with_suffix(".json").write_text(json.dumps(result, indent=2))
plot_steering(result, base.with_suffix(".png"))
print(f"\n → {base.with_suffix('.json')}")
print(f" → {base.with_suffix('.png')}")
if __name__ == "__main__":
main()
```
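The core of `run_steering` reduces to two linear-algebra steps. First, `_top_lda_direction` takes the top right singular vector of the matrix of per-hospital mean deviations; this is the top eigenvector of the unweighted between-hospital scatter, with no within-class whitening. Second, steering adds `α·σ·v` to every standardized OOD feature vector. Because `v` has unit norm, steering shifts every sample's projection onto `v` by exactly `α·σ`, which is why α can be read as "σ-units of shortcut activation".

The NumPy-only sketch below reproduces those two steps on synthetic data. The synthetic hospitals, the helper names, and the choice of channel 3 as the only hospital-dependent channel are illustrative assumptions. The archived script additionally standardizes features with `StandardScaler` and calls `inverse_transform` before feeding the head; that step is omitted here.

```python
import numpy as np


def top_between_group_direction(X: np.ndarray, groups: np.ndarray) -> np.ndarray:
    """Unit vector of largest spread among group means (same math as _top_lda_direction)."""
    global_mean = X.mean(axis=0, keepdims=True)
    means = np.vstack([
        X[groups == g].mean(axis=0, keepdims=True) - global_mean
        for g in np.unique(groups)
    ])
    # Rows of Vt are orthonormal; Vt[0] has the largest singular value.
    _, _, Vt = np.linalg.svd(means, full_matrices=False)
    return Vt[0]


def steer(X: np.ndarray, v: np.ndarray, alpha: float, sigma: float) -> np.ndarray:
    """h' = h + alpha * sigma * v, applied to every row."""
    return X + alpha * sigma * v[None, :]


rng = np.random.default_rng(0)
D = 16
groups = np.repeat([0, 1, 2], 100)            # three synthetic "hospitals"
offsets = np.zeros((3, D))
offsets[:, 3] = [-2.0, 0.0, 2.0]              # hospitals differ only along channel 3
X = rng.normal(size=(300, D)) + offsets[groups]

v = top_between_group_direction(X, groups)
sigma = float(np.std(X @ v))                  # scale unit for alpha
X_steered = steer(X, v, alpha=2.0, sigma=sigma)
shift = (X_steered @ v) - (X @ v)             # equals alpha * sigma for every sample
print(f"|v[3]| = {abs(v[3]):.3f}, mean shift = {shift.mean():.3f}, 2*sigma = {2 * sigma:.3f}")
```

The sign of `v` returned by the SVD is arbitrary, so the sign convention of α relative to any particular hospital is arbitrary too. This matters when reading "positive" vs "negative" α in the sweep tables, but not for effects that depend only on |α|.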
---
## 15. Full source: `experiments/mechinterp_m6_neuron_ablation.py`
```python
"""M6 — Neuron-level Ablation (the textbook reviewer-asked intervention).
Pipeline:
1. At a chosen checkpoint (default: peak_ood_epoch), extract avgpool
features for train (H0-H2) and OOD (H4) splits.
2. Score each of the 512 avgpool channels by *how predictive its activation
is of hospital ID*: fit one multiclass logistic regression on all
standardized channels and use the L2 norm of each channel's coefficient
column as the per-neuron shortcut score:
score_c = ‖β_{:,c}‖₂ (β = (n_classes, D) coefficient matrix of the LR fit)
↑ score_c → channel c is more strongly stain-shortcut-aligned.
3. Sweep top-K ∈ {0, 4, 8, 16, 32, 64, 128, 256} ablated neurons (zero out their
activations in standardized space, i.e. replace them with the training mean) and measure:
- head OOD acc (raw vs ablated)
- hospital-probe acc on raw vs ablated features
- tumor-probe acc on raw vs ablated features
4. Strong mechanistic claim:
- hospital-probe acc collapses sharply with K (these neurons are
carrying hospital info)
- head OOD acc *improves* (or at least preserves) at small K (the
model was using shortcut neurons to harm OOD)
- tumor-probe acc stays flat (causal info is distributed elsewhere)
Usage
-----
python -m experiments.mechinterp_m6_neuron_ablation \\
--run_dir experiments/runs/<id> \\
--data_root data/wilds \\
[--epoch 50] [--max_samples 1000] \\
[--ks "0,4,8,16,32,64,128,256"]
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, List
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from experiments.mechinterp_m1 import (
register_hooks, extract_features, load_model_from_checkpoint,
)
from experiments.mechinterp_m4_ablation import (
_select_epoch, _TransformWrapper,
_classifier_logits_from_features, _accuracy,
)
from utils.camelyon_data import get_camelyon_subsets
def _build_loaders_with_id(data_root: str, max_samples: int, seed: int = 42):
"""Like M4's _build_loaders but also returns an ID validation loader so
we can track ID acc and compute the OOD/ID degradation ratio."""
transform = transforms.Compose([
transforms.Resize((96, 96)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
root_dir=data_root, download=False
)
train_t = _TransformWrapper(train_ds, transform)
id_t = _TransformWrapper(id_val_ds, transform)
ood_t = _TransformWrapper(ood_test_ds, transform)
torch.manual_seed(seed)
train_idx = torch.randperm(len(train_t))[:max_samples]
id_idx = torch.randperm(len(id_t))[:max_samples // 2]
ood_idx = torch.randperm(len(ood_t))[:max_samples // 2]
train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128,
shuffle=False, num_workers=0)
id_loader = DataLoader(Subset(id_t, id_idx), batch_size=128,
shuffle=False, num_workers=0)
ood_loader = DataLoader(Subset(ood_t, ood_idx), batch_size=128,
shuffle=False, num_workers=0)
return train_loader, id_loader, ood_loader
def _per_neuron_shortcut_scores(X_n: np.ndarray, hosp: np.ndarray) -> np.ndarray:
"""Return a (D,) array — score per channel c, larger = more hospital-predictive.
Fitting a separate one-feature LR per channel would give |coef| values
dominated by per-channel scale, so instead we fit a single multiclass LR
over all features and use the L2 norm of β_{:,c} (the column norm of the
LR coefficient matrix) as channel c's hospital-discrimination score.
"""
clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1).fit(X_n, hosp)
W = clf.coef_ # (n_classes, D)
# column norms — large means many class-discriminations rely on this neuron
return np.linalg.norm(W, axis=0) # (D,)
def _ablate_and_eval(
X_n, mask, scaler, head_target, model, layer, device,
hosp_clf, tumor_clf, hosp_target, tumor_target,
):
"""Apply mask to normalized features, unscale, evaluate everything."""
X_ablated_n = X_n * mask[None, :]
X_ablated = scaler.inverse_transform(X_ablated_n)
logits = _classifier_logits_from_features(model, X_ablated, layer, device)
head_acc = _accuracy(logits, head_target)
hosp_acc = hosp_clf.score(X_ablated_n, hosp_target) if hosp_clf is not None and len(np.unique(hosp_target)) > 1 else float("nan")
tumor_acc = tumor_clf.score(X_ablated_n, tumor_target)
return head_acc, hosp_acc, tumor_acc
def run_neuron_ablation(
run_dir: Path,
data_root: str,
epoch: int | None = None,
max_samples: int = 1000,
device: str = "cuda",
ks: List[int] = None,
n_random_samples: int = 5,
include_morphology: bool = True,
include_id: bool = True,
) -> Dict:
if ks is None:
# Dose-response curve emphasizing small K (per reviewer guidance)
ks = [0, 4, 8, 16, 32, 64, 128, 256]
epoch, ckpt_path = _select_epoch(run_dir, epoch)
print(f"\n M6 — Neuron Ablation (with random + morphology controls)")
print(f" run_dir : {run_dir.name}")
print(f" epoch : {epoch} ({ckpt_path.name})")
print(f" ks : {ks}")
print(f" random ablation: {n_random_samples} samplings per K")
model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
model.eval()
register_hooks(model)
cfg_path = run_dir / "config.json"
seed = 42
if cfg_path.exists():
seed = json.loads(cfg_path.read_text()).get("seed", 42)
if include_id:
train_loader, id_loader, ood_loader = _build_loaders_with_id(data_root, max_samples, seed=seed)
else:
from experiments.mechinterp_m4_ablation import _build_loaders as _bl
train_loader, ood_loader = _bl(data_root, max_samples, seed=seed)
id_loader = None
print(f" Extracting features (train + id + ood)...")
feats_train, hosp_train, tumor_train = extract_features(
model, train_loader, device, max_samples=max_samples
)
feats_ood, hosp_ood, tumor_ood = extract_features(
model, ood_loader, device, max_samples=max_samples // 2
)
feats_id, hosp_id, tumor_id = (None, None, None)
if id_loader is not None:
feats_id, hosp_id, tumor_id = extract_features(
model, id_loader, device, max_samples=max_samples // 2
)
layer = "avgpool"
def _to_2d(arr):
a = np.asarray(arr); return a.reshape(a.shape[0], -1)
X_tr = _to_2d(feats_train[layer])
X_ood = _to_2d(feats_ood[layer])
X_id = _to_2d(feats_id[layer]) if feats_id is not None else None
hosp_train = np.asarray(hosp_train); hosp_ood = np.asarray(hosp_ood)
tumor_train = np.asarray(tumor_train); tumor_ood = np.asarray(tumor_ood)
if X_id is not None:
hosp_id = np.asarray(hosp_id); tumor_id = np.asarray(tumor_id)
scaler = StandardScaler().fit(X_tr)
X_tr_n = scaler.transform(X_tr)
X_ood_n = scaler.transform(X_ood)
X_id_n = scaler.transform(X_id) if X_id is not None else None
# 1. Per-neuron scores: shortcut (hospital) and morphology (tumor)
print(f" Scoring {X_tr.shape[1]} avgpool channels...")
shortcut_scores = _per_neuron_shortcut_scores(X_tr_n, hosp_train)
morphology_scores = _per_neuron_shortcut_scores(X_tr_n, tumor_train) if include_morphology else None
rank_shortcut = np.argsort(-shortcut_scores)
rank_morphology = np.argsort(-morphology_scores) if morphology_scores is not None else None
hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train)
tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train)
rng = np.random.default_rng(seed)
D = X_tr.shape[1]
sweep = []
for k in ks:
row = {"k": int(k)}
# Mask helpers
def make_mask(indices):
m = np.ones(D)
if k > 0:
m[indices[:k]] = 0.0
return m
# ── A: top-K SHORTCUT neurons (the targeted ablation) ──
mask_s = make_mask(rank_shortcut)
h_ood, hp_ood, tp_ood = _ablate_and_eval(
X_ood_n, mask_s, scaler, tumor_ood, model, layer, device,
hosp_clf, tumor_clf, hosp_ood, tumor_ood,
)
row["shortcut_head_ood"] = float(h_ood)
row["shortcut_hosp_probe"] = float(hp_ood)
row["shortcut_tumor_probe"] = float(tp_ood)
if X_id_n is not None:
h_id, _, _ = _ablate_and_eval(
X_id_n, mask_s, scaler, tumor_id, model, layer, device,
None, tumor_clf, hosp_id, tumor_id,
)
row["shortcut_head_id"] = float(h_id)
# ── B: top-K MORPHOLOGY neurons (control: ablate the causal neurons) ──
if include_morphology and rank_morphology is not None:
mask_m = make_mask(rank_morphology)
h_ood_m, _, _ = _ablate_and_eval(
X_ood_n, mask_m, scaler, tumor_ood, model, layer, device,
None, tumor_clf, hosp_ood, tumor_ood,
)
row["morphology_head_ood"] = float(h_ood_m)
if X_id_n is not None:
h_id_m, _, _ = _ablate_and_eval(
X_id_n, mask_m, scaler, tumor_id, model, layer, device,
None, tumor_clf, hosp_id, tumor_id,
)
row["morphology_head_id"] = float(h_id_m)
# ── C: K RANDOM neurons (control: damage uniformly) ──
if k > 0:
r_oods, r_ids = [], []
for s_ in range(n_random_samples):
idx = rng.permutation(D)[:k]
m = np.ones(D); m[idx] = 0.0
h_ood_r, _, _ = _ablate_and_eval(
X_ood_n, m, scaler, tumor_ood, model, layer, device,
None, tumor_clf, hosp_ood, tumor_ood,
)
r_oods.append(h_ood_r)
if X_id_n is not None:
h_id_r, _, _ = _ablate_and_eval(
X_id_n, m, scaler, tumor_id, model, layer, device,
None, tumor_clf, hosp_id, tumor_id,
)
r_ids.append(h_id_r)
row["random_head_ood_mean"] = float(np.mean(r_oods))
row["random_head_ood_std"] = float(np.std(r_oods))
if r_ids:
row["random_head_id_mean"] = float(np.mean(r_ids))
row["random_head_id_std"] = float(np.std(r_ids))
else:
row["random_head_ood_mean"] = row["shortcut_head_ood"] # K=0 same as baseline
row["random_head_ood_std"] = 0.0
if X_id_n is not None:
row["random_head_id_mean"] = row.get("shortcut_head_id", float("nan"))
row["random_head_id_std"] = 0.0
sweep.append(row)
# Concise log line
print(f" K={k:>4} shortcut={row['shortcut_head_ood']:.3f} "
f"random={row.get('random_head_ood_mean', float('nan')):.3f}±"
f"{row.get('random_head_ood_std', 0):.3f} "
+ (f"morphology={row.get('morphology_head_ood', float('nan')):.3f}"
if include_morphology else ""))
return {
"run_id": run_dir.name,
"epoch": epoch,
"layer": layer,
"max_samples": max_samples,
"feature_dim": int(X_tr.shape[1]),
"shortcut_scores_top10": [int(i) for i in rank_shortcut[:10]],
"morphology_scores_top10": ([int(i) for i in rank_morphology[:10]]
if rank_morphology is not None else []),
"n_random_samples": n_random_samples,
"include_id": include_id,
"include_morphology": include_morphology,
"sweep": sweep,
}
def plot_neuron_ablation(result: Dict, out_path: Path):
sweep = result["sweep"]
ks = [r["k"] for r in sweep]
has_id = result.get("include_id", False)
has_morph = result.get("include_morphology", False)
fig, axes = plt.subplots(1, 2 if has_id else 1, figsize=(13, 5)) if has_id else \
plt.subplots(1, 1, figsize=(8, 5))
if not has_id:
axes = [axes]
# Panel A — Head OOD: shortcut vs random (vs morphology)
ax = axes[0]
shortcut_ood = [r.get("shortcut_head_ood") for r in sweep]
random_ood_mu = [r.get("random_head_ood_mean") for r in sweep]
random_ood_sd = [r.get("random_head_ood_std", 0) for r in sweep]
morphology_ood = [r.get("morphology_head_ood") for r in sweep] if has_morph else None
ax.plot(ks, shortcut_ood, "r-o", lw=2.2, ms=7, label="top-K shortcut neurons (targeted)")
ax.plot(ks, random_ood_mu, "k-s", lw=1.8, ms=6, label="K random neurons (control)")
ax.fill_between(ks,
[m - s for m, s in zip(random_ood_mu, random_ood_sd)],
[m + s for m, s in zip(random_ood_mu, random_ood_sd)],
color="black", alpha=0.15)
if has_morph and morphology_ood is not None:
ax.plot(ks, morphology_ood, "g-^", lw=1.8, ms=6,
label="top-K morphology neurons (control)")
base = shortcut_ood[0]
ax.axhline(base, color="gray", ls=":", lw=1, alpha=0.5,
label=f"K=0 baseline ({base:.3f})")
ax.set_xlabel("K (neurons zeroed at avgpool)")
ax.set_ylabel("Head OOD (H4) accuracy")
ax.set_xscale("symlog", linthresh=4)
ax.set_title("Targeted vs random ablation — OOD effect\n"
"(separation = shortcut neurons selectively hurt OOD)",
fontweight="bold", fontsize=10)
ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3)
# Panel B — ID/OOD tradeoff
if has_id:
ax = axes[1]
shortcut_id = [r.get("shortcut_head_id") for r in sweep]
random_id_mu = [r.get("random_head_id_mean") for r in sweep]
random_id_sd = [r.get("random_head_id_std", 0) for r in sweep]
ax.plot(ks, shortcut_id, "r--o", lw=2, ms=7, alpha=0.85, label="ID (shortcut ablation)")
ax.plot(ks, shortcut_ood, "r-o", lw=2, ms=7, label="OOD (shortcut ablation)")
ax.plot(ks, random_id_mu, "k--s", lw=1.6, ms=5, alpha=0.7, label="ID (random ablation)")
ax.plot(ks, random_ood_mu, "k-s", lw=1.6, ms=5, alpha=0.7, label="OOD (random ablation)")
ax.set_xlabel("K (neurons zeroed at avgpool)")
ax.set_ylabel("Head accuracy")
ax.set_xscale("symlog", linthresh=4)
ax.set_title("ID vs OOD degradation tradeoff\n"
"(targeted: OOD steady or ↑ while ID slowly ↓ = good)",
fontweight="bold", fontsize=10)
ax.legend(fontsize=8, loc="best"); ax.grid(alpha=0.3)
fig.suptitle(f"M6 — Targeted Neuron Ablation vs Random Control: {result['run_id']} "
f"• ep{result['epoch']}",
fontsize=11, fontweight="bold")
plt.tight_layout()
fig.savefig(out_path, dpi=180, bbox_inches="tight")
plt.close(fig)
def main():
p = argparse.ArgumentParser()
p.add_argument("--run_dir", required=True)
p.add_argument("--data_root", default="data/wilds")
p.add_argument("--epoch", type=int, default=None)
p.add_argument("--max_samples", type=int, default=1000)
p.add_argument("--device", default="cuda")
p.add_argument("--ks", default=None,
help="Comma-separated K values, e.g. '0,4,8,16,32,64,128,256'")
p.add_argument("--n_random_samples", type=int, default=5,
help="Random ablation: averages over this many random K-subsets")
p.add_argument("--no_morphology", action="store_true",
help="Skip the morphology-targeted ablation control")
p.add_argument("--no_id", action="store_true",
help="Skip ID accuracy evaluation (faster but loses ID/OOD ratio)")
args = p.parse_args()
ks = None
if args.ks is not None:
ks = [int(x) for x in args.ks.split(",")]
run_dir = Path(args.run_dir)
out_dir = run_dir / "mechinterp"
out_dir.mkdir(parents=True, exist_ok=True)
result = run_neuron_ablation(
run_dir=run_dir, data_root=args.data_root, epoch=args.epoch,
max_samples=args.max_samples, device=args.device, ks=ks,
n_random_samples=args.n_random_samples,
include_morphology=not args.no_morphology,
include_id=not args.no_id,
)
base = out_dir / f"m6_neuron_ablation_ep{result['epoch']:05d}"
base.with_suffix(".json").write_text(json.dumps(result, indent=2))
plot_neuron_ablation(result, base.with_suffix(".png"))
print(f"\n → {base.with_suffix('.json')}")
print(f" → {base.with_suffix('.png')}")
if __name__ == "__main__":
main()
```
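`_per_neuron_shortcut_scores` reduces a fitted coefficient matrix `W` of shape `(n_classes, D)` to one score per channel via column L2 norms. `make_mask` then zeroes the top-K channels of the resulting ranking, and the random control zeroes K channels drawn with `rng.permutation`.

The NumPy-only sketch below covers those steps. It takes `W` as given rather than fitting scikit-learn's `LogisticRegression`, and the toy `W` values and helper names are hypothetical.

```python
import numpy as np


def channel_scores(W: np.ndarray) -> np.ndarray:
    """Column L2 norm of an (n_classes, D) coefficient matrix -> (D,) scores."""
    return np.linalg.norm(W, axis=0)


def topk_mask(scores: np.ndarray, k: int) -> np.ndarray:
    """1.0 everywhere except the k highest-scoring channels, which are set to 0.0."""
    mask = np.ones(scores.shape[0])
    if k > 0:
        mask[np.argsort(-scores)[:k]] = 0.0
    return mask


def random_mask(d: int, k: int, rng: np.random.Generator) -> np.ndarray:
    """Control: zero k channels chosen uniformly without replacement."""
    mask = np.ones(d)
    mask[rng.permutation(d)[:k]] = 0.0
    return mask


# Toy 3-hospital coefficient matrix over 6 channels (made-up numbers).
W = np.array([
    [0.1, 2.0, 0.0, -0.3, 0.0, 1.0],
    [0.0, -1.5, 0.2, 0.3, 0.1, -1.2],
    [-0.1, -0.5, -0.2, 0.0, -0.1, 0.2],
])
scores = channel_scores(W)
mask = topk_mask(scores, k=2)        # channels 1 and 5 have the largest column norms
X_n = np.ones((4, 6))                # toy standardized features
X_ablated_n = X_n * mask[None, :]    # same broadcasting as _ablate_and_eval
print(np.round(scores, 3), mask)
```

Because the mask is applied in `StandardScaler` space and then passed through `inverse_transform`, a zeroed channel maps back to its training-set mean in raw feature space. The "zero out" ablation in M6 is therefore a mean-ablation from the head's point of view, not a literal zeroing of avgpool activations.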
---
## 16. Run inventory and summary results
14 runs at the canonical 3000-epoch config. The 11 runs with full M4/M5/M6 trajectories carry the paper's mechanistic claims; the 3 n=300 grokking runs have M1 only.
### `n = 1000` (5 grokking-favorable + 3 standard)
| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ (final − peak) | Best ID | Final ‖W‖ | Final feature rank | Final IRM |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Grok | 7 | 20260508-183413_grokking_n1000_s7 | 0.6876 | 50 | 0.5882 | −0.0995 | 0.8797 | 1516.7 | 69.83 | 4.47e-12 |
| Grok | 42 | 20260505-080445_grokking_n1000_s42 | 0.7336 | 350 | 0.6639 | −0.0696 | 0.8976 | 1470.5 | 35.87 | 4.34e-12 |
| Grok | 123 | 20260505-100720_grokking_n1000_s123 | 0.7270 | 350 | 0.6447 | −0.0823 | 0.8994 | 1457.2 | 56.65 | 6.73e-15 |
| Grok | 456 | 20260505-100720_grokking_n1000_s456 | 0.6722 | 1100 | 0.5224 | −0.1498 | 0.8824 | 1493.6 | 64.54 | 2.87e-09 |
| Grok | 2024 | 20260508-183413_grokking_n1000_s2024 | 0.7056 | 400 | 0.5506 | −0.1550 | 0.8959 | 1632.4 | 65.77 | 4.56e-07 |
| Std | 42 | 20260505-100720_standard_n1000_s42 | 0.7615 | 1 | 0.6482 | −0.1133 | 0.9011 | 812.6 | 33.35 | 2.24e-13 |
| Std | 123 | 20260508-183413_standard_n1000_s123 | 0.8880* | 1 | 0.6645 | −0.2235 | 0.8957 | 798.3 | 37.08 | 9.53e-14 |
| Std | 456 | 20260508-183413_standard_n1000_s456 | 0.7450 | 1050 | 0.5783 | −0.1667 | 0.8950 | 792.4 | 35.30 | 7.18e-09 |
\* Std s123 peaks at epoch 1 on the random initialization (artifact).
**Aggregates (n=1000)**: grokking 5-seed mean peak = **0.7052 ± 0.0237**, mean Δ = **−0.1112 ± 0.0345**. Standard, excluding the s123 epoch-1 artifact: 2-seed mean peak (s42, s456) = **0.7533**; including all 3 seeds: 0.7982.
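The Δ column is final OOD accuracy minus peak OOD accuracy, and the aggregate means can be recomputed directly from the table. The sketch below uses the rounded four-decimal table values, so individual Δ values can differ from the table in the last digit (for example, s7 gives −0.0994 rather than −0.0995), since the table was produced from full-precision summaries. Spread statistics are left out because the table does not state which standard-deviation convention was used.

```python
import statistics

# (seed, peak OOD, final OOD) for the five grokking-favorable n=1000 runs, copied from the table.
grok = [
    (7, 0.6876, 0.5882),
    (42, 0.7336, 0.6639),
    (123, 0.7270, 0.6447),
    (456, 0.6722, 0.5224),
    (2024, 0.7056, 0.5506),
]
# seed -> peak OOD for the standard n=1000 runs.
std_peaks = {42: 0.7615, 123: 0.8880, 456: 0.7450}

grok_peak_mean = statistics.mean(p for _, p, _ in grok)
grok_delta_mean = statistics.mean(f - p for _, p, f in grok)       # delta = final - peak
std_corrected = statistics.mean(std_peaks[s] for s in (42, 456))   # s123 excluded
std_raw = statistics.mean(std_peaks.values())

print(f"grokking mean peak    = {grok_peak_mean:.4f}")
print(f"grokking mean delta   = {grok_delta_mean:.4f}")
print(f"standard, 2 seeds     = {std_corrected:.4f}")
print(f"standard, all 3 seeds = {std_raw:.4f}")
```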
### `n = 500` and `n = 300`
| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ (final − peak) | Best ID |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Grok | 42 | 20260505-080442_grokking_n500_s42 | 0.7924 | 50 | 0.5514 | −0.2410 | 0.8874 |
| Std | 42 | 20260505-100720_standard_n500_s42 | 0.7576 | 1050 | 0.6526 | −0.1050 | 0.8867 |
| Grok | 42 | 20260502-214859_grokking_n300_s42 | 0.7162 | 250 | 0.5189 | −0.1974 | 0.8664 |
| Grok | 123 | 20260502-214859_grokking_n300_s123 | 0.6961 | 50 | 0.5154 | −0.1807 | 0.8388 |
| Grok | 456 | 20260502-214859_grokking_n300_s456 | 0.6654 | 750 | 0.5469 | −0.1184 | 0.8522 |
| Std | 42 | 20260505-080836_standard_n300_s42 | 0.7647 | 250 | 0.7052 | −0.0596 | 0.8584 |
**Universal finding**: every run *ungrokks* (Δ < 0 for all 14). No run shows a plateau-then-jump. `grokking_epoch = -1` everywhere; `irm_drop_pct ≈ 100%` in every summary that records it (the three n=300 grokking summaries do not include that field).
---
## 17. Per-run `config.json` and `summary.json` (all 14 runs)
Every run's exact saved JSON, verbatim from disk.
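In each of the summaries checked here, `ood_delta` equals `final_ood − best_ood`, i.e. the Δ (final − peak) reported in Section 16. The check below copies the values from the three n=300 grokking summaries that follow. Other fields, such as `ood_improvement` and `final_ood_gap`, are not reconstructed because their definitions are not stated in this section.

```python
import math

# (best_ood, final_ood, ood_delta), copied verbatim from the n=300 grokking summaries below.
summaries = {
    "20260502-214859_grokking_n300_s42": (0.7162273379264937, 0.5188703647094787, -0.19735697321701506),
    "20260502-214859_grokking_n300_s123": (0.6961459778493663, 0.515413737155219, -0.18073224069414728),
    "20260502-214859_grokking_n300_s456": (0.6653655324852447, 0.5469348884238249, -0.11843064406141979),
}


def delta_matches(best: float, final: float, delta: float, tol: float = 1e-12) -> bool:
    """True if the stored delta equals final minus peak OOD accuracy."""
    return math.isclose(final - best, delta, abs_tol=tol)


for run_id, (best, final, delta) in summaries.items():
    print(run_id, delta_matches(best, final, delta))
```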
### `20260502-214859_grokking_n300_s42`
`config.json`:
```json
{
"seed": 42,
"n_train": 300,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260502-214859_grokking_n300_s42",
"run_dir": "experiments/runs/20260502-214859_grokking_n300_s42"
}
```
`results/summary.json`:
```json
{
"run_id": "20260502-214859_grokking_n300_s42",
"condition": "grokking",
"n_train": 300,
"seed": 42,
"best_id_val": 0.866388557806913,
"best_ood": 0.7162273379264937,
"peak_ood_epoch": 250,
"final_ood": 0.5188703647094787,
"ood_delta": -0.19735697321701506,
"ood_improvement": 0.01951701272132994,
"grokking_epoch": -1,
"final_weight_norm": 1131.2872887615824,
"final_feature_rank": 39.32365036010742,
"final_irm": 6.867336560523185e-14,
"final_shortcut_ratio": 1.0053635176200695,
"final_ood_gap": 0.3284478712857537,
"ungrokking_detected": true
}
```
### `20260502-214859_grokking_n300_s123`
`config.json`:
```json
{
"seed": 123,
"n_train": 300,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260502-214859_grokking_n300_s123",
"run_dir": "experiments/runs/20260502-214859_grokking_n300_s123"
}
```
`results/summary.json`:
```json
{
"run_id": "20260502-214859_grokking_n300_s123",
"condition": "grokking",
"n_train": 300,
"seed": 123,
"best_id_val": 0.8387663885578069,
"best_ood": 0.6961459778493663,
"peak_ood_epoch": 50,
"final_ood": 0.515413737155219,
"ood_delta": -0.18073224069414728,
"ood_improvement": 0.015413737155218987,
"grokking_epoch": -1,
"final_weight_norm": 981.2013665038359,
"final_feature_rank": 44.795772552490234,
"final_irm": 5.876544368309256e-13,
"final_shortcut_ratio": 1.0042143792951101,
"final_ood_gap": 0.23887708525240914,
"ungrokking_detected": true
}
```
### `20260502-214859_grokking_n300_s456`
`config.json`:
```json
{
"seed": 456,
"n_train": 300,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260502-214859_grokking_n300_s456",
"run_dir": "experiments/runs/20260502-214859_grokking_n300_s456"
}
```
`results/summary.json`:
```json
{
"run_id": "20260502-214859_grokking_n300_s456",
"condition": "grokking",
"n_train": 300,
"seed": 456,
"best_id_val": 0.8521752085816449,
"best_ood": 0.6653655324852447,
"peak_ood_epoch": 750,
"final_ood": 0.5469348884238249,
"ood_delta": -0.11843064406141979,
"ood_improvement": 0.04693488842382487,
"grokking_epoch": -1,
"final_weight_norm": 1109.5530019538146,
"final_feature_rank": 49.14884567260742,
"final_irm": 8.174740884214771e-10,
"final_shortcut_ratio": 0.9680124053677005,
"final_ood_gap": 0.25908418189798677,
"ungrokking_detected": true
}
```
### `20260505-080836_standard_n300_s42`
`config.json`:
```json
{
"seed": 42,
"n_train": 300,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "standard",
"lr": 0.001,
"weight_decay": 0.0001,
"n_epochs": 3000,
"init_scale": 1.0,
"use_grokfast": false,
"grad_clip": 1.0,
"run_id": "20260505-080836_standard_n300_s42",
"run_dir": "experiments/runs/20260505-080836_standard_n300_s42"
}
```
`results/summary.json`:
```json
{
"run_id": "20260505-080836_standard_n300_s42",
"condition": "standard",
"n_train": 300,
"seed": 42,
"best_id_val": 0.858373063170441,
"best_ood": 0.7647259388153408,
"peak_ood_epoch": 250,
"final_ood": 0.7051637783055471,
"ood_delta": -0.05956216050979368,
"ood_improvement": -0.026599572036588692,
"grokking_epoch": -1,
"irm_drop_pct": 99.99996794236725,
"irm_drop_epoch": 100,
"epoch_gap": -1,
"final_weight_norm": 452.7899771783757,
"final_feature_rank": 17.35470962524414,
"final_irm": 1.032695706726372e-09,
"final_shortcut_ratio": 1.0080115087353168,
"final_ood_gap": 0.12806029797574014
}
```
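Comparing the two n=300, seed=42 configs above shows how the conditions differ. The grokking-favorable and standard runs differ only in `condition`, `weight_decay` (0.005 vs 0.0001), `init_scale` (4.0 vs 1.0), `use_grokfast`, and the two Grokfast parameters that only the grokking config carries, plus the run identifiers. `lr`, `batch_size`, `n_epochs`, `img_size`, and `grad_clip` are shared.

The small diff sketch below works over the two saved dicts, copied here without `run_id`/`run_dir`. The helper name `config_diff` is hypothetical; in practice the dicts would be read with `json.loads(Path(...).read_text())`.

```python
from typing import Any, Dict, Iterable, Tuple

grokking = {
    "seed": 42, "n_train": 300, "batch_size": 32, "img_size": 96, "n_classes": 2,
    "log_every": 50, "device": "cuda", "condition": "grokking", "lr": 0.001,
    "weight_decay": 0.005, "n_epochs": 3000, "init_scale": 4.0, "use_grokfast": True,
    "grokfast_alpha": 0.98, "grokfast_lamb": 2.0, "grad_clip": 1.0,
}
standard = {
    "seed": 42, "n_train": 300, "batch_size": 32, "img_size": 96, "n_classes": 2,
    "log_every": 50, "device": "cuda", "condition": "standard", "lr": 0.001,
    "weight_decay": 0.0001, "n_epochs": 3000, "init_scale": 1.0, "use_grokfast": False,
    "grad_clip": 1.0,
}


def config_diff(a: Dict[str, Any], b: Dict[str, Any],
                ignore: Iterable[str] = ("run_id", "run_dir")) -> Dict[str, Tuple[Any, Any]]:
    """Map each differing key to (value in a, value in b); a missing key shows as None."""
    keys = (set(a) | set(b)) - set(ignore)
    return {k: (a.get(k), b.get(k)) for k in sorted(keys) if a.get(k) != b.get(k)}


for key, (g, s) in config_diff(grokking, standard).items():
    print(f"{key:>15}: grokking={g!r:<10} standard={s!r}")
```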
### `20260505-080442_grokking_n500_s42`
`config.json`:
```json
{
"seed": 42,
"n_train": 500,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260505-080442_grokking_n500_s42",
"run_dir": "experiments/runs/20260505-080442_grokking_n500_s42"
}
```
`results/summary.json`:
```json
{
"run_id": "20260505-080442_grokking_n500_s42",
"condition": "grokking",
"n_train": 500,
"seed": 42,
"best_id_val": 0.8873957091775924,
"best_ood": 0.7924495026688927,
"peak_ood_epoch": 50,
"final_ood": 0.5514496672702048,
"ood_delta": -0.24099983539868786,
"ood_improvement": -0.06643779246126003,
"grokking_epoch": -1,
"irm_drop_pct": 99.99999995667906,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 1046.672482929869,
"final_feature_rank": 34.301780700683594,
"final_irm": 6.552892841682478e-07,
"final_shortcut_ratio": 0.9969108279290858,
"final_ood_gap": 0.32080897396936603
}
```
### `20260505-100720_standard_n500_s42`
`config.json`:
```json
{
"seed": 42,
"n_train": 500,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "standard",
"lr": 0.001,
"weight_decay": 0.0001,
"n_epochs": 3000,
"init_scale": 1.0,
"use_grokfast": false,
"grad_clip": 1.0,
"run_id": "20260505-100720_standard_n500_s42",
"run_dir": "experiments/runs/20260505-100720_standard_n500_s42"
}
```
`results/summary.json`:
```json
{
"run_id": "20260505-100720_standard_n500_s42",
"condition": "standard",
"n_train": 500,
"seed": 42,
"best_id_val": 0.8866507747318236,
"best_ood": 0.7575775389752393,
"peak_ood_epoch": 1050,
"final_ood": 0.6525854163237472,
"ood_delta": -0.10499212265149205,
"ood_improvement": -0.08494603428410186,
"grokking_epoch": -1,
"irm_drop_pct": 99.99998953242944,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 555.1170332945212,
"final_feature_rank": 20.722705841064453,
"final_irm": 5.05333608291636e-10,
"final_shortcut_ratio": 0.9843676046096169,
"final_ood_gap": 0.19765296269889876
}
```
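The `grokking` and `standard` configs in this section differ only in a handful of condition-specific keys. As a quick check, here is a minimal sketch that diffs the two n=500, seed=42 configs above. The helper `config_diff` is hypothetical and is not part of the project's code.

```python
import json

def config_diff(a: dict, b: dict, ignore=("run_id", "run_dir")) -> dict:
    """Return {key: (a_value, b_value)} for every key whose value differs.

    A key present in only one config is reported with None on the other side.
    """
    keys = (set(a) | set(b)) - set(ignore)
    return {k: (a.get(k), b.get(k)) for k in sorted(keys) if a.get(k) != b.get(k)}

# n=500, seed=42 configs copied from this section (run_id / run_dir omitted).
grokking_cfg = json.loads("""{
  "seed": 42, "n_train": 500, "batch_size": 32, "img_size": 96, "n_classes": 2,
  "log_every": 50, "device": "cuda", "condition": "grokking", "lr": 0.001,
  "weight_decay": 0.005, "n_epochs": 3000, "init_scale": 4.0,
  "use_grokfast": true, "grokfast_alpha": 0.98, "grokfast_lamb": 2.0,
  "grad_clip": 1.0
}""")
standard_cfg = json.loads("""{
  "seed": 42, "n_train": 500, "batch_size": 32, "img_size": 96, "n_classes": 2,
  "log_every": 50, "device": "cuda", "condition": "standard", "lr": 0.001,
  "weight_decay": 0.0001, "n_epochs": 3000, "init_scale": 1.0,
  "use_grokfast": false, "grad_clip": 1.0
}""")

for key, (g, s) in config_diff(grokking_cfg, standard_cfg).items():
    print(f"{key:<15} grokking={g!r:<10} standard={s!r}")
```

The diff reports `condition`, `weight_decay` (0.005 vs 0.0001), `init_scale` (4.0 vs 1.0), `use_grokfast`, and the two Grokfast keys that exist only in grokking runs. Learning rate, batch size, epoch count, and gradient clipping are shared by both conditions.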
### `20260505-080445_grokking_n1000_s42`
`config.json`:
```json
{
"seed": 42,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260505-080445_grokking_n1000_s42",
"run_dir": "experiments/runs/20260505-080445_grokking_n1000_s42"
}
```
`results/summary.json`:
```json
{
"run_id": "20260505-080445_grokking_n1000_s42",
"condition": "grokking",
"n_train": 1000,
"seed": 42,
"best_id_val": 0.8976460071513707,
"best_ood": 0.7335575046441084,
"peak_ood_epoch": 350,
"final_ood": 0.6639076351494345,
"ood_delta": -0.06964986949467389,
"ood_improvement": 0.011399816587109313,
"grokking_epoch": -1,
"irm_drop_pct": 99.99999500910222,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 1470.4773196265805,
"final_feature_rank": 35.86945724487305,
"final_irm": 4.340286376830482e-12,
"final_shortcut_ratio": 0.9871634909839839,
"final_ood_gap": 0.20680154244293736
}
```
### `20260505-100720_grokking_n1000_s123`
`config.json`:
```json
{
"seed": 123,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260505-100720_grokking_n1000_s123",
"run_dir": "experiments/runs/20260505-100720_grokking_n1000_s123"
}
```
`results/summary.json`:
```json
{
"run_id": "20260505-100720_grokking_n1000_s123",
"condition": "grokking",
"n_train": 1000,
"seed": 123,
"best_id_val": 0.8994338498212158,
"best_ood": 0.7269734521598044,
"peak_ood_epoch": 350,
"final_ood": 0.6446610388694242,
"ood_delta": -0.08231241329038019,
"ood_improvement": 0.01956639311496222,
"grokking_epoch": -1,
"irm_drop_pct": 99.9999991755092,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 1457.2260357210637,
"final_feature_rank": 56.64516067504883,
"final_irm": 6.7278793620213426e-15,
"final_shortcut_ratio": 0.9760735232865748,
"final_ood_gap": 0.23534492060614198
}
```
### `20260505-100720_grokking_n1000_s456`
`config.json`:
```json
{
"seed": 456,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260505-100720_grokking_n1000_s456",
"run_dir": "experiments/runs/20260505-100720_grokking_n1000_s456"
}
```
`results/summary.json`:
```json
{
"run_id": "20260505-100720_grokking_n1000_s456",
"condition": "grokking",
"n_train": 1000,
"seed": 456,
"best_id_val": 0.8824493444576877,
"best_ood": 0.6721847297011311,
"peak_ood_epoch": 1100,
"final_ood": 0.522397535683213,
"ood_delta": -0.14978719401791807,
"ood_improvement": -0.08302960472170617,
"grokking_epoch": -1,
"irm_drop_pct": 99.99999977269624,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 1493.580040593733,
"final_feature_rank": 64.54296875,
"final_irm": 2.8693030174054e-09,
"final_shortcut_ratio": 1.0356636404914459,
"final_ood_gap": 0.3221495441737596
}
```
### `20260508-183413_grokking_n1000_s7`
`config.json`:
```json
{
"seed": 7,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260508-183413_grokking_n1000_s7",
"run_dir": "experiments/runs/20260508-183413_grokking_n1000_s7"
}
```
`results/summary.json`:
```json
{
"run_id": "20260508-183413_grokking_n1000_s7",
"condition": "grokking",
"n_train": 1000,
"seed": 7,
"best_id_val": 0.8797079856972586,
"best_ood": 0.6876454958026665,
"peak_ood_epoch": 50,
"final_ood": 0.5881910315799375,
"ood_delta": -0.09945446422272908,
"ood_improvement": -0.03798527993980294,
"grokking_epoch": -1,
"irm_drop_pct": 99.99999355254296,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 1516.661424559645,
"final_feature_rank": 69.8335952758789,
"final_irm": 4.471252361415434e-12,
"final_shortcut_ratio": 0.996297612800058,
"final_ood_gap": 0.26586141180504463
}
```
### `20260508-183413_grokking_n1000_s2024`
`config.json`:
```json
{
"seed": 2024,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "grokking",
"lr": 0.001,
"weight_decay": 0.005,
"n_epochs": 3000,
"init_scale": 4.0,
"use_grokfast": true,
"grokfast_alpha": 0.98,
"grokfast_lamb": 2.0,
"grad_clip": 1.0,
"run_id": "20260508-183413_grokking_n1000_s2024",
"run_dir": "experiments/runs/20260508-183413_grokking_n1000_s2024"
}
```
`results/summary.json`:
```json
{
"run_id": "20260508-183413_grokking_n1000_s2024",
"condition": "grokking",
"n_train": 1000,
"seed": 2024,
"best_id_val": 0.8959177592371871,
"best_ood": 0.7056105532955534,
"peak_ood_epoch": 400,
"final_ood": 0.5506031462365085,
"ood_delta": -0.15500740705904492,
"ood_improvement": -0.04230488865896964,
"grokking_epoch": -1,
"irm_drop_pct": 99.9999998741651,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 1632.3948325021925,
"final_feature_rank": 65.77389526367188,
"final_irm": 4.5635979972757923e-07,
"final_shortcut_ratio": 0.9633737610070339,
"final_ood_gap": 0.308663838268855
}
```
### `20260505-100720_standard_n1000_s42`
`config.json`:
```json
{
"seed": 42,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "standard",
"lr": 0.001,
"weight_decay": 0.0001,
"n_epochs": 3000,
"init_scale": 1.0,
"use_grokfast": false,
"grad_clip": 1.0,
"run_id": "20260505-100720_standard_n1000_s42",
"run_dir": "experiments/runs/20260505-100720_standard_n1000_s42"
}
```
`results/summary.json`:
```json
{
"run_id": "20260505-100720_standard_n1000_s42",
"condition": "standard",
"n_train": 1000,
"seed": 42,
"best_id_val": 0.9011025029797378,
"best_ood": 0.7615162132292426,
"peak_ood_epoch": 1,
"final_ood": 0.6482234815528958,
"ood_delta": -0.11329273167634679,
"ood_improvement": -0.022357561078844124,
"grokking_epoch": -1,
"irm_drop_pct": 99.99999329006182,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 812.5540534619066,
"final_feature_rank": 33.34842300415039,
"final_irm": 2.2439123855463178e-13,
"final_shortcut_ratio": 0.99423543483118,
"final_ood_gap": 0.24736650652815306
}
```
### `20260508-183413_standard_n1000_s123`
`config.json`:
```json
{
"seed": 123,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "standard",
"lr": 0.001,
"weight_decay": 0.0001,
"n_epochs": 3000,
"init_scale": 1.0,
"use_grokfast": false,
"grad_clip": 1.0,
"run_id": "20260508-183413_standard_n1000_s123",
"run_dir": "experiments/runs/20260508-183413_standard_n1000_s123"
}
```
`results/summary.json`:
```json
{
"run_id": "20260508-183413_standard_n1000_s123",
"condition": "standard",
"n_train": 1000,
"seed": 123,
"best_id_val": 0.8957091775923719,
"best_ood": 0.8879652926376185,
"peak_ood_epoch": 1,
"final_ood": 0.6644837397418111,
"ood_delta": -0.2234815528958074,
"ood_improvement": -0.10371998965363183,
"grokking_epoch": -1,
"irm_drop_pct": 99.9999667168104,
"irm_drop_epoch": 150,
"epoch_gap": -1,
"final_weight_norm": 798.3091903337586,
"final_feature_rank": 37.0809326171875,
"final_irm": 9.526699497634447e-14,
"final_shortcut_ratio": 0.9913218948516812,
"final_ood_gap": 0.22869266073494698
}
```
### `20260508-183413_standard_n1000_s456`
`config.json`:
```json
{
"seed": 456,
"n_train": 1000,
"batch_size": 32,
"img_size": 96,
"n_classes": 2,
"log_every": 50,
"device": "cuda",
"condition": "standard",
"lr": 0.001,
"weight_decay": 0.0001,
"n_epochs": 3000,
"init_scale": 1.0,
"use_grokfast": false,
"grad_clip": 1.0,
"run_id": "20260508-183413_standard_n1000_s456",
"run_dir": "experiments/runs/20260508-183413_standard_n1000_s456"
}
```
`results/summary.json`:
```json
{
"run_id": "20260508-183413_standard_n1000_s456",
"condition": "standard",
"n_train": 1000,
"seed": 456,
"best_id_val": 0.8949940405244339,
"best_ood": 0.7449737813624285,
"peak_ood_epoch": 1050,
"final_ood": 0.5783149528534813,
"ood_delta": -0.1666588285089472,
"ood_improvement": 0.02752839372633853,
"grokking_epoch": -1,
"irm_drop_pct": 99.99999467428148,
"irm_drop_epoch": 50,
"epoch_gap": -1,
"final_weight_norm": 792.3747983792097,
"final_feature_rank": 35.297889709472656,
"final_irm": 7.180730676736857e-09,
"final_shortcut_ratio": 0.987733651871859,
"final_ood_gap": 0.2792237837376986
}
```
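In every `summary.json` shown in this section, `ood_delta` equals `final_ood - best_ood` to floating-point precision, so a negative `ood_delta` means the run finished below its peak OOD accuracy. The sketch below checks that relationship with values copied from the summaries above; the helper name is hypothetical, and the formulas in Section 9 remain the authoritative definitions.

```python
import math

# run_id -> (best_ood, final_ood, ood_delta), copied from the summary.json files above.
SUMMARIES = {
    "20260502-214859_grokking_n300_s456": (0.6653655324852447, 0.5469348884238249, -0.11843064406141979),
    "20260505-080836_standard_n300_s42": (0.7647259388153408, 0.7051637783055471, -0.05956216050979368),
    "20260505-080442_grokking_n500_s42": (0.7924495026688927, 0.5514496672702048, -0.24099983539868786),
    "20260505-100720_standard_n500_s42": (0.7575775389752393, 0.6525854163237472, -0.10499212265149205),
    "20260505-080445_grokking_n1000_s42": (0.7335575046441084, 0.6639076351494345, -0.06964986949467389),
    "20260505-100720_grokking_n1000_s123": (0.7269734521598044, 0.6446610388694242, -0.08231241329038019),
    "20260505-100720_grokking_n1000_s456": (0.6721847297011311, 0.522397535683213, -0.14978719401791807),
    "20260508-183413_grokking_n1000_s7": (0.6876454958026665, 0.5881910315799375, -0.09945446422272908),
    "20260508-183413_grokking_n1000_s2024": (0.7056105532955534, 0.5506031462365085, -0.15500740705904492),
    "20260505-100720_standard_n1000_s42": (0.7615162132292426, 0.6482234815528958, -0.11329273167634679),
    "20260508-183413_standard_n1000_s123": (0.8879652926376185, 0.6644837397418111, -0.2234815528958074),
    "20260508-183413_standard_n1000_s456": (0.7449737813624285, 0.5783149528534813, -0.1666588285089472),
}

def check_ood_delta(summaries: dict, tol: float = 1e-12) -> list:
    """Return the run_ids whose ood_delta is NOT final_ood - best_ood within tol."""
    return [
        run_id
        for run_id, (best, final, delta) in summaries.items()
        if not math.isclose(final - best, delta, abs_tol=tol)
    ]

print(check_ood_delta(SUMMARIES))  # → []
```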
---
## 18. Full training log: grokking n=1000 seed=42
```
# launched: 2026-05-05T08:04:45Z
# host: ubuntu-Standard-PC-Q35-ICH9-2009
# pwd: /home/garima/CausalGrok
# cmd: /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-080445_grokking_n1000_s42 --wandb_project causalgrok --wandb_mode offline
----
Device: cuda
Run ID: 20260505-080445_grokking_n1000_s42
Started: 2026-05-05T08:04:50.408035+00:00
Env hospital=0: 181 samples, positive rate=0.53
Env hospital=3: 371 samples, positive rate=0.48
Env hospital=4: 448 samples, positive rate=0.50
Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054
Params: 11,177,538
============================================================
GROKKING | Camelyon17 v2 | 3000 epochs
WD=0.005 | α=4.0 | n=1000
Tracking: ID val (H3) + OOD test (H4) at every checkpoint
Grokking detection: watching OOD acc, not ID val acc
IRM envs: 3 hospitals
============================================================
ep 1 | tr 0.734 | id 0.728 | ood 0.499 | gap +0.228 | ‖W‖ 356.1 | rank 109.6 | IRM 0.2004 | sc 1.12x
ep 50 | tr 0.996 | id 0.859 | ood 0.697 | gap +0.162 | ‖W‖ 495.4 | rank 91.4 | IRM 0.0001 | sc 0.97x
ep 100 | tr 1.000 | id 0.883 | ood 0.592 | gap +0.291 | ‖W‖ 551.6 | rank 81.7 | IRM 0.0000 | sc 0.96x
ep 150 | tr 1.000 | id 0.872 | ood 0.641 | gap +0.231 | ‖W‖ 650.2 | rank 72.2 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00200.pt
ep 200 | tr 1.000 | id 0.885 | ood 0.613 | gap +0.272 | ‖W‖ 741.5 | rank 62.6 | IRM 0.0000 | sc 0.97x
ep 250 | tr 1.000 | id 0.890 | ood 0.659 | gap +0.231 | ‖W‖ 842.7 | rank 61.1 | IRM 0.0000 | sc 0.99x
ep 300 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.227 | ‖W‖ 878.1 | rank 60.7 | IRM 0.0000 | sc 0.99x
ep 350 | tr 1.000 | id 0.881 | ood 0.734 | gap +0.147 | ‖W‖ 907.9 | rank 53.9 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep00400.pt
ep 400 | tr 1.000 | id 0.875 | ood 0.647 | gap +0.229 | ‖W‖ 1005.7 | rank 60.9 | IRM 0.0000 | sc 1.00x
ep 450 | tr 1.000 | id 0.876 | ood 0.615 | gap +0.261 | ‖W‖ 1031.2 | rank 57.8 | IRM 0.0000 | sc 0.99x
ep 500 | tr 1.000 | id 0.880 | ood 0.607 | gap +0.273 | ‖W‖ 1023.1 | rank 55.7 | IRM 0.0000 | sc 0.99x
ep 550 | tr 1.000 | id 0.888 | ood 0.611 | gap +0.276 | ‖W‖ 1101.7 | rank 52.6 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep00600.pt
ep 600 | tr 1.000 | id 0.884 | ood 0.546 | gap +0.337 | ‖W‖ 1102.1 | rank 54.5 | IRM 0.0000 | sc 1.00x
ep 650 | tr 1.000 | id 0.878 | ood 0.605 | gap +0.273 | ‖W‖ 1142.4 | rank 55.2 | IRM 0.0000 | sc 0.98x
ep 700 | tr 1.000 | id 0.889 | ood 0.587 | gap +0.303 | ‖W‖ 1187.3 | rank 47.9 | IRM 0.0000 | sc 0.98x
ep 750 | tr 1.000 | id 0.892 | ood 0.670 | gap +0.222 | ‖W‖ 1180.7 | rank 48.4 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00800.pt
ep 800 | tr 1.000 | id 0.879 | ood 0.675 | gap +0.204 | ‖W‖ 1265.4 | rank 48.0 | IRM 0.0000 | sc 0.99x
ep 850 | tr 1.000 | id 0.886 | ood 0.627 | gap +0.259 | ‖W‖ 1300.7 | rank 48.3 | IRM 0.0000 | sc 0.98x
ep 900 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1290.6 | rank 48.0 | IRM 0.0000 | sc 0.98x
ep 950 | tr 1.000 | id 0.883 | ood 0.640 | gap +0.243 | ‖W‖ 1280.4 | rank 47.1 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep01000.pt
ep 1000 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1270.3 | rank 47.3 | IRM 0.0000 | sc 0.99x
ep 1050 | tr 1.000 | id 0.898 | ood 0.697 | gap +0.201 | ‖W‖ 1289.1 | rank 42.4 | IRM 0.0000 | sc 0.99x
ep 1100 | tr 0.999 | id 0.876 | ood 0.641 | gap +0.235 | ‖W‖ 1308.0 | rank 46.4 | IRM 0.0000 | sc 0.99x
ep 1150 | tr 1.000 | id 0.894 | ood 0.663 | gap +0.231 | ‖W‖ 1313.5 | rank 45.8 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01200.pt
ep 1200 | tr 0.999 | id 0.878 | ood 0.592 | gap +0.286 | ‖W‖ 1320.4 | rank 42.9 | IRM 0.0000 | sc 0.99x
ep 1250 | tr 1.000 | id 0.877 | ood 0.546 | gap +0.330 | ‖W‖ 1373.5 | rank 47.2 | IRM 0.0000 | sc 0.99x
ep 1300 | tr 1.000 | id 0.853 | ood 0.600 | gap +0.254 | ‖W‖ 1375.3 | rank 48.0 | IRM 0.0000 | sc 0.98x
ep 1350 | tr 1.000 | id 0.881 | ood 0.575 | gap +0.306 | ‖W‖ 1415.1 | rank 49.4 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01400.pt
ep 1400 | tr 1.000 | id 0.892 | ood 0.605 | gap +0.287 | ‖W‖ 1428.3 | rank 42.5 | IRM 0.0000 | sc 1.00x
ep 1450 | tr 1.000 | id 0.867 | ood 0.619 | gap +0.248 | ‖W‖ 1438.2 | rank 41.4 | IRM 0.0000 | sc 0.99x
ep 1500 | tr 1.000 | id 0.875 | ood 0.652 | gap +0.223 | ‖W‖ 1440.8 | rank 47.2 | IRM 0.0000 | sc 0.97x
ep 1550 | tr 1.000 | id 0.879 | ood 0.643 | gap +0.236 | ‖W‖ 1429.5 | rank 45.8 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep01600.pt
ep 1600 | tr 1.000 | id 0.878 | ood 0.630 | gap +0.248 | ‖W‖ 1418.2 | rank 46.4 | IRM 0.0000 | sc 0.98x
ep 1650 | tr 1.000 | id 0.881 | ood 0.636 | gap +0.245 | ‖W‖ 1406.9 | rank 45.9 | IRM 0.0000 | sc 0.98x
ep 1700 | tr 1.000 | id 0.883 | ood 0.634 | gap +0.249 | ‖W‖ 1395.7 | rank 47.0 | IRM 0.0000 | sc 0.98x
ep 1750 | tr 1.000 | id 0.875 | ood 0.702 | gap +0.173 | ‖W‖ 1441.2 | rank 45.9 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01800.pt
ep 1800 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 1433.0 | rank 45.6 | IRM 0.0000 | sc 0.98x
ep 1850 | tr 1.000 | id 0.880 | ood 0.644 | gap +0.236 | ‖W‖ 1421.6 | rank 45.8 | IRM 0.0000 | sc 0.98x
ep 1900 | tr 1.000 | id 0.885 | ood 0.644 | gap +0.242 | ‖W‖ 1410.4 | rank 43.9 | IRM 0.0000 | sc 0.98x
ep 1950 | tr 1.000 | id 0.883 | ood 0.549 | gap +0.334 | ‖W‖ 1415.3 | rank 49.1 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep02000.pt
ep 2000 | tr 1.000 | id 0.889 | ood 0.581 | gap +0.308 | ‖W‖ 1404.6 | rank 47.1 | IRM 0.0000 | sc 0.98x
ep 2050 | tr 1.000 | id 0.888 | ood 0.577 | gap +0.311 | ‖W‖ 1393.4 | rank 46.6 | IRM 0.0000 | sc 0.98x
ep 2100 | tr 1.000 | id 0.884 | ood 0.617 | gap +0.266 | ‖W‖ 1460.6 | rank 33.9 | IRM 0.0000 | sc 1.00x
ep 2150 | tr 1.000 | id 0.870 | ood 0.597 | gap +0.273 | ‖W‖ 1470.9 | rank 37.5 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep02200.pt
ep 2200 | tr 1.000 | id 0.869 | ood 0.568 | gap +0.301 | ‖W‖ 1460.2 | rank 38.5 | IRM 0.0000 | sc 0.99x
ep 2250 | tr 1.000 | id 0.870 | ood 0.588 | gap +0.282 | ‖W‖ 1448.6 | rank 36.9 | IRM 0.0000 | sc 0.98x
ep 2300 | tr 0.998 | id 0.872 | ood 0.706 | gap +0.166 | ‖W‖ 1485.2 | rank 41.9 | IRM 0.0000 | sc 0.97x
ep 2350 | tr 1.000 | id 0.877 | ood 0.648 | gap +0.229 | ‖W‖ 1506.9 | rank 41.2 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep02400.pt
ep 2400 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.226 | ‖W‖ 1495.0 | rank 40.1 | IRM 0.0000 | sc 1.00x
ep 2450 | tr 1.000 | id 0.869 | ood 0.682 | gap +0.187 | ‖W‖ 1486.6 | rank 36.8 | IRM 0.0000 | sc 0.98x
ep 2500 | tr 1.000 | id 0.884 | ood 0.621 | gap +0.263 | ‖W‖ 1510.2 | rank 40.7 | IRM 0.0000 | sc 1.00x
ep 2550 | tr 1.000 | id 0.876 | ood 0.667 | gap +0.209 | ‖W‖ 1499.1 | rank 38.9 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02600.pt
ep 2600 | tr 1.000 | id 0.871 | ood 0.635 | gap +0.236 | ‖W‖ 1506.6 | rank 39.4 | IRM 0.0000 | sc 1.01x
ep 2650 | tr 1.000 | id 0.874 | ood 0.692 | gap +0.183 | ‖W‖ 1500.6 | rank 36.8 | IRM 0.0000 | sc 0.99x
ep 2700 | tr 1.000 | id 0.877 | ood 0.607 | gap +0.269 | ‖W‖ 1510.5 | rank 39.9 | IRM 0.0000 | sc 0.97x
ep 2750 | tr 1.000 | id 0.873 | ood 0.599 | gap +0.273 | ‖W‖ 1521.2 | rank 38.5 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep02800.pt
ep 2800 | tr 1.000 | id 0.878 | ood 0.584 | gap +0.295 | ‖W‖ 1515.0 | rank 36.9 | IRM 0.0000 | sc 1.00x
ep 2850 | tr 1.000 | id 0.878 | ood 0.619 | gap +0.259 | ‖W‖ 1504.0 | rank 36.2 | IRM 0.0000 | sc 0.99x
ep 2900 | tr 1.000 | id 0.880 | ood 0.614 | gap +0.266 | ‖W‖ 1492.6 | rank 35.7 | IRM 0.0000 | sc 0.99x
ep 2950 | tr 1.000 | id 0.871 | ood 0.619 | gap +0.252 | ‖W‖ 1482.1 | rank 36.4 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep03000.pt
ep 3000 | tr 1.000 | id 0.871 | ood 0.664 | gap +0.207 | ‖W‖ 1470.5 | rank 35.9 | IRM 0.0000 | sc 0.99x
Best ID val (H3): 0.8976
Best OOD (H4): 0.7336
OOD improvement: +0.0114 ← did OOD grok?
Grokking at: None
IRM drop: 100.0%
Wall time: 358.5 min
```
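The per-checkpoint lines above follow a fixed `|`-separated layout. The following is a minimal parser sketch, assuming that layout holds for every log; `LINE_RE`, `parse_log`, and `gap_is_consistent` are hypothetical helpers, not project code. In the lines above, `gap` equals `id - ood` up to the three-decimal rounding of each field.

```python
import re

# Matches per-checkpoint lines such as
# "ep 50 | tr 0.996 | id 0.859 | ood 0.697 | gap +0.162 | ‖W‖ 495.4 | rank 91.4 | IRM 0.0001 | sc 0.97x".
# Checkpoint lines ("✓ Checkpoint → ep00200.pt") do not match because "ep" must be followed by whitespace.
LINE_RE = re.compile(
    r"ep\s+(?P<ep>\d+)\s*\|\s*tr\s+(?P<tr>[\d.]+)\s*\|\s*id\s+(?P<id>[\d.]+)"
    r"\s*\|\s*ood\s+(?P<ood>[\d.]+)\s*\|\s*gap\s+(?P<gap>[+-][\d.]+)"
    r"\s*\|\s*‖W‖\s+(?P<wnorm>[\d.]+)\s*\|\s*rank\s+(?P<rank>[\d.]+)"
    r"\s*\|\s*IRM\s+(?P<irm>[\d.]+)\s*\|\s*sc\s+(?P<sc>[\d.]+)x"
)

def parse_log(text: str) -> list:
    """Return one dict of floats per logged checkpoint (with 'ep' as int)."""
    rows = []
    for line in text.splitlines():
        m = LINE_RE.search(line)
        if m:
            row = {k: float(v) for k, v in m.groupdict().items()}
            row["ep"] = int(row["ep"])
            rows.append(row)
    return rows

def gap_is_consistent(row: dict, tol: float = 1.5e-3) -> bool:
    """gap is logged as id - ood; each field is rounded to 3 decimals, hence the tolerance."""
    return abs((row["id"] - row["ood"]) - row["gap"]) <= tol
```

Applied to the full grokking n=1000 seed=42 log above, the highest logged `ood` is 0.734 at epoch 350, which agrees with `best_ood` = 0.7336 and `peak_ood_epoch` = 350 in that run's summary.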
---
## 19. Full training log: standard n=1000 seed=42
```
# launched: 2026-05-05T10:07:20Z
# host: ubuntu-Standard-PC-Q35-ICH9-2009
# pwd: /home/garima/CausalGrok
# cmd: /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition standard --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-100720_standard_n1000_s42 --wandb_project causalgrok --wandb_mode offline --n_epochs 3000
----
Device: cuda
Run ID: 20260505-100720_standard_n1000_s42
Started: 2026-05-05T10:07:36.596218+00:00
Env hospital=0: 181 samples, positive rate=0.53
Env hospital=3: 371 samples, positive rate=0.48
Env hospital=4: 448 samples, positive rate=0.50
Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054
Params: 11,177,538
============================================================
STANDARD | Camelyon17 v2 | 3000 epochs
WD=0.0001 | α=1.0 | n=1000
Tracking: ID val (H3) + OOD test (H4) at every checkpoint
Grokking detection: watching OOD acc, not ID val acc
IRM envs: 3 hospitals
============================================================
ep 1 | tr 0.681 | id 0.664 | ood 0.762 | gap -0.098 | ‖W‖ 105.0 | rank 78.5 | IRM 0.1490 | sc 1.25x
ep 50 | tr 0.999 | id 0.897 | ood 0.620 | gap +0.276 | ‖W‖ 162.5 | rank 45.1 | IRM 0.0000 | sc 0.99x
ep 100 | tr 1.000 | id 0.901 | ood 0.722 | gap +0.179 | ‖W‖ 196.1 | rank 35.1 | IRM 0.0000 | sc 0.98x
ep 150 | tr 0.997 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 228.8 | rank 29.4 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00200.pt
ep 200 | tr 1.000 | id 0.886 | ood 0.611 | gap +0.274 | ‖W‖ 268.7 | rank 30.1 | IRM 0.0000 | sc 0.99x
ep 250 | tr 1.000 | id 0.890 | ood 0.672 | gap +0.218 | ‖W‖ 294.3 | rank 30.9 | IRM 0.0000 | sc 0.97x
ep 300 | tr 1.000 | id 0.900 | ood 0.684 | gap +0.216 | ‖W‖ 323.3 | rank 26.7 | IRM 0.0000 | sc 0.98x
ep 350 | tr 1.000 | id 0.891 | ood 0.573 | gap +0.318 | ‖W‖ 343.5 | rank 27.9 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00400.pt
ep 400 | tr 1.000 | id 0.885 | ood 0.642 | gap +0.243 | ‖W‖ 361.4 | rank 28.6 | IRM 0.0000 | sc 0.98x
ep 450 | tr 1.000 | id 0.894 | ood 0.700 | gap +0.194 | ‖W‖ 377.9 | rank 31.3 | IRM 0.0000 | sc 0.98x
ep 500 | tr 1.000 | id 0.890 | ood 0.705 | gap +0.185 | ‖W‖ 378.2 | rank 29.3 | IRM 0.0000 | sc 0.97x
ep 550 | tr 1.000 | id 0.895 | ood 0.656 | gap +0.239 | ‖W‖ 412.9 | rank 26.5 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep00600.pt
ep 600 | tr 1.000 | id 0.862 | ood 0.717 | gap +0.145 | ‖W‖ 426.0 | rank 29.7 | IRM 0.0000 | sc 0.97x
ep 650 | tr 1.000 | id 0.885 | ood 0.713 | gap +0.172 | ‖W‖ 445.0 | rank 25.9 | IRM 0.0000 | sc 0.99x
ep 700 | tr 1.000 | id 0.892 | ood 0.639 | gap +0.253 | ‖W‖ 454.4 | rank 28.0 | IRM 0.0000 | sc 0.98x
ep 750 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 472.1 | rank 25.5 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00800.pt
ep 800 | tr 1.000 | id 0.888 | ood 0.681 | gap +0.207 | ‖W‖ 489.6 | rank 28.8 | IRM 0.0000 | sc 0.97x
ep 850 | tr 1.000 | id 0.887 | ood 0.626 | gap +0.262 | ‖W‖ 506.4 | rank 28.2 | IRM 0.0000 | sc 0.99x
ep 900 | tr 1.000 | id 0.888 | ood 0.703 | gap +0.185 | ‖W‖ 515.4 | rank 31.3 | IRM 0.0000 | sc 0.99x
ep 950 | tr 1.000 | id 0.883 | ood 0.667 | gap +0.215 | ‖W‖ 526.2 | rank 27.6 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep01000.pt
ep 1000 | tr 1.000 | id 0.897 | ood 0.674 | gap +0.222 | ‖W‖ 530.5 | rank 26.2 | IRM 0.0000 | sc 1.00x
ep 1050 | tr 1.000 | id 0.896 | ood 0.581 | gap +0.315 | ‖W‖ 544.2 | rank 26.3 | IRM 0.0000 | sc 0.98x
ep 1100 | tr 1.000 | id 0.877 | ood 0.655 | gap +0.222 | ‖W‖ 559.3 | rank 26.2 | IRM 0.0000 | sc 1.00x
ep 1150 | tr 1.000 | id 0.899 | ood 0.694 | gap +0.205 | ‖W‖ 562.6 | rank 23.7 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01200.pt
ep 1200 | tr 1.000 | id 0.892 | ood 0.578 | gap +0.315 | ‖W‖ 581.9 | rank 27.0 | IRM 0.0000 | sc 0.98x
ep 1250 | tr 1.000 | id 0.889 | ood 0.647 | gap +0.242 | ‖W‖ 595.8 | rank 30.1 | IRM 0.0000 | sc 0.99x
ep 1300 | tr 1.000 | id 0.880 | ood 0.616 | gap +0.264 | ‖W‖ 604.9 | rank 31.7 | IRM 0.0000 | sc 0.98x
ep 1350 | tr 1.000 | id 0.878 | ood 0.649 | gap +0.228 | ‖W‖ 606.3 | rank 30.0 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep01400.pt
ep 1400 | tr 1.000 | id 0.879 | ood 0.735 | gap +0.144 | ‖W‖ 624.0 | rank 32.7 | IRM 0.0000 | sc 0.96x
ep 1450 | tr 1.000 | id 0.885 | ood 0.703 | gap +0.182 | ‖W‖ 625.3 | rank 33.0 | IRM 0.0000 | sc 0.98x
ep 1500 | tr 1.000 | id 0.894 | ood 0.686 | gap +0.207 | ‖W‖ 626.2 | rank 30.6 | IRM 0.0000 | sc 0.98x
ep 1550 | tr 1.000 | id 0.879 | ood 0.670 | gap +0.209 | ‖W‖ 640.3 | rank 32.3 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep01600.pt
ep 1600 | tr 1.000 | id 0.871 | ood 0.669 | gap +0.202 | ‖W‖ 653.6 | rank 31.0 | IRM 0.0000 | sc 0.98x
ep 1650 | tr 1.000 | id 0.886 | ood 0.540 | gap +0.346 | ‖W‖ 662.7 | rank 34.4 | IRM 0.0000 | sc 0.98x
ep 1700 | tr 1.000 | id 0.897 | ood 0.659 | gap +0.239 | ‖W‖ 667.4 | rank 38.2 | IRM 0.0000 | sc 0.97x
ep 1750 | tr 1.000 | id 0.871 | ood 0.716 | gap +0.155 | ‖W‖ 678.7 | rank 31.7 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep01800.pt
ep 1800 | tr 1.000 | id 0.887 | ood 0.665 | gap +0.222 | ‖W‖ 682.0 | rank 33.0 | IRM 0.0000 | sc 0.98x
ep 1850 | tr 1.000 | id 0.892 | ood 0.598 | gap +0.294 | ‖W‖ 685.4 | rank 31.1 | IRM 0.0000 | sc 0.98x
ep 1900 | tr 1.000 | id 0.890 | ood 0.574 | gap +0.316 | ‖W‖ 690.9 | rank 32.8 | IRM 0.0000 | sc 0.99x
ep 1950 | tr 1.000 | id 0.889 | ood 0.623 | gap +0.265 | ‖W‖ 706.3 | rank 30.0 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep02000.pt
ep 2000 | tr 1.000 | id 0.890 | ood 0.595 | gap +0.296 | ‖W‖ 714.3 | rank 31.5 | IRM 0.0000 | sc 0.98x
ep 2050 | tr 1.000 | id 0.888 | ood 0.596 | gap +0.292 | ‖W‖ 720.1 | rank 31.4 | IRM 0.0000 | sc 0.99x
ep 2100 | tr 1.000 | id 0.860 | ood 0.686 | gap +0.173 | ‖W‖ 730.1 | rank 30.6 | IRM 0.0000 | sc 0.99x
ep 2150 | tr 1.000 | id 0.892 | ood 0.712 | gap +0.180 | ‖W‖ 732.0 | rank 28.5 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02200.pt
ep 2200 | tr 1.000 | id 0.881 | ood 0.621 | gap +0.260 | ‖W‖ 736.6 | rank 31.0 | IRM 0.0000 | sc 0.99x
ep 2250 | tr 1.000 | id 0.882 | ood 0.621 | gap +0.261 | ‖W‖ 736.5 | rank 29.6 | IRM 0.0000 | sc 0.99x
ep 2300 | tr 1.000 | id 0.877 | ood 0.641 | gap +0.236 | ‖W‖ 748.9 | rank 32.6 | IRM 0.0000 | sc 0.99x
ep 2350 | tr 1.000 | id 0.883 | ood 0.652 | gap +0.230 | ‖W‖ 753.4 | rank 35.2 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02400.pt
ep 2400 | tr 1.000 | id 0.884 | ood 0.673 | gap +0.212 | ‖W‖ 753.4 | rank 34.2 | IRM 0.0000 | sc 1.00x
ep 2450 | tr 1.000 | id 0.887 | ood 0.677 | gap +0.209 | ‖W‖ 755.5 | rank 32.2 | IRM 0.0000 | sc 1.00x
ep 2500 | tr 1.000 | id 0.887 | ood 0.601 | gap +0.286 | ‖W‖ 768.2 | rank 35.0 | IRM 0.0000 | sc 0.98x
ep 2550 | tr 1.000 | id 0.880 | ood 0.645 | gap +0.235 | ‖W‖ 772.8 | rank 35.5 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep02600.pt
ep 2600 | tr 1.000 | id 0.884 | ood 0.631 | gap +0.253 | ‖W‖ 774.0 | rank 34.8 | IRM 0.0000 | sc 0.99x
ep 2650 | tr 1.000 | id 0.885 | ood 0.584 | gap +0.301 | ‖W‖ 776.4 | rank 37.4 | IRM 0.0000 | sc 0.99x
ep 2700 | tr 1.000 | id 0.888 | ood 0.597 | gap +0.291 | ‖W‖ 780.4 | rank 36.2 | IRM 0.0000 | sc 0.99x
ep 2750 | tr 1.000 | id 0.895 | ood 0.683 | gap +0.212 | ‖W‖ 786.8 | rank 35.2 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02800.pt
ep 2800 | tr 1.000 | id 0.894 | ood 0.597 | gap +0.297 | ‖W‖ 794.5 | rank 34.0 | IRM 0.0000 | sc 0.99x
ep 2850 | tr 1.000 | id 0.874 | ood 0.623 | gap +0.252 | ‖W‖ 803.4 | rank 32.7 | IRM 0.0000 | sc 0.99x
ep 2900 | tr 1.000 | id 0.891 | ood 0.679 | gap +0.212 | ‖W‖ 808.3 | rank 35.3 | IRM 0.0000 | sc 0.98x
ep 2950 | tr 1.000 | id 0.886 | ood 0.704 | gap +0.182 | ‖W‖ 811.6 | rank 33.7 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep03000.pt
ep 3000 | tr 1.000 | id 0.896 | ood 0.648 | gap +0.247 | ‖W‖ 812.6 | rank 33.3 | IRM 0.0000 | sc 0.99x
Best ID val (H3): 0.9011
Best OOD (H4): 0.7615
OOD improvement: -0.0224 ← did OOD grok?
Grokking at: None
IRM drop: 100.0%
Wall time: 327.3 min
```
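The closing lines of each log repeat headline numbers from `results/summary.json`, rounded to four decimals; in both logs above, `Grokking at: None` corresponds to `"grokking_epoch": -1` in the matching summary. A small cross-check is sketched below; `FOOTER_RE` and `footer_matches_summary` are hypothetical names.

```python
import re

# Footer lines as printed at the end of each training log.
FOOTER_RE = {
    "best_id_val": re.compile(r"Best ID val \(H3\):\s*([-+]?\d*\.\d+)"),
    "best_ood": re.compile(r"Best OOD \(H4\):\s*([-+]?\d*\.\d+)"),
    "ood_improvement": re.compile(r"OOD improvement:\s*([-+]?\d*\.\d+)"),
}

def footer_matches_summary(log_text: str, summary: dict, places: int = 4) -> bool:
    """True if every footer value equals the summary value rounded to `places` decimals."""
    for key, pattern in FOOTER_RE.items():
        m = pattern.search(log_text)
        if m is None or round(summary[key], places) != float(m.group(1)):
            return False
    return True

# Usage with the standard n=1000 seed=42 run (values from its summary.json and log).
std_log = """Best ID val (H3): 0.9011
Best OOD (H4): 0.7615
OOD improvement: -0.0224 ← did OOD grok?"""
std_summary = {
    "best_id_val": 0.9011025029797378,
    "best_ood": 0.7615162132292426,
    "ood_improvement": -0.022357561078844124,
}
print(footer_matches_summary(std_log, std_summary))
```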
---
## 20. M5 — Full activation-steering JSONs (8 runs at n=1000)
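Every run's full `m5_steering_ep*.json`, verbatim from disk. Each file steers the `layer` features of the checkpoint at `epoch` as `h' = h + alpha * sigma * v_s`, where `v_s` is the steering direction (unit norm; the recorded `v_norm` is 1.0 up to float32 rounding), `sigma` is the per-run scale stored in the `sigma` field, and `alpha` runs over the nine values in `sweep`. For each `alpha`, `head_ood_acc` is the classification head's OOD accuracy on the steered features and `tumor_probe` is the linear-probe accuracy on the same steered features. `hospital_probe` is NaN by construction (H4 hospital labels do not overlap with the training-hospital probe's class set); the bare `NaN` tokens are emitted by Python's `json` module and are rejected by strict JSON parsers.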
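The steering rule `h' = h + alpha * sigma * v_s` and the `head_ood_acc` sweep can be sketched in NumPy as below. This is an illustration under stated assumptions, not the code in `mechinterp_m5_steering.py` (Section 14): `sigma` is taken as given (read from the JSON), and the linear head `W`, `b` plus the helpers `steer`, `head_accuracy`, and `steering_sweep` are hypothetical stand-ins for the model's classifier.

```python
import numpy as np

ALPHAS = (-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0)  # alpha grid used in the JSONs

def steer(h: np.ndarray, v_s: np.ndarray, sigma: float, alpha: float) -> np.ndarray:
    """Return h' = h + alpha * sigma * v_s for features h of shape (N, D).

    v_s is renormalised to unit length here (the JSONs record v_norm ~= 1.0).
    """
    v = np.asarray(v_s, dtype=np.float64)
    v = v / np.linalg.norm(v)
    return h + alpha * sigma * v[None, :]

def head_accuracy(h: np.ndarray, labels: np.ndarray, W: np.ndarray, b: np.ndarray) -> float:
    """Accuracy of a hypothetical linear head with logits = h @ W.T + b."""
    preds = np.argmax(h @ W.T + b, axis=1)
    return float(np.mean(preds == labels))

def steering_sweep(h, labels, v_s, sigma, W, b, alphas=ALPHAS) -> list:
    """One record per alpha, shaped like the 'sweep' entries (head accuracy only)."""
    return [
        {"alpha": a, "head_ood_acc": head_accuracy(steer(h, v_s, sigma, a), labels, W, b)}
        for a in alphas
    ]
```

A `tumor_probe` entry would follow the same pattern, with a separately trained linear probe in place of the head.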
### `20260505-080445_grokking_n1000_s42`
```json
{
"run_id": "20260505-080445_grokking_n1000_s42",
"epoch": 400,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 1.0,
"sigma": 8.685188293457031,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.63,
"hospital_probe": NaN,
"tumor_probe": 0.6625
},
{
"alpha": -2.0,
"head_ood_acc": 0.625,
"hospital_probe": NaN,
"tumor_probe": 0.6525
},
{
"alpha": -1.0,
"head_ood_acc": 0.64,
"hospital_probe": NaN,
"tumor_probe": 0.665
},
{
"alpha": -0.5,
"head_ood_acc": 0.6325,
"hospital_probe": NaN,
"tumor_probe": 0.655
},
{
"alpha": 0.0,
"head_ood_acc": 0.6525,
"hospital_probe": NaN,
"tumor_probe": 0.66
},
{
"alpha": 0.5,
"head_ood_acc": 0.6575,
"hospital_probe": NaN,
"tumor_probe": 0.665
},
{
"alpha": 1.0,
"head_ood_acc": 0.6575,
"hospital_probe": NaN,
"tumor_probe": 0.6675
},
{
"alpha": 2.0,
"head_ood_acc": 0.62,
"hospital_probe": NaN,
"tumor_probe": 0.65
},
{
"alpha": 3.0,
"head_ood_acc": 0.59,
"hospital_probe": NaN,
"tumor_probe": 0.635
}
]
}
```
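Because these files contain bare `NaN`, they load with Python's `json.loads` (which accepts `NaN` by default) but not with strict parsers such as JavaScript's `JSON.parse`. The sketch below reads a sweep record and reports each `alpha`'s change relative to `alpha = 0`. The helper `steering_deltas` is hypothetical, and the usage excerpt is taken from the seed-42 grokking JSON above.

```python
import json
import math

def steering_deltas(record: dict, key: str = "head_ood_acc") -> dict:
    """Map each alpha to record[key] minus its value at alpha = 0.0."""
    sweep = {row["alpha"]: row[key] for row in record["sweep"]}
    base = sweep[0.0]
    return {alpha: value - base for alpha, value in sorted(sweep.items())}

# Excerpt of 20260505-080445_grokking_n1000_s42 (three of the nine alphas).
text = """{"run_id": "20260505-080445_grokking_n1000_s42", "sweep": [
  {"alpha": -3.0, "head_ood_acc": 0.63, "hospital_probe": NaN, "tumor_probe": 0.6625},
  {"alpha": 0.0, "head_ood_acc": 0.6525, "hospital_probe": NaN, "tumor_probe": 0.66},
  {"alpha": 3.0, "head_ood_acc": 0.59, "hospital_probe": NaN, "tumor_probe": 0.635}]}"""
record = json.loads(text)  # bare NaN becomes float('nan')
print(math.isnan(record["sweep"][0]["hospital_probe"]))
print(steering_deltas(record))
```

To write strict JSON instead, `json.dumps(..., allow_nan=False)` raises `ValueError` on NaN, so those values would need to be replaced (for example with `None`) before export.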
### `20260505-100720_grokking_n1000_s123`
```json
{
"run_id": "20260505-100720_grokking_n1000_s123",
"epoch": 400,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 1.0,
"sigma": 6.000692844390869,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.7325,
"hospital_probe": NaN,
"tumor_probe": 0.6925
},
{
"alpha": -2.0,
"head_ood_acc": 0.7325,
"hospital_probe": NaN,
"tumor_probe": 0.6925
},
{
"alpha": -1.0,
"head_ood_acc": 0.73,
"hospital_probe": NaN,
"tumor_probe": 0.695
},
{
"alpha": -0.5,
"head_ood_acc": 0.71,
"hospital_probe": NaN,
"tumor_probe": 0.695
},
{
"alpha": 0.0,
"head_ood_acc": 0.6975,
"hospital_probe": NaN,
"tumor_probe": 0.6925
},
{
"alpha": 0.5,
"head_ood_acc": 0.6925,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": 1.0,
"head_ood_acc": 0.69,
"hospital_probe": NaN,
"tumor_probe": 0.685
},
{
"alpha": 2.0,
"head_ood_acc": 0.68,
"hospital_probe": NaN,
"tumor_probe": 0.685
},
{
"alpha": 3.0,
"head_ood_acc": 0.6575,
"hospital_probe": NaN,
"tumor_probe": 0.685
}
]
}
```
### `20260505-100720_grokking_n1000_s456`
```json
{
"run_id": "20260505-100720_grokking_n1000_s456",
"epoch": 1000,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 1.0,
"sigma": 6.7488112449646,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.6825,
"hospital_probe": NaN,
"tumor_probe": 0.63
},
{
"alpha": -2.0,
"head_ood_acc": 0.6575,
"hospital_probe": NaN,
"tumor_probe": 0.615
},
{
"alpha": -1.0,
"head_ood_acc": 0.64,
"hospital_probe": NaN,
"tumor_probe": 0.605
},
{
"alpha": -0.5,
"head_ood_acc": 0.63,
"hospital_probe": NaN,
"tumor_probe": 0.595
},
{
"alpha": 0.0,
"head_ood_acc": 0.6225,
"hospital_probe": NaN,
"tumor_probe": 0.595
},
{
"alpha": 0.5,
"head_ood_acc": 0.62,
"hospital_probe": NaN,
"tumor_probe": 0.5925
},
{
"alpha": 1.0,
"head_ood_acc": 0.61,
"hospital_probe": NaN,
"tumor_probe": 0.5925
},
{
"alpha": 2.0,
"head_ood_acc": 0.6025,
"hospital_probe": NaN,
"tumor_probe": 0.5825
},
{
"alpha": 3.0,
"head_ood_acc": 0.5975,
"hospital_probe": NaN,
"tumor_probe": 0.59
}
]
}
```
### `20260508-183413_grokking_n1000_s7`
```json
{
"run_id": "20260508-183413_grokking_n1000_s7",
"epoch": 200,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 0.9999999403953552,
"sigma": 5.532620429992676,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.495,
"hospital_probe": NaN,
"tumor_probe": 0.58
},
{
"alpha": -2.0,
"head_ood_acc": 0.49,
"hospital_probe": NaN,
"tumor_probe": 0.5625
},
{
"alpha": -1.0,
"head_ood_acc": 0.485,
"hospital_probe": NaN,
"tumor_probe": 0.5425
},
{
"alpha": -0.5,
"head_ood_acc": 0.485,
"hospital_probe": NaN,
"tumor_probe": 0.5325
},
{
"alpha": 0.0,
"head_ood_acc": 0.485,
"hospital_probe": NaN,
"tumor_probe": 0.52
},
{
"alpha": 0.5,
"head_ood_acc": 0.4875,
"hospital_probe": NaN,
"tumor_probe": 0.5125
},
{
"alpha": 1.0,
"head_ood_acc": 0.4875,
"hospital_probe": NaN,
"tumor_probe": 0.52
},
{
"alpha": 2.0,
"head_ood_acc": 0.4775,
"hospital_probe": NaN,
"tumor_probe": 0.51
},
{
"alpha": 3.0,
"head_ood_acc": 0.4825,
"hospital_probe": NaN,
"tumor_probe": 0.51
}
]
}
```
### `20260508-183413_grokking_n1000_s2024`
```json
{
"run_id": "20260508-183413_grokking_n1000_s2024",
"epoch": 400,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 1.0,
"sigma": 9.224357604980469,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.7425,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": -2.0,
"head_ood_acc": 0.7775,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": -1.0,
"head_ood_acc": 0.7325,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": -0.5,
"head_ood_acc": 0.715,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": 0.0,
"head_ood_acc": 0.7125,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": 0.5,
"head_ood_acc": 0.705,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": 1.0,
"head_ood_acc": 0.68,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": 2.0,
"head_ood_acc": 0.6225,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": 3.0,
"head_ood_acc": 0.595,
"hospital_probe": NaN,
"tumor_probe": 0.69
}
]
}
```
### `20260505-100720_standard_n1000_s42`
```json
{
"run_id": "20260505-100720_standard_n1000_s42",
"epoch": 200,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 1.0,
"sigma": 7.23820686340332,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.585,
"hospital_probe": NaN,
"tumor_probe": 0.6125
},
{
"alpha": -2.0,
"head_ood_acc": 0.6,
"hospital_probe": NaN,
"tumor_probe": 0.62
},
{
"alpha": -1.0,
"head_ood_acc": 0.61,
"hospital_probe": NaN,
"tumor_probe": 0.625
},
{
"alpha": -0.5,
"head_ood_acc": 0.6175,
"hospital_probe": NaN,
"tumor_probe": 0.63
},
{
"alpha": 0.0,
"head_ood_acc": 0.6175,
"hospital_probe": NaN,
"tumor_probe": 0.625
},
{
"alpha": 0.5,
"head_ood_acc": 0.6175,
"hospital_probe": NaN,
"tumor_probe": 0.6225
},
{
"alpha": 1.0,
"head_ood_acc": 0.61,
"hospital_probe": NaN,
"tumor_probe": 0.6275
},
{
"alpha": 2.0,
"head_ood_acc": 0.6025,
"hospital_probe": NaN,
"tumor_probe": 0.6325
},
{
"alpha": 3.0,
"head_ood_acc": 0.5925,
"hospital_probe": NaN,
"tumor_probe": 0.625
}
]
}
```
### `20260508-183413_standard_n1000_s123`
```json
{
"run_id": "20260508-183413_standard_n1000_s123",
"epoch": 200,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 1.0,
"sigma": 13.366204261779785,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.5425,
"hospital_probe": NaN,
"tumor_probe": 0.565
},
{
"alpha": -2.0,
"head_ood_acc": 0.5975,
"hospital_probe": NaN,
"tumor_probe": 0.615
},
{
"alpha": -1.0,
"head_ood_acc": 0.6725,
"hospital_probe": NaN,
"tumor_probe": 0.68
},
{
"alpha": -0.5,
"head_ood_acc": 0.7125,
"hospital_probe": NaN,
"tumor_probe": 0.71
},
{
"alpha": 0.0,
"head_ood_acc": 0.72,
"hospital_probe": NaN,
"tumor_probe": 0.7375
},
{
"alpha": 0.5,
"head_ood_acc": 0.7175,
"hospital_probe": NaN,
"tumor_probe": 0.7425
},
{
"alpha": 1.0,
"head_ood_acc": 0.6825,
"hospital_probe": NaN,
"tumor_probe": 0.7425
},
{
"alpha": 2.0,
"head_ood_acc": 0.6275,
"hospital_probe": NaN,
"tumor_probe": 0.695
},
{
"alpha": 3.0,
"head_ood_acc": 0.555,
"hospital_probe": NaN,
"tumor_probe": 0.6425
}
]
}
```
### `20260508-183413_standard_n1000_s456`
```json
{
"run_id": "20260508-183413_standard_n1000_s456",
"epoch": 1000,
"layer": "avgpool",
"max_samples": 800,
"v_norm": 1.0,
"sigma": 10.473133087158203,
"sweep": [
{
"alpha": -3.0,
"head_ood_acc": 0.6375,
"hospital_probe": NaN,
"tumor_probe": 0.6475
},
{
"alpha": -2.0,
"head_ood_acc": 0.6575,
"hospital_probe": NaN,
"tumor_probe": 0.66
},
{
"alpha": -1.0,
"head_ood_acc": 0.675,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": -0.5,
"head_ood_acc": 0.6725,
"hospital_probe": NaN,
"tumor_probe": 0.69
},
{
"alpha": 0.0,
"head_ood_acc": 0.6175,
"hospital_probe": NaN,
"tumor_probe": 0.6575
},
{
"alpha": 0.5,
"head_ood_acc": 0.58,
"hospital_probe": NaN,
"tumor_probe": 0.5975
},
{
"alpha": 1.0,
"head_ood_acc": 0.5375,
"hospital_probe": NaN,
"tumor_probe": 0.585
},
{
"alpha": 2.0,
"head_ood_acc": 0.505,
"hospital_probe": NaN,
"tumor_probe": 0.5225
},
{
"alpha": 3.0,
"head_ood_acc": 0.505,
"hospital_probe": NaN,
"tumor_probe": 0.51
}
]
}
```
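The per-run records above can be read back and checked against the aggregate tables. The following is a minimal sketch that applies the endpoint criterion `acc(−3) ≥ acc(0) ≥ acc(+3)` to one record. It assumes only the fields shown above (`sweep`, `alpha`, `head_ood_acc`), and the helper names are hypothetical, not functions from the project's scripts. The bare `NaN` tokens are not valid strict JSON, but Python's `json` module accepts them by default and returns `float('nan')`.

```python
import json
import math


def load_steering(path):
    """Read one m5_steering_ep<E>.json file (bare NaN tokens are accepted by json)."""
    with open(path) as f:
        return json.load(f)


def head_acc_by_alpha(record):
    """Map each steering coefficient alpha to head OOD accuracy."""
    return {row["alpha"]: row["head_ood_acc"] for row in record["sweep"]}


def endpoint_monotone(acc):
    """Endpoint criterion used in the aggregate tables: acc(-3) >= acc(0) >= acc(+3)."""
    return acc[-3.0] >= acc[0.0] >= acc[3.0]


# Endpoint rows of the standard s456 record above (other alphas omitted)
EXAMPLE = """{"run_id": "20260508-183413_standard_n1000_s456", "epoch": 1000,
  "sweep": [
    {"alpha": -3.0, "head_ood_acc": 0.6375, "hospital_probe": NaN, "tumor_probe": 0.6475},
    {"alpha": 0.0, "head_ood_acc": 0.6175, "hospital_probe": NaN, "tumor_probe": 0.6575},
    {"alpha": 3.0, "head_ood_acc": 0.505, "hospital_probe": NaN, "tumor_probe": 0.51}]}"""

record = json.loads(EXAMPLE)
acc = head_acc_by_alpha(record)
print(endpoint_monotone(acc))                            # True
print(math.isnan(record["sweep"][0]["hospital_probe"]))  # True
```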
---
## 21. M5 — Aggregated sweep tables
### Grokking-favorable
| α | s7 (ep200, σ=5.53) | s42 (ep400, σ=8.69) | s123 (ep400, σ=6.00) | s456 (ep1000, σ=6.75) | s2024 (ep400, σ=9.22) |
| --- | --- | --- | --- | --- | --- |
| −3.0 | 0.4950 | 0.6300 | 0.7325 | 0.6825 | 0.7425 |
| −2.0 | 0.4900 | 0.6250 | 0.7325 | 0.6575 | 0.7775 |
| −1.0 | 0.4850 | 0.6400 | 0.7300 | 0.6400 | 0.7325 |
| −0.5 | 0.4850 | 0.6325 | 0.7100 | 0.6300 | 0.7150 |
| 0.0 | 0.4850 | 0.6525 | 0.6975 | 0.6225 | 0.7125 |
| +0.5 | 0.4875 | 0.6575 | 0.6925 | 0.6200 | 0.7050 |
| +1.0 | 0.4875 | 0.6575 | 0.6900 | 0.6100 | 0.6800 |
| +2.0 | 0.4775 | 0.6200 | 0.6800 | 0.6025 | 0.6225 |
| +3.0 | 0.4825 | 0.5900 | 0.6575 | 0.5975 | 0.5950 |
| Strict mono? | yes | no (acc(0) > acc(−3)) | yes | yes | yes |
### Standard
| α | s42 (ep200, σ=7.24) | s123 (ep200, σ=13.37) | s456 (ep1000, σ=10.47) |
| --- | --- | --- | --- |
| −3.0 | 0.5850 | 0.5425 | 0.6375 |
| −2.0 | 0.6000 | 0.5975 | 0.6575 |
| −1.0 | 0.6100 | 0.6725 | 0.6750 |
| −0.5 | 0.6175 | 0.7125 | 0.6725 |
| 0.0 | 0.6175 | 0.7200 | 0.6175 |
| +0.5 | 0.6175 | 0.7175 | 0.5800 |
| +1.0 | 0.6100 | 0.6825 | 0.5375 |
| +2.0 | 0.6025 | 0.6275 | 0.5050 |
| +3.0 | 0.5925 | 0.5550 | 0.5050 |
| Strict mono? | no (peak at α=0) | no (peak at α=0) | yes |
**Aggregates and statistics**:
- Strict monotonicity (`acc(−3) ≥ acc(0) ≥ acc(+3)`): **4/5 grokking** vs **1/3 standard**.
- Mean σ: grokking 7.24, standard 10.36 (1.43× ratio — the σ-scaling confound).
- Mean Δ(α=0→−3) = acc(−3) − acc(0) over seeds (± is the population SD): grokking **+0.0225 ± 0.0276**, standard **−0.0633 ± 0.0835**.
- Mean Δ(α=0→+3) = acc(+3) − acc(0): grokking **−0.0495 ± 0.0392**, standard **−0.1008 ± 0.0577**.
- **Fisher exact one-sided p = 0.286**, **Mann-Whitney U one-sided p = 0.071** (continuous statistic, primary), binomial sign test p = 0.188. **None reaches p < 0.05.**
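These aggregates can be recomputed from the two tables. The following is a standard-library sketch that relies on stated assumptions. It treats `±` as the population SD. It assumes that the Mann-Whitney test compares per-seed Δ(α=0→−3) between conditions and that the sign test compares the 4/5 monotone grokking runs against chance 0.5. Under these assumptions it reproduces p = 0.286, 0.071 and 0.1875 (reported as 0.188). The helper names are hypothetical, not taken from the project's scripts.

```python
from itertools import combinations
from math import comb
from statistics import mean, pstdev

# head OOD accuracy at alpha = -3, 0, +3, copied from the two tables above
GROK = {"s7": (0.4950, 0.4850, 0.4825), "s42": (0.6300, 0.6525, 0.5900),
        "s123": (0.7325, 0.6975, 0.6575), "s456": (0.6825, 0.6225, 0.5975),
        "s2024": (0.7425, 0.7125, 0.5950)}
STD = {"s42": (0.5850, 0.6175, 0.5925), "s123": (0.5425, 0.7200, 0.5550),
       "s456": (0.6375, 0.6175, 0.5050)}


def monotone_count(runs):
    """Number of runs with acc(-3) >= acc(0) >= acc(+3)."""
    return sum(m3 >= z >= p3 for m3, z, p3 in runs.values())


def delta(runs, end):
    """Per-seed acc(end) - acc(0); end is 0 for alpha=-3 and 2 for alpha=+3."""
    return [acc[end] - acc[1] for acc in runs.values()]


def fisher_one_sided(a_pos, a_n, b_pos, b_n):
    """One-sided Fisher exact p: P(X >= a_pos) under the hypergeometric null."""
    total, pos = a_n + b_n, a_pos + b_pos
    hits = sum(comb(pos, k) * comb(total - pos, a_n - k)
               for k in range(a_pos, min(pos, a_n) + 1))
    return hits / comb(total, a_n)


def mwu_exact_greater(a, b):
    """Exact one-sided Mann-Whitney p: share of relabelings with U >= observed U."""
    def u_stat(x, y):
        return sum((xi > yi) + 0.5 * (xi == yi) for xi in x for yi in y)
    pooled, u_obs = a + b, u_stat(a, b)
    splits = list(combinations(range(len(pooled)), len(a)))
    hits = 0
    for chosen in splits:
        x = [pooled[i] for i in chosen]
        y = [pooled[i] for i in range(len(pooled)) if i not in chosen]
        hits += u_stat(x, y) >= u_obs
    return hits / len(splits)


def sign_test_greater(k, n):
    """One-sided binomial sign test: P(X >= k) for X ~ Binomial(n, 0.5)."""
    return sum(comb(n, i) for i in range(k, n + 1)) / 2 ** n


for name, runs in (("grokking", GROK), ("standard", STD)):
    for label, end in (("0->-3", 0), ("0->+3", 2)):
        d = delta(runs, end)
        print(f"{name} delta({label}): {mean(d):+.4f} +/- {pstdev(d):.4f}")

g_mono, s_mono = monotone_count(GROK), monotone_count(STD)
print(g_mono, s_mono)  # 4 1
print(fisher_one_sided(g_mono, len(GROK), s_mono, len(STD)))
print(mwu_exact_greater(delta(GROK, 0), delta(STD, 0)))
print(sign_test_greater(g_mono, len(GROK)))
```

Exact enumeration is affordable here because there are only C(8, 5) = 56 relabelings.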
---
## 22. M6 — Full K-sweep results (per-seed, all K)
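Aggregated from `paper_figures/m6_summary.csv`. Δ(targ−rand) is `head_OOD(top-K shortcut ablated) − mean(head_OOD over 5 random K-neuron ablations)`. A positive value means that ablating the targeted shortcut neurons gives higher head OOD accuracy than ablating the same number of random neurons.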
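The following is a minimal NumPy sketch of this contrast under loudly stated assumptions. It assumes that "ablation" means zeroing avgpool feature columns before a linear head and that the caller supplies the per-unit shortcut score. The function names, the zero-ablation choice and the score input are hypothetical and are not taken from `mechinterp_m6_neuron_ablation.py`.

```python
import numpy as np


def head_acc(feats, labels, W, b, ablate=()):
    """Accuracy of a linear head after zeroing the feature columns in `ablate`."""
    h = feats.copy()
    h[:, np.asarray(ablate, dtype=int)] = 0.0  # zero-ablation (assumption)
    preds = (h @ W.T + b).argmax(axis=1)
    return float((preds == labels).mean())


def delta_targ_rand(feats, labels, W, b, scores, k, n_random=5, seed=0):
    """Accuracy with the top-k scored units ablated, minus the mean accuracy
    over n_random ablations of k units drawn uniformly without replacement."""
    if k == 0:
        return 0.0  # nothing ablated, matching the K=0 rows
    top_k = np.argsort(scores)[::-1][:k]
    targeted = head_acc(feats, labels, W, b, top_k)
    rng = np.random.default_rng(seed)
    random_accs = [
        head_acc(feats, labels, W, b,
                 rng.choice(feats.shape[1], size=k, replace=False))
        for _ in range(n_random)
    ]
    return targeted - float(np.mean(random_accs))


# Toy usage on synthetic features (not project data)
rng = np.random.default_rng(0)
feats = rng.normal(size=(400, 512))
labels = rng.integers(0, 2, size=400)
W, b = rng.normal(size=(2, 512)), np.zeros(2)
scores = np.abs(W[1] - W[0])  # hypothetical per-unit shortcut score
print(delta_targ_rand(feats, labels, W, b, scores, k=64))
```

With only 5 random draws, the random baseline is itself noisy. The `rand` ± sd in the per-seed rows gives the size of that noise for each run.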
### Grokking-favorable (n=1000)
| K | s7 | s42 | s123 | s456 | s2024 | N_+ |
| --- | --- | --- | --- | --- | --- | --- |
| 0 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | — |
| 4 | +0.0015 | +0.0045 | +0.0015 | +0.0025 | −0.0005 | 4/5 |
| 8 | −0.0010 | +0.0025 | +0.0010 | 0.0000 | −0.0020 | 2/5 |
| 16 | +0.0005 | −0.0010 | +0.0055 | +0.0035 | −0.0015 | 3/5 |
| 32 | −0.0045 | −0.0025 | +0.0010 | +0.0045 | −0.0020 | 2/5 |
| 64 | −0.0015 | +0.0005 | +0.0115 | +0.0120 | −0.0115 | 3/5 |
| 128 | −0.0060 | +0.0005 | +0.0090 | +0.0015 | −0.0225 | 3/5 |
| 256 | −0.0345 | −0.0010 | +0.0060 | +0.0075 | −0.0205 | 2/5 |
### Standard (n=1000)
| K | s42 | s123 | s456 | N_+ |
| --- | --- | --- | --- | --- |
| 0 | 0.0000 | 0.0000 | 0.0000 | — |
| 4 | +0.0015 | −0.0025 | −0.0025 | 1/3 |
| 8 | −0.0035 | −0.0010 | −0.0005 | 0/3 |
| 16 | −0.0030 | −0.0025 | +0.0025 | 1/3 |
| 32 | −0.0040 | −0.0055 | −0.0060 | 0/3 |
| 64 | −0.0035 | −0.0100 | −0.0040 | 0/3 |
| 128 | −0.0085 | −0.0090 | −0.0025 | 0/3 |
| 256 | −0.0115 | +0.0040 | +0.0075 | 2/3 |
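### Per-seed full ablation rows (K=64 and K=256)
Columns (all head OOD accuracy): `base` = no ablation, `short` = top-K shortcut neurons ablated, `rand` = mean ± sd over 5 random K-neuron ablations, `morph` = morphology-neuron ablation.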
```
grokking s42 K= 64: base=0.6575 short=0.6600 rand=0.6595±0.0043 morph=0.6475 Δshort-base=+0.0025 Δtarg-rand=+0.0005
grokking s42 K=256: base=0.6575 short=0.6625 rand=0.6635±0.0054 morph=0.5500 Δshort-base=+0.0050 Δtarg-rand=−0.0010
grokking s123 K= 64: base=0.6825 short=0.6950 rand=0.6835±0.0066 morph=0.6625 Δshort-base=+0.0125 Δtarg-rand=+0.0115
grokking s123 K=256: base=0.6825 short=0.7275 rand=0.7215±0.0179 morph=0.5575 Δshort-base=+0.0450 Δtarg-rand=+0.0060
grokking s456 K= 64: base=0.6450 short=0.6425 rand=0.6305±0.0120 morph=0.6550 Δshort-base=−0.0025 Δtarg-rand=+0.0120
grokking s456 K=256: base=0.6450 short=0.6325 rand=0.6250±0.0105 morph=0.5275 Δshort-base=−0.0125 Δtarg-rand=+0.0075
grokking s7 K= 64: base=0.4925 short=0.4875 rand=0.4890±0.0075 morph=0.4475 Δshort-base=−0.0050 Δtarg-rand=−0.0015
grokking s7 K=256: base=0.4925 short=0.4650 rand=0.4995±0.0056 morph=0.3800 Δshort-base=−0.0275 Δtarg-rand=−0.0345
grokking s2024 K= 64: base=0.7100 short=0.7125 rand=0.7240±0.0086 morph=0.7225 Δshort-base=+0.0025 Δtarg-rand=−0.0115
grokking s2024 K=256: base=0.7100 short=0.7225 rand=0.7430±0.0129 morph=0.5050 Δshort-base=+0.0125 Δtarg-rand=−0.0205
standard s42 K= 64: base=0.6150 short=0.6125 rand=0.6160±0.0034 morph=0.6100 Δshort-base=−0.0025 Δtarg-rand=−0.0035
standard s42 K=256: base=0.6150 short=0.6100 rand=0.6215±0.0020 morph=0.5950 Δshort-base=−0.0050 Δtarg-rand=−0.0115
standard s123 K= 64: base=0.7225 short=0.7150 rand=0.7250±0.0016 morph=0.7125 Δshort-base=−0.0075 Δtarg-rand=−0.0100
standard s123 K=256: base=0.7225 short=0.6900 rand=0.6860±0.0025 morph=0.6975 Δshort-base=−0.0325 Δtarg-rand=+0.0040
standard s456 K= 64: base=0.5975 short=0.5975 rand=0.6015±0.0030 morph=0.5850 Δshort-base= 0.0000 Δtarg-rand=−0.0040
standard s456 K=256: base=0.5975 short=0.5500 rand=0.5425±0.0055 morph=0.5525 Δshort-base=−0.0475 Δtarg-rand=+0.0075
```
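The two Δ columns in these rows can be re-derived from the raw columns: Δshort-base = short − base and Δtarg-rand = short − rand (the random mean). The following standard-library sketch parses one row and checks both values. The regex mirrors the row format shown above, the helper names are hypothetical, and the parser converts the Unicode minus sign (U+2212) used in the printed deltas.

```python
import re

# Mirrors the row format printed above; the Unicode minus (U+2212) in the
# delta columns is handled by num().
ROW = re.compile(
    r"(?P<cond>\w+)\s+s(?P<seed>\d+)\s+K=\s*(?P<k>\d+):\s+"
    r"base=(?P<base>[\d.]+)\s+short=(?P<short>[\d.]+)\s+"
    r"rand=(?P<rand>[\d.]+)±(?P<rand_sd>[\d.]+)\s+morph=(?P<morph>[\d.]+)\s+"
    r"Δshort-base=\s*(?P<d_sb>\S+)\s+Δtarg-rand=\s*(?P<d_tr>\S+)"
)
NUMERIC = ("base", "short", "rand", "rand_sd", "morph", "d_sb", "d_tr")


def num(text):
    """float() that also accepts the Unicode minus sign."""
    return float(text.replace("\u2212", "-"))


def parse_row(line):
    """Parse one per-seed ablation row into a dict of numbers plus cond/seed/k."""
    m = ROW.search(line)
    if m is None:
        raise ValueError(f"unrecognised row: {line!r}")
    row = {key: num(m.group(key)) for key in NUMERIC}
    row.update(cond=m.group("cond"), seed=int(m.group("seed")), k=int(m.group("k")))
    return row


def deltas_consistent(row, tol=5e-5):
    """True if d_sb == short - base and d_tr == short - rand (4-decimal values)."""
    return (abs(row["d_sb"] - (row["short"] - row["base"])) < tol
            and abs(row["d_tr"] - (row["short"] - row["rand"])) < tol)


LINE = ("grokking s7 K=256: base=0.4925 short=0.4650 rand=0.4995±0.0056 "
        "morph=0.3800 Δshort-base=−0.0275 Δtarg-rand=−0.0345")
row = parse_row(LINE)
print(row["cond"], row["seed"], row["k"], deltas_consistent(row))  # grokking 7 256 True
```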
**Aggregates**:
- K=64: grokking **3/5** positive Δ(targ−rand) (mean +0.0022 ± 0.0088); standard **0/3** (mean −0.0058 ± 0.0030). Fisher one-sided p = 0.179.
- K=256: grokking 2/5 positive, standard 2/3 positive; there is no grokking advantage at this K.
- ID accuracy stays within 0.01 of baseline across all targeted ablations.
- The random control averages only 5 random samplings per K. This small sample is the main weakness of M6.
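These aggregates follow directly from the K-sweep tables. The following sketch recomputes them under two assumptions: `±` is the population SD, and the Fisher test compares counts of positive Δ(targ−rand) between conditions. Under these assumptions it reproduces the K=64 line: 3/5 vs 0/3, +0.0022 ± 0.0088 vs −0.0058 ± 0.0030, and p = 0.179. The helper names are hypothetical.

```python
from math import comb
from statistics import mean, pstdev

# Delta(targ - rand) per seed, copied from the K-sweep tables above
GROK = {64: [-0.0015, 0.0005, 0.0115, 0.0120, -0.0115],   # s7, s42, s123, s456, s2024
        256: [-0.0345, -0.0010, 0.0060, 0.0075, -0.0205]}
STD = {64: [-0.0035, -0.0100, -0.0040],                     # s42, s123, s456
       256: [-0.0115, 0.0040, 0.0075]}


def n_positive(values):
    """Count of seeds where targeted ablation beat the random control."""
    return sum(v > 0 for v in values)


def fisher_one_sided(a_pos, a_n, b_pos, b_n):
    """One-sided Fisher exact p: P(X >= a_pos) under the hypergeometric null."""
    total, pos = a_n + b_n, a_pos + b_pos
    hits = sum(comb(pos, j) * comb(total - pos, a_n - j)
               for j in range(a_pos, min(pos, a_n) + 1))
    return hits / comb(total, a_n)


for k in (64, 256):
    g, s = GROK[k], STD[k]
    p = fisher_one_sided(n_positive(g), len(g), n_positive(s), len(s))
    print(f"K={k}: grokking {n_positive(g)}/{len(g)} "
          f"(mean {mean(g):+.4f} +/- {pstdev(g):.4f}); "
          f"standard {n_positive(s)}/{len(s)} "
          f"(mean {mean(s):+.4f} +/- {pstdev(s):.4f}); one-sided Fisher p = {p:.3f}")
```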
The complete 88-row reviewer CSV is at `paper_figures/m6_summary.csv` — every run, every K, every condition (baseline / shortcut / random ± sd / morphology).
---
## 23. Exact commands
Training (launched detached under nohup via `scripts/launch.sh`):
```bash
python -u -m experiments.causalgrok_camelyon_v2 \
--condition grokking --n_train 1000 --seed 42 \
--run_dir experiments/runs/<run_id> \
--wandb_project causalgrok --wandb_mode offline
# standard: --condition standard (wd/init/grokfast set automatically by get_config)
```
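The following sketch expands the single command above into the n=1000 seed set analysed in §20 to §22 (grokking: 7, 42, 123, 456, 2024; standard: 42, 123, 456). It uses only the flags shown above. The run-ID format mirrors IDs that appear in this document, but the timestamp is a placeholder and the helper names are hypothetical. The sketch only builds and prints argument lists. Launching (the project uses `scripts/launch.sh` under nohup) is left out.

```python
import shlex

# n=1000 seeds analysed in the M5/M6 tables of this document
SEEDS = {"grokking": (7, 42, 123, 456, 2024), "standard": (42, 123, 456)}


def run_id(stamp, condition, seed, n_train=1000):
    """Hypothetical helper mirroring run IDs such as 20260508-183413_standard_n1000_s456."""
    return f"{stamp}_{condition}_n{n_train}_s{seed}"


def train_argv(condition, seed, rid, n_train=1000):
    """Argument list equivalent to the training command above."""
    return ["python", "-u", "-m", "experiments.causalgrok_camelyon_v2",
            "--condition", condition, "--n_train", str(n_train),
            "--seed", str(seed), "--run_dir", f"experiments/runs/{rid}",
            "--wandb_project", "causalgrok", "--wandb_mode", "offline"]


STAMP = "YYYYMMDD-HHMMSS"  # placeholder; real IDs carry the launch timestamp
commands = [train_argv(cond, seed, run_id(STAMP, cond, seed))
            for cond, seeds in SEEDS.items() for seed in seeds]
for argv in commands:
    print(shlex.join(argv))
```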
Mechanistic interpretability:
```bash
python -m experiments.mechinterp_m1 --run_dir experiments/runs/<id> --data_root data/wilds
python -m experiments.mechinterp_m4_ablation --run_dir experiments/runs/<id> --data_root data/wilds --layer avgpool --all_epochs
python -m experiments.mechinterp_m5_steering --run_dir experiments/runs/<id> --data_root data/wilds
python -m experiments.mechinterp_m6_neuron_ablation --run_dir experiments/runs/<id> --data_root data/wilds \
--ks "0,4,8,16,32,64,128,256"
```
Figures:
```bash
python -m experiments.regenerate_all_figures # rebuilds all 7 paper figures from saved JSON in <2 min
```
---
## 24. Output layout (per run)
```
experiments/runs/<run_id>/
├── config.json # full hyperparameter config
├── run.pid
├── checkpoints/
│ ├── ep00200.pt … ep03000.pt # 15 periodic checkpoints, ~44 MB each
│ └── final.pt
├── logs/
│ ├── train.log # launch cmd + per-checkpoint log lines
│ └── train.err
├── results/
│ ├── history.json # 61 metric rows (full evaluation history)
│ └── summary.json # final-summary fields
├── wandb/ # offline wandb run metadata
└── mechinterp/
├── m1_probe_data.json + heatmap/curves PNG
├── m4_ablation_avgpool_trajectory.json + PNG
├── m5_steering_ep<E>.json + PNG
└── m6_neuron_ablation_ep<E>.json + PNG
```
Aggregated: `paper_figures/m6_summary.csv` (88 data rows), `paper_figures/*.{png,pdf}` (7 figures).
Checkpoints (~10 GB, 240 `.pt` files) are mirrored to Hugging Face at `nileshsarkar-ai/CausalGrok`.
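The following sketch lists the checkpoint files a complete run directory should contain and reports any that are absent, for example after pulling a copy from the mirror. The `ep<epoch:05d>.pt` pattern comes from the layout above. The 200-epoch interval is inferred from `ep00200.pt … ep03000.pt` being 15 files, so it is an assumption, not a value read from `config.json`. The helper names are hypothetical.

```python
from pathlib import Path


def expected_checkpoints(first=200, last=3000, every=200):
    """Periodic names ep<epoch:05d>.pt (interval inferred, not read from config) plus final.pt."""
    periodic = [f"ep{epoch:05d}.pt" for epoch in range(first, last + 1, every)]
    return periodic + ["final.pt"]


def missing_checkpoints(run_dir):
    """Names from expected_checkpoints() that are absent under <run_dir>/checkpoints."""
    present = {p.name for p in (Path(run_dir) / "checkpoints").glob("*.pt")}
    return [name for name in expected_checkpoints() if name not in present]


names = expected_checkpoints()
print(len(names), names[0], names[-2])  # 16 ep00200.pt ep03000.pt
```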
---
*All numerical values in this document were read directly from on-disk `config.json` / `results/summary.json` / `results/history.json` / `mechinterp/*.json` / `paper_figures/m6_summary.csv`. Training logs in §18 and §19 are verbatim from each run's `logs/train.log`. Source code in §10–§15 is the exact code that ran for every reported result.*