# CausalGrok — Complete Training, Evaluation, and Mechanistic-Interpretability Reference

Paper: *Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training*.

This document is the **complete preservation archive** of the project. It contains the full source of every script that ran, every per-run config and summary, the full training logs of two reference runs, every M5 activation-steering result, the full M6 K-sweep data, and all formulas. All numerical values are read directly from on-disk `config.json` / `results/summary.json` / `results/history.json` / `mechinterp/*.json` / `paper_figures/m6_summary.csv`. All source code is the exact code that produced every reported result.

**Contents**

1. Environment
2. Dataset and data pipeline
3. Model architecture and initialization
4. Hyperparameter tables
5. Loss function (cross-entropy) — formula
6. IRM penalty (diagnostic) — formula
7. Grokfast EMA — formula
8. Training loop and checkpointing
9. Evaluation metrics — formulas
10. Full source: `utils/grokfast.py`
11. Full source: `experiments/causalgrok_camelyon_v2.py`
12. Full source: `experiments/mechinterp_m1.py`
13. Full source: `experiments/mechinterp_m4_ablation.py`
14. Full source: `experiments/mechinterp_m5_steering.py`
15. Full source: `experiments/mechinterp_m6_neuron_ablation.py`
16. Run inventory and summary results (14 runs)
17. Per-run `config.json` and `summary.json` (all 14 runs)
18. Full training log: grokking n=1000 seed=42
19. Full training log: standard n=1000 seed=42
20. M5 — Full activation-steering JSONs (8 runs at n=1000)
21. M5 — Aggregated sweep tables
22. M6 — Full K-sweep results (per-seed, all K)
23. Exact commands
24. Output layout

---

## 1. Environment

| Item | Value |
| --- | --- |
| Python env | `conda env: causalgrok` (Python 3.10) |
| Framework | PyTorch, `timm`, `torchvision`, `scikit-learn`, `numpy`, `wandb` (offline) |
| Model | `timm.create_model("resnet18", pretrained=False, num_classes=2)` |
| Device | CUDA (NVIDIA A100 80GB PCIe) |
| Precision | TF32 (`set_float32_matmul_precision("high")`, `cudnn.benchmark=True`, `allow_tf32=True`) |
| Params | 11,177,538 |
| Dataset | Camelyon17 via WILDS (`utils.camelyon_data.get_camelyon_subsets`, auto-download) |
| Wall time per run | ~5.5 h (grokking n=1000 s42, 3000 epochs, A100) |

## 2. Dataset and data pipeline

Camelyon17 (WILDS), H&E-stained histopathology patches, binary tumor label. Five hospitals; the WILDS split provides train hospitals, an in-distribution validation set, and a held-out OOD test hospital.

- **ID validation**: 33,560 images.
- **OOD test (held-out hospital)**: 85,054 images.
- **Train**: subsampled to `n_train`. The n=1000 seed-42 run drew hospitals `{0, 3, 4}` with 181 / 371 / 448 samples (positive rates 0.53 / 0.48 / 0.50).

Transforms: `Resize((96, 96))`, `ToTensor`, `Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])`.

Train loader: `batch_size=32, shuffle=True`; eval loaders: `batch_size=256, shuffle=False`; `num_workers=0`, `pin_memory=True`.

IRM environments: one `{x, y}` dict per unique training hospital, built once from `metadata[:, 0]`.

## 3. Model architecture and initialization

ResNet-18 (`timm`, no ImageNet pretraining), 96×96 input, 2-class head, 11,177,538 parameters. The grokking-favorable regime multiplies every multi-dim weight tensor by `init_scale = 4.0` at initialization; standard uses `init_scale = 1.0` (no rescaling). `avgpool` feature dimension `D = 512`. Six probed stages: `stem, layer1, layer2, layer3, layer4, avgpool`.

## 4. Hyperparameter tables

| Hyperparameter | Standard | Grokking-favorable |
| --- | --- | --- |
| Optimizer | AdamW | AdamW |
| Learning rate | 1e-3 | 1e-3 |
| Weight decay | 1e-4 | 5e-3 (50×) |
| Epochs | 3000 | 3000 |
| Init scale | 1.0 | 4.0 |
| Grokfast EMA | off | on |
| Grokfast alpha (EMA decay) | — | 0.98 |
| Grokfast lamb (slow-grad amplification) | — | 2.0 |
| Gradient clip (max-norm) | 1.0 | 1.0 |
| Batch size | 32 | 32 |
| Image size | 96 × 96 | 96 × 96 |
| `log_every` (metric cadence) | 50 epochs | 50 epochs |
| Checkpoint cadence | 200 epochs | 200 epochs |
| IRM weight in loss | 0.0 | 0.0 |

**Three-axis confound**: weight decay (50× ratio), init scale (4× ratio), and Grokfast EMA (on vs off) differ simultaneously between regimes. No single-axis ablation was run.

## 5. Loss function — cross-entropy

```
L = CE(f_theta(x), y) = - (1/B) * sum_i log [ exp(z_{i, y_i}) / sum_k exp(z_{i, k}) ]
```

where `z = f_theta(x)` are the logits. The training objective is pure CE for every reported run (`irm_weight = 0.0`).

## 6. IRM penalty (diagnostic, not in loss)

IRMv1 penalty (Arjovsky et al. 2019):

```
IRM_e = || grad_{w=1.0} CE( w · f_theta(x^e), y^e ) ||^2
IRM   = (1/|E|) sum_{e in E} IRM_e
```

Logged at every checkpoint as `irm_mean`, `irm_var`. Collapses from ~0.20 to ~1e-13 within ~50–150 epochs in every run (a CE-memorization consequence, not invariance learning).

## 7. Grokfast EMA (grokking-favorable only)

Grokfast (Lee et al. 2024, arXiv:2405.20233):

```
g_ema <- alpha * g_ema + (1 - alpha) * g    # alpha = 0.98
g     <- g + lamb * g_ema                   # lamb = 2.0
```

Applied after `loss.backward()`, before `optimizer.step()`.

## 8. Training loop and checkpointing

- Loop over `epoch ∈ [1, 3000]`.
- Per minibatch: forward, CE loss, `loss.backward()`, Grokfast filter (grokking only), `clip_grad_norm_(max_norm=1.0)`, `optimizer.step()`.
- Metrics computed every 50 epochs (+ epoch 1) → `results/history.json` (61 rows/run).
- Checkpoints every 200 epochs → 15 `.pt` per run + `final.pt`.
- OOD-aware early stopping exists but defaults off; all reported runs train the full 3000 epochs.
- Grokking detection watches OOD-accuracy plateau-then-jump; `grokking_epoch = -1` (never fires) for every run.

## 9. Evaluation metrics — formulas

| Metric | Definition |
| --- | --- |
| Accuracy | `argmax` accuracy on a loader; `train_acc`, `id_val_acc`, `ood_acc`. `ood_gap = id_val_acc - ood_acc`. |
| Weight norm | `\|\|W\|\| = sqrt(sum_p \|\|p\|\|_2^2)` |
| Effective feature rank | `exp(- sum_i ŝ_i log ŝ_i)` where `ŝ_i = s_i / sum_j s_j` are the normalized SVD singular values of `layer4` avgpool features on ≤300 samples. |
| Shortcut ratio | `min(border_conf / center_conf, 10)`. `>1` = stain-reliant, `<1` = tissue-reliant. |
| IRM penalty | see §6; computed across the 3 training-hospital environments at each logged epoch (every 50 epochs, plus epoch 1). |

Summary fields (`summary.json`): `best_id_val`, `best_ood`, `peak_ood_epoch`, `final_ood`, `ood_delta` (= `final_ood − best_ood`, ungrokking signal), `ood_improvement`, `grokking_epoch`, `irm_drop_pct`, `irm_drop_epoch`, `epoch_gap`, `final_weight_norm`, `final_feature_rank`, `final_irm`, `final_shortcut_ratio`, `final_ood_gap`.

---

## 10. Full source: `utils/grokfast.py`

```python
"""
utils.grokfast — accelerated grokking by amplifying slow-varying gradient
components (Lee et al. 2024, arXiv:2405.20233).

Maintain an EMA of gradients across steps; the slow-EMA component corresponds
to the generalising circuit. Adding it back into the live gradient (scaled by
`lamb`) accelerates the grokking transition 20-100×.
"""
from __future__ import annotations


def gradfilter_ema(model, grads_ema, alpha: float = 0.98, lamb: float = 2.0):
    """
    Call this AFTER `loss.backward()` and BEFORE `optimizer.step()`.

    Args:
        model: the network whose gradients we are filtering.
        grads_ema: dict {param_name: ema_grad}, or None on the first call.
        alpha: EMA decay (0.98 → very slow, emphasises persistent grads).
        lamb: amplification factor for the slow component.

    Returns:
        Updated `grads_ema` dict — pass it back in on the next step.
    """
    if grads_ema is None:
        grads_ema = {}
    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            if name not in grads_ema:
                grads_ema[name] = p.grad.data.detach().clone()
            else:
                grads_ema[name] = (
                    grads_ema[name] * alpha + p.grad.data.detach() * (1 - alpha)
                )
            p.grad.data = p.grad.data + grads_ema[name] * lamb
    return grads_ema
```

---

## 11. Full source: `experiments/causalgrok_camelyon_v2.py`

```python
"""
CausalGrok — Camelyon17 Training Loop v2
Nilesh

KEY CHANGE FROM v1:
    OOD test accuracy (H4 — unseen hospital) is now tracked at EVERY
    checkpoint, not just at the end. Grokking detection watches OOD acc,
    not ID val acc. This is the correct signal.

The paper claim:
    after ID accuracy converges (fast, expected), the model undergoes a
    delayed phase transition in OOD generalization — grokking the
    cross-hospital invariant causal features. This co-occurs with a drop
    in IRM penalty. That is the grokking we care about for clinical
    deployment.
Two curves to watch:
    val_acc (H3 ID val) — converges fast, expected ~0.86 by ep 50
    ood_acc (H4 OOD test) — should plateau then JUMP (the grokking)

Run via:
    python -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 300
"""
from __future__ import annotations

import argparse
import json
import os
import time
from datetime import datetime, timezone

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import timm

try:
    import wandb
except ImportError:
    wandb = None

from utils.grokfast import gradfilter_ema
from utils.camelyon_data import get_camelyon_subsets
from utils.run_dir import make_run_dir, ensure_run_dir, save_config


# ──────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────

def get_config(condition):
    base = dict(
        seed=42,
        n_train=300,
        batch_size=32,
        img_size=96,
        n_classes=2,
        log_every=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    if condition == "standard":
        base.update(dict(
            condition="standard",
            lr=1e-3,
            weight_decay=1e-4,
            # Default 3000 epochs to match grokking config and the
            # paper's reported runs; previously defaulted to 300 which
            # made the standard baseline trivially under-trained
            # relative to grokking. See paper Limitations §M3.
            n_epochs=3000,
            init_scale=1.0,
            use_grokfast=False,
        ))
    elif condition == "grokking":
        base.update(dict(
            condition="grokking",
            lr=1e-3,
            weight_decay=5e-3,
            n_epochs=3000,
            init_scale=4.0,
            use_grokfast=True,
            grokfast_alpha=0.98,
            grokfast_lamb=2.0,
        ))
    return base


# ──────────────────────────────────────────────
# WILDS-SAFE METRICS
# All handle the (imgs, labels, metadata) 3-tuple WILDS batch format.
# ──────────────────────────────────────────────

@torch.no_grad()
def accuracy_wilds(model, loader, device, max_samples=None):
    model.eval()
    correct = total = 0
    for batch in loader:
        imgs = batch[0].to(device)
        labels = batch[1].squeeze().long().to(device)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        if max_samples and total >= max_samples:
            break
    return correct / max(total, 1)


@torch.no_grad()
def weight_norm_fn(model):
    return sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5


@torch.no_grad()
def feature_rank_wilds(model, loader, device, n=300):
    model.eval()
    feats = []

    def hook_fn(module, input, output):
        avg_pool = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
        feats.append(avg_pool.view(avg_pool.size(0), -1).cpu())

    hook = model.layer4[-1].register_forward_hook(hook_fn)
    count = 0
    for batch in loader:
        model(batch[0].to(device))
        count += batch[0].size(0)
        if count >= n:
            break
    hook.remove()
    if not feats:
        return float("nan")
    F_mat = torch.cat(feats)[:n]
    try:
        _, s, _ = torch.svd(F_mat)
        s = s / (s.sum() + 1e-10)
        return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()
    except Exception:
        return float("nan")


@torch.no_grad()
def shortcut_ratio_wilds(model, loader, device, n_samples=200):
    """
    Stain shortcut proxy: compare model confidence on center crop
    (tissue — causal features) vs. border region (stain — spurious).

    sc > 1.0 = relying on border stain more than tissue (shortcut)
    sc < 1.0 = relying on tissue center more than stain (causal)

    The transition from > 1.0 to < 1.0 during training is the
    attribution-level signature of the grokking transition.
    """
    model.eval()
    cc, bc = [], []
    count = 0
    for batch in loader:
        if count >= n_samples:
            break
        imgs = batch[0].to(device)
        B, C, H, W = imgs.shape
        hs, he = H // 4, 3 * H // 4
        ws, we = W // 4, 3 * W // 4
        center = F.interpolate(
            imgs[:, :, hs:he, ws:we], size=(H, W),
            mode="bilinear", align_corners=False
        )
        border = imgs.clone()
        border[:, :, hs:he, ws:we] = 0.0
        cc.append(F.softmax(model(center), 1).max(1).values.mean().item())
        bc.append(F.softmax(model(border), 1).max(1).values.mean().item())
        count += imgs.size(0)
    cconf = float(np.mean(cc)) if cc else 0.5
    bconf = float(np.mean(bc)) if bc else 0.5
    return cconf, bconf


def irm_penalty_wilds(model, envs, device):
    """
    IRMv1 penalty across TRAINING hospital environments (H0-H2).
    Diagnostic version: uses create_graph=False, returns floats.
    Used as a monitoring metric only (logged per epoch).
    """
    model.eval()
    penalties = []
    for env in envs:
        w = torch.tensor(1.0, requires_grad=True, device=device)
        logits = model(env["x"]) * w
        loss = F.cross_entropy(logits, env["y"])
        grad = torch.autograd.grad(loss, w, create_graph=False)[0]
        penalties.append(grad.item() ** 2)
    t = torch.tensor(penalties)
    return t.mean().item(), t.var().item()


def irm_penalty_train_time(logits_list, y_list):
    """
    IRMv1 penalty for use INSIDE the training loss (differentiable).

    Splits a batch by environment, computes per-env loss with a virtual
    scale variable, takes the squared gradient of each per-env loss w.r.t.
    that scale, returns the mean across envs.

    Args:
        logits_list: list of (per-env) logits tensors
        y_list: list of (per-env) label tensors

    Returns:
        scalar tensor (differentiable), the IRM penalty contribution.
    """
    penalty = 0.0
    n = 0
    for logits, y in zip(logits_list, y_list):
        if logits.shape[0] == 0:
            continue
        scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
        loss = F.cross_entropy(logits * scale, y)
        grad = torch.autograd.grad(loss, scale, create_graph=True)[0]
        penalty = penalty + grad ** 2
        n += 1
    if n == 0:
        return torch.tensor(0.0, device=logits_list[0].device)
    return penalty / n


def eval_irm_penalty_wilds(model, id_val_loader, ood_test_loader, device):
    """
    IRM penalty evaluated on HELD-OUT environments (H3 and H4).
    This avoids the measurement artifact of training on H0-H2 where loss→0.

    HIGH penalty = model relies on hospital-discriminating features = shortcuts.
    LOW penalty = model ignores hospital labels = causal features.
    """
    model.eval()
    penalties = []

    # Create environment views from eval data
    for loader, hospital_label in [
        (id_val_loader, "H3"),
        (ood_test_loader, "H4"),
    ]:
        xs, ys = [], []
        count = 0
        with torch.no_grad():
            for batch in loader:
                imgs = batch[0].to(device)
                labels = batch[1].squeeze().long().to(device)
                xs.append(model(imgs))
                ys.append(labels)
                count += imgs.size(0)
                if count >= 500:
                    break

        if xs:
            x = torch.cat(xs)
            y = torch.cat(ys)
            w = torch.tensor(1.0, requires_grad=True, device=device)
            logits = x * w
            loss = F.cross_entropy(logits, y)
            try:
                grad = torch.autograd.grad(loss, w, create_graph=False)[0]
                penalties.append(grad.item() ** 2)
            except:
                penalties.append(float("nan"))

    if penalties and not any(np.isnan(p) for p in penalties):
        return float(np.mean(penalties)), float(np.var(penalties))
    else:
        return float("nan"), float("nan")


# ──────────────────────────────────────────────
# DATA
# ──────────────────────────────────────────────

class TransformWrapper:
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def get_dataloaders(cfg, data_root):
    transform = transforms.Compose([
        transforms.Resize((cfg["img_size"], cfg["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=True)

    # Subsample training set
    torch.manual_seed(cfg["seed"])
    indices = torch.randperm(len(train_ds))[:cfg["n_train"]]
    train_subset = Subset(train_ds, indices)

    # Wrap with TransformWrapper to apply transforms
    train_subset = TransformWrapper(train_subset, transform)
    id_val_ds = TransformWrapper(id_val_ds, transform)
    ood_test_ds = TransformWrapper(ood_test_ds, transform)

    train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"],
                              shuffle=True, num_workers=0, pin_memory=True)
    id_val_loader = DataLoader(id_val_ds, batch_size=256,
                               shuffle=False, num_workers=0, pin_memory=True)
    ood_test_loader = DataLoader(ood_test_ds, batch_size=256,
                                 shuffle=False, num_workers=0, pin_memory=True)
    return train_loader, id_val_loader, ood_test_loader, train_subset


def get_hospital_environments(train_subset, device):
    """
    Build IRM environments from ground-truth hospital labels.
    Returns list of {x, y} dicts — one per unique hospital in the subset.
    Hospitals in Camelyon17 train split: 0, 1, 2.
    """
    loader = DataLoader(train_subset, batch_size=512, shuffle=False,
                        num_workers=4)
    all_imgs, all_labels, all_meta = [], [], []
    for imgs, labels, meta in loader:
        all_imgs.append(imgs)
        all_labels.append(labels.squeeze().long())
        all_meta.append(meta)

    all_imgs = torch.cat(all_imgs)
    all_labels = torch.cat(all_labels)
    hospitals = torch.cat(all_meta)[:, 0].long()  # field 0 = hospital ID

    envs = []
    for h in torch.unique(hospitals):
        mask = hospitals == h
        n = mask.sum().item()
        envs.append({
            "x": all_imgs[mask].to(device),
            "y": all_labels[mask].to(device),
            "hospital": int(h),
        })
        pos_rate = all_labels[mask].float().mean().item()
        print(f" Env hospital={int(h)}: {n} samples, "
              f"positive rate={pos_rate:.2f}")
    return envs


# ──────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────

def build_model(cfg):
    model = timm.create_model("resnet18", pretrained=False,
                              num_classes=cfg["n_classes"])
    if cfg["init_scale"] != 1.0:
        with torch.no_grad():
            for name, p in model.named_parameters():
                if "weight" in name and p.dim() > 1:
                    p.data *= cfg["init_scale"]
    return model.to(cfg["device"])


# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────

def train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir):
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_id_val = 0.0
    best_ood = 0.0
    peak_ood_epoch = None  # Epoch where best_ood was achieved
    grok_epoch = None
    irm_base = None
    history_path = os.path.join(run_dir, "results", "history.json")
    grad_clip = cfg.get("grad_clip", 1.0)

    # Grokking detection parameters.
    # We watch OOD accuracy (H4), not ID val accuracy (H3).
    # ID val converges fast (expected). OOD is what should grok.
    plateau_window = 10
    plateau_eps = 0.01

    # Ungrokking early stopping parameters.
    # If OOD peaks then declines, stop at the peak rather than training to convergence.
    ood_patience = cfg.get("ood_patience", 20)  # checkpoints to wait before stopping
    ood_min_delta = cfg.get("ood_min_delta", 0.01)  # minimum improvement threshold
    use_ood_early_stop = cfg.get("use_ood_early_stop", False)

    print(f"\n{'='*60}")
    print(f" {cfg['condition'].upper()} | Camelyon17 v2 | {cfg['n_epochs']} epochs")
    print(f" WD={cfg['weight_decay']} | α={cfg['init_scale']} | n={cfg['n_train']}")
    print(f" Tracking: ID val (H3) + OOD test (H4) at every checkpoint")
    print(f" Grokking detection: watching OOD acc, not ID val acc")
    print(f" IRM envs: {len(envs)} hospitals")
    print(f"{'='*60}", flush=True)

    irm_weight = float(cfg.get("irm_weight", 0.0))
    use_irm_in_loss = irm_weight > 0.0
    if use_irm_in_loss:
        print(f" IRM-in-loss: ENABLED, alpha={irm_weight}", flush=True)
    else:
        print(f" IRM-in-loss: disabled (CE-only training; IRM penalty is diagnostic)", flush=True)

    for epoch in range(1, cfg["n_epochs"] + 1):

        # ── Train step ────────────────────────────────────────────────
        model.train()
        loss_sum = n_b = 0
        for imgs, labels, metadata in train_loader:
            imgs = imgs.to(cfg["device"])
            labels = labels.squeeze().long().to(cfg["device"])
            optimizer.zero_grad()
            logits = model(imgs)
            ce_loss = criterion(logits, labels)

            if use_irm_in_loss:
                # Split this batch by training hospital (H0/H1/H2) and
                # compute IRMv1 penalty as a differentiable scalar.
                hosp_ids = metadata[:, 0].long().to(cfg["device"])
                logits_per_env, y_per_env = [], []
                for h in [0, 1, 2]:
                    mask = (hosp_ids == h)
                    if mask.sum() < 2:
                        continue
                    logits_per_env.append(logits[mask])
                    y_per_env.append(labels[mask])
                if len(logits_per_env) >= 2:
                    irm_term = irm_penalty_train_time(logits_per_env, y_per_env)
                    loss = ce_loss + irm_weight * irm_term
                else:
                    loss = ce_loss
            else:
                loss = ce_loss

            loss.backward()

            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))

            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=grad_clip)

            optimizer.step()
            loss_sum += loss.item()
            n_b += 1

        # ── Checkpoint metrics ────────────────────────────────────────
        if epoch % cfg["log_every"] == 0 or epoch == 1:
            tr_acc = accuracy_wilds(model, train_loader, cfg["device"])
            id_acc = accuracy_wilds(model, id_val_loader, cfg["device"])
            ood_acc = accuracy_wilds(model, ood_test_loader, cfg["device"])  # KEY
            wn = weight_norm_fn(model)
            fr = feature_rank_wilds(model, id_val_loader, cfg["device"])
            irm_m, irm_v = irm_penalty_wilds(model, envs, cfg["device"])
            cconf, bconf = shortcut_ratio_wilds(
                model, id_val_loader, cfg["device"])

            if irm_base is None:
                irm_base = irm_m

            # ── OOD grokking detection ────────────────────────────────
            # Require sustained plateau in OOD acc before the jump.
            # The ID val acc plateau is expected and not grokking.
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref = last[-1]["ood_acc"]
                flat = sum(1 for r in last
                           if abs(r["ood_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and ood_acc > best_ood + 0.05:
                    grok_epoch = epoch
                    irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n *** OOD GROKKING at epoch {epoch} ***")
                    print(f" OOD: {best_ood:.3f} → {ood_acc:.3f} | "
                          f"IRM drop: {irm_drop:.1f}%", flush=True)

            if id_acc > best_id_val:
                best_id_val = id_acc
            if ood_acc > best_ood:
                best_ood = ood_acc
                peak_ood_epoch = epoch  # Track when peak was achieved

            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)

            # OOD gap: how much worse is OOD vs ID?
            # This should shrink at the grokking transition.
            ood_gap = id_acc - ood_acc

            row = dict(
                epoch = epoch,
                train_loss = loss_sum / n_b,
                train_acc = tr_acc,
                id_val_acc = id_acc,
                ood_acc = ood_acc,  # ← primary grokking signal
                ood_gap = ood_gap,  # ← should narrow at transition
                weight_norm = wn,
                feature_rank = fr,
                irm_mean = irm_m,
                irm_var = irm_v,
                center_conf = cconf,
                border_conf = bconf,
                shortcut_ratio = sc_ratio,
                grokking_detected = grok_epoch is not None,
            )
            history.append(row)
            if wandb:
                wandb.log(row)

            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            # Save periodic checkpoint for M1 analysis (every 200 epochs)
            if epoch % 200 == 0:
                ckpt_dir = os.path.join(run_dir, "checkpoints")
                os.makedirs(ckpt_dir, exist_ok=True)
                ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:05d}.pt")
                torch.save(model.state_dict(), ckpt_path)
                print(f" ✓ Checkpoint → ep{epoch:05d}.pt", flush=True)

            # ── OOD-aware early stopping (if ungrokking detected) ───────
            # If OOD peaks then declines, stop at the peak rather than full epochs.
            if use_ood_early_stop and peak_ood_epoch is not None and len(history) >= ood_patience:
                recent_ood = [r["ood_acc"] for r in history[-ood_patience:]]
                ood_trend = max(recent_ood) - min(recent_ood)
                if ood_acc < best_ood - ood_min_delta:
                    print(f"\n *** EARLY STOP (OOD declining) at epoch {epoch} ***", flush=True)
                    print(f" Peak OOD: {best_ood:.4f} at epoch {peak_ood_epoch}", flush=True)
                    print(f" Current: {ood_acc:.4f} ({ood_acc-best_ood:+.4f})", flush=True)
                    # Save peak checkpoint separately for clinical deployment
                    if peak_ood_epoch and peak_ood_epoch % 200 == 0:
                        peak_src = os.path.join(run_dir, "checkpoints", f"ep{peak_ood_epoch:05d}.pt")
                        peak_dst = os.path.join(run_dir, "checkpoints", "peak_ood.pt")
                        if os.path.exists(peak_src):
                            import shutil
                            shutil.copy(peak_src, peak_dst)
                            print(f" Saved peak → checkpoints/peak_ood.pt", flush=True)
                    break  # Exit training loop

            print(f" ep {epoch:5d} | "
                  f"tr {tr_acc:.3f} | "
                  f"id {id_acc:.3f} | "
                  f"ood {ood_acc:.3f} | "
                  f"gap {ood_gap:+.3f} | "  # + means OOD worse than ID
                  f"‖W‖ {wn:.1f} | "
                  f"rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | "
                  f"sc {sc_ratio:.2f}x", flush=True)

    # ── Final summary ─────────────────────────────────────────────────
    # One final OOD eval at the very end
    final_ood = accuracy_wilds(model, ood_test_loader, cfg["device"])

    if wandb:
        wandb.log({"final_ood_acc": final_ood,
                   "grokking_epoch": grok_epoch or -1})

    # Decision numbers
    irm_drop_pct = float("nan")
    irm_drop_ep = epoch_gap = -1
    if history:
        irm0 = history[0]["irm_mean"]
        irm_min = min(r["irm_mean"] for r in history)
        if irm0:
            irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100
        if len(history) > 1:
            biggest = 0.0
            for prev, cur in zip(history[:-1], history[1:]):
                d = abs(cur["irm_mean"] - prev["irm_mean"])
                if d > biggest:
                    biggest = d
                    irm_drop_ep = cur["epoch"]
    if grok_epoch and irm_drop_ep > 0:
        epoch_gap = abs(grok_epoch - irm_drop_ep)

    # OOD grokking: did OOD acc improve significantly after ID convergence?
    # Measure: max OOD acc in last 20% of training vs.
    # OOD acc when ID first plateaued (epoch ~200-300 for standard training).
    ood_early = np.mean([r["ood_acc"] for r in history[:5]]) if history else 0
    ood_late = np.mean([r["ood_acc"] for r in history[-5:]]) if history else 0
    ood_improvement = ood_late - ood_early

    # Ungrokking detection: did OOD collapse after peaking?
    ood_delta = final_ood - best_ood  # Negative = ungrokking

    summary = dict(
        run_id = cfg["run_id"],
        condition = cfg["condition"],
        n_train = cfg["n_train"],
        seed = cfg["seed"],
        best_id_val = best_id_val,
        best_ood = best_ood,
        peak_ood_epoch = peak_ood_epoch or -1,  # When peak was achieved
        final_ood = final_ood,
        ood_delta = ood_delta,  # final - best (ungrokking signal)
        ood_improvement = ood_improvement,  # ← key: did OOD grok?
        grokking_epoch = grok_epoch or -1,
        irm_drop_pct = irm_drop_pct,
        irm_drop_epoch = irm_drop_ep,
        epoch_gap = epoch_gap,
        final_weight_norm = history[-1]["weight_norm"] if history else None,
        final_feature_rank= history[-1]["feature_rank"] if history else None,
        final_irm = history[-1]["irm_mean"] if history else None,
        final_shortcut_ratio = history[-1]["shortcut_ratio"] if history else None,
        final_ood_gap = history[-1]["ood_gap"] if history else None,
    )
    with open(os.path.join(run_dir, "results", "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    torch.save(model.state_dict(),
               os.path.join(run_dir, "checkpoints", "final.pt"))

    print(f"\n Best ID val (H3): {best_id_val:.4f}")
    print(f" Best OOD (H4): {best_ood:.4f}")
    print(f" OOD improvement: {ood_improvement:+.4f} ← did OOD grok?")
    print(f" Grokking at: {grok_epoch}")
    print(f" IRM drop: {irm_drop_pct:.1f}%", flush=True)
    return history


# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--condition", default="grokking",
                   choices=["standard", "grokking"])
    p.add_argument("--n_train", type=int, default=300)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--log_every",
                   type=int, default=50)
    p.add_argument("--wandb_project", default="causalgrok")
    p.add_argument("--wandb_mode", default="offline",
                   choices=["online", "offline", "disabled"])
    p.add_argument("--run_dir", default=None)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--weight_decay", type=float, default=None)
    p.add_argument("--init_scale", type=float, default=None)
    p.add_argument("--n_epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--grokfast", choices=["on", "off"], default=None)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--irm_weight", type=float, default=0.0,
                   help="IRMv1 penalty weight added to training loss "
                        "(0 = pure cross-entropy / diagnostic-only IRM).")
    args = p.parse_args()

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed,
               log_every=args.log_every, grad_clip=args.grad_clip)
    if args.weight_decay is not None:
        cfg["weight_decay"] = args.weight_decay
    if args.init_scale is not None:
        cfg["init_scale"] = args.init_scale
    if args.n_epochs is not None:
        cfg["n_epochs"] = args.n_epochs
    if args.lr is not None:
        cfg["lr"] = args.lr
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["irm_weight"] = args.irm_weight

    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        run_dir, run_id = make_run_dir(
            ["camelyon_v2", cfg["condition"],
             f"n{cfg['n_train']}", f"s{cfg['seed']}"])
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        run_id = os.path.basename(os.path.normpath(run_dir))
    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    if wandb:
        wandb.init(project=args.wandb_project, config=cfg,
                   name=run_id, mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started: {datetime.now(timezone.utc).isoformat()}", flush=True)

    train_loader, id_val_loader, ood_test_loader, train_subset = \
        get_dataloaders(cfg, args.data_root)
    envs = get_hospital_environments(train_subset, cfg["device"])
    model = build_model(cfg)
    print(f"Train: {len(train_subset)} | "
          f"ID val (H3): {len(id_val_loader.dataset)} | "
          f"OOD test (H4): {len(ood_test_loader.dataset)}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}", flush=True)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    t0 = time.time()
    train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time()-t0)/60:.1f} min", flush=True)

    if wandb:
        wandb.finish()


if __name__ == "__main__":
    main()
```

---

## 12. Full source: `experiments/mechinterp_m1.py`

```python
"""
CausalGrok — M1: Layer-wise Linear Probing
Nilesh

The mechanistic claim:
    Before grokking: hospital probe HIGH (model uses stain shortcut),
                     tumor probe LOW in early layers
    At transition:   hospital probe DROPS, tumor probe RISES
    After grokking:  inverted — tumor high, hospital low

If OOD acc jump + hospital probe drop + tumor probe rise all happen
at the same epoch → mechanistic claim proven. That's Figure 2 of the paper.
Usage:
  # Run on all saved checkpoints from a run
  python -m experiments.mechinterp_m1 \
      --run_dir experiments/runs/ \
      --data_root data/wilds

  # Run on latest checkpoint only (quick check while training)
  python -m experiments.mechinterp_m1 \
      --run_dir experiments/runs/ \
      --data_root data/wilds \
      --latest_only

  # Run on ALL camelyon_v2 grokking runs
  python -m experiments.mechinterp_m1 \
      --all_runs \
      --data_root data/wilds

Output per run:
  experiments/runs//mechinterp/
    m1_probe_heatmap.png  ← epoch × layer, hospital probe acc
    m1_tumor_heatmap.png  ← epoch × layer, tumor probe acc
    m1_probe_curves.png   ← hospital vs tumor probe over epochs (layer 4)
    m1_probe_data.json    ← raw numbers for paper tables
"""

from __future__ import annotations

import argparse
import glob
import json
import os
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import matplotlib
import matplotlib.pyplot as plt
import timm
import warnings

warnings.filterwarnings("ignore")

matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150})


# ──────────────────────────────────────────────
# RESNET-18 LAYER HOOKS
# Extract features after each of the 6 measurable stages:
# stem → layer1 → layer2 → layer3 → layer4 → avgpool
# ──────────────────────────────────────────────

LAYER_NAMES = [
    "stem",     # After initial conv + bn + relu + maxpool
    "layer1",   # ResNet block 1 (64 channels)
    "layer2",   # ResNet block 2 (128 channels)
    "layer3",   # ResNet block 3 (256 channels)
    "layer4",   # ResNet block 4 (512 channels)
    "avgpool",  # Global average pool — penultimate representation
]


def register_hooks(model):
    """
    Register forward hooks on all 6 extraction points.
    Returns (hooks, features_dict).
    """
    features = {name: [] for name in LAYER_NAMES}
    hooks = []

    def make_hook(name):
        def hook_fn(module, input, output):
            if output.dim() == 4:
                feat = output.mean(dim=[2, 3])
            else:
                feat = output.view(output.size(0), -1)
            features[name].append(feat.detach().cpu())
        return hook_fn

    hooks.append(model.maxpool.register_forward_hook(make_hook("stem")))
    hooks.append(model.layer1.register_forward_hook(make_hook("layer1")))
    hooks.append(model.layer2.register_forward_hook(make_hook("layer2")))
    hooks.append(model.layer3.register_forward_hook(make_hook("layer3")))
    hooks.append(model.layer4.register_forward_hook(make_hook("layer4")))
    hooks.append(model.global_pool.register_forward_hook(make_hook("avgpool")))

    return hooks, features


def extract_features(model, loader, device, max_samples=1000):
    """
    Run forward pass and collect features at all 6 layers.
    """
    model.eval()
    hooks, feat_dict = register_hooks(model)

    all_hospital = []
    all_tumor = []
    count = 0

    with torch.no_grad():
        for batch in loader:
            imgs = batch[0].to(device)
            labels = batch[1].squeeze().long()
            metadata = batch[2]

            model(imgs)

            all_hospital.append(metadata[:, 0].long())
            all_tumor.append(labels)
            count += imgs.size(0)
            if count >= max_samples:
                break

    for h in hooks:
        h.remove()

    features = {k: torch.cat(v).numpy() for k, v in feat_dict.items()}
    hospital_ids = torch.cat(all_hospital).numpy()
    tumor_labels = torch.cat(all_tumor).numpy()

    n = min(max_samples, len(hospital_ids))
    features = {k: v[:n] for k, v in features.items()}
    hospital_ids = hospital_ids[:n]
    tumor_labels = tumor_labels[:n]

    return features, hospital_ids, tumor_labels


def train_probe(X_train, y_train, X_val, y_val):
    """
    Train logistic regression probe on frozen features.
    """
    if len(np.unique(y_train)) < 2:
        return 0.5

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)

    clf = LogisticRegression(
        max_iter=500,
        C=1.0,
        solver="lbfgs",
        multi_class="auto",
        n_jobs=-1,
    )
    try:
        clf.fit(X_train, y_train)
        return clf.score(X_val, y_val)
    except Exception:
        return float("nan")


# ──────────────────────────────────────────────
# CHECKPOINT DISCOVERY
# ──────────────────────────────────────────────

def find_checkpoints(run_dir: str) -> List[tuple]:
    """
    Find all checkpoints in a run directory.
    Returns list of (epoch, checkpoint_path) sorted by epoch.
    """
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    if not os.path.isdir(ckpt_dir):
        return []

    checkpoints = []

    # Periodic checkpoints: ep050.pt, ep100.pt, etc.
    for f in sorted(glob.glob(os.path.join(ckpt_dir, "ep*.pt"))):
        epoch_str = os.path.basename(f).replace("ep", "").replace(".pt", "")
        try:
            epoch = int(epoch_str)
            checkpoints.append((epoch, f))
        except ValueError:
            continue

    # Final checkpoint
    final = os.path.join(ckpt_dir, "final.pt")
    if os.path.isfile(final):
        hist_path = os.path.join(run_dir, "results", "history.json")
        if os.path.isfile(hist_path):
            try:
                hist = json.load(open(hist_path))
                epoch = hist[-1]["epoch"] if hist else 9999
            except Exception:
                epoch = 9999
        else:
            epoch = 9999
        checkpoints.append((epoch, final))

    return sorted(checkpoints, key=lambda x: x[0])


def load_model_from_checkpoint(ckpt_path: str, n_classes: int = 2,
                               device: str = "cuda") -> nn.Module:
    model = timm.create_model("resnet18", pretrained=False,
                              num_classes=n_classes)
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model.to(device)


# ──────────────────────────────────────────────
# MAIN PROBE ANALYSIS
# ──────────────────────────────────────────────

def run_probe_analysis(run_dir: str, data_root: str, device: str = "cuda",
                       max_samples: int = 800,
                       latest_only: bool = False) -> Optional[Dict]:
    """
    For each checkpoint in a run,
    extract features at all 6 layers and train hospital + tumor probes.
    """
    from utils.camelyon_data import get_camelyon_subsets

    cfg_path = os.path.join(run_dir, "config.json")
    if not os.path.isfile(cfg_path):
        print(f" No config.json in {run_dir}, skipping")
        return None

    cfg = json.load(open(cfg_path))
    condition = cfg.get("condition", "unknown")
    n_train = cfg.get("n_train", 300)
    seed = cfg.get("seed", 42)

    print(f"\n{'='*55}")
    print(f" M1 Probe Analysis: {os.path.basename(run_dir)}")
    print(f" condition={condition}, n_train={n_train}, seed={seed}")
    print(f"{'='*55}")

    checkpoints = find_checkpoints(run_dir)
    if not checkpoints:
        print(f" No checkpoints found — skipping")
        return None

    if latest_only:
        checkpoints = checkpoints[-1:]

    print(f" Found {len(checkpoints)} checkpoints: "
          f"epochs {[e for e,_ in checkpoints]}")
    print(f" Hospital probe: fits on training data (H0-H2), "
          f"evaluates on H3 and H4 separately")

    # ── Data ─────────────────────────────────────────────────────────
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_ds, id_val_ds, ood_test_ds, full_ds = get_camelyon_subsets(
        root_dir=data_root, download=False)

    # Wrap datasets with transform (WILDS returns PIL images)
    class _TransformWrapper:
        def __init__(self, dataset, transform):
            self.dataset = dataset
            self.transform = transform

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            img, label, metadata = self.dataset[idx]
            return self.transform(img), label, metadata

    id_val_t = _TransformWrapper(id_val_ds, transform)
    ood_test_t = _TransformWrapper(ood_test_ds, transform)
    train_t = _TransformWrapper(train_ds, transform)

    torch.manual_seed(seed)
    probe_idx = torch.randperm(len(id_val_t))[:max_samples // 2]
    ood_idx = torch.randperm(len(ood_test_t))[:max_samples // 2]
    train_idx = torch.randperm(len(train_t))[:max_samples]

    probe_loader = DataLoader(
        Subset(id_val_t, probe_idx),
        batch_size=128, shuffle=False, num_workers=0)
    ood_loader = DataLoader(
        Subset(ood_test_t, ood_idx),
        batch_size=128, shuffle=False, num_workers=0)
    train_loader = DataLoader(
        Subset(train_t, train_idx),
        batch_size=128, shuffle=False, num_workers=0)

    # ── Results storage ──────────────────────────────────────────────
    results = {
        "run_id": os.path.basename(run_dir),
        "condition": condition,
        "n_train": n_train,
        "seed": seed,
        "epochs": [],
        "layers": LAYER_NAMES,
        "hospital_probe_id": [],   # Hospital accuracy on H3
        "hospital_probe_ood": [],  # Hospital accuracy on H4
        "tumor_probe_id": [],      # Tumor accuracy on H3
        "tumor_probe_ood": [],     # Tumor accuracy on H4
    }

    # ── Per-checkpoint analysis ───────────────────────────────────────
    for epoch, ckpt_path in checkpoints:
        print(f"\n Epoch {epoch} | {os.path.basename(ckpt_path)}")

        try:
            model = load_model_from_checkpoint(
                ckpt_path, n_classes=2, device=device)
        except Exception as e:
            print(f" Failed to load checkpoint: {e}")
            continue

        # Extract features from all three datasets
        feats_train, hosp_train, tumor_train = extract_features(
            model, train_loader, device, max_samples=max_samples)
        feats_id, hosp_id, tumor_id = extract_features(
            model, probe_loader, device, max_samples=max_samples // 2)
        feats_ood, hosp_ood, tumor_ood = extract_features(
            model, ood_loader, device, max_samples=max_samples // 2)

        epoch_hosp_id = []
        epoch_hosp_ood = []
        epoch_tumor_id = []
        epoch_tumor_ood = []

        for layer_name in LAYER_NAMES:
            # Fit probes on training features, evaluate on H3 and H4
            X_train_layer = feats_train[layer_name]
            X_id_layer = feats_id[layer_name]
            X_ood_layer = feats_ood[layer_name]

            # Hospital probe: can model distinguish hospitals H0-H2?
            # If yes on H3/H4 → stain is encoded
            h_acc_id = train_probe(X_train_layer, hosp_train,
                                   X_id_layer, hosp_id)
            h_acc_ood = train_probe(X_train_layer, hosp_train,
                                    X_ood_layer, hosp_ood)

            # Tumor probe: can model distinguish tumor vs normal?
            t_acc_id = train_probe(X_train_layer, tumor_train,
                                   X_id_layer, tumor_id)
            t_acc_ood = train_probe(X_train_layer, tumor_train,
                                    X_ood_layer, tumor_ood)

            epoch_hosp_id.append(h_acc_id)
            epoch_hosp_ood.append(h_acc_ood)
            epoch_tumor_id.append(t_acc_id)
            epoch_tumor_ood.append(t_acc_ood)

            print(f" {layer_name:8s}: "
                  f"hosp_H3={h_acc_id:.3f} hosp_H4={h_acc_ood:.3f} "
                  f"tumor_H3={t_acc_id:.3f} tumor_H4={t_acc_ood:.3f}")

        results["epochs"].append(epoch)
        results["hospital_probe_id"].append(epoch_hosp_id)
        results["hospital_probe_ood"].append(epoch_hosp_ood)
        results["tumor_probe_id"].append(epoch_tumor_id)
        results["tumor_probe_ood"].append(epoch_tumor_ood)

        del model

    # ── Save raw data ─────────────────────────────────────────────────
    out_dir = os.path.join(run_dir, "mechinterp")
    os.makedirs(out_dir, exist_ok=True)

    data_path = os.path.join(out_dir, "m1_probe_data.json")
    with open(data_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n Probe data → {data_path}")

    # ── Plots ─────────────────────────────────────────────────────────
    _plot_probe_heatmaps(results, out_dir)
    _plot_probe_curves(results, out_dir)
    print(f" Figures → {out_dir}/")

    return results


def _plot_probe_heatmaps(results: Dict, out_dir: str):
    """
    Epoch (x) × layer (y), color = probe accuracy.

    Hospital probe shown on H3 (held-in held-out hospital, classes overlap
    with training). The H4 version is degenerate by construction since the
    probe is fit on the training-hospital class set and H4 is not in it
    (hospital_probe_ood ≡ 0 across all epochs / layers).

    Tumor probe shown on H4 (truly OOD hospital) since tumor labels are
    binary and shared across hospitals — H4 captures the causal-feature
    transferability we care about.
    """
    epochs = results["epochs"]
    layers = results["layers"]
    if not epochs:
        return

    hosp_matrix = np.array(results["hospital_probe_id"])  # H3 — has signal
    tumor_matrix = np.array(results["tumor_probe_ood"])   # H4 — true OOD

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for ax, matrix, title, cmap in [
        (axes[0], hosp_matrix,
         "Hospital probe on H3 (shortcut recoverability)\nHigh = stain still encoded = BAD",
         "Reds"),
        (axes[1], tumor_matrix,
         "Tumor probe on H4 (causal, OOD)\nHigh = causal feature transfers = GOOD",
         "Greens"),
    ]:
        im = ax.imshow(
            matrix.T,
            aspect="auto",
            cmap=cmap,
            vmin=0.0,
            vmax=1.0,
            interpolation="nearest",
            origin="lower",
        )
        ax.set_xticks(range(len(epochs)))
        ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=8)
        ax.set_yticks(range(len(layers)))
        ax.set_yticklabels(layers, fontsize=9)
        ax.set_xlabel("Training epoch")
        ax.set_ylabel("ResNet layer")
        ax.set_title(title, fontsize=10, fontweight="bold")
        plt.colorbar(im, ax=ax, label="Probe accuracy")

    fig.suptitle(
        f"M1 — Layer-wise Linear Probing: {results['run_id']}\n"
        "Circuit signature: deep-layer hospital-probe drop (Reds) + sustained tumor recoverability (Greens)",
        fontsize=10, y=1.02
    )
    plt.tight_layout()
    out = os.path.join(out_dir, "m1_probe_heatmap.png")
    plt.savefig(out, bbox_inches="tight")
    plt.close()


def _plot_probe_curves(results: Dict, out_dir: str):
    """
    Line plot per-layer: hospital probe (H3 — recoverability) + tumor probe
    (H4 — causal-feature transfer to truly unseen hospital), with OOD
    accuracy from history.json overlaid.
    """
    epochs = results["epochs"]
    run_dir = os.path.join(out_dir, "..")
    layers = results["layers"]

    avgpool_idx = layers.index("avgpool")
    layer2_idx = layers.index("layer2")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for ax, layer_idx, layer_label in [
        (axes[0], avgpool_idx, "avgpool (penultimate)"),
        (axes[1], layer2_idx, "layer2 (early)"),
    ]:
        hosp = [results["hospital_probe_id"][i][layer_idx]
                for i in range(len(epochs))]
        tumor = [results["tumor_probe_ood"][i][layer_idx]
                 for i in range(len(epochs))]

        ax.plot(epochs, hosp, "r-o", markersize=4, lw=2,
                label="Hospital probe on H3 (shortcut recoverability ↓ want)")
        ax.plot(epochs, tumor, "g-s", markersize=4, lw=2,
                label="Tumor probe on H4 (causal transfer ↑ want)")

        hist_path = os.path.join(run_dir, "results", "history.json")
        if os.path.isfile(hist_path):
            try:
                hist = json.load(open(hist_path))
                hist_eps = [r["epoch"] for r in hist]
                ood_accs = [r.get("ood_acc", float("nan")) for r in hist]
                ax.plot(hist_eps, ood_accs, "b--", lw=1.5, alpha=0.7,
                        label="OOD accuracy (H4)")
            except Exception:
                pass

        ax.axhline(0.5, color="gray", ls=":", lw=1, alpha=0.5,
                   label="Chance (0.5)")
        ax.set_xlabel("Training epoch")
        ax.set_ylabel("Probe / OOD accuracy")
        ax.set_title(f"Layer: {layer_label}", fontweight="bold")
        ax.legend(fontsize=9)
        ax.set_ylim([0, 1.05])
        ax.grid(alpha=0.3)

    fig.suptitle(
        f"M1 — Probe Curves: {results['run_id']}\n"
        "Hospital recoverability (H3) drops in deep layers + tumor transfers (H4) — circuit signature",
        fontsize=10, y=1.02
    )
    plt.tight_layout()
    out = os.path.join(out_dir, "m1_probe_curves.png")
    plt.savefig(out, bbox_inches="tight")
    plt.close()


def main():
    p = argparse.ArgumentParser(
        description="M1: Layer-wise linear probing for CausalGrok")
    p.add_argument("--run_dir", default=None,
                   help="Single run directory to analyze")
    p.add_argument("--all_runs", action="store_true",
                   help="Analyze all camelyon_v2 grokking runs")
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--device", default="cuda")
    p.add_argument("--max_samples", type=int, default=800)
    p.add_argument("--latest_only", action="store_true",
                   help="Analyze only latest checkpoint (quick check)")
    args = p.parse_args()

    if args.all_runs:
        run_dirs = sorted(glob.glob(
            "experiments/runs/*camelyon_v2*grokking*"))
        print(f"Found {len(run_dirs)} grokking runs")
        all_results = []
        for rd in run_dirs:
            r = run_probe_analysis(rd, args.data_root,
                                   device=args.device,
                                   max_samples=args.max_samples,
                                   latest_only=args.latest_only)
            if r:
                all_results.append(r)

        if all_results:
            os.makedirs("paper_figures", exist_ok=True)
            with open("paper_figures/m1_all_probes.json", "w") as f:
                json.dump(all_results, f, indent=2)
            print(f"\nCombined → paper_figures/m1_all_probes.json")

    elif args.run_dir:
        run_probe_analysis(args.run_dir, args.data_root,
                           device=args.device,
                           max_samples=args.max_samples,
                           latest_only=args.latest_only)
    else:
        print("Specify --run_dir or --all_runs")


if __name__ == "__main__":
    main()
```

---

## 13. Full source: `experiments/mechinterp_m4_ablation.py`

```python
"""M4 — Representation Ablation: causal intervention on the shortcut subspace.

Pipeline:
  1. Pick a checkpoint (peak-OOD epoch by default).
  2. Extract features at avgpool (or `--layer`) for train (H0-H2) + OOD (H4) splits.
  3. Fit a hospital-classification logistic-regression probe on train features.
     The probe's weight rows define the *shortcut subspace* in feature space.
  4. Build the projector P = W^T (W W^T)^-1 W onto that subspace and define
     `ablate(h) = h - P h`.
  5. Re-classify OOD images with the *same* trained classifier head, fed:
       (a) raw features h — baseline OOD accuracy
       (b) ablated features h - Ph — post-intervention OOD accuracy
  6. Also report:
       (c) shortcut accuracy (probe.score on h vs h-Ph)
       (d) tumor probe accuracy on h vs h-Ph (sanity: the causal feature
           should survive the intervention)
       (e) head's tumor classification accuracy on H4 with raw vs ablated
           features

If the intervention is causal:
  - shortcut probe accuracy: collapses
  - OOD accuracy: improves (or at least doesn't decay as much)
  - tumor probe accuracy: largely preserved

Usage
-----
    python -m experiments.mechinterp_m4_ablation \\
        --run_dir experiments/runs/ \\
        --data_root data/wilds \\
        --layer avgpool \\
        [--epoch 50]              # default: peak_ood_epoch from summary.json
        [--max_samples 1000]

Output:
    /mechinterp/m4_ablation__ep.json
    /mechinterp/m4_ablation__ep.png
"""
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, Tuple

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Re-use M1 helpers — hooks, model loader, feature extraction, ckpt discovery.
from experiments.mechinterp_m1 import (
    register_hooks,
    extract_features,
    load_model_from_checkpoint,
    find_checkpoints,
)
from utils.camelyon_data import get_camelyon_subsets


class _TransformWrapper:
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def _build_loaders(data_root: str, max_samples: int, seed: int = 42):
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    train_t = _TransformWrapper(train_ds, transform)
    ood_t = _TransformWrapper(ood_test_ds, transform)

    torch.manual_seed(seed)
    train_idx = torch.randperm(len(train_t))[:max_samples]
    ood_idx = torch.randperm(len(ood_t))[:max_samples // 2]

    train_loader = DataLoader(Subset(train_t, train_idx),
                              batch_size=128, shuffle=False, num_workers=0)
    ood_loader = DataLoader(Subset(ood_t, ood_idx),
                            batch_size=128, shuffle=False, num_workers=0)
    return train_loader, ood_loader


def _select_epoch(run_dir: Path, requested: int | None) -> Tuple[int, Path]:
    ckpts = find_checkpoints(str(run_dir))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints in {run_dir}/checkpoints/")
    if requested is not None:
        for ep, p in ckpts:
            if ep == requested:
                return ep, Path(p)
        raise ValueError(f"Requested epoch {requested} not in checkpoints "
                         f"({[ep for ep, _ in ckpts]})")
    # default: peak OOD epoch from summary.json
    summary_path = run_dir / "results" / "summary.json"
    peak = None
    if summary_path.exists():
        s = json.loads(summary_path.read_text())
        peak = s.get("peak_ood_epoch", None)
    if peak is not None and peak > 0:
        # nearest periodic checkpoint
        nearest = min(ckpts, key=lambda x: abs(x[0] - peak))
        return nearest[0], Path(nearest[1])
    # fall back to last checkpoint
    return ckpts[-1][0], Path(ckpts[-1][1])


def _build_projector(W: np.ndarray) -> np.ndarray:
    """W has shape (k, d). Returns P (d, d) projecting onto rowspace(W)."""
    # Use SVD for a stable orthonormal basis of rowspace
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    # rowspace basis = Vt rows where singular values > tol
    tol = max(W.shape) * np.finfo(s.dtype).eps * (s.max() if s.size else 0.0)
    keep = s > tol
    basis = Vt[keep]  # (k', d)
    return basis.T @ basis  # (d, d) projector onto rowspace


def _build_shortcut_subspace(
    X: np.ndarray, hospital_ids: np.ndarray, method: str = "lda", subspace_dim: int = 32
) -> np.ndarray:
    """Return a (k, d) basis whose row-span is the 'shortcut subspace'.

    method='probe' — k = (n_classes - 1) probe weight rows (small subspace).
    method='lda' — k = subspace_dim top between-class directions:
        take per-hospital means in feature space, center them, and run SVD.
        This gives a rank-bounded but data-driven subspace that captures
        hospital-discriminating variance.
    method='pca-class' — top-PCs of features colored by hospital (mean-removed
        per class), giving us the variance directions that mostly reflect
        within-hospital structure × class.
    """
    if method == "probe":
        clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                 multi_class="auto", n_jobs=-1)
        clf.fit(X, hospital_ids)
        return clf.coef_

    if method == "lda":
        classes = np.unique(hospital_ids)
        global_mean = X.mean(axis=0, keepdims=True)
        between = []
        for c in classes:
            mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True)
            between.append(mu_c - global_mean)
        between = np.vstack(between)  # (n_classes, d)
        # Augment with random hospital-correlated directions to grow rank up
        # to subspace_dim — use top PCs of *centered-by-hospital-mean* features.
        if subspace_dim > between.shape[0]:
            # within-hospital residuals
            residuals = []
            for c in classes:
                mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True)
                residuals.append(X[hospital_ids == c] - mu_c)
            R = np.vstack(residuals)
            # PCA on residuals — these are within-hospital directions; remove
            # them from the shortcut subspace by KEEPING only the between-class
            # directions. So we just return between as-is, plus the top PCs of
            # the *original* features projected onto the orthogonal complement
            # of `between` IF the user wants more dims.
            U, s, Vt = np.linalg.svd(X - global_mean, full_matrices=False)
            top = Vt[:subspace_dim]
            # Score each PC by how much it correlates with hospital-id variance
            # (one-hot expansion); keep top by that correlation.
            one_hot = np.eye(len(classes))[
                np.searchsorted(classes, hospital_ids)
            ]  # (N, n_classes)
            proj = (X - global_mean) @ top.T  # (N, subspace_dim)
            corrs = np.array([
                np.max(np.abs([np.corrcoef(proj[:, k], one_hot[:, c])[0, 1]
                               for c in range(len(classes))]))
                for k in range(subspace_dim)
            ])
            # take the top-k most-hospital-correlated PCs
            order = np.argsort(-np.nan_to_num(corrs))
            top_hosp = top[order[:subspace_dim]]
            # combine: between-class means + top-hospital-correlated PCs
            return np.vstack([between, top_hosp])
        return between

    raise ValueError(f"Unknown method: {method}")


def _classifier_logits_from_features(
    model: nn.Module, features: np.ndarray, layer: str, device: str
) -> np.ndarray:
    """Apply the *post-`layer`* part of the network to the (modified) features
    and return the model's binary-classification logits.

    For ResNet, `avgpool` features have shape (N, C). The classifier head
    `model.fc` (timm: `model.get_classifier()`) maps C → 2. For non-avgpool
    layers we do not currently support full propagation — caller should use
    layer='avgpool' for OOD-accuracy interventions."""
    if layer != "avgpool":
        raise NotImplementedError(
            "Re-applying the classifier head from intermediate spatial layers "
            "is not yet supported. Use --layer avgpool for the head-level "
            "ablation."
        )
    # Find the classifier head (timm convention: model.fc or model.get_classifier())
    if hasattr(model, "get_classifier"):
        head = model.get_classifier()
    elif hasattr(model, "fc"):
        head = model.fc
    elif hasattr(model, "classifier"):
        head = model.classifier
    else:
        raise RuntimeError("Could not locate classifier head on the model.")

    head = head.to(device).eval()
    with torch.no_grad():
        x = torch.tensor(features, dtype=torch.float32, device=device)
        logits = head(x).cpu().numpy()
    return logits


def _accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    if logits.ndim == 1 or logits.shape[1] == 1:
        pred = (logits.flatten() > 0).astype(int)
    else:
        pred = logits.argmax(axis=1)
    return float((pred == labels).mean())


def run_ablation(
    run_dir: Path,
    data_root: str,
    layer: str = "avgpool",
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    subspace_method: str = "lda",
    subspace_dim: int = 32,
) -> Dict:
    epoch, ckpt_path = _select_epoch(run_dir, epoch)
    print(f"\n M4 — Representation Ablation")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" layer : {layer}")

    # Load model and dataloaders
    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)

    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)

    # Extract features
    print(f" Extracting features ({max_samples} samples per split)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )

    if layer not in feats_train:
        raise KeyError(f"Layer '{layer}' not in extracted features "
                       f"({list(feats_train.keys())})")

    X_tr = np.asarray(feats_train[layer])  # (N_tr, D)
    X_ood = np.asarray(feats_ood[layer])  # (N_ood, D)
    if X_tr.ndim > 2:  # spatial map; flatten
        X_tr = X_tr.reshape(X_tr.shape[0], -1)
        X_ood = X_ood.reshape(X_ood.shape[0], -1)

    # Normalize features (probe is sensitive to scale; classifier head was
    # trained on un-normalized features so we keep two parallel pipelines).
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    # ──────────── 1. Fit hospital probe + build shortcut subspace
    print(f" Fitting hospital probe on H0/H1/H2 train features...")
    hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                  multi_class="auto", n_jobs=-1)
    hosp_clf.fit(X_tr_n, hosp_train)
    hosp_acc_train = hosp_clf.score(X_tr_n, hosp_train)

    # Build a richer shortcut subspace via LDA-style between-class +
    # hospital-correlated top PCs. This catches more shortcut variance than
    # the (n_classes - 1)-D probe-rowspace alone.
    W = _build_shortcut_subspace(X_tr_n, np.asarray(hosp_train),
                                 method=subspace_method,
                                 subspace_dim=subspace_dim)
    P = _build_projector(W)  # (D, D)
    rank_subspace = int(np.linalg.matrix_rank(P, tol=1e-8))
    print(f" Shortcut subspace: dim={rank_subspace} method={subspace_method} "
          f"(probe train acc {hosp_acc_train:.3f})")

    # ──────────── 2. Build ablated versions of features
    # Apply the projection in the *normalized* feature space, then un-scale
    # for re-feeding to the classifier head (which was trained on raw features).
    def ablate_norm(X_n):
        return X_n - X_n @ P.T

    X_ood_ablated_n = ablate_norm(X_ood_n)
    # un-scale
    X_ood_ablated = scaler.inverse_transform(X_ood_ablated_n)

    # Sanity probe metrics
    print(f" Re-fitting tumor probe on train features...")
    tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                   multi_class="auto", n_jobs=-1)
    tumor_clf.fit(X_tr_n, tumor_train)
    tumor_acc_train = tumor_clf.score(X_tr_n, tumor_train)

    # Probe accuracies on raw vs ablated OOD features
    hosp_acc_ood_raw = hosp_clf.score(X_ood_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan")
    hosp_acc_ood_ablated = hosp_clf.score(X_ood_ablated_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan")
    tumor_acc_ood_raw = tumor_clf.score(X_ood_n, tumor_ood)
    tumor_acc_ood_ablated = tumor_clf.score(X_ood_ablated_n, tumor_ood)

    # ──────────── 3. Head-level OOD classification accuracy
    print(f" Re-classifying OOD with model head (raw vs ablated features)...")
    logits_raw = _classifier_logits_from_features(model, X_ood, layer, device)
    logits_ablated = _classifier_logits_from_features(model, X_ood_ablated, layer, device)
    head_acc_raw = _accuracy(logits_raw, tumor_ood)
    head_acc_ablated = _accuracy(logits_ablated, tumor_ood)

    # ──────────── 4. Pack + report
    result = {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "shortcut_subspace_dim": rank_subspace,
        "hospital_probe_train_acc": hosp_acc_train,
        "tumor_probe_train_acc": tumor_acc_train,
        "hospital_probe_ood_raw": hosp_acc_ood_raw,
        "hospital_probe_ood_ablated": hosp_acc_ood_ablated,
        "tumor_probe_ood_raw": tumor_acc_ood_raw,
        "tumor_probe_ood_ablated": tumor_acc_ood_ablated,
        "head_ood_acc_raw": head_acc_raw,
        "head_ood_acc_ablated": head_acc_ablated,
        "intervention_effect": {
            "shortcut_collapse": hosp_acc_ood_raw - hosp_acc_ood_ablated,
            "ood_improvement": head_acc_ablated - head_acc_raw,
            "tumor_preservation": tumor_acc_ood_ablated - tumor_acc_ood_raw,
        },
    }

    print(f"\n RESULTS")
    print(f" hospital probe (OOD): {hosp_acc_ood_raw:.3f} → {hosp_acc_ood_ablated:.3f} "
          f"(Δ {result['intervention_effect']['shortcut_collapse']:+.3f})")
    print(f" tumor probe (OOD) : {tumor_acc_ood_raw:.3f} → {tumor_acc_ood_ablated:.3f} "
          f"(Δ {result['intervention_effect']['tumor_preservation']:+.3f})")
    print(f" head OOD acc : {head_acc_raw:.3f} → {head_acc_ablated:.3f} "
          f"(Δ {result['intervention_effect']['ood_improvement']:+.3f})")

    return result


def plot_ablation(result: Dict, out_path: Path):
    metrics = ["hospital_probe_ood", "tumor_probe_ood", "head_ood_acc"]
    raw_keys = ["hospital_probe_ood_raw", "tumor_probe_ood_raw", "head_ood_acc_raw"]
    ablated_keys = ["hospital_probe_ood_ablated", "tumor_probe_ood_ablated", "head_ood_acc_ablated"]
    labels = ["Hospital probe\n(↓ = causal effect)",
              "Tumor probe\n(stable = good)",
              "Head OOD acc\n(↑ = causal effect)"]

    raws = [result[k] for k in raw_keys]
    ablateds = [result[k] for k in ablated_keys]

    fig, ax = plt.subplots(figsize=(9, 5))
    x = np.arange(len(metrics))
    w = 0.35
    b1 = ax.bar(x - w / 2, raws, w, label="raw features", color="#444")
    b2 = ax.bar(x + w / 2, ablateds, w, label="shortcut-ablated", color="#c33")

    for bars in (b1, b2):
        for b in bars:
            ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.005,
                    f"{b.get_height():.3f}", ha="center", va="bottom", fontsize=9)

    ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylim(0, 1.05); ax.set_ylabel("Accuracy")
    ax.set_title(f"M4 — Causal Ablation of Shortcut Subspace\n"
                 f"{result['run_id']} • ep{result['epoch']} • layer={result['layer']} "
                 f"• subspace dim={result['shortcut_subspace_dim']}",
                 fontsize=10, fontweight="bold")
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3, axis="y")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--run_dir", required=True)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--layer", default="avgpool",
                   choices=["avgpool"])  # head-level intervention only at avgpool
    p.add_argument("--epoch", type=int, default=None,
                   help="Specific checkpoint epoch; default = peak_ood_epoch from summary.json")
    p.add_argument("--max_samples", type=int, default=1000)
    p.add_argument("--device", default="cuda")
    p.add_argument("--subspace_method", default="lda", choices=["lda", "probe"],
                   help="lda = LDA-style between-class + hospital-correlated PCs; "
                        "probe = LR probe row-space (small, often only 2-D)")
    p.add_argument("--subspace_dim", type=int, default=32,
                   help="Target subspace dim for lda method")
    p.add_argument("--all_epochs", action="store_true",
                   help="Sweep across all periodic checkpoints")
    args = p.parse_args()

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.all_epochs:
        # Sweep across every periodic checkpoint, build a trajectory.
        ckpts = find_checkpoints(str(run_dir))
        # de-duplicate (final.pt may share epoch with last ep*.pt)
        seen = set(); uniq = []
        for ep, p in ckpts:
            if ep in seen:
                continue
            seen.add(ep); uniq.append((ep, p))
        traj = []
        for ep, _ in uniq:
            try:
                r = run_ablation(
                    run_dir=run_dir, data_root=args.data_root, layer=args.layer,
                    epoch=ep, max_samples=args.max_samples, device=args.device,
                    subspace_method=args.subspace_method,
                    subspace_dim=args.subspace_dim,
                )
                traj.append(r)
            except Exception as e:
                print(f" [skip ep{ep}] {e}")
        out = out_dir / f"m4_ablation_{args.layer}_trajectory.json"
        out.write_text(json.dumps(traj, indent=2))
        plot_trajectory(traj, out.with_suffix(".png"))
        print(f"\n → {out}")
        print(f" → {out.with_suffix('.png')}")
        return

    result = run_ablation(
        run_dir=run_dir, data_root=args.data_root, layer=args.layer,
        epoch=args.epoch, max_samples=args.max_samples, device=args.device,
        subspace_method=args.subspace_method, subspace_dim=args.subspace_dim,
    )
    base = out_dir / f"m4_ablation_{args.layer}_ep{result['epoch']:05d}"
    (base.with_suffix(".json")).write_text(json.dumps(result, indent=2))
    plot_ablation(result, base.with_suffix(".png"))
    print(f"\n → {base.with_suffix('.json')}")
    print(f" → {base.with_suffix('.png')}")


def plot_trajectory(traj, out_path: Path):
    """Plot the intervention effect across training epochs."""
    eps = [r["epoch"] for r in traj]
    head_raw = [r["head_ood_acc_raw"] for r in traj]
    head_abl = [r["head_ood_acc_ablated"] for r in traj]
    tum_raw = [r["tumor_probe_ood_raw"] for r in traj]
    tum_abl = [r["tumor_probe_ood_ablated"] for r in traj]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Panel A: head OOD acc raw vs ablated
    ax = axes[0]
    ax.plot(eps, head_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, head_abl, "r-s", lw=2, label="shortcut-ablated features")
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a > b for a, b in zip(head_abl, head_raw)],
                    color="seagreen", alpha=0.3, label="ablation helps")
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a < b for a, b in zip(head_abl, head_raw)],
                    color="salmon", alpha=0.3, label="ablation hurts")
    ax.set_xlabel("Training epoch"); ax.set_ylabel("OOD (H4) head accuracy")
    ax.set_title("Head OOD accuracy: raw vs shortcut-ablated", fontweight="bold")
    ax.legend(fontsize=9); ax.grid(alpha=0.3)

    # Panel B: tumor probe survival
    ax = axes[1]
    ax.plot(eps, tum_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, tum_abl, "g-s", lw=2, label="shortcut-ablated features")
    ax.set_xlabel("Training epoch"); ax.set_ylabel("Tumor probe OOD accuracy")
    ax.set_title("Tumor probe survival under ablation\n(stable line = causal feature preserved)",
                 fontweight="bold")
    ax.legend(fontsize=9); ax.grid(alpha=0.3); ax.set_ylim(0.4, 1.0)

    rid = traj[0]["run_id"] if traj else "?"
    layer = traj[0]["layer"] if traj else "?"
    fig.suptitle(f"M4 — Causal Ablation Trajectory: {rid} • layer={layer}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    main()
```

---

## 14. Full source: `experiments/mechinterp_m5_steering.py`

```python
"""M5 — Activation Steering: causally manipulate the shortcut direction.

For one checkpoint (default: peak_ood_epoch from summary.json), we:
  1. Extract avgpool features for train (H0-H2) + OOD (H4) splits.
  2. Identify the dominant shortcut direction `v_s` as the top eigenvector of
     the between-hospital covariance (LDA's first projection direction).
  3. Sweep α ∈ {-3, -2, -1, 0, +1, +2, +3} and apply
         h' = h + α · σ_align · v_s
     where σ_align is the std of features projected onto v_s (so α counts in
     'standard deviations of shortcut activation').
  4. Re-classify OOD with the original head.
  5. Re-fit hospital + tumor probes on the steered features and report
     accuracy curves.

Strong mechanistic claim if:
  - tumor-head OOD acc declines monotonically as |α| grows
  - hospital-probe acc on steered features rises with |α|
  - tumor-probe acc on steered features stays approximately flat (the
    *causal* feature isn't aligned with the shortcut direction)

Usage
-----
    python -m experiments.mechinterp_m5_steering \\
        --run_dir experiments/runs/ \\
        --data_root data/wilds \\
        [--epoch 50]            # default: peak_ood_epoch from summary.json
        [--max_samples 1000]
        [--alphas " -3,-2,-1,0,1,2,3 "]
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from experiments.mechinterp_m1 import (
    register_hooks, extract_features, load_model_from_checkpoint,
    find_checkpoints,
)
from experiments.mechinterp_m4_ablation import (
    _select_epoch, _build_loaders, _classifier_logits_from_features, _accuracy,
)


def _top_lda_direction(X: np.ndarray, hospital_ids: np.ndarray) -> np.ndarray:
    """Return a unit vector aligned with the dominant between-hospital
    direction in feature space (LDA-1)."""
    classes = np.unique(hospital_ids)
    global_mean = X.mean(axis=0, keepdims=True)
    means = np.vstack([
        X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean
        for c in classes
    ])
    # SVD: rows of Vt are the orthonormal between-class directions ranked by
    # singular value (variance explained between hospitals).
    U, s, Vt = np.linalg.svd(means, full_matrices=False)
    return Vt[0]  # (D,) unit vector


def run_steering(
    run_dir: Path,
    data_root: str,
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    alphas: List[float] = None,
) -> Dict:
    if alphas is None:
        alphas = [-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]

    epoch, ckpt_path = _select_epoch(run_dir, epoch)
    print(f"\n M5 — Activation Steering")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" alphas : {alphas}")

    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)

    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)
    print(f" Extracting features...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )

    layer = "avgpool"
    X_tr = np.asarray(feats_train[layer]); X_tr = X_tr.reshape(X_tr.shape[0], -1)
    X_ood = np.asarray(feats_ood[layer]); X_ood = X_ood.reshape(X_ood.shape[0], -1)
    hosp_train = np.asarray(hosp_train)
    hosp_ood = np.asarray(hosp_ood)
    tumor_train = np.asarray(tumor_train)
    tumor_ood = np.asarray(tumor_ood)

    # Standardize for probe-fitting; un-standardize when feeding head
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    # 1. Top LDA direction in normalized feature space
    v = _top_lda_direction(X_tr_n, hosp_train)  # (D,) unit vec
    # Std of training features projected onto v (scale unit for α)
    sigma = float(np.std(X_tr_n @ v))
    print(f" Top hospital direction v_s : ‖v‖={np.linalg.norm(v):.3f}, "
          f"σ(X_tr·v)={sigma:.3f}")

    # 2. Pre-fit reference probes on un-steered train features
    hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                  multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train)
    tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                   multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train)

    # 3. Sweep α
    sweep = []
    for alpha in alphas:
        # Steer features along v in normalized space, then un-scale for the head.
        X_ood_steered_n = X_ood_n + alpha * sigma * v[None, :]
        X_ood_steered = scaler.inverse_transform(X_ood_steered_n)

        # Head OOD accuracy
        logits = _classifier_logits_from_features(model, X_ood_steered, layer, device)
        head_acc = _accuracy(logits, tumor_ood)

        # Probe accuracies on steered features
        if len(np.unique(hosp_ood)) > 1:
            hosp_acc = hosp_clf.score(X_ood_steered_n, hosp_ood)
        else:
            hosp_acc = float("nan")
        tumor_acc = tumor_clf.score(X_ood_steered_n, tumor_ood)

        sweep.append({
            "alpha": float(alpha),
            "head_ood_acc": head_acc,
            "hospital_probe": hosp_acc,
            "tumor_probe": tumor_acc,
        })
        print(f" α={alpha:+.2f} head_ood={head_acc:.3f} "
              f"hosp_probe={hosp_acc if not np.isnan(hosp_acc) else 'nan':<5} "
              f"tumor_probe={tumor_acc:.3f}")

    return {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "v_norm": float(np.linalg.norm(v)),
        "sigma": sigma,
        "sweep": sweep,
    }


def plot_steering(result: Dict, out_path: Path):
    sweep = result["sweep"]
    a = [r["alpha"] for r in sweep]
    head = [r["head_ood_acc"] for r in sweep]
    hosp = [r["hospital_probe"] for r in sweep]
    tumor = [r["tumor_probe"] for r in sweep]

    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # Panel A — Head OOD acc vs α
    ax = axes[0]
    ax.plot(a, head, "k-o", lw=2, ms=7)
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (in σ-units of shortcut direction)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("Causal effect of steering activations along v_s\n"
                 "(monotonic decline as |α| grows = causal evidence)",
                 fontweight="bold", fontsize=10)
    ax.grid(alpha=0.3)
    ax.set_ylim(0.4, max(0.85, max(head) + 0.05))

    # Panel B — Probe accuracies vs α
    ax = axes[1]
    ax.plot(a, hosp, "r-s", lw=2, ms=7, label="Hospital probe (↑ with |α| = good)")
    ax.plot(a, tumor, "g-^", lw=2, ms=7, label="Tumor probe (flat = causal disjoint)")
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α")
    ax.set_ylabel("Probe accuracy")
    ax.set_title("Probe responses to steering", fontweight="bold", fontsize=10)
    ax.legend(loc="best", fontsize=9); ax.grid(alpha=0.3)
    ax.set_ylim(0, 1.05)

    fig.suptitle(f"M5 — Activation Steering: {result['run_id']} "
                 f"• ep{result['epoch']} • layer={result['layer']}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--run_dir", required=True)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--epoch", type=int, default=None)
    p.add_argument("--max_samples", type=int, default=1000)
    p.add_argument("--device", default="cuda")
    p.add_argument("--alphas", default=None,
                   help="Comma-separated α values, e.g. ' -3,-2,-1,0,1,2,3 '")
    p.add_argument("--all_epochs", action="store_true",
                   help="Sweep across all periodic checkpoints; output a trajectory")
    args = p.parse_args()

    alphas = None
    if args.alphas is not None:
        alphas = [float(x) for x in args.alphas.split(",")]

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.all_epochs:
        # Trajectory mode: run M5 at every periodic checkpoint
        ckpts = find_checkpoints(str(run_dir))
        seen = set(); uniq = []
        for ep, p in ckpts:
            if ep in seen:
                continue
            seen.add(ep); uniq.append((ep, p))
        traj = []
        for ep, _ in uniq:
            try:
                r = run_steering(
                    run_dir=run_dir, data_root=args.data_root,
                    epoch=ep, max_samples=args.max_samples,
                    device=args.device, alphas=alphas,
                )
                traj.append(r)
            except Exception as e:
                print(f" [skip ep{ep}] {e}")
        out = out_dir / "m5_steering_trajectory.json"
        out.write_text(json.dumps(traj, indent=2))
        print(f"\n → {out}")
        return

    result = run_steering(
        run_dir=run_dir, data_root=args.data_root,
        epoch=args.epoch, max_samples=args.max_samples,
        device=args.device, alphas=alphas,
    )
    base = out_dir / f"m5_steering_ep{result['epoch']:05d}"
    base.with_suffix(".json").write_text(json.dumps(result, indent=2))
    plot_steering(result, base.with_suffix(".png"))
    print(f"\n → {base.with_suffix('.json')}")
    print(f" → {base.with_suffix('.png')}")


if __name__ == "__main__":
    main()
```

---

## 15. Full source: `experiments/mechinterp_m6_neuron_ablation.py`

```python
"""M6 — Neuron-level Ablation (the textbook reviewer-asked intervention).

Pipeline:
  1. At a chosen checkpoint (default: peak_ood_epoch), extract avgpool features
     for train (H0-H2) and OOD (H4) splits.
  2. Score each of the 512 avgpool channels by *how predictive its activation
     is of hospital ID*: we fit a single multiclass logistic regression over
     all channels and use the L2 column norm of its coefficient matrix as the
     per-neuron shortcut score:
         score_c = ‖β_{:,c}‖₂    (β = LR coefficient matrix, classes × channels)
     ↑ score_c → channel c is more strongly stain-shortcut-aligned.
  3. Sweep top-K ablated neurons (default K ∈ {0, 4, 8, 16, 32, 64, 128, 256};
     activations are zeroed in standardized space, i.e. reset to the
     training-set mean of that channel) and measure:
       - head OOD acc (raw vs ablated)
       - hospital-probe acc on raw vs ablated features
       - tumor-probe acc on raw vs ablated features
  4. Strong mechanistic claim:
       - hospital-probe acc collapses sharply with K (these neurons are
         carrying hospital info)
       - head OOD acc *improves* (or at least preserves) at small K (the model
         was using shortcut neurons to harm OOD)
       - tumor-probe acc stays flat (causal info is distributed elsewhere)

Usage
-----
    python -m experiments.mechinterp_m6_neuron_ablation \\
        --run_dir experiments/runs/ \\
        --data_root data/wilds \\
        [--epoch 50] [--max_samples 1000] \\
        [--ks "0,4,8,16,32,64,128,256"]
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

from experiments.mechinterp_m1 import (
    register_hooks, extract_features, load_model_from_checkpoint,
)
from experiments.mechinterp_m4_ablation import (
    _select_epoch, _TransformWrapper,
    _classifier_logits_from_features, _accuracy,
)
from utils.camelyon_data import get_camelyon_subsets


def _build_loaders_with_id(data_root: str, max_samples: int, seed: int = 42):
    """Like M4's _build_loaders but also returns an ID validation loader so we
    can track ID acc and compute the OOD/ID degradation ratio."""
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    train_t = _TransformWrapper(train_ds, transform)
    id_t = _TransformWrapper(id_val_ds, transform)
    ood_t = _TransformWrapper(ood_test_ds, transform)

    torch.manual_seed(seed)
    train_idx = torch.randperm(len(train_t))[:max_samples]
    id_idx = torch.randperm(len(id_t))[:max_samples // 2]
    ood_idx = torch.randperm(len(ood_t))[:max_samples // 2]

    train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128,
                              shuffle=False, num_workers=0)
    id_loader = DataLoader(Subset(id_t, id_idx), batch_size=128,
                           shuffle=False, num_workers=0)
    ood_loader = DataLoader(Subset(ood_t, ood_idx), batch_size=128,
                            shuffle=False, num_workers=0)
    return train_loader, id_loader, ood_loader


def _per_neuron_shortcut_scores(X_n: np.ndarray, hosp: np.ndarray) -> np.ndarray:
    """Return a (D,) array with one score per channel c; larger = more
    hospital-predictive.

    A one-feature-at-a-time log-reg fit's |coef| would be dominated by feature
    scale; instead we fit a single multiclass LR over all features and use the
    L2 norm of β_{:,c} (the column norm of the LR coefficient matrix) as
    channel c's hospital-discrimination score.
    """
    clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                             multi_class="auto", n_jobs=-1).fit(X_n, hosp)
    W = clf.coef_  # (n_classes, D)
    # column norms — large means many class-discriminations rely on this neuron
    return np.linalg.norm(W, axis=0)  # (D,)


def _ablate_and_eval(
    X_n, mask, scaler, head_target, model, layer, device,
    hosp_clf, tumor_clf, hosp_target, tumor_target,
):
    """Apply mask to normalized features, unscale, evaluate everything."""
    X_ablated_n = X_n * mask[None, :]
    X_ablated = scaler.inverse_transform(X_ablated_n)
    logits = _classifier_logits_from_features(model, X_ablated, layer, device)
    head_acc = _accuracy(logits, head_target)
    hosp_acc = (
        hosp_clf.score(X_ablated_n, hosp_target)
        if hosp_clf is not None and len(np.unique(hosp_target)) > 1
        else float("nan")
    )
    tumor_acc = tumor_clf.score(X_ablated_n, tumor_target)
    return head_acc, hosp_acc, tumor_acc


def run_neuron_ablation(
    run_dir: Path,
    data_root: str,
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    ks: List[int] = None,
    n_random_samples: int = 5,
    include_morphology: bool = True,
    include_id: bool = True,
) -> Dict:
    if ks is None:
        # Dose-response curve emphasizing small K (per reviewer guidance)
        ks = [0, 4, 8, 16, 32, 64, 128, 256]

    epoch, ckpt_path = _select_epoch(run_dir, epoch)
    print(f"\n M6 — Neuron Ablation (with random + morphology controls)")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" ks : {ks}")
    print(f" random ablation: {n_random_samples} samplings per K")

    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)

    if include_id:
        train_loader, id_loader, ood_loader = _build_loaders_with_id(data_root, max_samples, seed=seed)
    else:
        from experiments.mechinterp_m4_ablation import _build_loaders as _bl
        train_loader, ood_loader = _bl(data_root, max_samples, seed=seed)
        id_loader = None

    print(f" Extracting features (train + id + ood)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )
    feats_id, hosp_id, tumor_id = (None, None, None)
    if id_loader is not None:
        feats_id, hosp_id, tumor_id = extract_features(
            model, id_loader, device, max_samples=max_samples // 2
        )

    layer = "avgpool"
    def _to_2d(arr):
        a = np.asarray(arr); return a.reshape(a.shape[0], -1)
    X_tr = _to_2d(feats_train[layer])
    X_ood = _to_2d(feats_ood[layer])
    X_id = _to_2d(feats_id[layer]) if feats_id is not None else None
    hosp_train = np.asarray(hosp_train); hosp_ood = np.asarray(hosp_ood)
    tumor_train = np.asarray(tumor_train); tumor_ood = np.asarray(tumor_ood)
    if X_id is not None:
        hosp_id = np.asarray(hosp_id); tumor_id = np.asarray(tumor_id)

    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)
    X_id_n = scaler.transform(X_id) if X_id is not None else None

    # 1. Per-neuron scores: shortcut (hospital) and morphology (tumor)
    print(f" Scoring {X_tr.shape[1]} avgpool channels...")
    shortcut_scores = _per_neuron_shortcut_scores(X_tr_n, hosp_train)
    morphology_scores = _per_neuron_shortcut_scores(X_tr_n, tumor_train) if include_morphology else None
    rank_shortcut = np.argsort(-shortcut_scores)
    rank_morphology = np.argsort(-morphology_scores) if morphology_scores is not None else None

    hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                  multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train)
    tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                   multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train)

    rng = np.random.default_rng(seed)
    D = X_tr.shape[1]
    sweep = []
    for k in ks:
        row = {"k": int(k)}

        # Mask helpers
        def make_mask(indices):
            m = np.ones(D)
            if k > 0:
                m[indices[:k]] = 0.0
            return m

        # ── A: top-K SHORTCUT neurons (the targeted ablation) ──
        mask_s = make_mask(rank_shortcut)
        h_ood, hp_ood, tp_ood = _ablate_and_eval(
            X_ood_n, mask_s, scaler, tumor_ood, model, layer, device,
            hosp_clf, tumor_clf, hosp_ood, tumor_ood,
        )
        row["shortcut_head_ood"] = float(h_ood)
        row["shortcut_hosp_probe"] = float(hp_ood)
        row["shortcut_tumor_probe"] = float(tp_ood)
        if X_id_n is not None:
            h_id, _, _ = _ablate_and_eval(
                X_id_n, mask_s, scaler, tumor_id, model, layer, device,
                None, tumor_clf, hosp_id, tumor_id,
            )
            row["shortcut_head_id"] = float(h_id)

        # ── B: top-K MORPHOLOGY neurons (control: ablate the causal neurons) ──
        if include_morphology and rank_morphology is not None:
            mask_m = make_mask(rank_morphology)
            h_ood_m, _, _ = _ablate_and_eval(
                X_ood_n, mask_m, scaler, tumor_ood, model, layer, device,
                None, tumor_clf, hosp_ood, tumor_ood,
            )
            row["morphology_head_ood"] = float(h_ood_m)
            if X_id_n is not None:
                h_id_m, _, _ = _ablate_and_eval(
                    X_id_n, mask_m, scaler, tumor_id, model, layer, device,
                    None, tumor_clf, hosp_id, tumor_id,
                )
                row["morphology_head_id"] = float(h_id_m)

        # ── C: K RANDOM neurons (control: damage uniformly) ──
        if k > 0:
            r_oods, r_ids = [], []
            for s_ in range(n_random_samples):
                idx = rng.permutation(D)[:k]
                m = np.ones(D); m[idx] = 0.0
                h_ood_r, _, _ = _ablate_and_eval(
                    X_ood_n, m, scaler, tumor_ood, model, layer, device,
                    None, tumor_clf, hosp_ood, tumor_ood,
                )
                r_oods.append(h_ood_r)
                if X_id_n is not None:
                    h_id_r, _, _ = _ablate_and_eval(
                        X_id_n, m, scaler, tumor_id, model, layer, device,
                        None, tumor_clf, hosp_id, tumor_id,
                    )
                    r_ids.append(h_id_r)
            row["random_head_ood_mean"] = float(np.mean(r_oods))
            row["random_head_ood_std"] = float(np.std(r_oods))
            if r_ids:
                row["random_head_id_mean"] = float(np.mean(r_ids))
                row["random_head_id_std"] = float(np.std(r_ids))
        else:
            row["random_head_ood_mean"] = row["shortcut_head_ood"]  # K=0 same as baseline
            row["random_head_ood_std"] = 0.0
            if X_id_n is not None:
                row["random_head_id_mean"] = row.get("shortcut_head_id", float("nan"))
                row["random_head_id_std"] = 0.0

        sweep.append(row)
        # Concise log line
        print(f" K={k:>4} shortcut={row['shortcut_head_ood']:.3f} "
              f"random={row.get('random_head_ood_mean', float('nan')):.3f}±"
              f"{row.get('random_head_ood_std', 0):.3f} "
              + (f"morphology={row.get('morphology_head_ood', float('nan')):.3f}"
                 if include_morphology else ""))

    return {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "feature_dim": int(X_tr.shape[1]),
        "shortcut_scores_top10": [int(i) for i in rank_shortcut[:10]],
        "morphology_scores_top10": ([int(i) for i in rank_morphology[:10]]
                                    if rank_morphology is not None else []),
        "n_random_samples": n_random_samples,
        "include_id": include_id,
        "include_morphology": include_morphology,
        "sweep": sweep,
    }


def plot_neuron_ablation(result: Dict, out_path: Path):
    sweep = result["sweep"]
    ks = [r["k"] for r in sweep]
    has_id = result.get("include_id", False)
    has_morph = result.get("include_morphology", False)

    fig, axes = plt.subplots(1, 2 if has_id else 1, figsize=(13, 5)) if has_id else \
        plt.subplots(1, 1, figsize=(8, 5))
    if not has_id:
        axes = [axes]

    # Panel A — Head OOD: shortcut vs random (vs morphology)
    ax = axes[0]
    shortcut_ood = [r.get("shortcut_head_ood") for r in sweep]
    random_ood_mu = [r.get("random_head_ood_mean") for r in sweep]
    random_ood_sd = [r.get("random_head_ood_std", 0) for r in sweep]
    morphology_ood = [r.get("morphology_head_ood") for r in sweep] if has_morph else None

    ax.plot(ks, shortcut_ood, "r-o", lw=2.2, ms=7, label="top-K shortcut neurons (targeted)")
    ax.plot(ks, random_ood_mu, "k-s", lw=1.8, ms=6, label="K random neurons (control)")
    ax.fill_between(ks,
                    [m - s for m, s in zip(random_ood_mu, random_ood_sd)],
                    [m + s for m, s in zip(random_ood_mu, random_ood_sd)],
                    color="black", alpha=0.15)
    if has_morph and morphology_ood is not None:
        ax.plot(ks, morphology_ood, "g-^", lw=1.8, ms=6, label="top-K morphology neurons (control)")
    base = shortcut_ood[0]
    ax.axhline(base, color="gray", ls=":", lw=1, alpha=0.5, label=f"K=0 baseline ({base:.3f})")
    ax.set_xlabel("K (neurons zeroed at avgpool)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_xscale("symlog", linthresh=4)
    ax.set_title("Targeted vs random ablation — OOD effect\n"
                 "(separation = shortcut neurons selectively hurt OOD)",
                 fontweight="bold", fontsize=10)
    ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3)

    # Panel B — ID/OOD tradeoff
    if has_id:
        ax = axes[1]
        shortcut_id = [r.get("shortcut_head_id") for r in sweep]
        random_id_mu = [r.get("random_head_id_mean") for r in sweep]
        random_id_sd = [r.get("random_head_id_std", 0) for r in sweep]
        ax.plot(ks, shortcut_id, "r--o", lw=2, ms=7, alpha=0.85, label="ID (shortcut ablation)")
        ax.plot(ks, shortcut_ood, "r-o", lw=2, ms=7, label="OOD (shortcut ablation)")
        ax.plot(ks, random_id_mu, "k--s", lw=1.6, ms=5, alpha=0.7, label="ID (random ablation)")
        ax.plot(ks, random_ood_mu, "k-s", lw=1.6, ms=5, alpha=0.7, label="OOD (random ablation)")
        ax.set_xlabel("K (neurons zeroed at avgpool)")
        ax.set_ylabel("Head accuracy")
        ax.set_xscale("symlog", linthresh=4)
        ax.set_title("ID vs OOD degradation tradeoff\n"
                     "(targeted: OOD steady or ↑ while ID slowly ↓ = good)",
                     fontweight="bold", fontsize=10)
        ax.legend(fontsize=8, loc="best"); ax.grid(alpha=0.3)

    fig.suptitle(f"M6 — Targeted Neuron Ablation vs Random Control: {result['run_id']} "
                 f"• ep{result['epoch']}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--run_dir", required=True)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--epoch", type=int, default=None)
    p.add_argument("--max_samples", type=int, default=1000)
    p.add_argument("--device", default="cuda")
    p.add_argument("--ks", default=None,
                   help="Comma-separated K values, e.g. '0,4,8,16,32,64,128,256'")
    p.add_argument("--n_random_samples", type=int, default=5,
                   help="Random ablation: averages over this many random K-subsets")
    p.add_argument("--no_morphology", action="store_true",
                   help="Skip the morphology-targeted ablation control")
    p.add_argument("--no_id", action="store_true",
                   help="Skip ID accuracy evaluation (faster but loses ID/OOD ratio)")
    args = p.parse_args()

    ks = None
    if args.ks is not None:
        ks = [int(x) for x in args.ks.split(",")]

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    result = run_neuron_ablation(
        run_dir=run_dir, data_root=args.data_root,
        epoch=args.epoch, max_samples=args.max_samples,
        device=args.device, ks=ks,
        n_random_samples=args.n_random_samples,
        include_morphology=not args.no_morphology,
        include_id=not args.no_id,
    )
    base = out_dir / f"m6_neuron_ablation_ep{result['epoch']:05d}"
    base.with_suffix(".json").write_text(json.dumps(result, indent=2))
    plot_neuron_ablation(result, base.with_suffix(".png"))
    print(f"\n → {base.with_suffix('.json')}")
    print(f" → {base.with_suffix('.png')}")


if __name__ == "__main__":
    main()
```

---

## 16. Run inventory and summary results

14 runs at the canonical 3000-epoch config. The 11 runs with full M4/M5/M6 trajectories carry the paper's mechanistic claims; the 3 n=300 grokking runs have M1 only.

### `n = 1000` (5 grokking-favorable + 3 standard)

| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ (ungrok) | Best ID | Final ‖W‖ | Final rank | Final IRM |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Grok | 7 | 20260508-183413_grokking_n1000_s7 | 0.6876 | 50 | 0.5882 | −0.0995 | 0.8797 | 1516.7 | 69.83 | 4.47e-12 |
| Grok | 42 | 20260505-080445_grokking_n1000_s42 | 0.7336 | 350 | 0.6639 | −0.0696 | 0.8976 | 1470.5 | 35.87 | 4.34e-12 |
| Grok | 123 | 20260505-100720_grokking_n1000_s123 | 0.7270 | 350 | 0.6447 | −0.0823 | 0.8994 | 1457.2 | 56.65 | 6.73e-15 |
| Grok | 456 | 20260505-100720_grokking_n1000_s456 | 0.6722 | 1100 | 0.5224 | −0.1498 | 0.8824 | 1493.6 | 64.54 | 2.87e-09 |
| Grok | 2024 | 20260508-183413_grokking_n1000_s2024 | 0.7056 | 400 | 0.5506 | −0.1550 | 0.8959 | 1632.4 | 65.77 | 4.56e-07 |
| Std | 42 | 20260505-100720_standard_n1000_s42 | 0.7615 | 1 | 0.6482 | −0.1133 | 0.9011 | 812.6 | 33.35 | 2.24e-13 |
| Std | 123 | 20260508-183413_standard_n1000_s123 | 0.8880* | 1 | 0.6645 | −0.2235 | 0.8957 | 798.3 | 37.08 | 9.53e-14 |
| Std | 456 | 20260508-183413_standard_n1000_s456 | 0.7450 | 1050 | 0.5783 | −0.1667 | 0.8950 | 792.4 | 35.30 | 7.18e-09 |

\* Std s123 peaks at epoch 1 on the random initialization (artifact).

**Aggregates (n=1000)**: grokking 5-seed mean peak = **0.7052 ± 0.0237**, mean Δ = **−0.1112 ± 0.0345**. Standard corrected 2-seed mean peak (s42, s456) = **0.7533**; raw 3-seed mean = 0.7982.

### `n = 500` and `n = 300`

| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ | Best ID |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Grok | 42 | 20260505-080442_grokking_n500_s42 | 0.7924 | 50 | 0.5514 | −0.2410 | 0.8874 |
| Std | 42 | 20260505-100720_standard_n500_s42 | 0.7576 | 1050 | 0.6526 | −0.1050 | 0.8867 |
| Grok | 42 | 20260502-214859_grokking_n300_s42 | 0.7162 | 250 | 0.5189 | −0.1974 | 0.8664 |
| Grok | 123 | 20260502-214859_grokking_n300_s123 | 0.6961 | 50 | 0.5154 | −0.1807 | 0.8388 |
| Grok | 456 | 20260502-214859_grokking_n300_s456 | 0.6654 | 750 | 0.5469 | −0.1184 | 0.8522 |
| Std | 42 | 20260505-080836_standard_n300_s42 | 0.7647 | 250 | 0.7052 | −0.0596 | 0.8584 |

**Universal finding**: every run *ungroks* (Δ < 0 for all 14). No run shows a plateau-then-jump. `grokking_epoch = -1` everywhere; `irm_drop_pct ≈ 100%` wherever it is recorded (the three n=300 grokking summaries below do not include this field).

---

## 17. Per-run `config.json` and `summary.json` (all 14 runs)

Every run's exact saved JSON, verbatim from disk.

### `20260502-214859_grokking_n300_s42`

`config.json`:

```json
{
  "seed": 42,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260502-214859_grokking_n300_s42",
  "run_dir": "experiments/runs/20260502-214859_grokking_n300_s42"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260502-214859_grokking_n300_s42",
  "condition": "grokking",
  "n_train": 300,
  "seed": 42,
  "best_id_val": 0.866388557806913,
  "best_ood": 0.7162273379264937,
  "peak_ood_epoch": 250,
  "final_ood": 0.5188703647094787,
  "ood_delta": -0.19735697321701506,
  "ood_improvement": 0.01951701272132994,
  "grokking_epoch": -1,
  "final_weight_norm": 1131.2872887615824,
  "final_feature_rank": 39.32365036010742,
  "final_irm": 6.867336560523185e-14,
  "final_shortcut_ratio": 1.0053635176200695,
  "final_ood_gap": 0.3284478712857537,
  "ungrokking_detected": true
}
```

### `20260502-214859_grokking_n300_s123`

`config.json`:

```json
{
  "seed": 123,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260502-214859_grokking_n300_s123",
  "run_dir": "experiments/runs/20260502-214859_grokking_n300_s123"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260502-214859_grokking_n300_s123",
  "condition": "grokking",
  "n_train": 300,
  "seed": 123,
  "best_id_val": 0.8387663885578069,
  "best_ood": 0.6961459778493663,
  "peak_ood_epoch": 50,
  "final_ood": 0.515413737155219,
  "ood_delta": -0.18073224069414728,
  "ood_improvement": 0.015413737155218987,
  "grokking_epoch": -1,
  "final_weight_norm": 981.2013665038359,
  "final_feature_rank": 44.795772552490234,
  "final_irm": 5.876544368309256e-13,
  "final_shortcut_ratio": 1.0042143792951101,
  "final_ood_gap": 0.23887708525240914,
  "ungrokking_detected": true
}
```

### `20260502-214859_grokking_n300_s456`

`config.json`:

```json
{
  "seed": 456,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260502-214859_grokking_n300_s456",
  "run_dir": "experiments/runs/20260502-214859_grokking_n300_s456"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260502-214859_grokking_n300_s456",
  "condition": "grokking",
  "n_train": 300,
  "seed": 456,
  "best_id_val": 0.8521752085816449,
  "best_ood": 0.6653655324852447,
  "peak_ood_epoch": 750,
  "final_ood": 0.5469348884238249,
  "ood_delta": -0.11843064406141979,
  "ood_improvement": 0.04693488842382487,
  "grokking_epoch": -1,
  "final_weight_norm": 1109.5530019538146,
  "final_feature_rank": 49.14884567260742,
  "final_irm": 8.174740884214771e-10,
  "final_shortcut_ratio": 0.9680124053677005,
  "final_ood_gap": 0.25908418189798677,
  "ungrokking_detected": true
}
```

### `20260505-080836_standard_n300_s42`

`config.json`:

```json
{
  "seed": 42,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "standard",
  "lr": 0.001,
  "weight_decay": 0.0001,
  "n_epochs": 3000,
  "init_scale": 1.0,
  "use_grokfast": false,
  "grad_clip": 1.0,
  "run_id": "20260505-080836_standard_n300_s42",
  "run_dir": "experiments/runs/20260505-080836_standard_n300_s42"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260505-080836_standard_n300_s42",
  "condition": "standard",
  "n_train": 300,
  "seed": 42,
  "best_id_val": 0.858373063170441,
  "best_ood": 0.7647259388153408,
  "peak_ood_epoch": 250,
  "final_ood": 0.7051637783055471,
  "ood_delta": -0.05956216050979368,
  "ood_improvement": -0.026599572036588692,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99996794236725,
  "irm_drop_epoch": 100,
  "epoch_gap": -1,
  "final_weight_norm": 452.7899771783757,
  "final_feature_rank": 17.35470962524414,
  "final_irm": 1.032695706726372e-09,
  "final_shortcut_ratio": 1.0080115087353168,
  "final_ood_gap": 0.12806029797574014
}
```

### `20260505-080442_grokking_n500_s42`

`config.json`:

```json
{
  "seed": 42,
  "n_train": 500,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-080442_grokking_n500_s42",
  "run_dir": "experiments/runs/20260505-080442_grokking_n500_s42"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260505-080442_grokking_n500_s42",
  "condition": "grokking",
  "n_train": 500,
  "seed": 42,
  "best_id_val": 0.8873957091775924,
  "best_ood": 0.7924495026688927,
  "peak_ood_epoch": 50,
  "final_ood": 0.5514496672702048,
  "ood_delta": -0.24099983539868786,
  "ood_improvement": -0.06643779246126003,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999995667906,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1046.672482929869,
  "final_feature_rank": 34.301780700683594,
  "final_irm": 6.552892841682478e-07,
  "final_shortcut_ratio": 0.9969108279290858,
  "final_ood_gap": 0.32080897396936603
}
```

### `20260505-100720_standard_n500_s42`

`config.json`:

```json
{
  "seed": 42,
  "n_train": 500,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "standard",
  "lr": 0.001,
  "weight_decay": 0.0001,
  "n_epochs": 3000,
  "init_scale": 1.0,
  "use_grokfast": false,
  "grad_clip": 1.0,
  "run_id": "20260505-100720_standard_n500_s42",
  "run_dir": "experiments/runs/20260505-100720_standard_n500_s42"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260505-100720_standard_n500_s42",
  "condition": "standard",
  "n_train": 500,
  "seed": 42,
  "best_id_val": 0.8866507747318236,
  "best_ood": 0.7575775389752393,
  "peak_ood_epoch": 1050,
  "final_ood": 0.6525854163237472,
  "ood_delta": -0.10499212265149205,
  "ood_improvement": -0.08494603428410186,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99998953242944,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 555.1170332945212,
  "final_feature_rank": 20.722705841064453,
  "final_irm": 5.05333608291636e-10,
  "final_shortcut_ratio": 0.9843676046096169,
  "final_ood_gap": 0.19765296269889876
}
```

### `20260505-080445_grokking_n1000_s42`

`config.json`:

```json
{
  "seed": 42,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-080445_grokking_n1000_s42",
  "run_dir": "experiments/runs/20260505-080445_grokking_n1000_s42"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260505-080445_grokking_n1000_s42",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 42,
  "best_id_val": 0.8976460071513707,
  "best_ood": 0.7335575046441084,
  "peak_ood_epoch": 350,
  "final_ood": 0.6639076351494345,
  "ood_delta": -0.06964986949467389,
  "ood_improvement": 0.011399816587109313,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999500910222,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1470.4773196265805,
  "final_feature_rank": 35.86945724487305,
  "final_irm": 4.340286376830482e-12,
  "final_shortcut_ratio": 0.9871634909839839,
  "final_ood_gap": 0.20680154244293736
}
```

### `20260505-100720_grokking_n1000_s123`

`config.json`:

```json
{
  "seed": 123,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-100720_grokking_n1000_s123",
  "run_dir": "experiments/runs/20260505-100720_grokking_n1000_s123"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260505-100720_grokking_n1000_s123",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 123,
  "best_id_val": 0.8994338498212158,
  "best_ood": 0.7269734521598044,
  "peak_ood_epoch": 350,
  "final_ood": 0.6446610388694242,
  "ood_delta": -0.08231241329038019,
  "ood_improvement": 0.01956639311496222,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.9999991755092,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1457.2260357210637,
  "final_feature_rank": 56.64516067504883,
  "final_irm": 6.7278793620213426e-15,
  "final_shortcut_ratio": 0.9760735232865748,
  "final_ood_gap": 0.23534492060614198
}
```

### `20260505-100720_grokking_n1000_s456`

`config.json`:

```json
{
  "seed": 456,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-100720_grokking_n1000_s456",
  "run_dir": "experiments/runs/20260505-100720_grokking_n1000_s456"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260505-100720_grokking_n1000_s456",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 456,
  "best_id_val": 0.8824493444576877,
  "best_ood": 0.6721847297011311,
  "peak_ood_epoch": 1100,
  "final_ood": 0.522397535683213,
  "ood_delta": -0.14978719401791807,
  "ood_improvement": -0.08302960472170617,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999977269624,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1493.580040593733,
  "final_feature_rank": 64.54296875,
  "final_irm": 2.8693030174054e-09,
  "final_shortcut_ratio": 1.0356636404914459,
  "final_ood_gap": 0.3221495441737596
}
```

### `20260508-183413_grokking_n1000_s7`

`config.json`:

```json
{
  "seed": 7,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260508-183413_grokking_n1000_s7",
  "run_dir": "experiments/runs/20260508-183413_grokking_n1000_s7"
}
```

`results/summary.json`:

```json
{
  "run_id": "20260508-183413_grokking_n1000_s7",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 7,
  "best_id_val": 0.8797079856972586,
  "best_ood": 0.6876454958026665,
  "peak_ood_epoch": 50,
  "final_ood": 0.5881910315799375,
  "ood_delta": -0.09945446422272908,
  "ood_improvement": -0.03798527993980294,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999355254296,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1516.661424559645,
  "final_feature_rank": 69.8335952758789,
  "final_irm": 4.471252361415434e-12,
  "final_shortcut_ratio": 0.996297612800058,
  "final_ood_gap": 0.26586141180504463
}
```

### `20260508-183413_grokking_n1000_s2024`
`config.json`:

```json
{ "seed": 2024, "n_train": 1000, "batch_size": 32, "img_size": 96, "n_classes": 2, "log_every": 50, "device": "cuda", "condition": "grokking", "lr": 0.001, "weight_decay": 0.005, "n_epochs": 3000, "init_scale": 4.0, "use_grokfast": true, "grokfast_alpha": 0.98, "grokfast_lamb": 2.0, "grad_clip": 1.0, "run_id": "20260508-183413_grokking_n1000_s2024", "run_dir": "experiments/runs/20260508-183413_grokking_n1000_s2024" }
```

`results/summary.json`:

```json
{ "run_id": "20260508-183413_grokking_n1000_s2024", "condition": "grokking", "n_train": 1000, "seed": 2024, "best_id_val": 0.8959177592371871, "best_ood": 0.7056105532955534, "peak_ood_epoch": 400, "final_ood": 0.5506031462365085, "ood_delta": -0.15500740705904492, "ood_improvement": -0.04230488865896964, "grokking_epoch": -1, "irm_drop_pct": 99.9999998741651, "irm_drop_epoch": 50, "epoch_gap": -1, "final_weight_norm": 1632.3948325021925, "final_feature_rank": 65.77389526367188, "final_irm": 4.5635979972757923e-07, "final_shortcut_ratio": 0.9633737610070339, "final_ood_gap": 0.308663838268855 }
```

### `20260505-100720_standard_n1000_s42`

`config.json`:

```json
{ "seed": 42, "n_train": 1000, "batch_size": 32, "img_size": 96, "n_classes": 2, "log_every": 50, "device": "cuda", "condition": "standard", "lr": 0.001, "weight_decay": 0.0001, "n_epochs": 3000, "init_scale": 1.0, "use_grokfast": false, "grad_clip": 1.0, "run_id": "20260505-100720_standard_n1000_s42", "run_dir": "experiments/runs/20260505-100720_standard_n1000_s42" }
```

`results/summary.json`:

```json
{ "run_id": "20260505-100720_standard_n1000_s42", "condition": "standard", "n_train": 1000, "seed": 42, "best_id_val": 0.9011025029797378, "best_ood": 0.7615162132292426, "peak_ood_epoch": 1, "final_ood": 0.6482234815528958, "ood_delta": -0.11329273167634679, "ood_improvement": -0.022357561078844124, "grokking_epoch": -1, "irm_drop_pct": 99.99999329006182, "irm_drop_epoch": 50, "epoch_gap": -1, "final_weight_norm": 812.5540534619066, "final_feature_rank": 33.34842300415039, "final_irm": 2.2439123855463178e-13, "final_shortcut_ratio": 0.99423543483118, "final_ood_gap": 0.24736650652815306 }
```

### `20260508-183413_standard_n1000_s123`

`config.json`:

```json
{ "seed": 123, "n_train": 1000, "batch_size": 32, "img_size": 96, "n_classes": 2, "log_every": 50, "device": "cuda", "condition": "standard", "lr": 0.001, "weight_decay": 0.0001, "n_epochs": 3000, "init_scale": 1.0, "use_grokfast": false, "grad_clip": 1.0, "run_id": "20260508-183413_standard_n1000_s123", "run_dir": "experiments/runs/20260508-183413_standard_n1000_s123" }
```

`results/summary.json`:

```json
{ "run_id": "20260508-183413_standard_n1000_s123", "condition": "standard", "n_train": 1000, "seed": 123, "best_id_val": 0.8957091775923719, "best_ood": 0.8879652926376185, "peak_ood_epoch": 1, "final_ood": 0.6644837397418111, "ood_delta": -0.2234815528958074, "ood_improvement": -0.10371998965363183, "grokking_epoch": -1, "irm_drop_pct": 99.9999667168104, "irm_drop_epoch": 150, "epoch_gap": -1, "final_weight_norm": 798.3091903337586, "final_feature_rank": 37.0809326171875, "final_irm": 9.526699497634447e-14, "final_shortcut_ratio": 0.9913218948516812, "final_ood_gap": 0.22869266073494698 }
```

### `20260508-183413_standard_n1000_s456`

`config.json`:

```json
{ "seed": 456, "n_train": 1000, "batch_size": 32, "img_size": 96, "n_classes": 2, "log_every": 50, "device": "cuda", "condition": "standard", "lr": 0.001, "weight_decay": 0.0001, "n_epochs": 3000, "init_scale": 1.0, "use_grokfast": false, "grad_clip": 1.0, "run_id": "20260508-183413_standard_n1000_s456", "run_dir": "experiments/runs/20260508-183413_standard_n1000_s456" }
```

`results/summary.json`:

```json
{ "run_id": "20260508-183413_standard_n1000_s456", "condition": "standard", "n_train": 1000, "seed": 456, "best_id_val": 0.8949940405244339, "best_ood": 0.7449737813624285, "peak_ood_epoch": 1050, "final_ood": 0.5783149528534813, "ood_delta": -0.1666588285089472, "ood_improvement":
0.02752839372633853, "grokking_epoch": -1, "irm_drop_pct": 99.99999467428148, "irm_drop_epoch": 50, "epoch_gap": -1, "final_weight_norm": 792.3747983792097, "final_feature_rank": 35.297889709472656, "final_irm": 7.180730676736857e-09, "final_shortcut_ratio": 0.987733651871859, "final_ood_gap": 0.2792237837376986 }
```

---

## 18. Full training log: grokking n=1000 seed=42

```
# launched: 2026-05-05T08:04:45Z
# host: ubuntu-Standard-PC-Q35-ICH9-2009
# pwd: /home/garima/CausalGrok
# cmd: /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-080445_grokking_n1000_s42 --wandb_project causalgrok --wandb_mode offline
----
Device: cuda
Run ID: 20260505-080445_grokking_n1000_s42
Started: 2026-05-05T08:04:50.408035+00:00
Env hospital=0: 181 samples, positive rate=0.53
Env hospital=3: 371 samples, positive rate=0.48
Env hospital=4: 448 samples, positive rate=0.50
Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054
Params: 11,177,538
============================================================
GROKKING | Camelyon17 v2 | 3000 epochs WD=0.005 | α=4.0 | n=1000
Tracking: ID val (H3) + OOD test (H4) at every checkpoint
Grokking detection: watching OOD acc, not ID val acc
IRM envs: 3 hospitals
============================================================
ep 1 | tr 0.734 | id 0.728 | ood 0.499 | gap +0.228 | ‖W‖ 356.1 | rank 109.6 | IRM 0.2004 | sc 1.12x
ep 50 | tr 0.996 | id 0.859 | ood 0.697 | gap +0.162 | ‖W‖ 495.4 | rank 91.4 | IRM 0.0001 | sc 0.97x
ep 100 | tr 1.000 | id 0.883 | ood 0.592 | gap +0.291 | ‖W‖ 551.6 | rank 81.7 | IRM 0.0000 | sc 0.96x
ep 150 | tr 1.000 | id 0.872 | ood 0.641 | gap +0.231 | ‖W‖ 650.2 | rank 72.2 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00200.pt
ep 200 | tr 1.000 | id 0.885 | ood 0.613 | gap +0.272 | ‖W‖ 741.5 | rank 62.6 | IRM 0.0000 | sc 0.97x
ep 250 | tr 1.000 | id 0.890 | ood 0.659 | gap +0.231 | ‖W‖ 842.7 | rank 61.1 | IRM 0.0000 | sc 0.99x
ep 300 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.227 | ‖W‖ 878.1 | rank 60.7 | IRM 0.0000 | sc 0.99x
ep 350 | tr 1.000 | id 0.881 | ood 0.734 | gap +0.147 | ‖W‖ 907.9 | rank 53.9 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep00400.pt
ep 400 | tr 1.000 | id 0.875 | ood 0.647 | gap +0.229 | ‖W‖ 1005.7 | rank 60.9 | IRM 0.0000 | sc 1.00x
ep 450 | tr 1.000 | id 0.876 | ood 0.615 | gap +0.261 | ‖W‖ 1031.2 | rank 57.8 | IRM 0.0000 | sc 0.99x
ep 500 | tr 1.000 | id 0.880 | ood 0.607 | gap +0.273 | ‖W‖ 1023.1 | rank 55.7 | IRM 0.0000 | sc 0.99x
ep 550 | tr 1.000 | id 0.888 | ood 0.611 | gap +0.276 | ‖W‖ 1101.7 | rank 52.6 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep00600.pt
ep 600 | tr 1.000 | id 0.884 | ood 0.546 | gap +0.337 | ‖W‖ 1102.1 | rank 54.5 | IRM 0.0000 | sc 1.00x
ep 650 | tr 1.000 | id 0.878 | ood 0.605 | gap +0.273 | ‖W‖ 1142.4 | rank 55.2 | IRM 0.0000 | sc 0.98x
ep 700 | tr 1.000 | id 0.889 | ood 0.587 | gap +0.303 | ‖W‖ 1187.3 | rank 47.9 | IRM 0.0000 | sc 0.98x
ep 750 | tr 1.000 | id 0.892 | ood 0.670 | gap +0.222 | ‖W‖ 1180.7 | rank 48.4 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00800.pt
ep 800 | tr 1.000 | id 0.879 | ood 0.675 | gap +0.204 | ‖W‖ 1265.4 | rank 48.0 | IRM 0.0000 | sc 0.99x
ep 850 | tr 1.000 | id 0.886 | ood 0.627 | gap +0.259 | ‖W‖ 1300.7 | rank 48.3 | IRM 0.0000 | sc 0.98x
ep 900 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1290.6 | rank 48.0 | IRM 0.0000 | sc 0.98x
ep 950 | tr 1.000 | id 0.883 | ood 0.640 | gap +0.243 | ‖W‖ 1280.4 | rank 47.1 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep01000.pt
ep 1000 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1270.3 | rank 47.3 | IRM 0.0000 | sc 0.99x
ep 1050 | tr 1.000 | id 0.898 | ood 0.697 | gap +0.201 | ‖W‖ 1289.1 | rank 42.4 | IRM 0.0000 | sc 0.99x
ep 1100 | tr 0.999 | id 0.876 | ood 0.641 | gap +0.235 | ‖W‖ 1308.0 | rank 46.4 | IRM 0.0000 | sc 0.99x
ep 1150 | tr 1.000 | id 0.894 | ood 0.663 | gap +0.231 | ‖W‖ 1313.5 | rank 45.8 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01200.pt
ep 1200 | tr 0.999 | id 0.878 | ood 0.592 | gap +0.286 | ‖W‖ 1320.4 | rank 42.9 | IRM 0.0000 | sc 0.99x
ep 1250 | tr 1.000 | id 0.877 | ood 0.546 | gap +0.330 | ‖W‖ 1373.5 | rank 47.2 | IRM 0.0000 | sc 0.99x
ep 1300 | tr 1.000 | id 0.853 | ood 0.600 | gap +0.254 | ‖W‖ 1375.3 | rank 48.0 | IRM 0.0000 | sc 0.98x
ep 1350 | tr 1.000 | id 0.881 | ood 0.575 | gap +0.306 | ‖W‖ 1415.1 | rank 49.4 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01400.pt
ep 1400 | tr 1.000 | id 0.892 | ood 0.605 | gap +0.287 | ‖W‖ 1428.3 | rank 42.5 | IRM 0.0000 | sc 1.00x
ep 1450 | tr 1.000 | id 0.867 | ood 0.619 | gap +0.248 | ‖W‖ 1438.2 | rank 41.4 | IRM 0.0000 | sc 0.99x
ep 1500 | tr 1.000 | id 0.875 | ood 0.652 | gap +0.223 | ‖W‖ 1440.8 | rank 47.2 | IRM 0.0000 | sc 0.97x
ep 1550 | tr 1.000 | id 0.879 | ood 0.643 | gap +0.236 | ‖W‖ 1429.5 | rank 45.8 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep01600.pt
ep 1600 | tr 1.000 | id 0.878 | ood 0.630 | gap +0.248 | ‖W‖ 1418.2 | rank 46.4 | IRM 0.0000 | sc 0.98x
ep 1650 | tr 1.000 | id 0.881 | ood 0.636 | gap +0.245 | ‖W‖ 1406.9 | rank 45.9 | IRM 0.0000 | sc 0.98x
ep 1700 | tr 1.000 | id 0.883 | ood 0.634 | gap +0.249 | ‖W‖ 1395.7 | rank 47.0 | IRM 0.0000 | sc 0.98x
ep 1750 | tr 1.000 | id 0.875 | ood 0.702 | gap +0.173 | ‖W‖ 1441.2 | rank 45.9 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01800.pt
ep 1800 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 1433.0 | rank 45.6 | IRM 0.0000 | sc 0.98x
ep 1850 | tr 1.000 | id 0.880 | ood 0.644 | gap +0.236 | ‖W‖ 1421.6 | rank 45.8 | IRM 0.0000 | sc 0.98x
ep 1900 | tr 1.000 | id 0.885 | ood 0.644 | gap +0.242 | ‖W‖ 1410.4 | rank 43.9 | IRM 0.0000 | sc 0.98x
ep 1950 | tr 1.000 | id 0.883 | ood 0.549 | gap +0.334 | ‖W‖ 1415.3 | rank 49.1 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep02000.pt
ep 2000 | tr 1.000 | id 0.889 | ood 0.581 | gap +0.308 | ‖W‖ 1404.6 | rank 47.1 | IRM 0.0000 | sc 0.98x
ep 2050 | tr 1.000 | id 0.888 | ood 0.577 | gap +0.311 | ‖W‖ 1393.4 | rank 46.6 | IRM 0.0000 | sc 0.98x
ep 2100 | tr 1.000 | id 0.884 | ood 0.617 | gap +0.266 | ‖W‖ 1460.6 | rank 33.9 | IRM 0.0000 | sc 1.00x
ep 2150 | tr 1.000 | id 0.870 | ood 0.597 | gap +0.273 | ‖W‖ 1470.9 | rank 37.5 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep02200.pt
ep 2200 | tr 1.000 | id 0.869 | ood 0.568 | gap +0.301 | ‖W‖ 1460.2 | rank 38.5 | IRM 0.0000 | sc 0.99x
ep 2250 | tr 1.000 | id 0.870 | ood 0.588 | gap +0.282 | ‖W‖ 1448.6 | rank 36.9 | IRM 0.0000 | sc 0.98x
ep 2300 | tr 0.998 | id 0.872 | ood 0.706 | gap +0.166 | ‖W‖ 1485.2 | rank 41.9 | IRM 0.0000 | sc 0.97x
ep 2350 | tr 1.000 | id 0.877 | ood 0.648 | gap +0.229 | ‖W‖ 1506.9 | rank 41.2 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep02400.pt
ep 2400 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.226 | ‖W‖ 1495.0 | rank 40.1 | IRM 0.0000 | sc 1.00x
ep 2450 | tr 1.000 | id 0.869 | ood 0.682 | gap +0.187 | ‖W‖ 1486.6 | rank 36.8 | IRM 0.0000 | sc 0.98x
ep 2500 | tr 1.000 | id 0.884 | ood 0.621 | gap +0.263 | ‖W‖ 1510.2 | rank 40.7 | IRM 0.0000 | sc 1.00x
ep 2550 | tr 1.000 | id 0.876 | ood 0.667 | gap +0.209 | ‖W‖ 1499.1 | rank 38.9 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02600.pt
ep 2600 | tr 1.000 | id 0.871 | ood 0.635 | gap +0.236 | ‖W‖ 1506.6 | rank 39.4 | IRM 0.0000 | sc 1.01x
ep 2650 | tr 1.000 | id 0.874 | ood 0.692 | gap +0.183 | ‖W‖ 1500.6 | rank 36.8 | IRM 0.0000 | sc 0.99x
ep 2700 | tr 1.000 | id 0.877 | ood 0.607 | gap +0.269 | ‖W‖ 1510.5 | rank 39.9 | IRM 0.0000 | sc 0.97x
ep 2750 | tr 1.000 | id 0.873 | ood 0.599 | gap +0.273 | ‖W‖ 1521.2 | rank 38.5 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep02800.pt
ep 2800 | tr 1.000 | id 0.878 | ood 0.584 | gap +0.295 | ‖W‖ 1515.0 | rank 36.9 | IRM 0.0000 | sc 1.00x
ep 2850 | tr 1.000 | id 0.878 | ood 0.619 | gap +0.259 | ‖W‖ 1504.0 | rank 36.2 | IRM 0.0000 | sc 0.99x
ep 2900 | tr 1.000 | id 0.880 | ood 0.614 | gap +0.266 | ‖W‖ 1492.6 | rank 35.7 | IRM 0.0000 | sc 0.99x
ep 2950 | tr 1.000 | id 0.871 | ood 0.619 | gap +0.252 | ‖W‖ 1482.1 | rank 36.4 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep03000.pt
ep 3000 | tr 1.000 | id 0.871 | ood 0.664 | gap +0.207 | ‖W‖ 1470.5 | rank 35.9 | IRM 0.0000 | sc 0.99x
Best ID val (H3): 0.8976
Best OOD (H4): 0.7336
OOD improvement: +0.0114 ← did OOD grok?
Grokking at: None
IRM drop: 100.0%
Wall time: 358.5 min
```

---

## 19. Full training log: standard n=1000 seed=42

```
# launched: 2026-05-05T10:07:20Z
# host: ubuntu-Standard-PC-Q35-ICH9-2009
# pwd: /home/garima/CausalGrok
# cmd: /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition standard --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-100720_standard_n1000_s42 --wandb_project causalgrok --wandb_mode offline --n_epochs 3000
----
Device: cuda
Run ID: 20260505-100720_standard_n1000_s42
Started: 2026-05-05T10:07:36.596218+00:00
Env hospital=0: 181 samples, positive rate=0.53
Env hospital=3: 371 samples, positive rate=0.48
Env hospital=4: 448 samples, positive rate=0.50
Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054
Params: 11,177,538
============================================================
STANDARD | Camelyon17 v2 | 3000 epochs WD=0.0001 | α=1.0 | n=1000
Tracking: ID val (H3) + OOD test (H4) at every checkpoint
Grokking detection: watching OOD acc, not ID val acc
IRM envs: 3 hospitals
============================================================
ep 1 | tr 0.681 | id 0.664 | ood 0.762 | gap -0.098 | ‖W‖ 105.0 | rank 78.5 | IRM 0.1490 | sc 1.25x
ep 50 | tr 0.999 | id 0.897 | ood 0.620 | gap +0.276 | ‖W‖ 162.5 | rank 45.1 | IRM 0.0000 | sc 0.99x
ep 100 | tr 1.000 | id 0.901 | ood 0.722 | gap +0.179 | ‖W‖ 196.1 | rank 35.1 | IRM 0.0000 | sc 0.98x
ep 150 | tr 0.997 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 228.8 | rank 29.4 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00200.pt
ep 200 | tr 1.000 | id 0.886 | ood 0.611 | gap +0.274 | ‖W‖ 268.7 | rank 30.1 | IRM 0.0000 | sc 0.99x
ep 250 | tr 1.000 | id 0.890 | ood 0.672 | gap +0.218 | ‖W‖ 294.3 | rank 30.9 | IRM 0.0000 | sc 0.97x
ep 300 | tr 1.000 | id 0.900 | ood 0.684 | gap +0.216 | ‖W‖ 323.3 | rank 26.7 | IRM 0.0000 | sc 0.98x
ep 350 | tr 1.000 | id 0.891 | ood 0.573 | gap +0.318 | ‖W‖ 343.5 | rank 27.9 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00400.pt
ep 400 | tr 1.000 | id 0.885 | ood 0.642 | gap +0.243 | ‖W‖ 361.4 | rank 28.6 | IRM 0.0000 | sc 0.98x
ep 450 | tr 1.000 | id 0.894 | ood 0.700 | gap +0.194 | ‖W‖ 377.9 | rank 31.3 | IRM 0.0000 | sc 0.98x
ep 500 | tr 1.000 | id 0.890 | ood 0.705 | gap +0.185 | ‖W‖ 378.2 | rank 29.3 | IRM 0.0000 | sc 0.97x
ep 550 | tr 1.000 | id 0.895 | ood 0.656 | gap +0.239 | ‖W‖ 412.9 | rank 26.5 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep00600.pt
ep 600 | tr 1.000 | id 0.862 | ood 0.717 | gap +0.145 | ‖W‖ 426.0 | rank 29.7 | IRM 0.0000 | sc 0.97x
ep 650 | tr 1.000 | id 0.885 | ood 0.713 | gap +0.172 | ‖W‖ 445.0 | rank 25.9 | IRM 0.0000 | sc 0.99x
ep 700 | tr 1.000 | id 0.892 | ood 0.639 | gap +0.253 | ‖W‖ 454.4 | rank 28.0 | IRM 0.0000 | sc 0.98x
ep 750 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 472.1 | rank 25.5 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep00800.pt
ep 800 | tr 1.000 | id 0.888 | ood 0.681 | gap +0.207 | ‖W‖ 489.6 | rank 28.8 | IRM 0.0000 | sc 0.97x
ep 850 | tr 1.000 | id 0.887 | ood 0.626 | gap +0.262 | ‖W‖ 506.4 | rank 28.2 | IRM 0.0000 | sc 0.99x
ep 900 | tr 1.000 | id 0.888 | ood 0.703 | gap +0.185 | ‖W‖ 515.4 | rank 31.3 | IRM 0.0000 | sc 0.99x
ep 950 | tr 1.000 | id 0.883 | ood 0.667 | gap +0.215 | ‖W‖ 526.2 | rank 27.6 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep01000.pt
ep 1000 | tr 1.000 | id 0.897 | ood 0.674 | gap +0.222 | ‖W‖ 530.5 | rank 26.2 | IRM 0.0000 | sc 1.00x
ep 1050 | tr 1.000 | id 0.896 | ood 0.581 | gap +0.315 | ‖W‖ 544.2 | rank 26.3 | IRM 0.0000 | sc 0.98x
ep 1100 | tr 1.000 | id 0.877 | ood 0.655 | gap +0.222 | ‖W‖ 559.3 | rank 26.2 | IRM 0.0000 | sc 1.00x
ep 1150 | tr 1.000 | id 0.899 | ood 0.694 | gap +0.205 | ‖W‖ 562.6 | rank 23.7 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep01200.pt
ep 1200 | tr 1.000 | id 0.892 | ood 0.578 | gap +0.315 | ‖W‖ 581.9 | rank 27.0 | IRM 0.0000 | sc 0.98x
ep 1250 | tr 1.000 | id 0.889 | ood 0.647 | gap +0.242 | ‖W‖ 595.8 | rank 30.1 | IRM 0.0000 | sc 0.99x
ep 1300 | tr 1.000 | id 0.880 | ood 0.616 | gap +0.264 | ‖W‖ 604.9 | rank 31.7 | IRM 0.0000 | sc 0.98x
ep 1350 | tr 1.000 | id 0.878 | ood 0.649 | gap +0.228 | ‖W‖ 606.3 | rank 30.0 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep01400.pt
ep 1400 | tr 1.000 | id 0.879 | ood 0.735 | gap +0.144 | ‖W‖ 624.0 | rank 32.7 | IRM 0.0000 | sc 0.96x
ep 1450 | tr 1.000 | id 0.885 | ood 0.703 | gap +0.182 | ‖W‖ 625.3 | rank 33.0 | IRM 0.0000 | sc 0.98x
ep 1500 | tr 1.000 | id 0.894 | ood 0.686 | gap +0.207 | ‖W‖ 626.2 | rank 30.6 | IRM 0.0000 | sc 0.98x
ep 1550 | tr 1.000 | id 0.879 | ood 0.670 | gap +0.209 | ‖W‖ 640.3 | rank 32.3 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep01600.pt
ep 1600 | tr 1.000 | id 0.871 | ood 0.669 | gap +0.202 | ‖W‖ 653.6 | rank 31.0 | IRM 0.0000 | sc 0.98x
ep 1650 | tr 1.000 | id 0.886 | ood 0.540 | gap +0.346 | ‖W‖ 662.7 | rank 34.4 | IRM 0.0000 | sc 0.98x
ep 1700 | tr 1.000 | id 0.897 | ood 0.659 | gap +0.239 | ‖W‖ 667.4 | rank 38.2 | IRM 0.0000 | sc 0.97x
ep 1750 | tr 1.000 | id 0.871 | ood 0.716 | gap +0.155 | ‖W‖ 678.7 | rank 31.7 | IRM 0.0000 | sc 0.97x
✓ Checkpoint → ep01800.pt
ep 1800 | tr 1.000 | id 0.887 | ood 0.665 | gap +0.222 | ‖W‖ 682.0 | rank 33.0 | IRM 0.0000 | sc 0.98x
ep 1850 | tr 1.000 | id 0.892 | ood 0.598 | gap +0.294 | ‖W‖ 685.4 | rank 31.1 | IRM 0.0000 | sc 0.98x
ep 1900 | tr 1.000 | id 0.890 | ood 0.574 | gap +0.316 | ‖W‖ 690.9 | rank 32.8 | IRM 0.0000 | sc 0.99x
ep 1950 | tr 1.000 | id 0.889 | ood 0.623 | gap +0.265 | ‖W‖ 706.3 | rank 30.0 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep02000.pt
ep 2000 | tr 1.000 | id 0.890 | ood 0.595 | gap +0.296 | ‖W‖ 714.3 | rank 31.5 | IRM 0.0000 | sc 0.98x
ep 2050 | tr 1.000 | id 0.888 | ood 0.596 | gap +0.292 | ‖W‖ 720.1 | rank 31.4 | IRM 0.0000 | sc 0.99x
ep 2100 | tr 1.000 | id 0.860 | ood 0.686 | gap +0.173 | ‖W‖ 730.1 | rank 30.6 | IRM 0.0000 | sc 0.99x
ep 2150 | tr 1.000 | id 0.892 | ood 0.712 | gap +0.180 | ‖W‖ 732.0 | rank 28.5 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02200.pt
ep 2200 | tr 1.000 | id 0.881 | ood 0.621 | gap +0.260 | ‖W‖ 736.6 | rank 31.0 | IRM 0.0000 | sc 0.99x
ep 2250 | tr 1.000 | id 0.882 | ood 0.621 | gap +0.261 | ‖W‖ 736.5 | rank 29.6 | IRM 0.0000 | sc 0.99x
ep 2300 | tr 1.000 | id 0.877 | ood 0.641 | gap +0.236 | ‖W‖ 748.9 | rank 32.6 | IRM 0.0000 | sc 0.99x
ep 2350 | tr 1.000 | id 0.883 | ood 0.652 | gap +0.230 | ‖W‖ 753.4 | rank 35.2 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02400.pt
ep 2400 | tr 1.000 | id 0.884 | ood 0.673 | gap +0.212 | ‖W‖ 753.4 | rank 34.2 | IRM 0.0000 | sc 1.00x
ep 2450 | tr 1.000 | id 0.887 | ood 0.677 | gap +0.209 | ‖W‖ 755.5 | rank 32.2 | IRM 0.0000 | sc 1.00x
ep 2500 | tr 1.000 | id 0.887 | ood 0.601 | gap +0.286 | ‖W‖ 768.2 | rank 35.0 | IRM 0.0000 | sc 0.98x
ep 2550 | tr 1.000 | id 0.880 | ood 0.645 | gap +0.235 | ‖W‖ 772.8 | rank 35.5 | IRM 0.0000 | sc 0.98x
✓ Checkpoint → ep02600.pt
ep 2600 | tr 1.000 | id 0.884 | ood 0.631 | gap +0.253 | ‖W‖ 774.0 | rank 34.8 | IRM 0.0000 | sc 0.99x
ep 2650 | tr 1.000 | id 0.885 | ood 0.584 | gap +0.301 | ‖W‖ 776.4 | rank 37.4 | IRM 0.0000 | sc 0.99x
ep 2700 | tr 1.000 | id 0.888 | ood 0.597 | gap +0.291 | ‖W‖ 780.4 | rank 36.2 | IRM 0.0000 | sc 0.99x
ep 2750 | tr 1.000 | id 0.895 | ood 0.683 | gap +0.212 | ‖W‖ 786.8 | rank 35.2 | IRM 0.0000 | sc 0.99x
✓ Checkpoint → ep02800.pt
ep 2800 | tr 1.000 | id 0.894 | ood 0.597 | gap +0.297 | ‖W‖ 794.5 | rank 34.0 | IRM 0.0000 | sc 0.99x
ep 2850 | tr 1.000 | id 0.874 | ood 0.623 | gap +0.252 | ‖W‖ 803.4 | rank 32.7 | IRM 0.0000 | sc 0.99x
ep 2900 | tr 1.000 | id 0.891 | ood 0.679 | gap +0.212 | ‖W‖ 808.3 | rank 35.3 | IRM 0.0000 | sc 0.98x
ep 2950 | tr 1.000 | id 0.886 | ood 0.704 | gap +0.182 | ‖W‖ 811.6 | rank 33.7 | IRM 0.0000 | sc 1.00x
✓ Checkpoint → ep03000.pt
ep 3000 | tr 1.000 | id 0.896 | ood 0.648 | gap +0.247 | ‖W‖ 812.6 | rank 33.3 | IRM 0.0000 | sc 0.99x
Best ID val (H3): 0.9011
Best OOD (H4): 0.7615
OOD improvement: -0.0224 ← did OOD grok?
Grokking at: None
IRM drop: 100.0%
Wall time: 327.3 min
```

---

## 20. M5 — Full activation-steering JSONs (8 runs at n=1000)

Every run's full `m5_steering_ep*.json`, verbatim from disk. `head_ood_acc` is the classification head's OOD accuracy on the steered features `h' = h + alpha * sigma * v_s`. `tumor_probe` is the linear-probe accuracy on the same steered features. `hospital_probe` is NaN by construction (H4 hospital labels do not overlap with the training-hospital probe's class set).

### `20260505-080445_grokking_n1000_s42`

```json
{ "run_id": "20260505-080445_grokking_n1000_s42", "epoch": 400, "layer": "avgpool", "max_samples": 800, "v_norm": 1.0, "sigma": 8.685188293457031, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.63, "hospital_probe": NaN, "tumor_probe": 0.6625 },
  { "alpha": -2.0, "head_ood_acc": 0.625, "hospital_probe": NaN, "tumor_probe": 0.6525 },
  { "alpha": -1.0, "head_ood_acc": 0.64, "hospital_probe": NaN, "tumor_probe": 0.665 },
  { "alpha": -0.5, "head_ood_acc": 0.6325, "hospital_probe": NaN, "tumor_probe": 0.655 },
  { "alpha": 0.0, "head_ood_acc": 0.6525, "hospital_probe": NaN, "tumor_probe": 0.66 },
  { "alpha": 0.5, "head_ood_acc": 0.6575, "hospital_probe": NaN, "tumor_probe": 0.665 },
  { "alpha": 1.0, "head_ood_acc": 0.6575, "hospital_probe": NaN, "tumor_probe": 0.6675 },
  { "alpha": 2.0, "head_ood_acc": 0.62, "hospital_probe": NaN, "tumor_probe": 0.65 },
  { "alpha": 3.0, "head_ood_acc": 0.59, "hospital_probe": NaN, "tumor_probe": 0.635 }
] }
```

### `20260505-100720_grokking_n1000_s123`

```json
{ "run_id": "20260505-100720_grokking_n1000_s123", "epoch": 400, "layer": "avgpool", "max_samples": 800, "v_norm": 1.0, "sigma": 6.000692844390869, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.7325, "hospital_probe": NaN, "tumor_probe": 0.6925 },
  { "alpha": -2.0, "head_ood_acc": 0.7325, "hospital_probe": NaN, "tumor_probe": 0.6925 },
  { "alpha": -1.0, "head_ood_acc": 0.73, "hospital_probe": NaN, "tumor_probe": 0.695 },
  { "alpha": -0.5, "head_ood_acc": 0.71, "hospital_probe": NaN, "tumor_probe": 0.695 },
  { "alpha": 0.0, "head_ood_acc": 0.6975, "hospital_probe": NaN, "tumor_probe": 0.6925 },
  { "alpha": 0.5, "head_ood_acc": 0.6925, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": 1.0, "head_ood_acc": 0.69, "hospital_probe": NaN, "tumor_probe": 0.685 },
  { "alpha": 2.0, "head_ood_acc": 0.68, "hospital_probe": NaN, "tumor_probe": 0.685 },
  { "alpha": 3.0, "head_ood_acc": 0.6575, "hospital_probe": NaN, "tumor_probe": 0.685 }
] }
```

### `20260505-100720_grokking_n1000_s456`

```json
{ "run_id": "20260505-100720_grokking_n1000_s456", "epoch": 1000, "layer": "avgpool", "max_samples": 800, "v_norm": 1.0, "sigma": 6.7488112449646, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.6825, "hospital_probe": NaN, "tumor_probe": 0.63 },
  { "alpha": -2.0, "head_ood_acc": 0.6575, "hospital_probe": NaN, "tumor_probe": 0.615 },
  { "alpha": -1.0, "head_ood_acc": 0.64, "hospital_probe": NaN, "tumor_probe": 0.605 },
  { "alpha": -0.5, "head_ood_acc": 0.63, "hospital_probe": NaN, "tumor_probe": 0.595 },
  { "alpha": 0.0, "head_ood_acc": 0.6225, "hospital_probe": NaN, "tumor_probe": 0.595 },
  { "alpha": 0.5, "head_ood_acc": 0.62, "hospital_probe": NaN, "tumor_probe": 0.5925 },
  { "alpha": 1.0, "head_ood_acc": 0.61, "hospital_probe": NaN, "tumor_probe": 0.5925 },
  { "alpha": 2.0, "head_ood_acc": 0.6025, "hospital_probe": NaN, "tumor_probe": 0.5825 },
  { "alpha": 3.0, "head_ood_acc": 0.5975, "hospital_probe": NaN, "tumor_probe": 0.59 }
] }
```

### `20260508-183413_grokking_n1000_s7`

```json
{ "run_id": "20260508-183413_grokking_n1000_s7", "epoch": 200, "layer": "avgpool", "max_samples": 800, "v_norm": 0.9999999403953552, "sigma": 5.532620429992676, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.495, "hospital_probe": NaN, "tumor_probe": 0.58 },
  { "alpha": -2.0, "head_ood_acc": 0.49, "hospital_probe": NaN, "tumor_probe": 0.5625 },
  { "alpha": -1.0, "head_ood_acc": 0.485, "hospital_probe": NaN, "tumor_probe": 0.5425 },
  { "alpha": -0.5, "head_ood_acc": 0.485, "hospital_probe": NaN, "tumor_probe": 0.5325 },
  { "alpha": 0.0, "head_ood_acc": 0.485, "hospital_probe": NaN, "tumor_probe": 0.52 },
  { "alpha": 0.5, "head_ood_acc": 0.4875, "hospital_probe": NaN, "tumor_probe": 0.5125 },
  { "alpha": 1.0, "head_ood_acc": 0.4875, "hospital_probe": NaN, "tumor_probe": 0.52 },
  { "alpha": 2.0, "head_ood_acc": 0.4775, "hospital_probe": NaN, "tumor_probe": 0.51 },
  { "alpha": 3.0, "head_ood_acc": 0.4825, "hospital_probe": NaN, "tumor_probe": 0.51 }
] }
```

### `20260508-183413_grokking_n1000_s2024`

```json
{ "run_id": "20260508-183413_grokking_n1000_s2024", "epoch": 400, "layer": "avgpool", "max_samples": 800, "v_norm": 1.0, "sigma": 9.224357604980469, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.7425, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": -2.0, "head_ood_acc": 0.7775, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": -1.0, "head_ood_acc": 0.7325, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": -0.5, "head_ood_acc": 0.715, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": 0.0, "head_ood_acc": 0.7125, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": 0.5, "head_ood_acc": 0.705, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": 1.0, "head_ood_acc": 0.68, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": 2.0, "head_ood_acc": 0.6225, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": 3.0, "head_ood_acc": 0.595, "hospital_probe": NaN, "tumor_probe": 0.69 }
] }
```

### `20260505-100720_standard_n1000_s42`

```json
{ "run_id": "20260505-100720_standard_n1000_s42", "epoch": 200, "layer": "avgpool", "max_samples": 800, "v_norm": 1.0, "sigma": 7.23820686340332, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.585, "hospital_probe": NaN, "tumor_probe": 0.6125 },
  { "alpha": -2.0, "head_ood_acc": 0.6, "hospital_probe": NaN, "tumor_probe": 0.62 },
  { "alpha": -1.0, "head_ood_acc": 0.61, "hospital_probe": NaN, "tumor_probe": 0.625 },
  { "alpha": -0.5, "head_ood_acc": 0.6175, "hospital_probe": NaN, "tumor_probe": 0.63 },
  { "alpha": 0.0, "head_ood_acc": 0.6175, "hospital_probe": NaN, "tumor_probe": 0.625 },
  { "alpha": 0.5, "head_ood_acc": 0.6175, "hospital_probe": NaN, "tumor_probe": 0.6225 },
  { "alpha": 1.0, "head_ood_acc": 0.61, "hospital_probe": NaN, "tumor_probe": 0.6275 },
  { "alpha": 2.0, "head_ood_acc": 0.6025, "hospital_probe": NaN, "tumor_probe": 0.6325 },
  { "alpha": 3.0, "head_ood_acc": 0.5925, "hospital_probe": NaN, "tumor_probe": 0.625 }
] }
```

### `20260508-183413_standard_n1000_s123`

```json
{ "run_id": "20260508-183413_standard_n1000_s123", "epoch": 200, "layer": "avgpool", "max_samples": 800, "v_norm": 1.0, "sigma": 13.366204261779785, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.5425, "hospital_probe": NaN, "tumor_probe": 0.565 },
  { "alpha": -2.0, "head_ood_acc": 0.5975, "hospital_probe": NaN, "tumor_probe": 0.615 },
  { "alpha": -1.0, "head_ood_acc": 0.6725, "hospital_probe": NaN, "tumor_probe": 0.68 },
  { "alpha": -0.5, "head_ood_acc": 0.7125, "hospital_probe": NaN, "tumor_probe": 0.71 },
  { "alpha": 0.0, "head_ood_acc": 0.72, "hospital_probe": NaN, "tumor_probe": 0.7375 },
  { "alpha": 0.5, "head_ood_acc": 0.7175, "hospital_probe": NaN, "tumor_probe": 0.7425 },
  { "alpha": 1.0, "head_ood_acc": 0.6825, "hospital_probe": NaN, "tumor_probe": 0.7425 },
  { "alpha": 2.0, "head_ood_acc": 0.6275, "hospital_probe": NaN, "tumor_probe": 0.695 },
  { "alpha": 3.0, "head_ood_acc": 0.555, "hospital_probe": NaN, "tumor_probe": 0.6425 }
] }
```

### `20260508-183413_standard_n1000_s456`

```json
{ "run_id": "20260508-183413_standard_n1000_s456", "epoch": 1000, "layer": "avgpool", "max_samples": 800, "v_norm": 1.0, "sigma": 10.473133087158203, "sweep": [
  { "alpha": -3.0, "head_ood_acc": 0.6375, "hospital_probe": NaN, "tumor_probe": 0.6475 },
  { "alpha": -2.0, "head_ood_acc": 0.6575, "hospital_probe": NaN, "tumor_probe": 0.66 },
  { "alpha": -1.0, "head_ood_acc": 0.675, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": -0.5, "head_ood_acc": 0.6725, "hospital_probe": NaN, "tumor_probe": 0.69 },
  { "alpha": 0.0, "head_ood_acc": 0.6175, "hospital_probe": NaN, "tumor_probe": 0.6575 },
  { "alpha": 0.5, "head_ood_acc": 0.58, "hospital_probe": NaN, "tumor_probe": 0.5975 },
  { "alpha": 1.0, "head_ood_acc": 0.5375, "hospital_probe": NaN, "tumor_probe": 0.585 },
  { "alpha": 2.0, "head_ood_acc": 0.505, "hospital_probe": NaN, "tumor_probe": 0.5225 },
  { "alpha": 3.0, "head_ood_acc": 0.505, "hospital_probe": NaN, "tumor_probe": 0.51 }
] }
```

---

## 21. M5 — Aggregated sweep tables

### Grokking-favorable

| α | s7 (ep200, σ=5.53) | s42 (ep400, σ=8.69) | s123 (ep400, σ=6.00) | s456 (ep1000, σ=6.75) | s2024 (ep400, σ=9.22) |
| --- | --- | --- | --- | --- | --- |
| −3.0 | 0.4950 | 0.6300 | 0.7325 | 0.6825 | 0.7425 |
| −2.0 | 0.4900 | 0.6250 | 0.7325 | 0.6575 | 0.7775 |
| −1.0 | 0.4850 | 0.6400 | 0.7300 | 0.6400 | 0.7325 |
| −0.5 | 0.4850 | 0.6325 | 0.7100 | 0.6300 | 0.7150 |
| 0.0 | 0.4850 | 0.6525 | 0.6975 | 0.6225 | 0.7125 |
| +0.5 | 0.4875 | 0.6575 | 0.6925 | 0.6200 | 0.7050 |
| +1.0 | 0.4875 | 0.6575 | 0.6900 | 0.6100 | 0.6800 |
| +2.0 | 0.4775 | 0.6200 | 0.6800 | 0.6025 | 0.6225 |
| +3.0 | 0.4825 | 0.5900 | 0.6575 | 0.5975 | 0.5950 |
| Strict mono? | yes | no (acc(0) > acc(−3)) | yes | yes | yes |

### Standard

| α | s42 (ep200, σ=7.24) | s123 (ep200, σ=13.37) | s456 (ep1000, σ=10.47) |
| --- | --- | --- | --- |
| −3.0 | 0.5850 | 0.5425 | 0.6375 |
| −2.0 | 0.6000 | 0.5975 | 0.6575 |
| −1.0 | 0.6100 | 0.6725 | 0.6750 |
| −0.5 | 0.6175 | 0.7125 | 0.6725 |
| 0.0 | 0.6175 | 0.7200 | 0.6175 |
| +0.5 | 0.6175 | 0.7175 | 0.5800 |
| +1.0 | 0.6100 | 0.6825 | 0.5375 |
| +2.0 | 0.6025 | 0.6275 | 0.5050 |
| +3.0 | 0.5925 | 0.5550 | 0.5050 |
| Strict mono? | no (peak at α=0) | no (peak at α=0) | yes |

**Aggregates and statistics**:

- Strict monotonicity (`acc(−3) ≥ acc(0) ≥ acc(+3)`): **4/5 grokking** vs **1/3 standard**.
- Mean σ: grokking 7.24, standard 10.36 (a 1.43× ratio, the σ-scaling confound).
- Mean Δ(α=0→−3): grokking **+0.0225 ± 0.0276**, standard **−0.0633 ± 0.0835**.
- Mean Δ(α=0→+3): grokking **−0.0495 ± 0.039**, standard **−0.101 ± 0.058**.
- **Fisher exact one-sided p = 0.286**, **Mann-Whitney U one-sided p = 0.071** (continuous statistic, primary), binomial sign test p = 0.188. **None reaches p < 0.05.**

---

## 22. M6 — Full K-sweep results (per-seed, all K)

Aggregated from `paper_figures/m6_summary.csv`. Δ(targ−rand) is `head_OOD(top-K shortcut ablated) − mean(head_OOD over 5 K-random ablations)`. A positive value means the targeted shortcut ablation beats the random control.

### Grokking-favorable (n=1000)

| K | s7 | s42 | s123 | s456 | s2024 | N_+ |
| --- | --- | --- | --- | --- | --- | --- |
| 0 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | — |
| 4 | +0.0015 | +0.0045 | +0.0015 | +0.0025 | −0.0005 | 4/5 |
| 8 | −0.0010 | +0.0025 | +0.0010 | 0.0000 | −0.0020 | 2/5 |
| 16 | +0.0005 | −0.0010 | +0.0055 | +0.0035 | −0.0015 | 3/5 |
| 32 | −0.0045 | −0.0025 | +0.0010 | +0.0045 | −0.0020 | 2/5 |
| 64 | −0.0015 | +0.0005 | +0.0115 | +0.0120 | −0.0115 | 3/5 |
| 128 | −0.0060 | +0.0005 | +0.0090 | +0.0015 | −0.0225 | 3/5 |
| 256 | −0.0345 | −0.0010 | +0.0060 | +0.0075 | −0.0205 | 2/5 |

### Standard (n=1000)

| K | s42 | s123 | s456 | N_+ |
| --- | --- | --- | --- | --- |
| 0 | 0.0000 | 0.0000 | 0.0000 | — |
| 4 | +0.0015 | −0.0025 | −0.0025 | 1/3 |
| 8 | −0.0035 | −0.0010 | −0.0005 | 0/3 |
| 16 | −0.0030 | −0.0025 | +0.0025 | 1/3 |
| 32 | −0.0040 | −0.0055 | −0.0060 | 0/3 |
| 64 | −0.0035 | −0.0100 | −0.0040 | 0/3 |
| 128 | −0.0085 | −0.0090 | −0.0025 | 0/3 |
| 256 | −0.0115 | +0.0040 | +0.0075 | 2/3 |

### Per-seed full ablation rows (K=64 and K=256)

```
grokking s42   K= 64: base=0.6575 short=0.6600 rand=0.6595±0.0043 morph=0.6475 Δshort-base=+0.0025 Δtarg-rand=+0.0005
grokking s42   K=256: base=0.6575 short=0.6625 rand=0.6635±0.0054 morph=0.5500 Δshort-base=+0.0050 Δtarg-rand=−0.0010
grokking s123  K= 64: base=0.6825 short=0.6950 rand=0.6835±0.0066 morph=0.6625 Δshort-base=+0.0125 Δtarg-rand=+0.0115
grokking s123  K=256: base=0.6825 short=0.7275 rand=0.7215±0.0179 morph=0.5575 Δshort-base=+0.0450 Δtarg-rand=+0.0060
grokking s456  K= 64: base=0.6450 short=0.6425 rand=0.6305±0.0120 morph=0.6550 Δshort-base=−0.0025 Δtarg-rand=+0.0120
grokking s456  K=256: base=0.6450 short=0.6325 rand=0.6250±0.0105 morph=0.5275 Δshort-base=−0.0125 Δtarg-rand=+0.0075
grokking s7    K= 64: base=0.4925 short=0.4875 rand=0.4890±0.0075 morph=0.4475 Δshort-base=−0.0050 Δtarg-rand=−0.0015
grokking s7    K=256: base=0.4925 short=0.4650 rand=0.4995±0.0056 morph=0.3800 Δshort-base=−0.0275 Δtarg-rand=−0.0345
grokking s2024 K= 64: base=0.7100 short=0.7125 rand=0.7240±0.0086 morph=0.7225 Δshort-base=+0.0025 Δtarg-rand=−0.0115
grokking s2024 K=256: base=0.7100 short=0.7225 rand=0.7430±0.0129 morph=0.5050 Δshort-base=+0.0125 Δtarg-rand=−0.0205
standard s42   K= 64: base=0.6150 short=0.6125 rand=0.6160±0.0034 morph=0.6100 Δshort-base=−0.0025 Δtarg-rand=−0.0035
standard s42   K=256: base=0.6150 short=0.6100 rand=0.6215±0.0020 morph=0.5950 Δshort-base=−0.0050 Δtarg-rand=−0.0115
standard s123  K= 64: base=0.7225 short=0.7150 rand=0.7250±0.0016 morph=0.7125 Δshort-base=−0.0075 Δtarg-rand=−0.0100
standard s123  K=256: base=0.7225 short=0.6900 rand=0.6860±0.0025 morph=0.6975 Δshort-base=−0.0325 Δtarg-rand=+0.0040
standard s456  K= 64: base=0.5975 short=0.5975 rand=0.6015±0.0030 morph=0.5850 Δshort-base= 0.0000 Δtarg-rand=−0.0040
standard s456  K=256: base=0.5975 short=0.5500 rand=0.5425±0.0055 morph=0.5525 Δshort-base=−0.0475 Δtarg-rand=+0.0075
```

**Aggregates**:

- K=64: grokking **3/5** positive Δ(targ−rand) (mean +0.0022 ± 0.0088); standard **0/3** (mean −0.0058 ± 0.0030). Fisher one-sided p = 0.179.
- K=256: grokking 2/5 positive; standard 2/3 positive, so there is no grokking-specific advantage at this K.
- ID accuracy stays within 0.01 of baseline across all targeted ablations.
- The random control is averaged over only 5 samplings per K (the main M6 weakness).
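The K=64 aggregates can be rechecked from the per-seed rows with the standard library alone. The sketch below is not part of the project's scripts: the names `GROK_K64`, `STD_K64`, `fisher_one_sided`, and `summarize` are illustrative, and it assumes the reported ± values are population standard deviations (ddof = 0), an inference from the fact that this choice matches the numbers listed above.

```python
from math import comb
from statistics import mean, pstdev

# Δ(targ−rand) at K=64, copied from the per-seed rows above
# (grokking order: s42, s123, s456, s7, s2024; standard order: s42, s123, s456).
GROK_K64 = [0.0005, 0.0115, 0.0120, -0.0015, -0.0115]
STD_K64 = [-0.0035, -0.0100, -0.0040]


def fisher_one_sided(pos_a, n_a, pos_b, n_b):
    """One-sided Fisher exact p-value: P(group A has >= pos_a positives)
    under the hypergeometric null with both table margins held fixed."""
    total_pos = pos_a + pos_b
    denom = comb(n_a + n_b, total_pos)
    upper = min(n_a, total_pos)
    tail = sum(
        comb(n_a, k) * comb(n_b, total_pos - k) for k in range(pos_a, upper + 1)
    )
    return tail / denom


def summarize(deltas):
    """Return (count of strictly positive deltas, mean, population SD)."""
    return sum(d > 0 for d in deltas), mean(deltas), pstdev(deltas)


grok_pos, grok_mean, grok_sd = summarize(GROK_K64)
std_pos, std_mean, std_sd = summarize(STD_K64)
p_value = fisher_one_sided(grok_pos, len(GROK_K64), std_pos, len(STD_K64))

print(f"grokking: {grok_pos}/{len(GROK_K64)} positive, mean {grok_mean:+.4f} ± {grok_sd:.4f}")
print(f"standard: {std_pos}/{len(STD_K64)} positive, mean {std_mean:+.4f} ± {std_sd:.4f}")
print(f"Fisher exact one-sided p = {p_value:.3f}")
```

The same `fisher_one_sided` helper applied to the §21 monotonicity counts (4/5 vs 1/3) gives the reported p = 0.286. With groups of 5 and 3 seeds, the smallest p-value this test can return is 1/56 ≈ 0.018, which helps put the non-significant results in context.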

The complete 88-row reviewer CSV is at `paper_figures/m6_summary.csv`: every run, every K, every condition (baseline / shortcut / random ± sd / morphology).

---

## 23. Exact commands

Training (launched detached under nohup via `scripts/launch.sh`):

```bash
python -u -m experiments.causalgrok_camelyon_v2 \
    --condition grokking --n_train 1000 --seed 42 \
    --run_dir experiments/runs/ \
    --wandb_project causalgrok --wandb_mode offline
# standard: --condition standard (wd/init/grokfast set automatically by get_config)
```

Mechanistic interpretability:

```bash
python -m experiments.mechinterp_m1 --run_dir experiments/runs/ --data_root data/wilds
python -m experiments.mechinterp_m4_ablation --run_dir experiments/runs/ --data_root data/wilds --layer avgpool --all_epochs
python -m experiments.mechinterp_m5_steering --run_dir experiments/runs/ --data_root data/wilds
python -m experiments.mechinterp_m6_neuron_ablation --run_dir experiments/runs/ --data_root data/wilds \
    --ks "0,4,8,16,32,64,128,256"
```

Figures:

```bash
python -m experiments.regenerate_all_figures  # rebuilds all 7 paper figures from saved JSON in <2 min
```

---

## 24. Output layout (per run)

```
experiments/runs//
├── config.json                  # full hyperparameter config
├── run.pid
├── checkpoints/
│   ├── ep00200.pt … ep03000.pt  # 15 periodic checkpoints, ~44 MB each
│   └── final.pt
├── logs/
│   ├── train.log                # launch cmd + per-checkpoint log lines
│   └── train.err
├── results/
│   ├── history.json             # 61 per-checkpoint metric rows
│   └── summary.json             # final-summary fields
├── wandb/                       # offline wandb run metadata
└── mechinterp/
    ├── m1_probe_data.json + heatmap/curves PNG
    ├── m4_ablation_avgpool_trajectory.json + PNG
    ├── m5_steering_ep.json + PNG
    └── m6_neuron_ablation_ep.json + PNG
```

Aggregated: `paper_figures/m6_summary.csv` (88 data rows), `paper_figures/*.{png,pdf}` (7 figures). Checkpoints (~10 GB, 240 `.pt` files) are mirrored to Hugging Face at `nileshsarkar-ai/CausalGrok`.
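
When reading the M5 steering files under `mechinterp/`, note that `hospital_probe` is stored as a bare `NaN` token, which is not strict JSON. The sketch below shows one way to handle that in Python. It parses an abbreviated inline copy of the standard s456 record from §20 (three of the nine sweep rows) instead of a file, since the exact file name depends on the run and epoch. `RAW` and `steering_deltas` are illustrative names, not project code.

```python
import json
import math

# Abbreviated copy of the standard s456 steering record from §20
# (three of the nine sweep rows; values copied verbatim).
RAW = """
{
  "run_id": "20260508-183413_standard_n1000_s456",
  "epoch": 1000,
  "layer": "avgpool",
  "sigma": 10.473133087158203,
  "sweep": [
    {"alpha": -1.0, "head_ood_acc": 0.675, "hospital_probe": NaN, "tumor_probe": 0.69},
    {"alpha": 0.0, "head_ood_acc": 0.6175, "hospital_probe": NaN, "tumor_probe": 0.6575},
    {"alpha": 1.0, "head_ood_acc": 0.5375, "hospital_probe": NaN, "tumor_probe": 0.585}
  ]
}
"""


def steering_deltas(record):
    """Map each alpha to head_ood_acc(alpha) - head_ood_acc(alpha=0)."""
    by_alpha = {row["alpha"]: row["head_ood_acc"] for row in record["sweep"]}
    baseline = by_alpha[0.0]
    return {alpha: acc - baseline for alpha, acc in by_alpha.items()}


# Python's json module accepts the bare NaN token by default;
# strict JSON parsers in other tools may reject these files.
record = json.loads(RAW)
deltas = steering_deltas(record)

# NaN never compares equal to anything, so filter it with math.isnan
# before aggregating any probe column.
finite_hospital = [
    row["hospital_probe"]
    for row in record["sweep"]
    if not math.isnan(row["hospital_probe"])
]

for alpha in sorted(deltas):
    print(f"alpha={alpha:+.1f}  delta_head_ood_acc={deltas[alpha]:+.4f}")
print(f"finite hospital_probe values: {len(finite_hospital)}")
```

For a file on disk, `json.load` on an open file handle behaves the same way as `json.loads` here.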

---

*All numerical values in this document were read directly from on-disk `config.json` / `results/summary.json` / `results/history.json` / `mechinterp/*.json` / `paper_figures/m6_summary.csv`. Training logs in §18 and §19 are verbatim from each run's `logs/train.log`. Source code in §10–§15 is the exact code that ran for every reported result.*