| # CausalGrok — Complete Training, Evaluation, and Mechanistic-Interpretability Reference |
|
|
| Paper: *Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training*. |
|
|
| This document is the **complete preservation archive** of the project. It contains the full source of every script that ran, every per-run config and summary, the full training logs of two reference runs, every M5 activation-steering result, the full M6 K-sweep data, and all formulas. All numerical values are read directly from on-disk `config.json` / `results/summary.json` / `results/history.json` / `mechinterp/*.json` / `paper_figures/m6_summary.csv`. All source code is the exact code that produced every reported result. |
|
|
| **Contents** |
|
|
| 1. Environment |
| 2. Dataset and data pipeline |
| 3. Model architecture and initialization |
| 4. Hyperparameter tables |
| 5. Loss function (cross-entropy) — formula |
| 6. IRM penalty (diagnostic) — formula |
| 7. Grokfast EMA — formula |
| 8. Training loop and checkpointing |
| 9. Evaluation metrics — formulas |
| 10. Full source: `utils/grokfast.py` |
| 11. Full source: `experiments/causalgrok_camelyon_v2.py` |
| 12. Full source: `experiments/mechinterp_m1.py` |
| 13. Full source: `experiments/mechinterp_m4_ablation.py` |
| 14. Full source: `experiments/mechinterp_m5_steering.py` |
| 15. Full source: `experiments/mechinterp_m6_neuron_ablation.py` |
| 16. Run inventory and summary results (14 runs) |
| 17. Per-run `config.json` and `summary.json` (all 14 runs) |
| 18. Full training log: grokking n=1000 seed=42 |
| 19. Full training log: standard n=1000 seed=42 |
| 20. M5 — Full activation-steering JSONs (8 runs at n=1000) |
| 21. M5 — Aggregated sweep tables |
| 22. M6 — Full K-sweep results (per-seed, all K) |
| 23. Exact commands |
| 24. Output layout |
|
|
| --- |
|
|
| ## 1. Environment |
|
|
| | Item | Value | |
| | --- | --- | |
| | Python env | `conda env: causalgrok` (Python 3.10) | |
| | Framework | PyTorch, `timm`, `torchvision`, `scikit-learn`, `numpy`, `wandb` (offline) | |
| | Model | `timm.create_model("resnet18", pretrained=False, num_classes=2)` | |
| | Device | CUDA (NVIDIA A100 80GB PCIe) | |
| | Precision | TF32 (`set_float32_matmul_precision("high")`, `cudnn.benchmark=True`, `allow_tf32=True`) | |
| | Params | 11,177,538 | |
| | Dataset | Camelyon17 via WILDS (`utils.camelyon_data.get_camelyon_subsets`, auto-download) | |
| | Wall time per run | ~5.5 h (grokking n=1000 s42, 3000 epochs, A100) | |
|
|
| ## 2. Dataset and data pipeline |
|
|
| Camelyon17 (WILDS), H&E-stained histopathology patches, binary tumor label. Five hospitals; the WILDS split provides train hospitals, an in-distribution validation set, and a held-out OOD test hospital. |
|
|
| - **ID validation**: 33,560 images. |
| - **OOD test (held-out hospital)**: 85,054 images. |
| - **Train**: subsampled to `n_train`. The n=1000 seed-42 run drew hospitals `{0, 3, 4}` with 181 / 371 / 448 samples (positive rates 0.53 / 0.48 / 0.50). |
|
|
| Transforms: `Resize((96, 96))`, `ToTensor`, `Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])`. Train loader: `batch_size=32, shuffle=True`; eval loaders: `batch_size=256, shuffle=False`; `num_workers=0`, `pin_memory=True`. |
|
|
| IRM environments — one `{x, y}` dict per unique training hospital, built once from `metadata[:, 0]`. |
|
|
| ## 3. Model architecture and initialization |
|
|
| ResNet-18 (`timm`, no ImageNet pretraining), 96×96 input, 2-class head, 11,177,538 parameters. The grokking-favorable regime multiplies every multi-dim weight tensor by `init_scale = 4.0` at initialization; standard uses `init_scale = 1.0` (no rescaling). `avgpool` feature dimension `D = 512`. Six probed stages: `stem, layer1, layer2, layer3, layer4, avgpool`. |
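
The parameter count and the rescaling rule above can be checked with a short sketch. It assumes the commonly cited figure of 11,689,512 parameters for ResNet-18 with a 1000-class head carries over to the `timm` model; the helper names and the example parameter list are illustrative, not taken from the project code.

```python
# Commonly cited parameter count for ResNet-18 with a 1000-class head
# (assumption: the timm "resnet18" backbone has the same layout).
RESNET18_PARAMS_1000_CLASSES = 11_689_512
FEATURE_DIM = 512  # avgpool feature dimension D


def head_params(num_classes: int, feature_dim: int = FEATURE_DIM) -> int:
    """Weights plus biases of the final linear layer."""
    return feature_dim * num_classes + num_classes


def resnet18_params(num_classes: int) -> int:
    """Swap the 1000-class head for a `num_classes`-way head."""
    backbone = RESNET18_PARAMS_1000_CLASSES - head_params(1000)
    return backbone + head_params(num_classes)


def is_rescaled(name: str, ndim: int) -> bool:
    """Selection rule: 'weight' in the parameter name and more than one dimension."""
    return "weight" in name and ndim > 1


# Illustrative (name, ndim) pairs following the usual ResNet naming.
EXAMPLE_PARAMS = [
    ("conv1.weight", 4),  # convolution kernel: rescaled
    ("bn1.weight", 1),    # batch-norm scale is 1-D: left alone
    ("bn1.bias", 1),      # bias: left alone
    ("fc.weight", 2),     # classifier matrix: rescaled
    ("fc.bias", 1),       # bias: left alone
]

if __name__ == "__main__":
    print(resnet18_params(2))
    for name, ndim in EXAMPLE_PARAMS:
        print(name, is_rescaled(name, ndim))
```

Under that assumption the 2-class model has 11,689,512 − 513,000 + 1,026 = 11,177,538 parameters, which matches the value reported above. The rule leaves batch-norm scales and all biases at their default initialization.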
|
|
| ## 4. Hyperparameter tables |
|
|
| | Hyperparameter | Standard | Grokking-favorable | |
| | --- | --- | --- | |
| | Optimizer | AdamW | AdamW | |
| | Learning rate | 1e-3 | 1e-3 | |
| | Weight decay | 1e-4 | 5e-3 (50×) | |
| | Epochs | 3000 | 3000 | |
| | Init scale | 1.0 | 4.0 | |
| | Grokfast EMA | off | on | |
| | Grokfast alpha (EMA decay) | — | 0.98 | |
| | Grokfast lamb (slow-grad amplification) | — | 2.0 | |
| | Gradient clip (max-norm) | 1.0 | 1.0 | |
| | Batch size | 32 | 32 | |
| | Image size | 96 × 96 | 96 × 96 | |
| | `log_every` (metric cadence) | 50 epochs | 50 epochs | |
| | Checkpoint cadence | 200 epochs | 200 epochs | |
| | IRM weight in loss | 0.0 | 0.0 | |
|
|
| **Three-axis confound**: weight decay (50× ratio), init scale (4× ratio), and Grokfast EMA (on vs off) differ simultaneously between regimes. No single-axis ablation was run. |
|
|
| ## 5. Loss function — cross-entropy |
|
|
| ``` |
| L = CE(f_theta(x), y) = - (1/B) * sum_i log [ exp(z_{i, y_i}) / sum_k exp(z_{i, k}) ] |
| ``` |
|
|
| where `z = f_theta(x)` are the logits. The training objective is pure CE for every reported run (`irm_weight = 0.0`). |
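
As a worked instance of the formula, the following pure-Python sketch evaluates the batch-mean cross-entropy on small hand-written logits. The function name and the example values are illustrative; the training code itself uses `nn.CrossEntropyLoss`.

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch of logit rows and integer labels."""
    total = 0.0
    for z, y in zip(logits, labels):
        m = max(z)  # subtract the max for numerical stability
        log_norm = m + math.log(sum(math.exp(v - m) for v in z))
        total += log_norm - z[y]  # equals -log softmax(z)[y]
    return total / len(labels)


if __name__ == "__main__":
    # Uninformative logits: the loss equals log(2) for two classes.
    print(cross_entropy([[0.0, 0.0]], [1]))
    # Confident, correct logits: the loss approaches zero.
    print(cross_entropy([[6.0, -6.0], [-6.0, 6.0]], [0, 1]))
```

With two classes, equal logits give a loss of exactly `log 2`, and the loss shrinks toward zero as the correct-class margin grows.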
|
|
| ## 6. IRM penalty (diagnostic, not in loss) |
|
|
| IRMv1 penalty (Arjovsky et al. 2019): |
|
|
| ``` |
| IRM_e = || grad_{w=1.0} CE( w · f_theta(x^e), y^e ) ||^2 |
| IRM = (1/|E|) sum_{e in E} IRM_e |
| ``` |
|
|
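| Logged at every metric checkpoint as `irm_mean` and `irm_var` (mean and variance of `IRM_e` across the training-hospital environments). In every run the penalty collapses from ~0.20 to ~1e-13 within ~50–150 epochs. This is a consequence of CE memorization, not invariance learning: once the training samples are fit with saturated logits, the gradient of the CE with respect to the dummy scale `w` vanishes in every training environment. |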
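
The collapse can be reproduced on toy logits. For a single environment, the derivative of the mean CE with respect to the scale is `mean_i( sum_k softmax(z_i)_k * z_ik - z_i[y_i] )` at `w = 1`. The sketch below (pure Python, illustrative names and values) checks that closed form against a finite difference and shows the penalty vanishing once the logits saturate on the correct class.

```python
import math


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def mean_ce(logits, labels, w=1.0):
    """Mean cross-entropy of the scaled logits w * z."""
    loss = 0.0
    for z, y in zip(logits, labels):
        p = softmax([w * v for v in z])
        loss -= math.log(p[y])
    return loss / len(labels)


def irm_penalty(logits, labels):
    """Squared derivative of mean CE(w * z, y) with respect to w at w = 1."""
    grad = 0.0
    for z, y in zip(logits, labels):
        p = softmax(z)
        grad += sum(pk * zk for pk, zk in zip(p, z)) - z[y]
    grad /= len(labels)
    return grad ** 2


if __name__ == "__main__":
    # One environment that is not yet fit (second sample is misclassified).
    unfit_logits, unfit_labels = [[0.5, -0.5], [0.5, -0.5]], [0, 1]
    h = 1e-6
    fd = (mean_ce(unfit_logits, unfit_labels, 1 + h)
          - mean_ce(unfit_logits, unfit_labels, 1 - h)) / (2 * h)
    print(irm_penalty(unfit_logits, unfit_labels), fd ** 2)

    # The same environment after memorization: saturated, correct logits.
    fit_logits, fit_labels = [[20.0, -20.0], [-20.0, 20.0]], [0, 1]
    print(irm_penalty(fit_logits, fit_labels))
```

The saturated case drives the penalty to numerical zero without any constraint on which features the model uses, which is why the logged value is treated as a diagnostic only.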
|
|
| ## 7. Grokfast EMA (grokking-favorable only) |
|
|
| Grokfast (Lee et al. 2024, arXiv:2405.20233): |
|
|
| ``` |
| g_ema <- alpha * g_ema + (1 - alpha) * g # alpha = 0.98 |
| g <- g + lamb * g_ema # lamb = 2.0 |
| ``` |
|
|
| Applied after `loss.backward()` and before gradient clipping and `optimizer.step()`. On the first step the EMA is initialized to the current gradient. |
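
A scalar sketch of the two update lines illustrates what the filter does. It is a simplified stand-in for `gradfilter_ema` that works on plain floats; the function names and the toy gradient sequences are illustrative.

```python
def grokfast_step(grad, ema, alpha=0.98, lamb=2.0):
    """One filter step on a scalar gradient; returns (filtered_grad, new_ema)."""
    if ema is None:
        ema = grad  # first call: the EMA starts at the current gradient
    else:
        ema = alpha * ema + (1 - alpha) * grad
    return grad + lamb * ema, ema


def run_filter(grads, alpha=0.98, lamb=2.0):
    """Apply the filter to a sequence of scalar gradients."""
    ema, filtered = None, []
    for g in grads:
        out, ema = grokfast_step(g, ema, alpha, lamb)
        filtered.append(out)
    return filtered, ema


if __name__ == "__main__":
    # A persistent gradient is amplified by a factor of (1 + lamb).
    constant, _ = run_filter([1.0] * 5)
    print(constant[-1])
    # A sign-flipping gradient leaves almost nothing in the EMA.
    _, ema = run_filter([1.0 if t % 2 == 0 else -1.0 for t in range(2000)])
    print(abs(ema))
```

With `lamb = 2.0`, a gradient component that keeps its sign is effectively tripled, while a component that flips sign every step settles at an EMA magnitude of about `(1 - alpha) / (1 + alpha)`, roughly 0.01, and is passed through almost unchanged.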
|
|
| ## 8. Training loop and checkpointing |
|
|
| - Loop over `epoch ∈ [1, 3000]`. |
| - Per minibatch: forward, CE loss, `loss.backward()`, Grokfast filter (grokking only), `clip_grad_norm_(max_norm=1.0)`, `optimizer.step()`. |
| - Metrics computed every 50 epochs (+ epoch 1) → `results/history.json` (61 rows/run). |
| - Checkpoints every 200 epochs → 15 `.pt` per run + `final.pt`. |
| - OOD-aware early stopping exists but defaults off; all reported runs train the full 3000 epochs. |
| - Grokking detection watches OOD-accuracy plateau-then-jump; `grokking_epoch = -1` (never fires) for every run. |
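
The row and checkpoint counts above follow from the two cadences. A minimal sketch, assuming (as the loop appears to do) that the checkpoint test is evaluated only on metric epochs; the function names are illustrative:

```python
def metric_epochs(n_epochs=3000, log_every=50):
    """Epochs at which metrics are computed: every `log_every` epochs, plus epoch 1."""
    return [e for e in range(1, n_epochs + 1) if e % log_every == 0 or e == 1]


def checkpoint_epochs(n_epochs=3000, log_every=50, ckpt_every=200):
    """Epochs at which a periodic checkpoint is written (excluding final.pt)."""
    return [e for e in metric_epochs(n_epochs, log_every) if e % ckpt_every == 0]


if __name__ == "__main__":
    print(len(metric_epochs()), metric_epochs()[:3])
    print(len(checkpoint_epochs()), checkpoint_epochs()[-1])
```

With the defaults this yields 61 history rows and 15 periodic checkpoints, the last at epoch 3000. Under the stated assumption, a `log_every` that does not divide 200 would skip some checkpoints.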
|
|
| ## 9. Evaluation metrics — formulas |
|
|
| | Metric | Definition | |
| | --- | --- | |
| | Accuracy | `argmax` accuracy on a loader; `train_acc`, `id_val_acc`, `ood_acc`. `ood_gap = id_val_acc - ood_acc`. | |
| | Weight norm | `||W|| = sqrt(sum_p ||p||_2^2)` | |
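| | Effective feature rank | `exp(- sum_i ŝ_i log ŝ_i)` where `ŝ_i = s_i / sum_j s_j` are the normalized singular values of the average-pooled `layer4` features, computed on the first 300 ID-validation samples. | |
| | Shortcut ratio | `min(border_conf / center_conf, 10)`, where `center_conf` is the mean max-softmax confidence on the central 48×48 crop upsampled to 96×96 and `border_conf` is the same quantity on the image with that central region zeroed (first ID-validation batch, 256 images). `>1` = stain-reliant, `<1` = tissue-reliant. | |
| | IRM penalty | see §6; computed across the 3 training-hospital environments at every metric checkpoint (every 50 epochs, plus epoch 1). | |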
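
The effective-rank and shortcut-ratio definitions in the table can be exercised on toy inputs. The sketch below is pure Python and takes singular values and confidences directly; the function names and inputs are illustrative, and the small `eps` terms mirror the guards used in the training script.

```python
import math


def effective_rank(singular_values, eps=1e-10):
    """exp of the Shannon entropy of the normalized singular values."""
    total = sum(singular_values) + eps
    s_hat = [s / total for s in singular_values]
    entropy = -sum(s * math.log(s + eps) for s in s_hat)
    return math.exp(entropy)


def shortcut_ratio(border_conf, center_conf, cap=10.0, eps=1e-8):
    """Border-to-center confidence ratio, clipped at `cap`."""
    return min(border_conf / (center_conf + eps), cap)


if __name__ == "__main__":
    print(effective_rank([1.0, 1.0, 1.0, 1.0]))  # four equal directions
    print(effective_rank([1.0, 0.0, 0.0, 0.0]))  # one dominant direction
    print(shortcut_ratio(0.9, 0.6))              # above 1: border more confident
    print(shortcut_ratio(0.9, 0.0))              # degenerate case hits the cap
```

Equal singular values give an effective rank equal to their count, and a single non-zero singular value gives an effective rank of about 1, so the metric reads as a soft count of the feature directions in use.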
|
|
| Summary fields (`summary.json`): `best_id_val`, `best_ood`, `peak_ood_epoch`, `final_ood`, `ood_delta` (= `final_ood − best_ood`, ungrokking signal), `ood_improvement`, `grokking_epoch`, `irm_drop_pct`, `irm_drop_epoch`, `epoch_gap`, `final_weight_norm`, `final_feature_rank`, `final_irm`, `final_shortcut_ratio`, `final_ood_gap`. |
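
The OOD summary fields can be derived from the logged rows as sketched below. The function name is illustrative and the history values are made up for the example (they are not results from any run); the five-row windows match the training script.

```python
def ood_summary(history, final_ood, window=5):
    """Derive OOD summary fields from logged rows with 'epoch' and 'ood_acc' keys.

    Assumes a non-empty history.
    """
    best_ood, peak_ood_epoch = 0.0, -1
    for row in history:
        if row["ood_acc"] > best_ood:  # strict comparison: the first peak wins ties
            best_ood, peak_ood_epoch = row["ood_acc"], row["epoch"]
    first, last = history[:window], history[-window:]
    ood_early = sum(r["ood_acc"] for r in first) / len(first)
    ood_late = sum(r["ood_acc"] for r in last) / len(last)
    return {
        "best_ood": best_ood,
        "peak_ood_epoch": peak_ood_epoch,
        "final_ood": final_ood,
        "ood_delta": final_ood - best_ood,       # negative means OOD fell after its peak
        "ood_improvement": ood_late - ood_early,  # late mean minus early mean
    }


if __name__ == "__main__":
    # Hypothetical history, for illustration only.
    toy_history = [
        {"epoch": e, "ood_acc": a}
        for e, a in zip([1, 50, 100, 150, 200, 250],
                        [0.50, 0.60, 0.70, 0.80, 0.75, 0.72])
    ]
    print(ood_summary(toy_history, final_ood=0.72))
```

A negative `ood_delta` together with a positive `ood_improvement`, as in this toy history, describes a run whose OOD accuracy rose early, peaked, and then partly declined.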
|
|
| --- |
|
|
| ## 10. Full source: `utils/grokfast.py` |
|
|
| ```python |
| """ |
| utils.grokfast — accelerated grokking by amplifying slow-varying gradient |
| components (Lee et al. 2024, arXiv:2405.20233). |
| |
| Maintain an EMA of gradients across steps; the slow-EMA component |
| corresponds to the generalising circuit. Adding it back into the live |
| gradient (scaled by `lamb`) accelerates the grokking transition 20-100×. |
| """ |
| |
| from __future__ import annotations |
| |
| |
| def gradfilter_ema(model, grads_ema, alpha: float = 0.98, lamb: float = 2.0): |
|     """ |
|     Call this AFTER `loss.backward()` and BEFORE `optimizer.step()`. |
|
|     Args: |
|         model: the network whose gradients we are filtering. |
|         grads_ema: dict {param_name: ema_grad}, or None on the first call. |
|         alpha: EMA decay (0.98 → very slow, emphasises persistent grads). |
|         lamb: amplification factor for the slow component. |
|
|     Returns: |
|         Updated `grads_ema` dict; pass it back in on the next step. |
|     """ |
|     if grads_ema is None: |
|         grads_ema = {} |
|
|     for name, p in model.named_parameters(): |
|         if p.requires_grad and p.grad is not None: |
|             if name not in grads_ema: |
|                 grads_ema[name] = p.grad.data.detach().clone() |
|             else: |
|                 grads_ema[name] = ( |
|                     grads_ema[name] * alpha |
|                     + p.grad.data.detach() * (1 - alpha) |
|                 ) |
|             p.grad.data = p.grad.data + grads_ema[name] * lamb |
|
|     return grads_ema |
| |
| ``` |
|
|
| --- |
|
|
| ## 11. Full source: `experiments/causalgrok_camelyon_v2.py` |
|
|
| ```python |
| """ |
| CausalGrok — Camelyon17 Training Loop v2 |
| Nilesh |
| |
| KEY CHANGE FROM v1: |
| OOD test accuracy (H4 — unseen hospital) is now tracked at EVERY |
| checkpoint, not just at the end. Grokking detection watches OOD acc, |
| not ID val acc. This is the correct signal. |
| |
| The paper claim: after ID accuracy converges (fast, expected), the model |
| undergoes a delayed phase transition in OOD generalization — grokking |
| the cross-hospital invariant causal features. This co-occurs with a drop |
| in IRM penalty. That is the grokking we care about for clinical deployment. |
| |
| Two curves to watch: |
| val_acc (H3 ID val) — converges fast, expected ~0.86 by ep 50 |
| ood_acc (H4 OOD test) — should plateau then JUMP (the grokking) |
| |
| Run via: |
| python -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 300 |
| """ |
| |
| from __future__ import annotations |
| |
| import argparse |
| import json |
| import os |
| import time |
| from datetime import datetime, timezone |
| |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Subset |
| import timm |
| try: |
|     import wandb |
| except ImportError: |
|     wandb = None |
| |
| from utils.grokfast import gradfilter_ema |
| from utils.camelyon_data import get_camelyon_subsets |
| from utils.run_dir import make_run_dir, ensure_run_dir, save_config |
| |
| |
| # ────────────────────────────────────────────── |
| # CONFIG |
| # ────────────────────────────────────────────── |
| |
| def get_config(condition): |
|     base = dict( |
|         seed=42, n_train=300, batch_size=32, img_size=96, |
|         n_classes=2, log_every=50, |
|         device="cuda" if torch.cuda.is_available() else "cpu", |
|     ) |
|     if condition == "standard": |
|         base.update(dict( |
|             condition="standard", |
|             lr=1e-3, weight_decay=1e-4, |
|             # Default 3000 epochs to match grokking config and the |
|             # paper's reported runs; previously defaulted to 300 which |
|             # made the standard baseline trivially under-trained |
|             # relative to grokking. See paper Limitations §M3. |
|             n_epochs=3000, init_scale=1.0, use_grokfast=False, |
|         )) |
|     elif condition == "grokking": |
|         base.update(dict( |
|             condition="grokking", |
|             lr=1e-3, weight_decay=5e-3, |
|             n_epochs=3000, init_scale=4.0, use_grokfast=True, |
|             grokfast_alpha=0.98, grokfast_lamb=2.0, |
|         )) |
|     return base |
| |
| |
| # ────────────────────────────────────────────── |
| # WILDS-SAFE METRICS |
| # All handle the (imgs, labels, metadata) 3-tuple WILDS batch format. |
| # ────────────────────────────────────────────── |
| |
| @torch.no_grad() |
| def accuracy_wilds(model, loader, device, max_samples=None): |
|     model.eval() |
|     correct = total = 0 |
|     for batch in loader: |
|         imgs = batch[0].to(device) |
|         labels = batch[1].squeeze().long().to(device) |
|         preds = model(imgs).argmax(1) |
|         correct += (preds == labels).sum().item() |
|         total += labels.size(0) |
|         if max_samples and total >= max_samples: |
|             break |
|     return correct / max(total, 1) |
|
|
| @torch.no_grad() |
| def weight_norm_fn(model): |
|     return sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5 |
|
|
| @torch.no_grad() |
| def feature_rank_wilds(model, loader, device, n=300): |
|     model.eval() |
|     feats = [] |
|
|     def hook_fn(module, input, output): |
|         avg_pool = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1)) |
|         feats.append(avg_pool.view(avg_pool.size(0), -1).cpu()) |
|
|     hook = model.layer4[-1].register_forward_hook(hook_fn) |
|     count = 0 |
|     for batch in loader: |
|         model(batch[0].to(device)) |
|         count += batch[0].size(0) |
|         if count >= n: |
|             break |
|     hook.remove() |
|     if not feats: |
|         return float("nan") |
|     F_mat = torch.cat(feats)[:n] |
|     try: |
|         _, s, _ = torch.svd(F_mat) |
|         s = s / (s.sum() + 1e-10) |
|         return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item() |
|     except Exception: |
|         return float("nan") |
| |
| |
| @torch.no_grad() |
| def shortcut_ratio_wilds(model, loader, device, n_samples=200): |
|     """ |
|     Stain shortcut proxy: compare model confidence on center crop |
|     (tissue: causal features) vs. border region (stain: spurious). |
|
|     sc > 1.0 = relying on border stain more than tissue (shortcut) |
|     sc < 1.0 = relying on tissue center more than stain (causal) |
|
|     The transition from > 1.0 to < 1.0 during training is the |
|     attribution-level signature of the grokking transition. |
|
|     Returns (center_conf, border_conf); the caller forms the ratio |
|     sc = border_conf / center_conf. |
|     """ |
|     model.eval() |
|     cc, bc = [], [] |
|     count = 0 |
|     for batch in loader: |
|         if count >= n_samples: |
|             break |
|         imgs = batch[0].to(device) |
|         B, C, H, W = imgs.shape |
|         hs, he = H // 4, 3 * H // 4 |
|         ws, we = W // 4, 3 * W // 4 |
|         center = F.interpolate( |
|             imgs[:, :, hs:he, ws:we], size=(H, W), |
|             mode="bilinear", align_corners=False |
|         ) |
|         border = imgs.clone() |
|         border[:, :, hs:he, ws:we] = 0.0 |
|         cc.append(F.softmax(model(center), 1).max(1).values.mean().item()) |
|         bc.append(F.softmax(model(border), 1).max(1).values.mean().item()) |
|         count += imgs.size(0) |
|     cconf = float(np.mean(cc)) if cc else 0.5 |
|     bconf = float(np.mean(bc)) if bc else 0.5 |
|     return cconf, bconf |
| |
| |
| def irm_penalty_wilds(model, envs, device): |
|     """ |
|     IRMv1 penalty across TRAINING hospital environments (H0-H2). |
|     Diagnostic version: uses create_graph=False, returns floats. Used as a |
|     monitoring metric only (logged at each metric checkpoint). |
|     """ |
|     model.eval() |
|     penalties = [] |
|     for env in envs: |
|         w = torch.tensor(1.0, requires_grad=True, device=device) |
|         logits = model(env["x"]) * w |
|         loss = F.cross_entropy(logits, env["y"]) |
|         grad = torch.autograd.grad(loss, w, create_graph=False)[0] |
|         penalties.append(grad.item() ** 2) |
|     t = torch.tensor(penalties) |
|     return t.mean().item(), t.var().item() |
|
|
| def irm_penalty_train_time(logits_list, y_list): |
|     """ |
|     IRMv1 penalty for use INSIDE the training loss (differentiable). |
|     Splits a batch by environment, computes per-env loss with a virtual |
|     scale variable, takes the squared gradient of each per-env loss w.r.t. |
|     that scale, returns the mean across envs. |
|
|     Args: |
|         logits_list: list of (per-env) logits tensors |
|         y_list: list of (per-env) label tensors |
|
|     Returns: |
|         scalar tensor (differentiable), the IRM penalty contribution. |
|     """ |
|     penalty = 0.0 |
|     n = 0 |
|     for logits, y in zip(logits_list, y_list): |
|         if logits.shape[0] == 0: |
|             continue |
|         scale = torch.tensor(1.0, requires_grad=True, device=logits.device) |
|         loss = F.cross_entropy(logits * scale, y) |
|         grad = torch.autograd.grad(loss, scale, create_graph=True)[0] |
|         penalty = penalty + grad ** 2 |
|         n += 1 |
|     if n == 0: |
|         return torch.tensor(0.0, device=logits_list[0].device) |
|     return penalty / n |
|
|
| def eval_irm_penalty_wilds(model, id_val_loader, ood_test_loader, device): |
|     """ |
|     IRM penalty evaluated on HELD-OUT environments (H3 and H4). |
|     This avoids the measurement artifact of training on H0-H2 where loss→0. |
|     HIGH penalty = model relies on hospital-discriminating features = shortcuts. |
|     LOW penalty = model ignores hospital labels = causal features. |
|     """ |
|     model.eval() |
|     penalties = [] |
|
|     # Create environment views from eval data |
|     for loader, hospital_label in [ |
|         (id_val_loader, "H3"), |
|         (ood_test_loader, "H4"), |
|     ]: |
|         xs, ys = [], [] |
|         count = 0 |
|         with torch.no_grad(): |
|             for batch in loader: |
|                 imgs = batch[0].to(device) |
|                 labels = batch[1].squeeze().long().to(device) |
|                 xs.append(model(imgs)) |
|                 ys.append(labels) |
|                 count += imgs.size(0) |
|                 if count >= 500: |
|                     break |
|         if xs: |
|             x = torch.cat(xs) |
|             y = torch.cat(ys) |
|             w = torch.tensor(1.0, requires_grad=True, device=device) |
|             logits = x * w |
|             loss = F.cross_entropy(logits, y) |
|             try: |
|                 grad = torch.autograd.grad(loss, w, create_graph=False)[0] |
|                 penalties.append(grad.item() ** 2) |
|             except Exception: |
|                 penalties.append(float("nan")) |
|
|     if penalties and not any(np.isnan(p) for p in penalties): |
|         return float(np.mean(penalties)), float(np.var(penalties)) |
|     else: |
|         return float("nan"), float("nan") |
| |
| |
| # ────────────────────────────────────────────── |
| # DATA |
| # ────────────────────────────────────────────── |
| |
| class TransformWrapper: |
|     def __init__(self, dataset, transform): |
|         self.dataset = dataset |
|         self.transform = transform |
|
|     def __len__(self): |
|         return len(self.dataset) |
|
|     def __getitem__(self, idx): |
|         img, label, metadata = self.dataset[idx] |
|         return self.transform(img), label, metadata |
|
|
| def get_dataloaders(cfg, data_root): |
|     transform = transforms.Compose([ |
|         transforms.Resize((cfg["img_size"], cfg["img_size"])), |
|         transforms.ToTensor(), |
|         transforms.Normalize(mean=[0.485, 0.456, 0.406], |
|                              std=[0.229, 0.224, 0.225]), |
|     ]) |
|
|     train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets( |
|         root_dir=data_root, download=True) |
|
|     # Subsample training set |
|     torch.manual_seed(cfg["seed"]) |
|     indices = torch.randperm(len(train_ds))[:cfg["n_train"]] |
|     train_subset = Subset(train_ds, indices) |
|
|     # Wrap with TransformWrapper to apply transforms |
|     train_subset = TransformWrapper(train_subset, transform) |
|     id_val_ds = TransformWrapper(id_val_ds, transform) |
|     ood_test_ds = TransformWrapper(ood_test_ds, transform) |
|
|     train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"], |
|                               shuffle=True, num_workers=0, pin_memory=True) |
|     id_val_loader = DataLoader(id_val_ds, batch_size=256, |
|                                shuffle=False, num_workers=0, pin_memory=True) |
|     ood_test_loader = DataLoader(ood_test_ds, batch_size=256, |
|                                  shuffle=False, num_workers=0, pin_memory=True) |
|
|     return train_loader, id_val_loader, ood_test_loader, train_subset |
|
|
| def get_hospital_environments(train_subset, device): |
|     """ |
|     Build IRM environments from ground-truth hospital labels. |
|     Returns list of {x, y} dicts, one per unique hospital in the subset. |
|     Hospital IDs in the WILDS Camelyon17 train split: 0, 3, 4. |
|     """ |
|     loader = DataLoader(train_subset, batch_size=512, |
|                         shuffle=False, num_workers=4) |
|     all_imgs, all_labels, all_meta = [], [], [] |
|     for imgs, labels, meta in loader: |
|         all_imgs.append(imgs) |
|         all_labels.append(labels.squeeze().long()) |
|         all_meta.append(meta) |
|
|     all_imgs = torch.cat(all_imgs) |
|     all_labels = torch.cat(all_labels) |
|     hospitals = torch.cat(all_meta)[:, 0].long()  # field 0 = hospital ID |
|
|     envs = [] |
|     for h in torch.unique(hospitals): |
|         mask = hospitals == h |
|         n = mask.sum().item() |
|         envs.append({ |
|             "x": all_imgs[mask].to(device), |
|             "y": all_labels[mask].to(device), |
|             "hospital": int(h), |
|         }) |
|         pos_rate = all_labels[mask].float().mean().item() |
|         print(f" Env hospital={int(h)}: {n} samples, " |
|               f"positive rate={pos_rate:.2f}") |
|     return envs |
| |
| |
| # ────────────────────────────────────────────── |
| # MODEL |
| # ────────────────────────────────────────────── |
| |
| def build_model(cfg): |
|     model = timm.create_model("resnet18", pretrained=False, |
|                               num_classes=cfg["n_classes"]) |
|     if cfg["init_scale"] != 1.0: |
|         with torch.no_grad(): |
|             for name, p in model.named_parameters(): |
|                 if "weight" in name and p.dim() > 1: |
|                     p.data *= cfg["init_scale"] |
|     return model.to(cfg["device"]) |
| |
| |
| # ────────────────────────────────────────────── |
| # TRAIN |
| # ────────────────────────────────────────────── |
| |
| def train(cfg, model, train_loader, id_val_loader, ood_test_loader, |
|           envs, optimizer, run_dir): |
|
|     criterion = nn.CrossEntropyLoss() |
|     grads_ema = None |
|     history = [] |
|     best_id_val = 0.0 |
|     best_ood = 0.0 |
|     peak_ood_epoch = None  # Epoch where best_ood was achieved |
|     grok_epoch = None |
|     irm_base = None |
|     history_path = os.path.join(run_dir, "results", "history.json") |
|     grad_clip = cfg.get("grad_clip", 1.0) |
|
|     # Grokking detection parameters. |
|     # We watch OOD accuracy (H4), not ID val accuracy (H3). |
|     # ID val converges fast (expected). OOD is what should grok. |
|     plateau_window = 10 |
|     plateau_eps = 0.01 |
|
|     # Ungrokking early stopping parameters. |
|     # If OOD peaks then declines, stop at the peak rather than training to convergence. |
|     ood_patience = cfg.get("ood_patience", 20)  # checkpoints to wait before stopping |
|     ood_min_delta = cfg.get("ood_min_delta", 0.01)  # minimum improvement threshold |
|     use_ood_early_stop = cfg.get("use_ood_early_stop", False) |
|
|     print(f"\n{'='*60}") |
|     print(f" {cfg['condition'].upper()} | Camelyon17 v2 | {cfg['n_epochs']} epochs") |
|     print(f" WD={cfg['weight_decay']} | α={cfg['init_scale']} | n={cfg['n_train']}") |
|     print(f" Tracking: ID val (H3) + OOD test (H4) at every checkpoint") |
|     print(f" Grokking detection: watching OOD acc, not ID val acc") |
|     print(f" IRM envs: {len(envs)} hospitals") |
|     print(f"{'='*60}", flush=True) |
|
|     irm_weight = float(cfg.get("irm_weight", 0.0)) |
|     use_irm_in_loss = irm_weight > 0.0 |
|     if use_irm_in_loss: |
|         print(f" IRM-in-loss: ENABLED, alpha={irm_weight}", flush=True) |
|     else: |
|         print(f" IRM-in-loss: disabled (CE-only training; IRM penalty is diagnostic)", flush=True) |
| |
|     for epoch in range(1, cfg["n_epochs"] + 1): |
|         # ── Train step ──────────────────────────────────────────────── |
|         model.train() |
|         loss_sum = n_b = 0 |
|         for imgs, labels, metadata in train_loader: |
|             imgs = imgs.to(cfg["device"]) |
|             labels = labels.squeeze().long().to(cfg["device"]) |
|             optimizer.zero_grad() |
|             logits = model(imgs) |
|             ce_loss = criterion(logits, labels) |
| |
|             if use_irm_in_loss: |
|                 # Split this batch by the training hospitals present in it |
|                 # (metadata field 0) and compute the IRMv1 penalty as a |
|                 # differentiable scalar. |
|                 hosp_ids = metadata[:, 0].long().to(cfg["device"]) |
|                 logits_per_env, y_per_env = [], [] |
|                 for h in torch.unique(hosp_ids).tolist(): |
|                     mask = (hosp_ids == h) |
|                     if mask.sum() < 2: |
|                         continue |
|                     logits_per_env.append(logits[mask]) |
|                     y_per_env.append(labels[mask]) |
|                 if len(logits_per_env) >= 2: |
|                     irm_term = irm_penalty_train_time(logits_per_env, y_per_env) |
|                     loss = ce_loss + irm_weight * irm_term |
|                 else: |
|                     loss = ce_loss |
|             else: |
|                 loss = ce_loss |
| |
|             loss.backward() |
|             if cfg.get("use_grokfast"): |
|                 grads_ema = gradfilter_ema( |
|                     model, grads_ema, |
|                     alpha=cfg.get("grokfast_alpha", 0.98), |
|                     lamb=cfg.get("grokfast_lamb", 2.0)) |
|             if grad_clip > 0: |
|                 torch.nn.utils.clip_grad_norm_( |
|                     model.parameters(), max_norm=grad_clip) |
|             optimizer.step() |
|             loss_sum += loss.item() |
|             n_b += 1 |
| |
|         # ── Checkpoint metrics ──────────────────────────────────────── |
|         if epoch % cfg["log_every"] == 0 or epoch == 1: |
|             tr_acc = accuracy_wilds(model, train_loader, cfg["device"]) |
|             id_acc = accuracy_wilds(model, id_val_loader, cfg["device"]) |
|             ood_acc = accuracy_wilds(model, ood_test_loader, cfg["device"])  # KEY |
|             wn = weight_norm_fn(model) |
|             fr = feature_rank_wilds(model, id_val_loader, cfg["device"]) |
|             irm_m, irm_v = irm_penalty_wilds(model, envs, cfg["device"]) |
|             cconf, bconf = shortcut_ratio_wilds( |
|                 model, id_val_loader, cfg["device"]) |
|
|             if irm_base is None: |
|                 irm_base = irm_m |
|
|             # ── OOD grokking detection ──────────────────────────────── |
|             # Require sustained plateau in OOD acc before the jump. |
|             # The ID val acc plateau is expected and not grokking. |
|             if grok_epoch is None and len(history) >= plateau_window: |
|                 last = history[-plateau_window:] |
|                 ref = last[-1]["ood_acc"] |
|                 flat = sum(1 for r in last |
|                            if abs(r["ood_acc"] - ref) < plateau_eps) |
|                 if flat >= plateau_window - 2 and ood_acc > best_ood + 0.05: |
|                     grok_epoch = epoch |
|                     irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100 |
|                     print(f"\n *** OOD GROKKING at epoch {epoch} ***") |
|                     print(f" OOD: {best_ood:.3f} → {ood_acc:.3f} | " |
|                           f"IRM drop: {irm_drop:.1f}%", flush=True) |
|
|             if id_acc > best_id_val: |
|                 best_id_val = id_acc |
|             if ood_acc > best_ood: |
|                 best_ood = ood_acc |
|                 peak_ood_epoch = epoch  # Track when peak was achieved |
|
|             sc_ratio = min(bconf / (cconf + 1e-8), 10.0) |
|
|             # OOD gap: how much worse is OOD vs ID? |
|             # This should shrink at the grokking transition. |
|             ood_gap = id_acc - ood_acc |
|
|             row = dict( |
|                 epoch=epoch, |
|                 train_loss=loss_sum / n_b, |
|                 train_acc=tr_acc, |
|                 id_val_acc=id_acc, |
|                 ood_acc=ood_acc,  # ← primary grokking signal |
|                 ood_gap=ood_gap,  # ← should narrow at transition |
|                 weight_norm=wn, |
|                 feature_rank=fr, |
|                 irm_mean=irm_m, |
|                 irm_var=irm_v, |
|                 center_conf=cconf, |
|                 border_conf=bconf, |
|                 shortcut_ratio=sc_ratio, |
|                 grokking_detected=grok_epoch is not None, |
|             ) |
|             history.append(row) |
|             if wandb: |
|                 wandb.log(row) |
|
|             with open(history_path, "w") as f: |
|                 json.dump(history, f, indent=2) |
|
|             # Save periodic checkpoint for M1 analysis (every 200 epochs) |
|             if epoch % 200 == 0: |
|                 ckpt_dir = os.path.join(run_dir, "checkpoints") |
|                 os.makedirs(ckpt_dir, exist_ok=True) |
|                 ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:05d}.pt") |
|                 torch.save(model.state_dict(), ckpt_path) |
|                 print(f" ✓ Checkpoint → ep{epoch:05d}.pt", flush=True) |
|
|             # ── OOD-aware early stopping (if ungrokking detected) ─────── |
|             # If OOD peaks then declines, stop at the peak rather than full epochs. |
|             if use_ood_early_stop and peak_ood_epoch is not None and len(history) >= ood_patience: |
|                 recent_ood = [r["ood_acc"] for r in history[-ood_patience:]] |
|                 ood_trend = max(recent_ood) - min(recent_ood) |
|
|                 if ood_acc < best_ood - ood_min_delta: |
|                     print(f"\n *** EARLY STOP (OOD declining) at epoch {epoch} ***", flush=True) |
|                     print(f" Peak OOD: {best_ood:.4f} at epoch {peak_ood_epoch}", flush=True) |
|                     print(f" Current: {ood_acc:.4f} ({ood_acc-best_ood:+.4f})", flush=True) |
|
|                     # Save peak checkpoint separately for clinical deployment |
|                     if peak_ood_epoch and peak_ood_epoch % 200 == 0: |
|                         peak_src = os.path.join(run_dir, "checkpoints", f"ep{peak_ood_epoch:05d}.pt") |
|                         peak_dst = os.path.join(run_dir, "checkpoints", "peak_ood.pt") |
|                         if os.path.exists(peak_src): |
|                             import shutil |
|                             shutil.copy(peak_src, peak_dst) |
|                             print(f" Saved peak → checkpoints/peak_ood.pt", flush=True) |
|
|                     break  # Exit training loop |
|
|             print(f" ep {epoch:5d} | " |
|                   f"tr {tr_acc:.3f} | " |
|                   f"id {id_acc:.3f} | " |
|                   f"ood {ood_acc:.3f} | " |
|                   f"gap {ood_gap:+.3f} | "  # + means OOD worse than ID |
|                   f"‖W‖ {wn:.1f} | " |
|                   f"rank {fr:.1f} | " |
|                   f"IRM {irm_m:.4f} | " |
|                   f"sc {sc_ratio:.2f}x", |
|                   flush=True) |
| |
|     # ── Final summary ───────────────────────────────────────────────── |
|     # One final OOD eval at the very end |
|     final_ood = accuracy_wilds(model, ood_test_loader, cfg["device"]) |
|     if wandb: |
|         wandb.log({"final_ood_acc": final_ood, |
|                    "grokking_epoch": grok_epoch or -1}) |
|
|     # Decision numbers |
|     irm_drop_pct = float("nan") |
|     irm_drop_ep = epoch_gap = -1 |
|     if history: |
|         irm0 = history[0]["irm_mean"] |
|         irm_min = min(r["irm_mean"] for r in history) |
|         if irm0: |
|             irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100 |
|         if len(history) > 1: |
|             biggest = 0.0 |
|             for prev, cur in zip(history[:-1], history[1:]): |
|                 d = abs(cur["irm_mean"] - prev["irm_mean"]) |
|                 if d > biggest: |
|                     biggest = d |
|                     irm_drop_ep = cur["epoch"] |
|         if grok_epoch and irm_drop_ep > 0: |
|             epoch_gap = abs(grok_epoch - irm_drop_ep) |
|
|     # OOD improvement: did OOD acc rise over the course of training? |
|     # Measure: mean OOD acc over the last 5 logged checkpoints minus the |
|     # mean over the first 5 logged checkpoints. |
|     ood_early = np.mean([r["ood_acc"] for r in history[:5]]) if history else 0 |
|     ood_late = np.mean([r["ood_acc"] for r in history[-5:]]) if history else 0 |
|     ood_improvement = ood_late - ood_early |
|
|     # Ungrokking detection: did OOD collapse after peaking? |
|     ood_delta = final_ood - best_ood  # Negative = ungrokking |
|
|     summary = dict( |
|         run_id=cfg["run_id"], |
|         condition=cfg["condition"], |
|         n_train=cfg["n_train"], |
|         seed=cfg["seed"], |
|         best_id_val=best_id_val, |
|         best_ood=best_ood, |
|         peak_ood_epoch=peak_ood_epoch or -1,  # When peak was achieved |
|         final_ood=final_ood, |
|         ood_delta=ood_delta,  # final - best (ungrokking signal) |
|         ood_improvement=ood_improvement,  # ← key: did OOD grok? |
|         grokking_epoch=grok_epoch or -1, |
|         irm_drop_pct=irm_drop_pct, |
|         irm_drop_epoch=irm_drop_ep, |
|         epoch_gap=epoch_gap, |
|         final_weight_norm=history[-1]["weight_norm"] if history else None, |
|         final_feature_rank=history[-1]["feature_rank"] if history else None, |
|         final_irm=history[-1]["irm_mean"] if history else None, |
|         final_shortcut_ratio=history[-1]["shortcut_ratio"] if history else None, |
|         final_ood_gap=history[-1]["ood_gap"] if history else None, |
|     ) |
|     with open(os.path.join(run_dir, "results", "summary.json"), "w") as f: |
|         json.dump(summary, f, indent=2) |
|
|     torch.save(model.state_dict(), |
|                os.path.join(run_dir, "checkpoints", "final.pt")) |
|
|     print(f"\n Best ID val (H3): {best_id_val:.4f}") |
|     print(f" Best OOD (H4): {best_ood:.4f}") |
|     print(f" OOD improvement: {ood_improvement:+.4f} ← did OOD grok?") |
|     print(f" Grokking at: {grok_epoch}") |
|     print(f" IRM drop: {irm_drop_pct:.1f}%", |
|           flush=True) |
|     return history |
| |
| |
| # ────────────────────────────────────────────── |
| # MAIN |
| # ────────────────────────────────────────────── |
| |
| def main(): |
|     p = argparse.ArgumentParser() |
|     p.add_argument("--condition", default="grokking", |
|                    choices=["standard", "grokking"]) |
|     p.add_argument("--n_train", type=int, default=300) |
|     p.add_argument("--seed", type=int, default=42) |
|     p.add_argument("--log_every", type=int, default=50) |
|     p.add_argument("--wandb_project", default="causalgrok") |
|     p.add_argument("--wandb_mode", default="offline", |
|                    choices=["online", "offline", "disabled"]) |
|     p.add_argument("--run_dir", default=None) |
|     p.add_argument("--data_root", default="data/wilds") |
|     p.add_argument("--weight_decay", type=float, default=None) |
|     p.add_argument("--init_scale", type=float, default=None) |
|     p.add_argument("--n_epochs", type=int, default=None) |
|     p.add_argument("--lr", type=float, default=None) |
|     p.add_argument("--grokfast", choices=["on", "off"], default=None) |
|     p.add_argument("--grad_clip", type=float, default=1.0) |
|     p.add_argument("--irm_weight", type=float, default=0.0, |
|                    help="IRMv1 penalty weight added to training loss " |
|                         "(0 = pure cross-entropy / diagnostic-only IRM).") |
|     args = p.parse_args() |
|
|     cfg = get_config(args.condition) |
|     cfg.update(n_train=args.n_train, seed=args.seed, |
|                log_every=args.log_every, grad_clip=args.grad_clip) |
|
|     if args.weight_decay is not None: cfg["weight_decay"] = args.weight_decay |
|     if args.init_scale is not None: cfg["init_scale"] = args.init_scale |
|     if args.n_epochs is not None: cfg["n_epochs"] = args.n_epochs |
|     if args.lr is not None: cfg["lr"] = args.lr |
|     if args.grokfast is not None: cfg["use_grokfast"] = (args.grokfast == "on") |
|     cfg["irm_weight"] = args.irm_weight |
|
|     if cfg["device"] == "cuda": |
|         torch.set_float32_matmul_precision("high") |
|         torch.backends.cudnn.benchmark = True |
|         torch.backends.cuda.matmul.allow_tf32 = True |
|         torch.backends.cudnn.allow_tf32 = True |
|
|     torch.manual_seed(cfg["seed"]) |
|     np.random.seed(cfg["seed"]) |
|
|     if args.run_dir is None: |
|         run_dir, run_id = make_run_dir( |
|             ["camelyon_v2", cfg["condition"], |
|              f"n{cfg['n_train']}", f"s{cfg['seed']}"]) |
|     else: |
|         run_dir = args.run_dir |
|         ensure_run_dir(run_dir) |
|         run_id = os.path.basename(os.path.normpath(run_dir)) |
|
|     cfg["run_id"] = run_id |
|     cfg["run_dir"] = run_dir |
|     save_config(cfg, run_dir) |
|
|     if wandb: |
|         wandb.init(project=args.wandb_project, config=cfg, name=run_id, |
|                    mode=args.wandb_mode, dir=run_dir) |
|
|     print(f"\nDevice: {cfg['device']}") |
|     print(f"Run ID: {run_id}") |
|     print(f"Started: {datetime.now(timezone.utc).isoformat()}", flush=True) |
|
|     train_loader, id_val_loader, ood_test_loader, train_subset = \ |
|         get_dataloaders(cfg, args.data_root) |
|
|     envs = get_hospital_environments(train_subset, cfg["device"]) |
|     model = build_model(cfg) |
|
|     print(f"Train: {len(train_subset)} | " |
|           f"ID val (H3): {len(id_val_loader.dataset)} | " |
|           f"OOD test (H4): {len(ood_test_loader.dataset)}") |
|     print(f"Params: {sum(p.numel() for p in model.parameters()):,}", |
|           flush=True) |
|
|     optimizer = torch.optim.AdamW( |
|         model.parameters(), |
|         lr=cfg["lr"], weight_decay=cfg["weight_decay"]) |
|
|     t0 = time.time() |
|     train(cfg, model, train_loader, id_val_loader, ood_test_loader, |
|           envs, optimizer, run_dir) |
|     print(f"\nWall time: {(time.time()-t0)/60:.1f} min", flush=True) |
|     if wandb: |
|         wandb.finish() |
|
|
| if __name__ == "__main__": |
|     main() |
| |
| ``` |
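
The `main()` function above layers optional command-line overrides on top of a per-condition base config: flags such as `--lr` and `--weight_decay` default to `None` and replace the config value only when they are explicitly passed. The pattern can be sketched in isolation as follows. `BASE_CONFIGS` and `resolve_config` are hypothetical stand-ins for `get_config` and the override block, and the numeric values are illustrative only, not the project's hyperparameters.

```python
import argparse

# Hypothetical stand-in for get_config(); values are illustrative only.
BASE_CONFIGS = {
    "standard": {"lr": 1e-3, "weight_decay": 0.0},
    "grokking": {"lr": 1e-3, "weight_decay": 1.0},
}


def resolve_config(argv):
    """Merge CLI overrides into the base config for the chosen condition."""
    p = argparse.ArgumentParser()
    p.add_argument("--condition", default="grokking",
                   choices=sorted(BASE_CONFIGS))
    # None means "not given on the command line": keep the base value.
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--weight_decay", type=float, default=None)
    args = p.parse_args(argv)

    cfg = dict(BASE_CONFIGS[args.condition])  # copy, so the base stays intact
    cfg["condition"] = args.condition
    for key in ("lr", "weight_decay"):
        value = getattr(args, key)
        # `is not None` (rather than truthiness) lets an explicit 0.0 through.
        if value is not None:
            cfg[key] = value
    return cfg


print(resolve_config(["--weight_decay", "0.0"]))
```

Testing against `None` instead of truthiness matters here because `0.0` is a legitimate override, for example to switch weight decay off in the grokking condition.
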
|
|
| --- |
|
|
| ## 12. Full source: `experiments/mechinterp_m1.py` |
| |
| ```python |
| """ |
| CausalGrok — M1: Layer-wise Linear Probing |
| Nilesh |
| |
| The mechanistic claim: |
| Before grokking: hospital probe HIGH (model uses stain shortcut), |
| tumor probe LOW in early layers |
| At transition: hospital probe DROPS, tumor probe RISES |
| After grokking: inverted — tumor high, hospital low |
| |
| If OOD acc jump + hospital probe drop + tumor probe rise |
| all happen at the same epoch → mechanistic claim proven. |
| That's Figure 2 of the paper. |
| |
| Usage: |
| # Run on all saved checkpoints from a run |
| python -m experiments.mechinterp_m1 \ |
| --run_dir experiments/runs/<run_id> \ |
| --data_root data/wilds |
| |
| # Run on latest checkpoint only (quick check while training) |
| python -m experiments.mechinterp_m1 \ |
| --run_dir experiments/runs/<run_id> \ |
| --data_root data/wilds \ |
| --latest_only |
| |
| # Run on ALL camelyon_v2 grokking runs |
| python -m experiments.mechinterp_m1 \ |
| --all_runs \ |
| --data_root data/wilds |
| |
| Output per run: |
| experiments/runs/<run_id>/mechinterp/ |
| m1_probe_heatmap.png ← epoch × layer; hospital probe (H3) and tumor probe (H4) panels |
| m1_probe_curves.png ← hospital vs tumor probe over epochs (avgpool and layer2) |
| m1_probe_data.json ← raw numbers for paper tables |
| """ |
| |
| from __future__ import annotations |
|
|
| import argparse |
| import glob |
| import json |
| import os |
| from typing import Dict, List, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Subset |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
| import matplotlib |
| import matplotlib.pyplot as plt |
| import timm |
| import warnings |
| warnings.filterwarnings("ignore") |
| |
| matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150}) |
| |
| |
| # ────────────────────────────────────────────── |
| # RESNET-18 LAYER HOOKS |
| # Extract features after each of the 6 measurable stages: |
| # stem → layer1 → layer2 → layer3 → layer4 → avgpool |
| # ────────────────────────────────────────────── |
| |
| LAYER_NAMES = [ |
| "stem", # After initial conv + bn + relu + maxpool |
| "layer1", # ResNet block 1 (64 channels) |
| "layer2", # ResNet block 2 (128 channels) |
| "layer3", # ResNet block 3 (256 channels) |
| "layer4", # ResNet block 4 (512 channels) |
| "avgpool", # Global average pool — penultimate representation |
| ] |
| |
|
|
| def register_hooks(model): |
| """ |
| Register forward hooks on all 6 extraction points. |
| Returns (hooks, features_dict). |
| """ |
| features = {name: [] for name in LAYER_NAMES} |
| hooks = [] |
| |
| def make_hook(name): |
| def hook_fn(module, input, output): |
| if output.dim() == 4: |
| feat = output.mean(dim=[2, 3]) |
| else: |
| feat = output.view(output.size(0), -1) |
| features[name].append(feat.detach().cpu()) |
| return hook_fn |
| |
| hooks.append(model.maxpool.register_forward_hook(make_hook("stem"))) |
| hooks.append(model.layer1.register_forward_hook(make_hook("layer1"))) |
| hooks.append(model.layer2.register_forward_hook(make_hook("layer2"))) |
| hooks.append(model.layer3.register_forward_hook(make_hook("layer3"))) |
| hooks.append(model.layer4.register_forward_hook(make_hook("layer4"))) |
| hooks.append(model.global_pool.register_forward_hook(make_hook("avgpool"))) |
| |
| return hooks, features |
| |
|
|
| def extract_features(model, loader, device, max_samples=1000): |
| """ |
| Run forward pass and collect features at all 6 layers. |
| """ |
| model.eval() |
| hooks, feat_dict = register_hooks(model) |
| |
| all_hospital = [] |
| all_tumor = [] |
| count = 0 |
| |
| with torch.no_grad(): |
| for batch in loader: |
| imgs = batch[0].to(device) |
| labels = batch[1].squeeze().long() |
| metadata = batch[2] |
| model(imgs) |
| all_hospital.append(metadata[:, 0].long()) |
| all_tumor.append(labels) |
| count += imgs.size(0) |
| if count >= max_samples: |
| break |
| |
| for h in hooks: |
| h.remove() |
| |
| features = {k: torch.cat(v).numpy() for k, v in feat_dict.items()} |
| hospital_ids = torch.cat(all_hospital).numpy() |
| tumor_labels = torch.cat(all_tumor).numpy() |
| |
| n = min(max_samples, len(hospital_ids)) |
| features = {k: v[:n] for k, v in features.items()} |
| hospital_ids = hospital_ids[:n] |
| tumor_labels = tumor_labels[:n] |
| |
| return features, hospital_ids, tumor_labels |
| |
|
|
| def train_probe(X_train, y_train, X_val, y_val): |
| """ |
| Train logistic regression probe on frozen features. |
| """ |
| if len(np.unique(y_train)) < 2: |
| return 0.5 |
| |
| scaler = StandardScaler() |
| X_train = scaler.fit_transform(X_train) |
| X_val = scaler.transform(X_val) |
| |
| clf = LogisticRegression( |
| max_iter=500, |
| C=1.0, |
| solver="lbfgs", |
| multi_class="auto", |
| n_jobs=-1, |
| ) |
| try: |
| clf.fit(X_train, y_train) |
| return clf.score(X_val, y_val) |
| except Exception: |
| return float("nan") |
| |
|
|
| # ────────────────────────────────────────────── |
| # CHECKPOINT DISCOVERY |
| # ────────────────────────────────────────────── |
|
|
| def find_checkpoints(run_dir: str) -> List[tuple]: |
| """ |
| Find all checkpoints in a run directory. |
| Returns list of (epoch, checkpoint_path) sorted by epoch. |
| """ |
| ckpt_dir = os.path.join(run_dir, "checkpoints") |
| if not os.path.isdir(ckpt_dir): |
| return [] |
| |
| checkpoints = [] |
| |
| # Periodic checkpoints: ep050.pt, ep100.pt, etc. |
| for f in sorted(glob.glob(os.path.join(ckpt_dir, "ep*.pt"))): |
| epoch_str = os.path.basename(f).replace("ep", "").replace(".pt", "") |
| try: |
| epoch = int(epoch_str) |
| checkpoints.append((epoch, f)) |
| except ValueError: |
| continue |
| |
| # Final checkpoint |
| final = os.path.join(ckpt_dir, "final.pt") |
| if os.path.isfile(final): |
| hist_path = os.path.join(run_dir, "results", "history.json") |
| if os.path.isfile(hist_path): |
| try: |
| hist = json.load(open(hist_path)) |
| epoch = hist[-1]["epoch"] if hist else 9999 |
| except Exception: |
| epoch = 9999 |
| else: |
| epoch = 9999 |
| checkpoints.append((epoch, final)) |
| |
| return sorted(checkpoints, key=lambda x: x[0]) |
| |
|
|
| def load_model_from_checkpoint(ckpt_path: str, n_classes: int = 2, |
| device: str = "cuda") -> nn.Module: |
| model = timm.create_model("resnet18", pretrained=False, |
| num_classes=n_classes) |
| state = torch.load(ckpt_path, map_location=device) |
| model.load_state_dict(state, strict=True) |
| model.eval() |
| return model.to(device) |
| |
|
|
| # ────────────────────────────────────────────── |
| # MAIN PROBE ANALYSIS |
| # ────────────────────────────────────────────── |
|
|
| def run_probe_analysis(run_dir: str, data_root: str, |
| device: str = "cuda", |
| max_samples: int = 800, |
| latest_only: bool = False) -> Optional[Dict]: |
| """ |
| For each checkpoint in a run, extract features at all 6 layers |
| and train hospital + tumor probes. |
| """ |
| from utils.camelyon_data import get_camelyon_subsets |
| |
| cfg_path = os.path.join(run_dir, "config.json") |
| if not os.path.isfile(cfg_path): |
| print(f" No config.json in {run_dir}, skipping") |
| return None |
| |
| cfg = json.load(open(cfg_path)) |
| condition = cfg.get("condition", "unknown") |
| n_train = cfg.get("n_train", 300) |
| seed = cfg.get("seed", 42) |
| |
| print(f"\n{'='*55}") |
| print(f" M1 Probe Analysis: {os.path.basename(run_dir)}") |
| print(f" condition={condition}, n_train={n_train}, seed={seed}") |
| print(f"{'='*55}") |
| |
| checkpoints = find_checkpoints(run_dir) |
| if not checkpoints: |
| print(f" No checkpoints found — skipping") |
| return None |
| |
| if latest_only: |
| checkpoints = checkpoints[-1:] |
| |
| print(f" Found {len(checkpoints)} checkpoints: " |
| f"epochs {[e for e,_ in checkpoints]}") |
| print(f" Hospital probe: fits on training data (H0-H2), " |
| f"evaluates on H3 and H4 separately") |
| |
| # ── Data ───────────────────────────────────────────────────────── |
| transform = transforms.Compose([ |
| transforms.Resize((96, 96)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225]), |
| ]) |
| train_ds, id_val_ds, ood_test_ds, full_ds = get_camelyon_subsets( |
| root_dir=data_root, download=False) |
| |
| # Wrap datasets with transform (WILDS returns PIL images) |
| class _TransformWrapper: |
| def __init__(self, dataset, transform): |
| self.dataset = dataset |
| self.transform = transform |
| def __len__(self): |
| return len(self.dataset) |
| def __getitem__(self, idx): |
| img, label, metadata = self.dataset[idx] |
| return self.transform(img), label, metadata |
| |
| id_val_t = _TransformWrapper(id_val_ds, transform) |
| ood_test_t = _TransformWrapper(ood_test_ds, transform) |
| train_t = _TransformWrapper(train_ds, transform) |
| |
| torch.manual_seed(seed) |
| probe_idx = torch.randperm(len(id_val_t))[:max_samples // 2] |
| ood_idx = torch.randperm(len(ood_test_t))[:max_samples // 2] |
| train_idx = torch.randperm(len(train_t))[:max_samples] |
| |
| probe_loader = DataLoader( |
| Subset(id_val_t, probe_idx), |
| batch_size=128, shuffle=False, num_workers=0) |
| ood_loader = DataLoader( |
| Subset(ood_test_t, ood_idx), |
| batch_size=128, shuffle=False, num_workers=0) |
| train_loader = DataLoader( |
| Subset(train_t, train_idx), |
| batch_size=128, shuffle=False, num_workers=0) |
| |
| # ── Results storage ────────────────────────────────────────────── |
| results = { |
| "run_id": os.path.basename(run_dir), |
| "condition": condition, |
| "n_train": n_train, |
| "seed": seed, |
| "epochs": [], |
| "layers": LAYER_NAMES, |
| "hospital_probe_id": [], # Hospital accuracy on H3 |
| "hospital_probe_ood": [], # Hospital accuracy on H4 |
| "tumor_probe_id": [], # Tumor accuracy on H3 |
| "tumor_probe_ood": [], # Tumor accuracy on H4 |
| } |
| |
| # ── Per-checkpoint analysis ─────────────────────────────────────── |
| for epoch, ckpt_path in checkpoints: |
| print(f"\n Epoch {epoch} | {os.path.basename(ckpt_path)}") |
| |
| try: |
| model = load_model_from_checkpoint( |
| ckpt_path, n_classes=2, device=device) |
| except Exception as e: |
| print(f" Failed to load checkpoint: {e}") |
| continue |
| |
| # Extract features from all three datasets |
| feats_train, hosp_train, tumor_train = extract_features( |
| model, train_loader, device, max_samples=max_samples) |
| |
| feats_id, hosp_id, tumor_id = extract_features( |
| model, probe_loader, device, max_samples=max_samples // 2) |
| |
| feats_ood, hosp_ood, tumor_ood = extract_features( |
| model, ood_loader, device, max_samples=max_samples // 2) |
| |
| epoch_hosp_id = [] |
| epoch_hosp_ood = [] |
| epoch_tumor_id = [] |
| epoch_tumor_ood = [] |
| |
| for layer_name in LAYER_NAMES: |
| # Fit probes on training features, evaluate on H3 and H4 |
| X_train_layer = feats_train[layer_name] |
| X_id_layer = feats_id[layer_name] |
| X_ood_layer = feats_ood[layer_name] |
| |
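| # Hospital probe: are the training-hospital labels (H0-H2) linearly decodable? |
| # High accuracy on H3 → stain is encoded. The H4 score is degenerate because |
| # H4 is not among the probe's training classes (see _plot_probe_heatmaps). |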
| h_acc_id = train_probe(X_train_layer, hosp_train, |
| X_id_layer, hosp_id) |
| h_acc_ood = train_probe(X_train_layer, hosp_train, |
| X_ood_layer, hosp_ood) |
| |
| # Tumor probe: can model distinguish tumor vs normal? |
| t_acc_id = train_probe(X_train_layer, tumor_train, |
| X_id_layer, tumor_id) |
| t_acc_ood = train_probe(X_train_layer, tumor_train, |
| X_ood_layer, tumor_ood) |
| |
| epoch_hosp_id.append(h_acc_id) |
| epoch_hosp_ood.append(h_acc_ood) |
| epoch_tumor_id.append(t_acc_id) |
| epoch_tumor_ood.append(t_acc_ood) |
| |
| print(f" {layer_name:8s}: " |
| f"hosp_H3={h_acc_id:.3f} hosp_H4={h_acc_ood:.3f} " |
| f"tumor_H3={t_acc_id:.3f} tumor_H4={t_acc_ood:.3f}") |
| |
| results["epochs"].append(epoch) |
| results["hospital_probe_id"].append(epoch_hosp_id) |
| results["hospital_probe_ood"].append(epoch_hosp_ood) |
| results["tumor_probe_id"].append(epoch_tumor_id) |
| results["tumor_probe_ood"].append(epoch_tumor_ood) |
| |
| del model |
| |
| # ── Save raw data ───────────────────────────────────────────────── |
| out_dir = os.path.join(run_dir, "mechinterp") |
| os.makedirs(out_dir, exist_ok=True) |
| |
| data_path = os.path.join(out_dir, "m1_probe_data.json") |
| with open(data_path, "w") as f: |
| json.dump(results, f, indent=2) |
| print(f"\n Probe data → {data_path}") |
| |
| # ── Plots ───────────────────────────────────────────────────────── |
| _plot_probe_heatmaps(results, out_dir) |
| _plot_probe_curves(results, out_dir) |
| |
| print(f" Figures → {out_dir}/") |
| return results |
| |
|
|
| def _plot_probe_heatmaps(results: Dict, out_dir: str): |
| """ |
| Epoch (x) × layer (y), color = probe accuracy. |
| |
| Hospital probe shown on H3 (the ID validation split, whose hospital classes |
| overlap with training). The H4 version is degenerate by construction since the |
| probe is fit on the training-hospital class set and H4 is not in it |
| (hospital_probe_ood ≡ 0 across all epochs / layers). |
| |
| Tumor probe shown on H4 (truly OOD hospital) since tumor labels are |
| binary and shared across hospitals — H4 captures the causal-feature |
| transferability we care about. |
| """ |
| epochs = results["epochs"] |
| layers = results["layers"] |
| |
| if not epochs: |
| return |
| |
| hosp_matrix = np.array(results["hospital_probe_id"]) # H3 — has signal |
| tumor_matrix = np.array(results["tumor_probe_ood"]) # H4 — true OOD |
| |
| fig, axes = plt.subplots(1, 2, figsize=(14, 5)) |
| |
| for ax, matrix, title, cmap in [ |
| (axes[0], hosp_matrix, "Hospital probe on H3 (shortcut recoverability)\nHigh = stain still encoded = BAD", "Reds"), |
| (axes[1], tumor_matrix, "Tumor probe on H4 (causal, OOD)\nHigh = causal feature transfers = GOOD", "Greens"), |
| ]: |
| im = ax.imshow( |
| matrix.T, |
| aspect="auto", |
| cmap=cmap, |
| vmin=0.0, vmax=1.0, |
| interpolation="nearest", |
| origin="lower", |
| ) |
| ax.set_xticks(range(len(epochs))) |
| ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=8) |
| ax.set_yticks(range(len(layers))) |
| ax.set_yticklabels(layers, fontsize=9) |
| ax.set_xlabel("Training epoch") |
| ax.set_ylabel("ResNet layer") |
| ax.set_title(title, fontsize=10, fontweight="bold") |
| plt.colorbar(im, ax=ax, label="Probe accuracy") |
| |
| fig.suptitle( |
| f"M1 — Layer-wise Linear Probing: {results['run_id']}\n" |
| "Circuit signature: deep-layer hospital-probe drop (Reds) + sustained tumor recoverability (Greens)", |
| fontsize=10, y=1.02 |
| ) |
| plt.tight_layout() |
| out = os.path.join(out_dir, "m1_probe_heatmap.png") |
| plt.savefig(out, bbox_inches="tight") |
| plt.close() |
| |
|
|
| def _plot_probe_curves(results: Dict, out_dir: str): |
| """ |
| Line plots for two layers (avgpool and layer2): hospital probe (H3, |
| recoverability) and tumor probe (H4, causal-feature transfer to a truly unseen |
| hospital), with OOD accuracy from history.json overlaid. |
| """ |
| epochs = results["epochs"] |
| run_dir = os.path.join(out_dir, "..") |
| layers = results["layers"] |
| avgpool_idx = layers.index("avgpool") |
| layer2_idx = layers.index("layer2") |
| |
| fig, axes = plt.subplots(1, 2, figsize=(14, 5)) |
| |
| for ax, layer_idx, layer_label in [ |
| (axes[0], avgpool_idx, "avgpool (penultimate)"), |
| (axes[1], layer2_idx, "layer2 (early)"), |
| ]: |
| hosp = [results["hospital_probe_id"][i][layer_idx] |
| for i in range(len(epochs))] |
| tumor = [results["tumor_probe_ood"][i][layer_idx] |
| for i in range(len(epochs))] |
| |
| ax.plot(epochs, hosp, "r-o", markersize=4, lw=2, |
| label="Hospital probe on H3 (shortcut recoverability ↓ want)") |
| ax.plot(epochs, tumor, "g-s", markersize=4, lw=2, |
| label="Tumor probe on H4 (causal transfer ↑ want)") |
| |
| hist_path = os.path.join(run_dir, "results", "history.json") |
| if os.path.isfile(hist_path): |
| try: |
| hist = json.load(open(hist_path)) |
| hist_eps = [r["epoch"] for r in hist] |
| ood_accs = [r.get("ood_acc", float("nan")) for r in hist] |
| ax.plot(hist_eps, ood_accs, "b--", lw=1.5, alpha=0.7, |
| label="OOD accuracy (H4)") |
| except Exception: |
| pass |
| |
| ax.axhline(0.5, color="gray", ls=":", lw=1, alpha=0.5, |
| label="Chance (0.5)") |
| ax.set_xlabel("Training epoch") |
| ax.set_ylabel("Probe / OOD accuracy") |
| ax.set_title(f"Layer: {layer_label}", fontweight="bold") |
| ax.legend(fontsize=9) |
| ax.set_ylim([0, 1.05]) |
| ax.grid(alpha=0.3) |
| |
| fig.suptitle( |
| f"M1 — Probe Curves: {results['run_id']}\n" |
| "Hospital recoverability (H3) drops in deep layers + tumor transfers (H4) — circuit signature", |
| fontsize=10, y=1.02 |
| ) |
| plt.tight_layout() |
| out = os.path.join(out_dir, "m1_probe_curves.png") |
| plt.savefig(out, bbox_inches="tight") |
| plt.close() |
| |
|
|
| def main(): |
| p = argparse.ArgumentParser( |
| description="M1: Layer-wise linear probing for CausalGrok") |
| p.add_argument("--run_dir", default=None, |
| help="Single run directory to analyze") |
| p.add_argument("--all_runs", action="store_true", |
| help="Analyze all camelyon_v2 grokking runs") |
| p.add_argument("--data_root", default="data/wilds") |
| p.add_argument("--device", default="cuda") |
| p.add_argument("--max_samples", type=int, default=800) |
| p.add_argument("--latest_only", action="store_true", |
| help="Analyze only latest checkpoint (quick check)") |
| args = p.parse_args() |
| |
| if args.all_runs: |
| run_dirs = sorted(glob.glob( |
| "experiments/runs/*camelyon_v2*grokking*")) |
| print(f"Found {len(run_dirs)} grokking runs") |
| all_results = [] |
| for rd in run_dirs: |
| r = run_probe_analysis(rd, args.data_root, |
| device=args.device, |
| max_samples=args.max_samples, |
| latest_only=args.latest_only) |
| if r: |
| all_results.append(r) |
| |
| if all_results: |
| os.makedirs("paper_figures", exist_ok=True) |
| with open("paper_figures/m1_all_probes.json", "w") as f: |
| json.dump(all_results, f, indent=2) |
| print(f"\nCombined → paper_figures/m1_all_probes.json") |
| |
| elif args.run_dir: |
| run_probe_analysis(args.run_dir, args.data_root, |
| device=args.device, |
| max_samples=args.max_samples, |
| latest_only=args.latest_only) |
| else: |
| print("Specify --run_dir <path> or --all_runs") |
| |
|
|
| if __name__ == "__main__": |
| main() |
| |
| ``` |
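
The probing step in the listing above (standardize frozen features, fit a logistic-regression probe on one split, score it on another) can be exercised without any checkpoints. The sketch below uses synthetic Gaussian features in place of ResNet activations. `probe_accuracy` and `make_features` are hypothetical helpers written for this illustration, and the planted signal strength is an arbitrary assumption.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler


def probe_accuracy(X_train, y_train, X_eval, y_eval):
    """Fit a linear probe on train features and score it on eval features."""
    scaler = StandardScaler().fit(X_train)      # statistics from train only
    clf = LogisticRegression(max_iter=500, C=1.0)
    clf.fit(scaler.transform(X_train), y_train)
    return float(clf.score(scaler.transform(X_eval), y_eval))


def make_features(rng, n, d, signal):
    """Synthetic 'activations': noise plus a label-dependent shift in dim 0."""
    y = rng.integers(0, 2, size=n)
    X = rng.normal(size=(n, d))
    X[:, 0] += signal * y
    return X, y


rng = np.random.default_rng(0)

# Features that linearly encode the label: the probe should score well.
X_tr, y_tr = make_features(rng, 400, 16, signal=4.0)
X_ev, y_ev = make_features(rng, 400, 16, signal=4.0)
acc_signal = probe_accuracy(X_tr, y_tr, X_ev, y_ev)

# Features with no label information: the probe should sit near chance.
X_tr0, y_tr0 = make_features(rng, 400, 16, signal=0.0)
X_ev0, y_ev0 = make_features(rng, 400, 16, signal=0.0)
acc_noise = probe_accuracy(X_tr0, y_tr0, X_ev0, y_ev0)

print(f"probe with planted signal: {acc_signal:.3f}")
print(f"probe on pure noise:       {acc_noise:.3f}")
```

The noise baseline is a useful reference when reading the probe heatmaps: a probe score only indicates encoded information to the extent that it clears what the same probe reaches on uninformative features.
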
| |
| --- |
| |
| ## 13. Full source: `experiments/mechinterp_m4_ablation.py` |
| |
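
The listing below removes a "shortcut subspace" from features by building an orthogonal projector onto the row space of a basis matrix and subtracting the projected component. A minimal sketch of that one step on random data follows. `rowspace_projector` mirrors the SVD-based construction in the listing but is written here for illustration, and the matrix shapes are arbitrary assumptions.

```python
import numpy as np


def rowspace_projector(W):
    """Return the (d, d) orthogonal projector onto the row space of W (k, d)."""
    _, s, Vt = np.linalg.svd(W, full_matrices=False)
    # Drop directions whose singular value is numerically zero.
    tol = max(W.shape) * np.finfo(s.dtype).eps * (s.max() if s.size else 0.0)
    basis = Vt[s > tol]            # orthonormal rows spanning rowspace(W)
    return basis.T @ basis


rng = np.random.default_rng(0)

# Three direction vectors in an 8-D feature space; the third is a linear
# combination of the first two, so the subspace has rank 2.
W = rng.normal(size=(3, 8))
W[2] = W[0] + W[1]

P = rowspace_projector(W)

# Ablate: remove the component of each feature row that lies in rowspace(W).
H = rng.normal(size=(5, 8))
H_ablated = H - H @ P.T

print("rank of subspace:", int(round(np.trace(P))))
print("max |<h_ablated, w>|:", float(np.abs(H_ablated @ W.T).max()))
```

After ablation every feature row is orthogonal to every row of `W`, so a linear readout whose weights lie in that subspace can no longer separate the samples. The listing applies the same operation in standardized feature space and then inverts the scaling before re-using the classifier head.
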
| ```python |
| """M4 — Representation Ablation: causal intervention on the shortcut subspace. |
|
|
| Pipeline: |
| 1. Pick a checkpoint (peak-OOD epoch by default). |
| 2. Extract features at avgpool (or `--layer`) for train (H0-H2) + OOD (H4) splits. |
| 3. Fit a hospital-classification logistic-regression probe on train features. |
| The probe's weight rows define the *shortcut subspace* in feature space. |
| 4. Build the projector P = W^T (W W^T)^-1 W onto that subspace and define |
| `ablate(h) = h - P h`. |
| 5. Re-classify OOD images with the *same* trained classifier head, fed: |
| (a) raw features h — baseline OOD accuracy |
| (b) ablated features h - Ph — post-intervention OOD accuracy |
| 6. Also report: |
| (c) shortcut accuracy (probe.score on h vs h-Ph) |
| (d) tumor probe accuracy on h vs h-Ph (sanity: the causal feature |
| should survive the intervention) |
| (e) head's tumor classification accuracy on H4 with raw vs ablated features |
| |
| If the intervention is causal: |
| - shortcut probe accuracy: collapses |
| - OOD accuracy: improves (or at least doesn't decay as much) |
| - tumor probe accuracy: largely preserved |
|
|
| Usage |
| ----- |
| python -m experiments.mechinterp_m4_ablation \\ |
| --run_dir experiments/runs/<id> \\ |
| --data_root data/wilds \\ |
| --layer avgpool \\ |
| [--epoch 50] # default: peak_ood_epoch from summary.json |
| [--max_samples 1000] |
| |
| Output: |
| <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.json |
| <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.png |
| """ |
| from __future__ import annotations |
| |
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
| from torch.utils.data import DataLoader, Subset |
| from torchvision import transforms |
| |
| ROOT = Path(__file__).resolve().parent.parent |
| sys.path.insert(0, str(ROOT)) |
| |
| # Re-use M1 helpers — hooks, model loader, feature extraction, ckpt discovery. |
| from experiments.mechinterp_m1 import ( |
| register_hooks, |
| extract_features, |
| load_model_from_checkpoint, |
| find_checkpoints, |
| ) |
| from utils.camelyon_data import get_camelyon_subsets |
| |
|
|
| class _TransformWrapper: |
| def __init__(self, dataset, transform): |
| self.dataset = dataset |
| self.transform = transform |
| def __len__(self): |
| return len(self.dataset) |
| def __getitem__(self, idx): |
| img, label, metadata = self.dataset[idx] |
| return self.transform(img), label, metadata |
| |
| |
| def _build_loaders(data_root: str, max_samples: int, seed: int = 42): |
| transform = transforms.Compose([ |
| transforms.Resize((96, 96)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225]), |
| ]) |
| train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets( |
| root_dir=data_root, download=False |
| ) |
| train_t = _TransformWrapper(train_ds, transform) |
| ood_t = _TransformWrapper(ood_test_ds, transform) |
| |
| torch.manual_seed(seed) |
| train_idx = torch.randperm(len(train_t))[:max_samples] |
| ood_idx = torch.randperm(len(ood_t))[:max_samples // 2] |
| |
| train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128, |
| shuffle=False, num_workers=0) |
| ood_loader = DataLoader(Subset(ood_t, ood_idx), batch_size=128, |
| shuffle=False, num_workers=0) |
| return train_loader, ood_loader |
| |
|
|
| def _select_epoch(run_dir: Path, requested: int | None) -> Tuple[int, Path]: |
| ckpts = find_checkpoints(str(run_dir)) |
| if not ckpts: |
| raise FileNotFoundError(f"No checkpoints in {run_dir}/checkpoints/") |
|
|
| if requested is not None: |
| for ep, p in ckpts: |
| if ep == requested: |
| return ep, Path(p) |
| raise ValueError(f"Requested epoch {requested} not in checkpoints " |
| f"({[ep for ep, _ in ckpts]})") |
| |
| # default: peak OOD epoch from summary.json |
| summary_path = run_dir / "results" / "summary.json" |
| peak = None |
| if summary_path.exists(): |
| s = json.loads(summary_path.read_text()) |
| peak = s.get("peak_ood_epoch", None) |
| |
| if peak is not None and peak > 0: |
| # nearest periodic checkpoint |
| nearest = min(ckpts, key=lambda x: abs(x[0] - peak)) |
| return nearest[0], Path(nearest[1]) |
| |
| # fall back to last checkpoint |
| return ckpts[-1][0], Path(ckpts[-1][1]) |
| |
|
|
| def _build_projector(W: np.ndarray) -> np.ndarray: |
| """W has shape (k, d). Returns P (d, d) projecting onto rowspace(W).""" |
| # Use SVD for a stable orthonormal basis of rowspace |
| U, s, Vt = np.linalg.svd(W, full_matrices=False) |
| # rowspace basis = Vt rows where singular values > tol |
| tol = max(W.shape) * np.finfo(s.dtype).eps * (s.max() if s.size else 0.0) |
| keep = s > tol |
| basis = Vt[keep] # (k', d) |
| return basis.T @ basis # (d, d) projector onto rowspace |
| |
|
|
| def _build_shortcut_subspace( |
| X: np.ndarray, hospital_ids: np.ndarray, |
| method: str = "lda", subspace_dim: int = 32 |
| ) -> np.ndarray: |
| """Return a (k, d) basis whose row-span is the 'shortcut subspace'. |
| |
| method='probe': the hospital-probe weight rows (`clf.coef_`); a small subspace. |
| method='lda': between-class rows (per-hospital mean minus global mean, |
| n_classes rows). If subspace_dim > n_classes, the top |
| `subspace_dim` PCs of the globally centered features are stacked |
| underneath, giving up to n_classes + subspace_dim rows before |
| rank reduction in `_build_projector`. |
| Any other method, including 'pca-class', is not implemented here and |
| raises ValueError. |
| """ |
| if method == "probe": |
| clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1) |
| clf.fit(X, hospital_ids) |
| return clf.coef_ |
| |
| if method == "lda": |
| classes = np.unique(hospital_ids) |
| global_mean = X.mean(axis=0, keepdims=True) |
| between = [] |
| for c in classes: |
| mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True) |
| between.append(mu_c - global_mean) |
| between = np.vstack(between) # (n_classes, d) |
| # If more dims are requested than there are classes, extend the basis with |
| # the top `subspace_dim` PCs of the globally centered features (computed below). |
| if subspace_dim > between.shape[0]: |
| # within-hospital residuals |
| residuals = [] |
| for c in classes: |
| mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True) |
| residuals.append(X[hospital_ids == c] - mu_c) |
| R = np.vstack(residuals) |
| # NOTE: R (the within-hospital residuals) is computed but not used below. |
| # The extra directions come from an SVD of the globally centered features, |
| # not from the residuals or from the orthogonal complement of `between`. |
| U, s, Vt = np.linalg.svd(X - global_mean, full_matrices=False) |
| top = Vt[:subspace_dim] |
| # Score each PC by its maximum absolute correlation with any one-hot |
| # hospital indicator column. |
| one_hot = np.eye(len(classes))[ |
| np.searchsorted(classes, hospital_ids) |
| ] # (N, n_classes) |
| proj = (X - global_mean) @ top.T # (N, subspace_dim) |
| corrs = np.array([ |
| np.max(np.abs([np.corrcoef(proj[:, k], one_hot[:, c])[0, 1] |
| for c in range(len(classes))])) |
| for k in range(subspace_dim) |
| ]) |
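| # Sort PCs by that correlation. `top` has at most subspace_dim rows, so the |
| # slice keeps every PC; only the row order changes, not the spanned subspace. |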
| order = np.argsort(-np.nan_to_num(corrs)) |
| top_hosp = top[order[:subspace_dim]] |
| # combine: between-class mean offsets + the (reordered) top PCs |
| return np.vstack([between, top_hosp]) |
| |
| return between |
| |
| raise ValueError(f"Unknown method: {method}") |
| |
|
|
| def _classifier_logits_from_features( |
| model: nn.Module, features: np.ndarray, layer: str, device: str |
| ) -> np.ndarray: |
| """Apply the *post-`layer`* part of the network to the (modified) features |
| and return the model's binary-classification logits. |
| |
| For ResNet, `avgpool` features have shape (N, C). The classifier head |
| `model.fc` (timm: `model.get_classifier()`) maps C → 2. For non-avgpool |
| layers we do not currently support full propagation — caller should use |
| layer='avgpool' for OOD-accuracy interventions.""" |
| if layer != "avgpool": |
| raise NotImplementedError( |
| "Re-applying the classifier head from intermediate spatial layers " |
| "is not yet supported. Use --layer avgpool for the head-level " |
| "ablation." |
| ) |
| |
| # Find the classifier head (timm convention: model.fc or model.get_classifier()) |
| if hasattr(model, "get_classifier"): |
| head = model.get_classifier() |
| elif hasattr(model, "fc"): |
| head = model.fc |
| elif hasattr(model, "classifier"): |
| head = model.classifier |
| else: |
| raise RuntimeError("Could not locate classifier head on the model.") |
| |
| head = head.to(device).eval() |
| with torch.no_grad(): |
| x = torch.tensor(features, dtype=torch.float32, device=device) |
| logits = head(x).cpu().numpy() |
| return logits |
| |
|
|
| def _accuracy(logits: np.ndarray, labels: np.ndarray) -> float: |
| if logits.ndim == 1 or logits.shape[1] == 1: |
| pred = (logits.flatten() > 0).astype(int) |
| else: |
| pred = logits.argmax(axis=1) |
| return float((pred == labels).mean()) |
| |
| |
| def run_ablation( |
| run_dir: Path, |
| data_root: str, |
| layer: str = "avgpool", |
| epoch: int | None = None, |
| max_samples: int = 1000, |
| device: str = "cuda", |
| subspace_method: str = "lda", |
| subspace_dim: int = 32, |
| ) -> Dict: |
| epoch, ckpt_path = _select_epoch(run_dir, epoch) |
| |
| print(f"\n M4 — Representation Ablation") |
| print(f" run_dir : {run_dir.name}") |
| print(f" epoch : {epoch} ({ckpt_path.name})") |
| print(f" layer : {layer}") |
| |
| # Load model and dataloaders |
| model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device) |
| model.eval() |
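| # Hooks are registered and removed inside extract_features(); registering |
| # them here as well would leave stray hooks attached to the model. |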
| |
| cfg_path = run_dir / "config.json" |
| seed = 42 |
| if cfg_path.exists(): |
| seed = json.loads(cfg_path.read_text()).get("seed", 42) |
| train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed) |
| |
| # Extract features |
| print(f" Extracting features (up to {max_samples} train / {max_samples // 2} OOD samples)...") |
| feats_train, hosp_train, tumor_train = extract_features( |
| model, train_loader, device, max_samples=max_samples |
| ) |
| feats_ood, hosp_ood, tumor_ood = extract_features( |
| model, ood_loader, device, max_samples=max_samples // 2 |
| ) |
| |
| if layer not in feats_train: |
| raise KeyError(f"Layer '{layer}' not in extracted features " |
| f"({list(feats_train.keys())})") |
| |
| X_tr = np.asarray(feats_train[layer]) # (N_tr, D) |
| X_ood = np.asarray(feats_ood[layer]) # (N_ood, D) |
| if X_tr.ndim > 2: # spatial map; flatten |
| X_tr = X_tr.reshape(X_tr.shape[0], -1) |
| X_ood = X_ood.reshape(X_ood.shape[0], -1) |
| |
| # Normalize features (probe is sensitive to scale; classifier head was |
| # trained on un-normalized features so we keep two parallel pipelines). |
| scaler = StandardScaler().fit(X_tr) |
| X_tr_n = scaler.transform(X_tr) |
| X_ood_n = scaler.transform(X_ood) |
| |
| # ──────────── 1. Fit hospital probe + build shortcut subspace |
| print(f" Fitting hospital probe on H0/H1/H2 train features...") |
| hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1) |
| hosp_clf.fit(X_tr_n, hosp_train) |
| hosp_acc_train = hosp_clf.score(X_tr_n, hosp_train) |
| |
| # Build a richer shortcut subspace: between-class mean offsets plus top PCs |
| # of the centered features (see _build_shortcut_subspace). This catches more |
| # shortcut variance than the (n_classes - 1)-D probe-rowspace alone. |
| W = _build_shortcut_subspace(X_tr_n, np.asarray(hosp_train), |
| method=subspace_method, |
| subspace_dim=subspace_dim) |
| P = _build_projector(W) # (D, D) |
| rank_subspace = int(np.linalg.matrix_rank(P, tol=1e-8)) |
| print(f" Shortcut subspace: dim={rank_subspace} method={subspace_method} " |
| f"(probe train acc {hosp_acc_train:.3f})") |
| |
| # ──────────── 2. Build ablated versions of features |
| # Apply the projection in the *normalized* feature space, then un-scale |
| # for re-feeding to the classifier head (which was trained on raw features). |
| def ablate_norm(X_n): |
| return X_n - X_n @ P.T |
| |
| X_ood_ablated_n = ablate_norm(X_ood_n) |
| # un-scale |
| X_ood_ablated = scaler.inverse_transform(X_ood_ablated_n) |
| |
| # Sanity probe metrics |
| print(f" Re-fitting tumor probe on train features...") |
| tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1) |
| tumor_clf.fit(X_tr_n, tumor_train) |
| tumor_acc_train = tumor_clf.score(X_tr_n, tumor_train) |
| |
| # Probe accuracies on raw vs ablated OOD features |
| hosp_acc_ood_raw = hosp_clf.score(X_ood_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan") |
| hosp_acc_ood_ablated = hosp_clf.score(X_ood_ablated_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan") |
| tumor_acc_ood_raw = tumor_clf.score(X_ood_n, tumor_ood) |
| tumor_acc_ood_ablated = tumor_clf.score(X_ood_ablated_n, tumor_ood) |
| |
| # ──────────── 3. Head-level OOD classification accuracy |
| print(f" Re-classifying OOD with model head (raw vs ablated features)...") |
| logits_raw = _classifier_logits_from_features(model, X_ood, layer, device) |
| logits_ablated = _classifier_logits_from_features(model, X_ood_ablated, layer, device) |
| |
| head_acc_raw = _accuracy(logits_raw, tumor_ood) |
| head_acc_ablated = _accuracy(logits_ablated, tumor_ood) |
| |
| # ──────────── 4. Pack + report |
| result = { |
| "run_id": run_dir.name, |
| "epoch": epoch, |
| "layer": layer, |
| "max_samples": max_samples, |
| "shortcut_subspace_dim": rank_subspace, |
| "hospital_probe_train_acc": hosp_acc_train, |
| "tumor_probe_train_acc": tumor_acc_train, |
| "hospital_probe_ood_raw": hosp_acc_ood_raw, |
| "hospital_probe_ood_ablated": hosp_acc_ood_ablated, |
| "tumor_probe_ood_raw": tumor_acc_ood_raw, |
| "tumor_probe_ood_ablated": tumor_acc_ood_ablated, |
| "head_ood_acc_raw": head_acc_raw, |
| "head_ood_acc_ablated": head_acc_ablated, |
| "intervention_effect": { |
| "shortcut_collapse": hosp_acc_ood_raw - hosp_acc_ood_ablated, |
| "ood_improvement": head_acc_ablated - head_acc_raw, |
| "tumor_preservation": tumor_acc_ood_ablated - tumor_acc_ood_raw, |
| }, |
| } |
| |
| print(f"\n RESULTS") |
| print(f" hospital probe (OOD): {hosp_acc_ood_raw:.3f} → {hosp_acc_ood_ablated:.3f} " |
| f"(Δ {result['intervention_effect']['shortcut_collapse']:+.3f})") |
| print(f" tumor probe (OOD) : {tumor_acc_ood_raw:.3f} → {tumor_acc_ood_ablated:.3f} " |
| f"(Δ {result['intervention_effect']['tumor_preservation']:+.3f})") |
| print(f" head OOD acc : {head_acc_raw:.3f} → {head_acc_ablated:.3f} " |
| f"(Δ {result['intervention_effect']['ood_improvement']:+.3f})") |
| |
| return result |
| |
|
|
| def plot_ablation(result: Dict, out_path: Path): |
| metrics = ["hospital_probe_ood", "tumor_probe_ood", "head_ood_acc"] |
| raw_keys = ["hospital_probe_ood_raw", "tumor_probe_ood_raw", "head_ood_acc_raw"] |
| ablated_keys = ["hospital_probe_ood_ablated", "tumor_probe_ood_ablated", "head_ood_acc_ablated"] |
| labels = ["Hospital probe\n(↓ = causal effect)", |
| "Tumor probe\n(stable = good)", |
| "Head OOD acc\n(↑ = causal effect)"] |
| raws = [result[k] for k in raw_keys] |
| ablateds = [result[k] for k in ablated_keys] |
| |
| fig, ax = plt.subplots(figsize=(9, 5)) |
| x = np.arange(len(metrics)) |
| w = 0.35 |
| b1 = ax.bar(x - w / 2, raws, w, label="raw features", color="#444") |
| b2 = ax.bar(x + w / 2, ablateds, w, label="shortcut-ablated", color="#c33") |
| for bars in (b1, b2): |
| for b in bars: |
| ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.005, |
| f"{b.get_height():.3f}", ha="center", va="bottom", fontsize=9) |
| ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=9) |
| ax.set_ylim(0, 1.05); ax.set_ylabel("Accuracy") |
| ax.set_title(f"M4 — Causal Ablation of Shortcut Subspace\n" |
| f"{result['run_id']} • ep{result['epoch']} • layer={result['layer']} " |
| f"• subspace dim={result['shortcut_subspace_dim']}", |
| fontsize=10, fontweight="bold") |
| ax.legend(loc="upper right") |
| ax.grid(alpha=0.3, axis="y") |
| plt.tight_layout() |
| fig.savefig(out_path, dpi=180, bbox_inches="tight") |
| plt.close(fig) |
| |
|
|
| def main(): |
| p = argparse.ArgumentParser() |
| p.add_argument("--run_dir", required=True) |
| p.add_argument("--data_root", default="data/wilds") |
| p.add_argument("--layer", default="avgpool", |
| choices=["avgpool"]) # head-level intervention only at avgpool |
| p.add_argument("--epoch", type=int, default=None, |
| help="Specific checkpoint epoch; default = peak_ood_epoch from summary.json") |
| p.add_argument("--max_samples", type=int, default=1000) |
| p.add_argument("--device", default="cuda") |
| p.add_argument("--subspace_method", default="lda", |
| choices=["lda", "probe"], |
| help="lda = LDA-style between-class + hospital-correlated PCs; " |
| "probe = LR probe row-space (small, often only 2-D)") |
| p.add_argument("--subspace_dim", type=int, default=32, |
| help="Target subspace dim for lda method") |
| p.add_argument("--all_epochs", action="store_true", |
| help="Sweep across all periodic checkpoints") |
| args = p.parse_args() |
| |
| run_dir = Path(args.run_dir) |
| out_dir = run_dir / "mechinterp" |
| out_dir.mkdir(parents=True, exist_ok=True) |
| |
| if args.all_epochs: |
| # Sweep across every periodic checkpoint, build a trajectory. |
| ckpts = find_checkpoints(str(run_dir)) |
| # de-duplicate (final.pt may share epoch with last ep*.pt) |
| seen = set(); uniq = [] |
| for ep, p in ckpts: |
| if ep in seen: |
| continue |
| seen.add(ep); uniq.append((ep, p)) |
| |
| traj = [] |
| for ep, _ in uniq: |
| try: |
| r = run_ablation( |
| run_dir=run_dir, data_root=args.data_root, layer=args.layer, |
| epoch=ep, max_samples=args.max_samples, device=args.device, |
| subspace_method=args.subspace_method, |
| subspace_dim=args.subspace_dim, |
| ) |
| traj.append(r) |
| except Exception as e: |
| print(f" [skip ep{ep}] {e}") |
| |
| out = out_dir / f"m4_ablation_{args.layer}_trajectory.json" |
| out.write_text(json.dumps(traj, indent=2)) |
| plot_trajectory(traj, out.with_suffix(".png")) |
| print(f"\n → {out}") |
| print(f" → {out.with_suffix('.png')}") |
| return |
| |
| result = run_ablation( |
| run_dir=run_dir, |
| data_root=args.data_root, |
| layer=args.layer, |
| epoch=args.epoch, |
| max_samples=args.max_samples, |
| device=args.device, |
| subspace_method=args.subspace_method, |
| subspace_dim=args.subspace_dim, |
| ) |
| |
| base = out_dir / f"m4_ablation_{args.layer}_ep{result['epoch']:05d}" |
| (base.with_suffix(".json")).write_text(json.dumps(result, indent=2)) |
| plot_ablation(result, base.with_suffix(".png")) |
| print(f"\n → {base.with_suffix('.json')}") |
| print(f" → {base.with_suffix('.png')}") |
| |
|
|
| def plot_trajectory(traj, out_path: Path): |
| """Plot the intervention effect across training epochs.""" |
| eps = [r["epoch"] for r in traj] |
| head_raw = [r["head_ood_acc_raw"] for r in traj] |
| head_abl = [r["head_ood_acc_ablated"] for r in traj] |
| tum_raw = [r["tumor_probe_ood_raw"] for r in traj] |
| tum_abl = [r["tumor_probe_ood_ablated"] for r in traj] |
| |
| fig, axes = plt.subplots(1, 2, figsize=(14, 5)) |
| |
| # Panel A: head OOD acc raw vs ablated |
| ax = axes[0] |
| ax.plot(eps, head_raw, "k-o", lw=2, label="raw features") |
| ax.plot(eps, head_abl, "r-s", lw=2, label="shortcut-ablated features") |
| ax.fill_between(eps, head_raw, head_abl, |
| where=[a > b for a, b in zip(head_abl, head_raw)], |
| color="seagreen", alpha=0.3, label="ablation helps") |
| ax.fill_between(eps, head_raw, head_abl, |
| where=[a < b for a, b in zip(head_abl, head_raw)], |
| color="salmon", alpha=0.3, label="ablation hurts") |
| ax.set_xlabel("Training epoch"); ax.set_ylabel("OOD (H4) head accuracy") |
| ax.set_title("Head OOD accuracy: raw vs shortcut-ablated", fontweight="bold") |
| ax.legend(fontsize=9); ax.grid(alpha=0.3) |
| |
| # Panel B: tumor probe survival |
| ax = axes[1] |
| ax.plot(eps, tum_raw, "k-o", lw=2, label="raw features") |
| ax.plot(eps, tum_abl, "g-s", lw=2, label="shortcut-ablated features") |
| ax.set_xlabel("Training epoch"); ax.set_ylabel("Tumor probe OOD accuracy") |
| ax.set_title("Tumor probe survival under ablation\n(stable line = causal feature preserved)", |
| fontweight="bold") |
| ax.legend(fontsize=9); ax.grid(alpha=0.3); ax.set_ylim(0.4, 1.0) |
| |
| rid = traj[0]["run_id"] if traj else "?" |
| layer = traj[0]["layer"] if traj else "?" |
| fig.suptitle(f"M4 — Causal Ablation Trajectory: {rid} • layer={layer}", |
| fontsize=11, fontweight="bold") |
| plt.tight_layout() |
| fig.savefig(out_path, dpi=180, bbox_inches="tight") |
| plt.close(fig) |
| |
|
|
| if __name__ == "__main__": |
| main() |
| |
| ``` |
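
The ablation step in `run_ablation` can be illustrated on synthetic data. The sketch below is a minimal, hedged example and not the project's code. `orthogonal_projector` is a hypothetical stand-in for `_build_projector` (defined outside this excerpt), assumed here to return the orthogonal projector onto the row space of `W`. The manual standardization mirrors the defaults of `StandardScaler`, and all data are random. It checks that after `X_n - X_n @ P.T` the features keep no component along any shortcut direction.

```python
import numpy as np


def orthogonal_projector(W: np.ndarray) -> np.ndarray:
    """Orthogonal projector onto the row space of W, shape (k, D) -> (D, D).

    Hypothetical helper for illustration; assumes W has full row rank.
    """
    # Reduced QR of W.T gives an orthonormal basis Q (D, k) for span(rows of W).
    Q, _ = np.linalg.qr(W.T)
    return Q @ Q.T


def ablate(X_n: np.ndarray, P: np.ndarray) -> np.ndarray:
    """Remove the component of each row of X_n that lies in the subspace of P."""
    return X_n - X_n @ P.T


rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=2.0, size=(200, 6))   # synthetic raw features

# Standardize with train statistics (population std, as StandardScaler does).
mu, sd = X.mean(axis=0), X.std(axis=0)
X_n = (X - mu) / sd

W = rng.normal(size=(2, 6))          # two made-up "shortcut" directions
P = orthogonal_projector(W)

X_abl_n = ablate(X_n, P)             # ablated, standardized space
X_abl = X_abl_n * sd + mu            # back to raw space for the classifier head

# Largest remaining projection onto any shortcut direction (should be ~0).
residual = float(np.abs(X_abl_n @ W.T).max())
rank = int(np.linalg.matrix_rank(P, tol=1e-8))
```

Because `P` is symmetric and idempotent, applying `ablate` twice gives the same result as applying it once, which is a cheap sanity check when swapping in a different subspace construction.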
| |
| --- |
| |
| ## 14. Full source: `experiments/mechinterp_m5_steering.py` |
| |
| ```python |
| """M5 — Activation Steering: causally manipulate the shortcut direction. |
|
|
| For one checkpoint (default: peak_ood_epoch from summary.json), we: |
| 1. Extract avgpool features for train (H0-H2) + OOD (H4) splits. |
| 2. Identify the dominant shortcut direction `v_s` as the top right singular |
| vector of the centered hospital-mean matrix, i.e. the leading eigenvector |
| of the unweighted between-hospital scatter (an LDA-style direction, |
| computed without within-class whitening). |
| 3. Sweep α (default {-3, -2, -1, -0.5, 0, +0.5, +1, +2, +3}) and apply, in |
| standardized feature space, |
| h' = h + α · σ_align · v_s |
| where σ_align is the std of train features projected onto v_s (so α counts |
| in 'standard deviations of shortcut activation'). |
| 4. Re-classify OOD with the original head. |
| 5. Score hospital + tumor probes, fit once on the un-steered train features, |
| on the steered OOD features and report accuracy curves. |
| |
| Strong mechanistic claim if: |
| - tumor-head OOD acc declines monotonically as |α| grows |
| - hospital-probe acc on steered features rises with |α| |
| - tumor-probe acc on steered features stays approximately flat (the |
| *causal* feature isn't aligned with the shortcut direction) |
| |
| Usage |
| ----- |
| python -m experiments.mechinterp_m5_steering \\ |
| --run_dir experiments/runs/<id> \\ |
| --data_root data/wilds \\ |
| [--epoch 50] # default: peak_ood_epoch from summary.json |
| [--max_samples 1000] [--alphas " -3,-2,-1,0,1,2,3 "] |
| """ |
| from __future__ import annotations |
| |
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Dict, List, Tuple |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
| |
| from experiments.mechinterp_m1 import ( |
| register_hooks, extract_features, load_model_from_checkpoint, |
| find_checkpoints, |
| ) |
| from experiments.mechinterp_m4_ablation import ( |
| _select_epoch, _build_loaders, |
| _classifier_logits_from_features, _accuracy, |
| ) |
| |
|
|
| def _top_lda_direction(X: np.ndarray, hospital_ids: np.ndarray) -> np.ndarray: |
| """Return a unit vector aligned with the dominant between-hospital direction |
| in feature space (LDA-1).""" |
| classes = np.unique(hospital_ids) |
| global_mean = X.mean(axis=0, keepdims=True) |
| means = np.vstack([ |
| X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean |
| for c in classes |
| ]) |
| # SVD: rows of Vt are the orthonormal between-class directions ranked by |
| # singular value (variance explained between hospitals). |
| U, s, Vt = np.linalg.svd(means, full_matrices=False) |
| return Vt[0] # (D,) unit vector |
| |
|
|
| def run_steering( |
| run_dir: Path, |
| data_root: str, |
| epoch: int | None = None, |
| max_samples: int = 1000, |
| device: str = "cuda", |
| alphas: List[float] | None = None, |
| ) -> Dict: |
| if alphas is None: |
| alphas = [-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0] |
| |
| epoch, ckpt_path = _select_epoch(run_dir, epoch) |
| |
| print(f"\n M5 — Activation Steering") |
| print(f" run_dir : {run_dir.name}") |
| print(f" epoch : {epoch} ({ckpt_path.name})") |
| print(f" alphas : {alphas}") |
| |
| model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device) |
| model.eval() |
| register_hooks(model) |
| |
| cfg_path = run_dir / "config.json" |
| seed = 42 |
| if cfg_path.exists(): |
| seed = json.loads(cfg_path.read_text()).get("seed", 42) |
| train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed) |
| |
| print(f" Extracting features...") |
| feats_train, hosp_train, tumor_train = extract_features( |
| model, train_loader, device, max_samples=max_samples |
| ) |
| feats_ood, hosp_ood, tumor_ood = extract_features( |
| model, ood_loader, device, max_samples=max_samples // 2 |
| ) |
| |
| layer = "avgpool" |
| X_tr = np.asarray(feats_train[layer]); X_tr = X_tr.reshape(X_tr.shape[0], -1) |
| X_ood = np.asarray(feats_ood[layer]); X_ood = X_ood.reshape(X_ood.shape[0], -1) |
| hosp_train = np.asarray(hosp_train) |
| hosp_ood = np.asarray(hosp_ood) |
| tumor_train = np.asarray(tumor_train) |
| tumor_ood = np.asarray(tumor_ood) |
| |
| # Standardize for probe-fitting; un-standardize when feeding head |
| scaler = StandardScaler().fit(X_tr) |
| X_tr_n = scaler.transform(X_tr) |
| X_ood_n = scaler.transform(X_ood) |
| |
| # 1. Top LDA direction in normalized feature space |
| v = _top_lda_direction(X_tr_n, hosp_train) # (D,) unit vec |
| # Std of training features projected onto v (scale unit for α) |
| sigma = float(np.std(X_tr_n @ v)) |
| print(f" Top hospital direction v_s : ‖v‖={np.linalg.norm(v):.3f}, " |
| f"σ(X_tr·v)={sigma:.3f}") |
| |
| # 2. Pre-fit reference probes on un-steered train features |
| hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train) |
| tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train) |
| |
| # 3. Sweep α |
| sweep = [] |
| for alpha in alphas: |
| # Steer features along v in normalized space, then un-scale for the head. |
| X_ood_steered_n = X_ood_n + alpha * sigma * v[None, :] |
| X_ood_steered = scaler.inverse_transform(X_ood_steered_n) |
| |
| # Head OOD accuracy |
| logits = _classifier_logits_from_features(model, X_ood_steered, layer, device) |
| head_acc = _accuracy(logits, tumor_ood) |
| |
| # Probe accuracies on steered features |
| if len(np.unique(hosp_ood)) > 1: |
| hosp_acc = hosp_clf.score(X_ood_steered_n, hosp_ood) |
| else: |
| hosp_acc = float("nan") |
| tumor_acc = tumor_clf.score(X_ood_steered_n, tumor_ood) |
| |
| sweep.append({ |
| "alpha": float(alpha), |
| "head_ood_acc": head_acc, |
| "hospital_probe": hosp_acc, |
| "tumor_probe": tumor_acc, |
| }) |
| print(f" α={alpha:+.2f} head_ood={head_acc:.3f} " |
| f"hosp_probe={hosp_acc if not np.isnan(hosp_acc) else 'nan':<5} " |
| f"tumor_probe={tumor_acc:.3f}") |
| |
| return { |
| "run_id": run_dir.name, |
| "epoch": epoch, |
| "layer": layer, |
| "max_samples": max_samples, |
| "v_norm": float(np.linalg.norm(v)), |
| "sigma": sigma, |
| "sweep": sweep, |
| } |
| |
|
|
| def plot_steering(result: Dict, out_path: Path): |
| sweep = result["sweep"] |
| a = [r["alpha"] for r in sweep] |
| head = [r["head_ood_acc"] for r in sweep] |
| hosp = [r["hospital_probe"] for r in sweep] |
| tumor = [r["tumor_probe"] for r in sweep] |
| |
| fig, axes = plt.subplots(1, 2, figsize=(13, 5)) |
| |
| # Panel A — Head OOD acc vs α |
| ax = axes[0] |
| ax.plot(a, head, "k-o", lw=2, ms=7) |
| ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5) |
| ax.set_xlabel("Steering coefficient α (in σ-units of shortcut direction)") |
| ax.set_ylabel("Head OOD (H4) accuracy") |
| ax.set_title("Causal effect of steering activations along v_s\n" |
| "(monotonic decline as |α| grows = causal evidence)", |
| fontweight="bold", fontsize=10) |
| ax.grid(alpha=0.3) |
| ax.set_ylim(0.4, max(0.85, max(head) + 0.05)) |
| |
| # Panel B — Probe accuracies vs α |
| ax = axes[1] |
| ax.plot(a, hosp, "r-s", lw=2, ms=7, label="Hospital probe (↑ with |α| = good)") |
| ax.plot(a, tumor, "g-^", lw=2, ms=7, label="Tumor probe (flat = causal disjoint)") |
| ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5) |
| ax.set_xlabel("Steering coefficient α") |
| ax.set_ylabel("Probe accuracy") |
| ax.set_title("Probe responses to steering", fontweight="bold", fontsize=10) |
| ax.legend(loc="best", fontsize=9); ax.grid(alpha=0.3) |
| ax.set_ylim(0, 1.05) |
| |
| fig.suptitle(f"M5 — Activation Steering: {result['run_id']} " |
| f"• ep{result['epoch']} • layer={result['layer']}", |
| fontsize=11, fontweight="bold") |
| plt.tight_layout() |
| fig.savefig(out_path, dpi=180, bbox_inches="tight") |
| plt.close(fig) |
| |
|
|
| def main(): |
| p = argparse.ArgumentParser() |
| p.add_argument("--run_dir", required=True) |
| p.add_argument("--data_root", default="data/wilds") |
| p.add_argument("--epoch", type=int, default=None) |
| p.add_argument("--max_samples", type=int, default=1000) |
| p.add_argument("--device", default="cuda") |
| p.add_argument("--alphas", default=None, |
| help="Comma-separated α values, e.g. ' -3,-2,-1,0,1,2,3 '") |
| p.add_argument("--all_epochs", action="store_true", |
| help="Sweep across all periodic checkpoints; output a trajectory") |
| args = p.parse_args() |
| |
| alphas = None |
| if args.alphas is not None: |
| alphas = [float(x) for x in args.alphas.split(",")] |
| |
| run_dir = Path(args.run_dir) |
| out_dir = run_dir / "mechinterp" |
| out_dir.mkdir(parents=True, exist_ok=True) |
| |
| if args.all_epochs: |
| # Trajectory mode: run M5 at every periodic checkpoint |
| ckpts = find_checkpoints(str(run_dir)) |
| seen = set(); uniq = [] |
| for ep, p in ckpts: |
| if ep in seen: |
| continue |
| seen.add(ep); uniq.append((ep, p)) |
| traj = [] |
| for ep, _ in uniq: |
| try: |
| r = run_steering( |
| run_dir=run_dir, data_root=args.data_root, epoch=ep, |
| max_samples=args.max_samples, device=args.device, alphas=alphas, |
| ) |
| traj.append(r) |
| except Exception as e: |
| print(f" [skip ep{ep}] {e}") |
| out = out_dir / "m5_steering_trajectory.json" |
| out.write_text(json.dumps(traj, indent=2)) |
| print(f"\n → {out}") |
| return |
| |
| result = run_steering( |
| run_dir=run_dir, data_root=args.data_root, epoch=args.epoch, |
| max_samples=args.max_samples, device=args.device, alphas=alphas, |
| ) |
| base = out_dir / f"m5_steering_ep{result['epoch']:05d}" |
| base.with_suffix(".json").write_text(json.dumps(result, indent=2)) |
| plot_steering(result, base.with_suffix(".png")) |
| print(f"\n → {base.with_suffix('.json')}") |
| print(f" → {base.with_suffix('.png')}") |
| |
|
|
| if __name__ == "__main__": |
| main() |
| |
| ``` |
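
The direction-finding and steering arithmetic of `run_steering` can be sketched on toy data. This is a hedged illustration rather than the project's pipeline: the features are synthetic, the three groups stand in for hospitals, `top_between_class_direction` re-implements the same SVD-of-class-means idea as `_top_lda_direction`, and `steer` is a hypothetical helper name. It shows that a step of size α moves every sample by exactly α·σ along the unit vector `v`.

```python
import numpy as np


def top_between_class_direction(X: np.ndarray, groups: np.ndarray) -> np.ndarray:
    """Unit vector along the dominant direction separating the group means."""
    global_mean = X.mean(axis=0, keepdims=True)
    centered_means = np.vstack([
        X[groups == c].mean(axis=0, keepdims=True) - global_mean
        for c in np.unique(groups)
    ])
    # Rows of Vt are orthonormal; the first has the largest singular value.
    _, _, Vt = np.linalg.svd(centered_means, full_matrices=False)
    return Vt[0]


def steer(X_n: np.ndarray, v: np.ndarray, sigma: float, alpha: float) -> np.ndarray:
    """Shift every row of X_n by alpha * sigma along v (standardized space)."""
    return X_n + alpha * sigma * v[None, :]


rng = np.random.default_rng(0)
groups = np.repeat([0, 1, 2], 400)          # three synthetic "hospitals"
X = rng.normal(size=(1200, 4))
X[:, 0] += 5.0 * (groups - 1)               # groups differ only along axis 0

mu, sd = X.mean(axis=0), X.std(axis=0)
X_n = (X - mu) / sd                         # standardized features

v = top_between_class_direction(X_n, groups)
sigma = float(np.std(X_n @ v))              # scale unit for alpha

shifted = steer(X_n, v, sigma, alpha=2.0)
delta = (shifted - X_n) @ v                 # displacement of each row along v
```

The sign of a singular vector is arbitrary, so which side counts as positive α depends on the SVD output. Comparisons across checkpoints are therefore safest in terms of |α| unless the sign of `v` is fixed by some convention.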
| |
| --- |
| |
| ## 15. Full source: `experiments/mechinterp_m6_neuron_ablation.py` |
| |
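
As a minimal sketch of the scoring-and-masking arithmetic used by the script in this section: channels are ranked by the column norm of a probe coefficient matrix, and the top-K are masked in standardized space. The coefficient matrix and features below are made up for illustration (the script obtains its coefficients from a fitted `LogisticRegression`), and `column_norm_scores` and `top_k_mask` are hypothetical helper names. The example also shows that masking a standardized channel and un-scaling sets that channel to its training mean, not to a raw zero.

```python
import numpy as np


def column_norm_scores(coef: np.ndarray) -> np.ndarray:
    """L2 norm of each column of an (n_classes, D) coefficient matrix."""
    return np.linalg.norm(coef, axis=0)


def top_k_mask(scores: np.ndarray, k: int) -> np.ndarray:
    """Mask of ones with zeros at the k highest-scoring channels."""
    mask = np.ones(scores.shape[0])
    if k > 0:
        mask[np.argsort(-scores)[:k]] = 0.0
    return mask


rng = np.random.default_rng(0)
X = rng.normal(loc=2.0, scale=3.0, size=(100, 5))   # synthetic raw features
mu, sd = X.mean(axis=0), X.std(axis=0)
X_n = (X - mu) / sd                                  # standardized features

# Made-up probe coefficients: channels 1 and 3 carry the largest weights.
coef = np.array([
    [0.1, 2.0, 0.0, -0.3, 0.2],
    [0.0, -1.5, 0.1, 0.4, -0.2],
    [-0.1, -0.5, -0.1, -0.1, 0.0],
])
scores = column_norm_scores(coef)
mask = top_k_mask(scores, k=2)

# Mask in standardized space, then undo the scaling for the classifier head.
X_abl = (X_n * mask[None, :]) * sd + mu
```

Under these assumptions the masked channels become constant at their training mean while the remaining channels are returned unchanged, so the intervention is a mean ablation from the head's point of view.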
| ```python |
| """M6 — Neuron-level Ablation (the textbook reviewer-asked intervention). |
|
|
| Pipeline: |
| 1. At a chosen checkpoint (default: peak_ood_epoch), extract avgpool |
| features for train (H0-H2) and OOD (H4) splits. |
| 2. Score each of the 512 avgpool channels by *how predictive its activation |
| is of hospital ID*: we fit one multiclass logistic regression over all |
| standardized channels and use the column norm of its coefficient matrix |
| as the per-neuron shortcut score: |
| score_c = ‖β_{:,c}‖₂ (β = (n_hospitals × D) LR coefficient matrix) |
| ↑ score_c → channel c is more strongly stain-shortcut-aligned. |
| 3. Sweep top-K ∈ {0, 4, 8, 16, 32, 64, 128, 256} ablated neurons (zero their |
| standardized activations, i.e. set each to its training mean) and measure: |
| - head OOD acc (raw vs ablated) |
| - hospital-probe acc on raw vs ablated features |
| - tumor-probe acc on raw vs ablated features |
| 4. Strong mechanistic claim: |
| - hospital-probe acc collapses sharply with K (these neurons are |
| carrying hospital info) |
| - head OOD acc *improves* (or is at least preserved) at small K (the |
| model was using shortcut neurons in a way that harms OOD) |
| - tumor-probe acc stays flat (causal info is distributed elsewhere) |
| |
| Usage |
| ----- |
| python -m experiments.mechinterp_m6_neuron_ablation \\ |
| --run_dir experiments/runs/<id> \\ |
| --data_root data/wilds \\ |
| [--epoch 50] [--max_samples 1000] \\ |
| [--ks "0,4,8,16,32,64,128,256"] |
| """ |
| from __future__ import annotations |
| |
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
| |
| from torch.utils.data import DataLoader, Subset |
| from torchvision import transforms |
| |
| from experiments.mechinterp_m1 import ( |
| register_hooks, extract_features, load_model_from_checkpoint, |
| ) |
| from experiments.mechinterp_m4_ablation import ( |
| _select_epoch, _TransformWrapper, |
| _classifier_logits_from_features, _accuracy, |
| ) |
| from utils.camelyon_data import get_camelyon_subsets |
| |
|
|
| def _build_loaders_with_id(data_root: str, max_samples: int, seed: int = 42): |
| """Like M4's _build_loaders but also returns an ID validation loader so |
| we can track ID acc and compute the OOD/ID degradation ratio.""" |
| transform = transforms.Compose([ |
| transforms.Resize((96, 96)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225]), |
| ]) |
| train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets( |
| root_dir=data_root, download=False |
| ) |
| train_t = _TransformWrapper(train_ds, transform) |
| id_t = _TransformWrapper(id_val_ds, transform) |
| ood_t = _TransformWrapper(ood_test_ds, transform) |
| |
| torch.manual_seed(seed) |
| train_idx = torch.randperm(len(train_t))[:max_samples] |
| id_idx = torch.randperm(len(id_t))[:max_samples // 2] |
| ood_idx = torch.randperm(len(ood_t))[:max_samples // 2] |
| |
| train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128, |
| shuffle=False, num_workers=0) |
| id_loader = DataLoader(Subset(id_t, id_idx), batch_size=128, |
| shuffle=False, num_workers=0) |
| ood_loader = DataLoader(Subset(ood_t, ood_idx), batch_size=128, |
| shuffle=False, num_workers=0) |
| return train_loader, id_loader, ood_loader |
| |
|
|
| def _per_neuron_shortcut_scores(X_n: np.ndarray, hosp: np.ndarray) -> np.ndarray: |
| """Return a (D,) array — score per channel c, larger = more hospital-predictive. |
| |
| A one-feature-at-a-time log-reg fit's |coef| would be dominated by feature |
| scale; instead we fit a single multiclass LR over all features and use the |
| L2 norm of β_{:,c} (the column norm of the LR coefficient matrix) |
| as channel c's hospital-discrimination score. |
| """ |
| clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1).fit(X_n, hosp) |
| W = clf.coef_ # (n_classes, D) |
| # column norms — large means many class-discriminations rely on this neuron |
| return np.linalg.norm(W, axis=0) # (D,) |
| |
|
|
| def _ablate_and_eval( |
| X_n, mask, scaler, head_target, model, layer, device, |
| hosp_clf, tumor_clf, hosp_target, tumor_target, |
| ): |
| """Apply mask to normalized features, unscale, evaluate everything.""" |
| X_ablated_n = X_n * mask[None, :] |
| # A zeroed standardized channel maps back to its training mean under |
| # scaler.inverse_transform, so the head sees mean-ablated raw features. |
| X_ablated = scaler.inverse_transform(X_ablated_n) |
| logits = _classifier_logits_from_features(model, X_ablated, layer, device) |
| head_acc = _accuracy(logits, head_target) |
| hosp_acc = hosp_clf.score(X_ablated_n, hosp_target) if hosp_clf is not None and len(np.unique(hosp_target)) > 1 else float("nan") |
| tumor_acc = tumor_clf.score(X_ablated_n, tumor_target) |
| return head_acc, hosp_acc, tumor_acc |
| |
|
|
| def run_neuron_ablation( |
| run_dir: Path, |
| data_root: str, |
| epoch: int | None = None, |
| max_samples: int = 1000, |
| device: str = "cuda", |
| ks: List[int] | None = None, |
| n_random_samples: int = 5, |
| include_morphology: bool = True, |
| include_id: bool = True, |
| ) -> Dict: |
| if ks is None: |
| # Dose-response curve emphasizing small K (per reviewer guidance) |
| ks = [0, 4, 8, 16, 32, 64, 128, 256] |
| |
| epoch, ckpt_path = _select_epoch(run_dir, epoch) |
| |
| print(f"\n M6 — Neuron Ablation (with random + morphology controls)") |
| print(f" run_dir : {run_dir.name}") |
| print(f" epoch : {epoch} ({ckpt_path.name})") |
| print(f" ks : {ks}") |
| print(f" random ablation: {n_random_samples} samplings per K") |
| |
| model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device) |
| model.eval() |
| register_hooks(model) |
| |
| cfg_path = run_dir / "config.json" |
| seed = 42 |
| if cfg_path.exists(): |
| seed = json.loads(cfg_path.read_text()).get("seed", 42) |
| |
| if include_id: |
| train_loader, id_loader, ood_loader = _build_loaders_with_id(data_root, max_samples, seed=seed) |
| else: |
| from experiments.mechinterp_m4_ablation import _build_loaders as _bl |
| train_loader, ood_loader = _bl(data_root, max_samples, seed=seed) |
| id_loader = None |
| |
| print(f" Extracting features (train + id + ood)...") |
| feats_train, hosp_train, tumor_train = extract_features( |
| model, train_loader, device, max_samples=max_samples |
| ) |
| feats_ood, hosp_ood, tumor_ood = extract_features( |
| model, ood_loader, device, max_samples=max_samples // 2 |
| ) |
| feats_id, hosp_id, tumor_id = (None, None, None) |
| if id_loader is not None: |
| feats_id, hosp_id, tumor_id = extract_features( |
| model, id_loader, device, max_samples=max_samples // 2 |
| ) |
| |
| layer = "avgpool" |
| def _to_2d(arr): |
| a = np.asarray(arr); return a.reshape(a.shape[0], -1) |
| X_tr = _to_2d(feats_train[layer]) |
| X_ood = _to_2d(feats_ood[layer]) |
| X_id = _to_2d(feats_id[layer]) if feats_id is not None else None |
| hosp_train = np.asarray(hosp_train); hosp_ood = np.asarray(hosp_ood) |
| tumor_train = np.asarray(tumor_train); tumor_ood = np.asarray(tumor_ood) |
| if X_id is not None: |
| hosp_id = np.asarray(hosp_id); tumor_id = np.asarray(tumor_id) |
| |
| scaler = StandardScaler().fit(X_tr) |
| X_tr_n = scaler.transform(X_tr) |
| X_ood_n = scaler.transform(X_ood) |
| X_id_n = scaler.transform(X_id) if X_id is not None else None |
| |
| # 1. Per-neuron scores: shortcut (hospital) and morphology (tumor) |
| print(f" Scoring {X_tr.shape[1]} avgpool channels...") |
| shortcut_scores = _per_neuron_shortcut_scores(X_tr_n, hosp_train) |
| morphology_scores = _per_neuron_shortcut_scores(X_tr_n, tumor_train) if include_morphology else None |
| rank_shortcut = np.argsort(-shortcut_scores) |
| rank_morphology = np.argsort(-morphology_scores) if morphology_scores is not None else None |
| |
| hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train) |
| tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", |
| multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train) |
| |
| rng = np.random.default_rng(seed) |
| D = X_tr.shape[1] |
| |
| sweep = [] |
| for k in ks: |
| row = {"k": int(k)} |
| |
| # Mask helpers |
| def make_mask(indices): |
| m = np.ones(D) |
| if k > 0: |
| m[indices[:k]] = 0.0 |
| return m |
| |
| # ── A: top-K SHORTCUT neurons (the targeted ablation) ── |
| mask_s = make_mask(rank_shortcut) |
| h_ood, hp_ood, tp_ood = _ablate_and_eval( |
| X_ood_n, mask_s, scaler, tumor_ood, model, layer, device, |
| hosp_clf, tumor_clf, hosp_ood, tumor_ood, |
| ) |
| row["shortcut_head_ood"] = float(h_ood) |
| row["shortcut_hosp_probe"] = float(hp_ood) |
| row["shortcut_tumor_probe"] = float(tp_ood) |
| if X_id_n is not None: |
| h_id, _, _ = _ablate_and_eval( |
| X_id_n, mask_s, scaler, tumor_id, model, layer, device, |
| None, tumor_clf, hosp_id, tumor_id, |
| ) |
| row["shortcut_head_id"] = float(h_id) |
| |
| # ── B: top-K MORPHOLOGY neurons (control: ablate the causal neurons) ── |
| if include_morphology and rank_morphology is not None: |
| mask_m = make_mask(rank_morphology) |
| h_ood_m, _, _ = _ablate_and_eval( |
| X_ood_n, mask_m, scaler, tumor_ood, model, layer, device, |
| None, tumor_clf, hosp_ood, tumor_ood, |
| ) |
| row["morphology_head_ood"] = float(h_ood_m) |
| if X_id_n is not None: |
| h_id_m, _, _ = _ablate_and_eval( |
| X_id_n, mask_m, scaler, tumor_id, model, layer, device, |
| None, tumor_clf, hosp_id, tumor_id, |
| ) |
| row["morphology_head_id"] = float(h_id_m) |
| |
| # ── C: K RANDOM neurons (control: damage uniformly) ── |
| if k > 0: |
| r_oods, r_ids = [], [] |
| for s_ in range(n_random_samples): |
| idx = rng.permutation(D)[:k] |
| m = np.ones(D); m[idx] = 0.0 |
| h_ood_r, _, _ = _ablate_and_eval( |
| X_ood_n, m, scaler, tumor_ood, model, layer, device, |
| None, tumor_clf, hosp_ood, tumor_ood, |
| ) |
| r_oods.append(h_ood_r) |
| if X_id_n is not None: |
| h_id_r, _, _ = _ablate_and_eval( |
| X_id_n, m, scaler, tumor_id, model, layer, device, |
| None, tumor_clf, hosp_id, tumor_id, |
| ) |
| r_ids.append(h_id_r) |
| row["random_head_ood_mean"] = float(np.mean(r_oods)) |
| row["random_head_ood_std"] = float(np.std(r_oods)) |
| if r_ids: |
| row["random_head_id_mean"] = float(np.mean(r_ids)) |
| row["random_head_id_std"] = float(np.std(r_ids)) |
| else: |
| row["random_head_ood_mean"] = row["shortcut_head_ood"] # K=0 same as baseline |
| row["random_head_ood_std"] = 0.0 |
| if X_id_n is not None: |
| row["random_head_id_mean"] = row.get("shortcut_head_id", float("nan")) |
| row["random_head_id_std"] = 0.0 |
| |
| sweep.append(row) |
| # Concise log line |
| print(f" K={k:>4} shortcut={row['shortcut_head_ood']:.3f} " |
| f"random={row.get('random_head_ood_mean', float('nan')):.3f}±" |
| f"{row.get('random_head_ood_std', 0):.3f} " |
| + (f"morphology={row.get('morphology_head_ood', float('nan')):.3f}" |
| if include_morphology else "")) |
| |
| return { |
| "run_id": run_dir.name, |
| "epoch": epoch, |
| "layer": layer, |
| "max_samples": max_samples, |
| "feature_dim": int(X_tr.shape[1]), |
| "shortcut_scores_top10": [int(i) for i in rank_shortcut[:10]], |
| "morphology_scores_top10": ([int(i) for i in rank_morphology[:10]] |
| if rank_morphology is not None else []), |
| "n_random_samples": n_random_samples, |
| "include_id": include_id, |
| "include_morphology": include_morphology, |
| "sweep": sweep, |
| } |
| |
|
|
| def plot_neuron_ablation(result: Dict, out_path: Path): |
| sweep = result["sweep"] |
| ks = [r["k"] for r in sweep] |
| |
| has_id = result.get("include_id", False) |
| has_morph = result.get("include_morphology", False) |
| |
| fig, axes = (plt.subplots(1, 2, figsize=(13, 5)) if has_id |
| else plt.subplots(1, 1, figsize=(8, 5))) |
| if not has_id: |
| axes = [axes] |
| |
| # Panel A — Head OOD: shortcut vs random (vs morphology) |
| ax = axes[0] |
| shortcut_ood = [r.get("shortcut_head_ood") for r in sweep] |
| random_ood_mu = [r.get("random_head_ood_mean") for r in sweep] |
| random_ood_sd = [r.get("random_head_ood_std", 0) for r in sweep] |
| morphology_ood = [r.get("morphology_head_ood") for r in sweep] if has_morph else None |
| |
| ax.plot(ks, shortcut_ood, "r-o", lw=2.2, ms=7, label="top-K shortcut neurons (targeted)") |
| ax.plot(ks, random_ood_mu, "k-s", lw=1.8, ms=6, label="K random neurons (control)") |
| ax.fill_between(ks, |
| [m - s for m, s in zip(random_ood_mu, random_ood_sd)], |
| [m + s for m, s in zip(random_ood_mu, random_ood_sd)], |
| color="black", alpha=0.15) |
| if has_morph and morphology_ood is not None: |
| ax.plot(ks, morphology_ood, "g-^", lw=1.8, ms=6, |
| label="top-K morphology neurons (control)") |
| |
| base = shortcut_ood[0] |
| ax.axhline(base, color="gray", ls=":", lw=1, alpha=0.5, |
| label=f"K=0 baseline ({base:.3f})") |
| ax.set_xlabel("K (neurons zeroed at avgpool)") |
| ax.set_ylabel("Head OOD (H4) accuracy") |
| ax.set_xscale("symlog", linthresh=4) |
| ax.set_title("Targeted vs random ablation — OOD effect\n" |
| "(separation = shortcut neurons selectively hurt OOD)", |
| fontweight="bold", fontsize=10) |
| ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3) |
| |
| # Panel B — ID/OOD tradeoff |
| if has_id: |
| ax = axes[1] |
| shortcut_id = [r.get("shortcut_head_id") for r in sweep] |
| random_id_mu = [r.get("random_head_id_mean") for r in sweep] |
| random_id_sd = [r.get("random_head_id_std", 0) for r in sweep] |
| ax.plot(ks, shortcut_id, "r--o", lw=2, ms=7, alpha=0.85, label="ID (shortcut ablation)") |
| ax.plot(ks, shortcut_ood, "r-o", lw=2, ms=7, label="OOD (shortcut ablation)") |
| ax.plot(ks, random_id_mu, "k--s", lw=1.6, ms=5, alpha=0.7, label="ID (random ablation)") |
| ax.plot(ks, random_ood_mu, "k-s", lw=1.6, ms=5, alpha=0.7, label="OOD (random ablation)") |
| ax.set_xlabel("K (neurons zeroed at avgpool)") |
| ax.set_ylabel("Head accuracy") |
| ax.set_xscale("symlog", linthresh=4) |
| ax.set_title("ID vs OOD degradation tradeoff\n" |
| "(targeted: OOD steady or ↑ while ID slowly ↓ = good)", |
| fontweight="bold", fontsize=10) |
| ax.legend(fontsize=8, loc="best"); ax.grid(alpha=0.3) |
| |
| fig.suptitle(f"M6 — Targeted Neuron Ablation vs Random Control: {result['run_id']} " |
| f"• ep{result['epoch']}", |
| fontsize=11, fontweight="bold") |
| plt.tight_layout() |
| fig.savefig(out_path, dpi=180, bbox_inches="tight") |
| plt.close(fig) |
| |
|
|
| def main(): |
| p = argparse.ArgumentParser() |
| p.add_argument("--run_dir", required=True) |
| p.add_argument("--data_root", default="data/wilds") |
| p.add_argument("--epoch", type=int, default=None) |
| p.add_argument("--max_samples", type=int, default=1000) |
| p.add_argument("--device", default="cuda") |
| p.add_argument("--ks", default=None, |
| help="Comma-separated K values, e.g. '0,4,8,16,32,64,128,256'") |
| p.add_argument("--n_random_samples", type=int, default=5, |
| help="Random ablation: averages over this many random K-subsets") |
| p.add_argument("--no_morphology", action="store_true", |
| help="Skip the morphology-targeted ablation control") |
| p.add_argument("--no_id", action="store_true", |
| help="Skip ID accuracy evaluation (faster but loses ID/OOD ratio)") |
| args = p.parse_args() |
| |
| ks = None |
| if args.ks is not None: |
| ks = [int(x) for x in args.ks.split(",")] |
| |
| run_dir = Path(args.run_dir) |
| out_dir = run_dir / "mechinterp" |
| out_dir.mkdir(parents=True, exist_ok=True) |
| |
| result = run_neuron_ablation( |
| run_dir=run_dir, data_root=args.data_root, epoch=args.epoch, |
| max_samples=args.max_samples, device=args.device, ks=ks, |
| n_random_samples=args.n_random_samples, |
| include_morphology=not args.no_morphology, |
| include_id=not args.no_id, |
| ) |
| base = out_dir / f"m6_neuron_ablation_ep{result['epoch']:05d}" |
| base.with_suffix(".json").write_text(json.dumps(result, indent=2)) |
| plot_neuron_ablation(result, base.with_suffix(".png")) |
| print(f"\n → {base.with_suffix('.json')}") |
| print(f" → {base.with_suffix('.png')}") |
| |
|
|
| if __name__ == "__main__": |
| main() |
| |
| ``` |
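
One way to read the JSON that `main()` writes is to compare targeted and random ablation at each K. The sketch below is not part of the archived scripts: the helper name `ood_separation`, the one-standard-deviation criterion, and the example numbers are assumptions made for illustration. Only the key names (`sweep`, `k`, `shortcut_head_ood`, `random_head_ood_mean`, `random_head_ood_std`) are taken from the source above.

```python
import json
from typing import Dict, List


def ood_separation(result: Dict) -> List[Dict]:
    """Per-K gap between targeted and random ablation on Head OOD accuracy."""
    rows = []
    for r in result["sweep"]:
        targeted = r.get("shortcut_head_ood")
        rand_mu = r.get("random_head_ood_mean")
        rand_sd = r.get("random_head_ood_std", 0) or 0.0
        if targeted is None or rand_mu is None:
            continue  # skip sweep entries that lack either measurement
        gap = targeted - rand_mu
        rows.append({
            "k": r["k"],
            "gap": gap,
            # Assumed criterion: the gap exceeds one std of the random control.
            "outside_1sd": abs(gap) > rand_sd,
        })
    return rows


# Hypothetical sweep values, for illustration only (not measured results).
example = {
    "sweep": [
        {"k": 0, "shortcut_head_ood": 0.60,
         "random_head_ood_mean": 0.60, "random_head_ood_std": 0.0},
        {"k": 16, "shortcut_head_ood": 0.66,
         "random_head_ood_mean": 0.61, "random_head_ood_std": 0.02},
    ]
}

# Round-trip through JSON to mirror reading the saved file from disk.
rows = ood_separation(json.loads(json.dumps(example)))
for row in rows:
    print(row)
```

With a real run, `example` would be replaced by `json.loads(path.read_text())` on the `m6_neuron_ablation_ep*.json` file that `main()` saves under `<run_dir>/mechinterp/`.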
| |
| --- |
| |
| ## 16. Run inventory and summary results |
| |
| All 14 runs use the canonical 3000-epoch config. The 11 runs with full M4/M5/M6 trajectories carry the paper's mechanistic claims; the 3 n=300 grokking runs have M1 only.
| |
| ### `n = 1000` (5 grokking-favorable + 3 standard) |
| |
| | Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ (ungrok) | Best ID | Final ‖W‖ | Final rank | Final IRM | |
| | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | |
| | Grok | 7 | 20260508-183413_grokking_n1000_s7 | 0.6876 | 50 | 0.5882 | −0.0995 | 0.8797 | 1516.7 | 69.83 | 4.47e-12 | |
| | Grok | 42 | 20260505-080445_grokking_n1000_s42 | 0.7336 | 350 | 0.6639 | −0.0696 | 0.8976 | 1470.5 | 35.87 | 4.34e-12 | |
| | Grok | 123 | 20260505-100720_grokking_n1000_s123 | 0.7270 | 350 | 0.6447 | −0.0823 | 0.8994 | 1457.2 | 56.65 | 6.73e-15 | |
| | Grok | 456 | 20260505-100720_grokking_n1000_s456 | 0.6722 | 1100 | 0.5224 | −0.1498 | 0.8824 | 1493.6 | 64.54 | 2.87e-09 | |
| | Grok | 2024 | 20260508-183413_grokking_n1000_s2024 | 0.7056 | 400 | 0.5506 | −0.1550 | 0.8959 | 1632.4 | 65.77 | 4.56e-07 | |
| | Std | 42 | 20260505-100720_standard_n1000_s42 | 0.7615 | 1 | 0.6482 | −0.1133 | 0.9011 | 812.6 | 33.35 | 2.24e-13 | |
| | Std | 123 | 20260508-183413_standard_n1000_s123 | 0.8880* | 1 | 0.6645 | −0.2235 | 0.8957 | 798.3 | 37.08 | 9.53e-14 | |
| | Std | 456 | 20260508-183413_standard_n1000_s456 | 0.7450 | 1050 | 0.5783 | −0.1667 | 0.8950 | 792.4 | 35.30 | 7.18e-09 | |
| |
| \* Std s123 records its peak OOD at epoch 1, essentially at the random initialization; the value is treated as an artifact.
| |
| **Aggregates (n=1000)**: grokking 5-seed mean peak = **0.7052 ± 0.0237**, mean Δ = **−0.1112 ± 0.0345**. Standard 2-seed mean peak with the s123 artifact excluded (s42, s456) = **0.7533**; raw 3-seed mean = 0.7982.
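
The means quoted above can be re-derived from the 4-decimal table entries. The following is a minimal sketch under that assumption; the variable names are illustrative, and the ± terms are not recomputed because the source does not state which dispersion convention produced them.

```python
from statistics import mean

# Peak OOD and delta (final minus peak) per seed, copied from the n = 1000 table.
grok_peak = {7: 0.6876, 42: 0.7336, 123: 0.7270, 456: 0.6722, 2024: 0.7056}
grok_delta = {7: -0.0995, 42: -0.0696, 123: -0.0823, 456: -0.1498, 2024: -0.1550}
std_peak = {42: 0.7615, 123: 0.8880, 456: 0.7450}

grok_mean_peak = mean(grok_peak.values())
grok_mean_delta = mean(grok_delta.values())
std_raw_mean = mean(std_peak.values())
# Exclude seed 123, whose epoch-1 peak is flagged as an artifact in the table note.
std_mean_excl_s123 = mean(v for seed, v in std_peak.items() if seed != 123)

print(f"grokking mean peak OOD     : {grok_mean_peak:.4f}")
print(f"grokking mean delta        : {grok_mean_delta:.4f}")
print(f"standard mean peak (raw)   : {std_raw_mean:.4f}")
print(f"standard mean peak (no 123): {std_mean_excl_s123:.4f}")
```

Because the inputs are already rounded to four decimals, the last printed digit can differ by one unit from a value computed on the full-precision `summary.json` fields.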
| |
| ### `n = 500` and `n = 300` |
| |
| | Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ | Best ID | |
| | --- | --- | --- | --- | --- | --- | --- | --- | |
| | Grok | 42 | 20260505-080442_grokking_n500_s42 | 0.7924 | 50 | 0.5514 | −0.2410 | 0.8874 | |
| | Std | 42 | 20260505-100720_standard_n500_s42 | 0.7576 | 1050 | 0.6526 | −0.1050 | 0.8867 | |
| | Grok | 42 | 20260502-214859_grokking_n300_s42 | 0.7162 | 250 | 0.5189 | −0.1974 | 0.8664 | |
| | Grok | 123 | 20260502-214859_grokking_n300_s123 | 0.6961 | 50 | 0.5154 | −0.1807 | 0.8388 | |
| | Grok | 456 | 20260502-214859_grokking_n300_s456 | 0.6654 | 750 | 0.5469 | −0.1184 | 0.8522 | |
| | Std | 42 | 20260505-080836_standard_n300_s42 | 0.7647 | 250 | 0.7052 | −0.0596 | 0.8584 | |
| |
| **Universal finding**: every run *ungroks* (Δ < 0 for all 14); no run shows a plateau-then-jump. `grokking_epoch = -1` in every summary, and `irm_drop_pct ≈ 100%` in each of the 11 summaries that record it (the three n=300 grokking summaries omit that field).
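
These expectations can be checked mechanically against the saved summaries. The sketch below is an illustration rather than an archived script: `load_summaries`, `check_summary`, and the 99.9 threshold are assumptions, while the field names and the two sample records are copied from the `results/summary.json` files in this archive.

```python
import json
import math
from pathlib import Path
from typing import Dict, List


def load_summaries(runs_root: str) -> List[Dict]:
    """Read every <run>/results/summary.json under runs_root."""
    paths = sorted(Path(runs_root).glob("*/results/summary.json"))
    return [json.loads(p.read_text()) for p in paths]


def check_summary(s: Dict, min_irm_drop_pct: float = 99.9) -> List[str]:
    """Return the list of violated expectations (empty if all hold)."""
    problems = []
    if not s["ood_delta"] < 0:
        problems.append("ood_delta is not negative")
    if not math.isclose(s["ood_delta"], s["final_ood"] - s["best_ood"],
                        abs_tol=1e-9):
        problems.append("ood_delta differs from final_ood - best_ood")
    if s["grokking_epoch"] != -1:
        problems.append("grokking_epoch is not -1")
    drop = s.get("irm_drop_pct")  # some summaries do not record this field
    if drop is not None and drop < min_irm_drop_pct:
        problems.append("irm_drop_pct below threshold")
    return problems


# Two records copied from this archive (only the fields used here).
samples = [
    {"run_id": "20260502-214859_grokking_n300_s42",
     "best_ood": 0.7162273379264937, "final_ood": 0.5188703647094787,
     "ood_delta": -0.19735697321701506, "grokking_epoch": -1},
    {"run_id": "20260505-080445_grokking_n1000_s42",
     "best_ood": 0.7335575046441084, "final_ood": 0.6639076351494345,
     "ood_delta": -0.06964986949467389, "grokking_epoch": -1,
     "irm_drop_pct": 99.99999500910222},
]

report = {s["run_id"]: check_summary(s) for s in samples}
for run_id, problems in report.items():
    print(run_id, "OK" if not problems else problems)
```

To cover all runs, `samples` could be replaced by `load_summaries("experiments/runs")`, assuming the run directories follow the `run_dir` layout recorded in each `config.json`.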
| |
| --- |
| |
| ## 17. Per-run `config.json` and `summary.json` (all 14 runs) |
| |
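| Every run's exact saved JSON, verbatim from disk. The three n=300 grokking summaries (dated 20260502) carry `ungrokking_detected` and omit `irm_drop_pct`, `irm_drop_epoch`, and `epoch_gap`; the other 11 summaries carry those three fields and omit `ungrokking_detected`.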
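
Each run ID encodes the launch timestamp, condition, training-set size, and seed, so a saved `config.json` can be cross-checked against its own `run_id`. The following sketch assumes the `<YYYYMMDD-HHMMSS>_<condition>_n<n_train>_s<seed>` pattern seen in the run IDs listed in this document; `parse_run_id` and `config_matches_run_id` are hypothetical helpers, not functions from the project.

```python
import re
from typing import Dict

# Pattern inferred from the run IDs in this archive (an assumption).
RUN_ID = re.compile(
    r"^(?P<stamp>\d{8}-\d{6})_(?P<condition>[a-z]+)"
    r"_n(?P<n_train>\d+)_s(?P<seed>\d+)$"
)


def parse_run_id(run_id: str) -> Dict:
    """Split a run ID into its timestamp, condition, n_train, and seed."""
    m = RUN_ID.match(run_id)
    if m is None:
        raise ValueError(f"unrecognized run_id: {run_id!r}")
    return {
        "stamp": m["stamp"],
        "condition": m["condition"],
        "n_train": int(m["n_train"]),
        "seed": int(m["seed"]),
    }


def config_matches_run_id(config: Dict) -> bool:
    """True if condition, n_train, seed, and run_dir agree with run_id."""
    parsed = parse_run_id(config["run_id"])
    return (
        parsed["condition"] == config["condition"]
        and parsed["n_train"] == config["n_train"]
        and parsed["seed"] == config["seed"]
        and config["run_dir"].endswith(config["run_id"])
    )


# Fields copied from one config.json in this archive.
config = {
    "seed": 42,
    "n_train": 300,
    "condition": "grokking",
    "run_id": "20260502-214859_grokking_n300_s42",
    "run_dir": "experiments/runs/20260502-214859_grokking_n300_s42",
}

parsed = parse_run_id(config["run_id"])
print(parsed)
print(config_matches_run_id(config))
```

A mismatch would point to a copied or hand-edited config rather than to a training problem, which is why the check looks only at identifying fields.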
| |
| ### `20260502-214859_grokking_n300_s42` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 42, |
| "n_train": 300, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260502-214859_grokking_n300_s42", |
| "run_dir": "experiments/runs/20260502-214859_grokking_n300_s42" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260502-214859_grokking_n300_s42", |
| "condition": "grokking", |
| "n_train": 300, |
| "seed": 42, |
| "best_id_val": 0.866388557806913, |
| "best_ood": 0.7162273379264937, |
| "peak_ood_epoch": 250, |
| "final_ood": 0.5188703647094787, |
| "ood_delta": -0.19735697321701506, |
| "ood_improvement": 0.01951701272132994, |
| "grokking_epoch": -1, |
| "final_weight_norm": 1131.2872887615824, |
| "final_feature_rank": 39.32365036010742, |
| "final_irm": 6.867336560523185e-14, |
| "final_shortcut_ratio": 1.0053635176200695, |
| "final_ood_gap": 0.3284478712857537, |
| "ungrokking_detected": true |
| } |
| ``` |
| |
| ### `20260502-214859_grokking_n300_s123` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 123, |
| "n_train": 300, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260502-214859_grokking_n300_s123", |
| "run_dir": "experiments/runs/20260502-214859_grokking_n300_s123" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260502-214859_grokking_n300_s123", |
| "condition": "grokking", |
| "n_train": 300, |
| "seed": 123, |
| "best_id_val": 0.8387663885578069, |
| "best_ood": 0.6961459778493663, |
| "peak_ood_epoch": 50, |
| "final_ood": 0.515413737155219, |
| "ood_delta": -0.18073224069414728, |
| "ood_improvement": 0.015413737155218987, |
| "grokking_epoch": -1, |
| "final_weight_norm": 981.2013665038359, |
| "final_feature_rank": 44.795772552490234, |
| "final_irm": 5.876544368309256e-13, |
| "final_shortcut_ratio": 1.0042143792951101, |
| "final_ood_gap": 0.23887708525240914, |
| "ungrokking_detected": true |
| } |
| ``` |
| |
| ### `20260502-214859_grokking_n300_s456` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 456, |
| "n_train": 300, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260502-214859_grokking_n300_s456", |
| "run_dir": "experiments/runs/20260502-214859_grokking_n300_s456" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260502-214859_grokking_n300_s456", |
| "condition": "grokking", |
| "n_train": 300, |
| "seed": 456, |
| "best_id_val": 0.8521752085816449, |
| "best_ood": 0.6653655324852447, |
| "peak_ood_epoch": 750, |
| "final_ood": 0.5469348884238249, |
| "ood_delta": -0.11843064406141979, |
| "ood_improvement": 0.04693488842382487, |
| "grokking_epoch": -1, |
| "final_weight_norm": 1109.5530019538146, |
| "final_feature_rank": 49.14884567260742, |
| "final_irm": 8.174740884214771e-10, |
| "final_shortcut_ratio": 0.9680124053677005, |
| "final_ood_gap": 0.25908418189798677, |
| "ungrokking_detected": true |
| } |
| ``` |
| |
| ### `20260505-080836_standard_n300_s42` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 42, |
| "n_train": 300, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "standard", |
| "lr": 0.001, |
| "weight_decay": 0.0001, |
| "n_epochs": 3000, |
| "init_scale": 1.0, |
| "use_grokfast": false, |
| "grad_clip": 1.0, |
| "run_id": "20260505-080836_standard_n300_s42", |
| "run_dir": "experiments/runs/20260505-080836_standard_n300_s42" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260505-080836_standard_n300_s42", |
| "condition": "standard", |
| "n_train": 300, |
| "seed": 42, |
| "best_id_val": 0.858373063170441, |
| "best_ood": 0.7647259388153408, |
| "peak_ood_epoch": 250, |
| "final_ood": 0.7051637783055471, |
| "ood_delta": -0.05956216050979368, |
| "ood_improvement": -0.026599572036588692, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99996794236725, |
| "irm_drop_epoch": 100, |
| "epoch_gap": -1, |
| "final_weight_norm": 452.7899771783757, |
| "final_feature_rank": 17.35470962524414, |
| "final_irm": 1.032695706726372e-09, |
| "final_shortcut_ratio": 1.0080115087353168, |
| "final_ood_gap": 0.12806029797574014 |
| } |
| ``` |
| |
| ### `20260505-080442_grokking_n500_s42` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 42, |
| "n_train": 500, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260505-080442_grokking_n500_s42", |
| "run_dir": "experiments/runs/20260505-080442_grokking_n500_s42" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260505-080442_grokking_n500_s42", |
| "condition": "grokking", |
| "n_train": 500, |
| "seed": 42, |
| "best_id_val": 0.8873957091775924, |
| "best_ood": 0.7924495026688927, |
| "peak_ood_epoch": 50, |
| "final_ood": 0.5514496672702048, |
| "ood_delta": -0.24099983539868786, |
| "ood_improvement": -0.06643779246126003, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99999995667906, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 1046.672482929869, |
| "final_feature_rank": 34.301780700683594, |
| "final_irm": 6.552892841682478e-07, |
| "final_shortcut_ratio": 0.9969108279290858, |
| "final_ood_gap": 0.32080897396936603 |
| } |
| ``` |
| |
| ### `20260505-100720_standard_n500_s42` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 42, |
| "n_train": 500, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "standard", |
| "lr": 0.001, |
| "weight_decay": 0.0001, |
| "n_epochs": 3000, |
| "init_scale": 1.0, |
| "use_grokfast": false, |
| "grad_clip": 1.0, |
| "run_id": "20260505-100720_standard_n500_s42", |
| "run_dir": "experiments/runs/20260505-100720_standard_n500_s42" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260505-100720_standard_n500_s42", |
| "condition": "standard", |
| "n_train": 500, |
| "seed": 42, |
| "best_id_val": 0.8866507747318236, |
| "best_ood": 0.7575775389752393, |
| "peak_ood_epoch": 1050, |
| "final_ood": 0.6525854163237472, |
| "ood_delta": -0.10499212265149205, |
| "ood_improvement": -0.08494603428410186, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99998953242944, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 555.1170332945212, |
| "final_feature_rank": 20.722705841064453, |
| "final_irm": 5.05333608291636e-10, |
| "final_shortcut_ratio": 0.9843676046096169, |
| "final_ood_gap": 0.19765296269889876 |
| } |
| ``` |
| |
| ### `20260505-080445_grokking_n1000_s42` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 42, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260505-080445_grokking_n1000_s42", |
| "run_dir": "experiments/runs/20260505-080445_grokking_n1000_s42" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260505-080445_grokking_n1000_s42", |
| "condition": "grokking", |
| "n_train": 1000, |
| "seed": 42, |
| "best_id_val": 0.8976460071513707, |
| "best_ood": 0.7335575046441084, |
| "peak_ood_epoch": 350, |
| "final_ood": 0.6639076351494345, |
| "ood_delta": -0.06964986949467389, |
| "ood_improvement": 0.011399816587109313, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99999500910222, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 1470.4773196265805, |
| "final_feature_rank": 35.86945724487305, |
| "final_irm": 4.340286376830482e-12, |
| "final_shortcut_ratio": 0.9871634909839839, |
| "final_ood_gap": 0.20680154244293736 |
| } |
| ``` |
| |
| ### `20260505-100720_grokking_n1000_s123` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 123, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260505-100720_grokking_n1000_s123", |
| "run_dir": "experiments/runs/20260505-100720_grokking_n1000_s123" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260505-100720_grokking_n1000_s123", |
| "condition": "grokking", |
| "n_train": 1000, |
| "seed": 123, |
| "best_id_val": 0.8994338498212158, |
| "best_ood": 0.7269734521598044, |
| "peak_ood_epoch": 350, |
| "final_ood": 0.6446610388694242, |
| "ood_delta": -0.08231241329038019, |
| "ood_improvement": 0.01956639311496222, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.9999991755092, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 1457.2260357210637, |
| "final_feature_rank": 56.64516067504883, |
| "final_irm": 6.7278793620213426e-15, |
| "final_shortcut_ratio": 0.9760735232865748, |
| "final_ood_gap": 0.23534492060614198 |
| } |
| ``` |
| |
| ### `20260505-100720_grokking_n1000_s456` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 456, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260505-100720_grokking_n1000_s456", |
| "run_dir": "experiments/runs/20260505-100720_grokking_n1000_s456" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260505-100720_grokking_n1000_s456", |
| "condition": "grokking", |
| "n_train": 1000, |
| "seed": 456, |
| "best_id_val": 0.8824493444576877, |
| "best_ood": 0.6721847297011311, |
| "peak_ood_epoch": 1100, |
| "final_ood": 0.522397535683213, |
| "ood_delta": -0.14978719401791807, |
| "ood_improvement": -0.08302960472170617, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99999977269624, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 1493.580040593733, |
| "final_feature_rank": 64.54296875, |
| "final_irm": 2.8693030174054e-09, |
| "final_shortcut_ratio": 1.0356636404914459, |
| "final_ood_gap": 0.3221495441737596 |
| } |
| ``` |
| |
| ### `20260508-183413_grokking_n1000_s7` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 7, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260508-183413_grokking_n1000_s7", |
| "run_dir": "experiments/runs/20260508-183413_grokking_n1000_s7" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260508-183413_grokking_n1000_s7", |
| "condition": "grokking", |
| "n_train": 1000, |
| "seed": 7, |
| "best_id_val": 0.8797079856972586, |
| "best_ood": 0.6876454958026665, |
| "peak_ood_epoch": 50, |
| "final_ood": 0.5881910315799375, |
| "ood_delta": -0.09945446422272908, |
| "ood_improvement": -0.03798527993980294, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99999355254296, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 1516.661424559645, |
| "final_feature_rank": 69.8335952758789, |
| "final_irm": 4.471252361415434e-12, |
| "final_shortcut_ratio": 0.996297612800058, |
| "final_ood_gap": 0.26586141180504463 |
| } |
| ``` |
| |
| ### `20260508-183413_grokking_n1000_s2024` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 2024, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "grokking", |
| "lr": 0.001, |
| "weight_decay": 0.005, |
| "n_epochs": 3000, |
| "init_scale": 4.0, |
| "use_grokfast": true, |
| "grokfast_alpha": 0.98, |
| "grokfast_lamb": 2.0, |
| "grad_clip": 1.0, |
| "run_id": "20260508-183413_grokking_n1000_s2024", |
| "run_dir": "experiments/runs/20260508-183413_grokking_n1000_s2024" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260508-183413_grokking_n1000_s2024", |
| "condition": "grokking", |
| "n_train": 1000, |
| "seed": 2024, |
| "best_id_val": 0.8959177592371871, |
| "best_ood": 0.7056105532955534, |
| "peak_ood_epoch": 400, |
| "final_ood": 0.5506031462365085, |
| "ood_delta": -0.15500740705904492, |
| "ood_improvement": -0.04230488865896964, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.9999998741651, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 1632.3948325021925, |
| "final_feature_rank": 65.77389526367188, |
| "final_irm": 4.5635979972757923e-07, |
| "final_shortcut_ratio": 0.9633737610070339, |
| "final_ood_gap": 0.308663838268855 |
| } |
| ``` |
| |
| ### `20260505-100720_standard_n1000_s42` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 42, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "standard", |
| "lr": 0.001, |
| "weight_decay": 0.0001, |
| "n_epochs": 3000, |
| "init_scale": 1.0, |
| "use_grokfast": false, |
| "grad_clip": 1.0, |
| "run_id": "20260505-100720_standard_n1000_s42", |
| "run_dir": "experiments/runs/20260505-100720_standard_n1000_s42" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260505-100720_standard_n1000_s42", |
| "condition": "standard", |
| "n_train": 1000, |
| "seed": 42, |
| "best_id_val": 0.9011025029797378, |
| "best_ood": 0.7615162132292426, |
| "peak_ood_epoch": 1, |
| "final_ood": 0.6482234815528958, |
| "ood_delta": -0.11329273167634679, |
| "ood_improvement": -0.022357561078844124, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99999329006182, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 812.5540534619066, |
| "final_feature_rank": 33.34842300415039, |
| "final_irm": 2.2439123855463178e-13, |
| "final_shortcut_ratio": 0.99423543483118, |
| "final_ood_gap": 0.24736650652815306 |
| } |
| ``` |
| |
| ### `20260508-183413_standard_n1000_s123` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 123, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "standard", |
| "lr": 0.001, |
| "weight_decay": 0.0001, |
| "n_epochs": 3000, |
| "init_scale": 1.0, |
| "use_grokfast": false, |
| "grad_clip": 1.0, |
| "run_id": "20260508-183413_standard_n1000_s123", |
| "run_dir": "experiments/runs/20260508-183413_standard_n1000_s123" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260508-183413_standard_n1000_s123", |
| "condition": "standard", |
| "n_train": 1000, |
| "seed": 123, |
| "best_id_val": 0.8957091775923719, |
| "best_ood": 0.8879652926376185, |
| "peak_ood_epoch": 1, |
| "final_ood": 0.6644837397418111, |
| "ood_delta": -0.2234815528958074, |
| "ood_improvement": -0.10371998965363183, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.9999667168104, |
| "irm_drop_epoch": 150, |
| "epoch_gap": -1, |
| "final_weight_norm": 798.3091903337586, |
| "final_feature_rank": 37.0809326171875, |
| "final_irm": 9.526699497634447e-14, |
| "final_shortcut_ratio": 0.9913218948516812, |
| "final_ood_gap": 0.22869266073494698 |
| } |
| ``` |
| |
| ### `20260508-183413_standard_n1000_s456` |
| |
| `config.json`: |
| |
| ```json |
| { |
| "seed": 456, |
| "n_train": 1000, |
| "batch_size": 32, |
| "img_size": 96, |
| "n_classes": 2, |
| "log_every": 50, |
| "device": "cuda", |
| "condition": "standard", |
| "lr": 0.001, |
| "weight_decay": 0.0001, |
| "n_epochs": 3000, |
| "init_scale": 1.0, |
| "use_grokfast": false, |
| "grad_clip": 1.0, |
| "run_id": "20260508-183413_standard_n1000_s456", |
| "run_dir": "experiments/runs/20260508-183413_standard_n1000_s456" |
| } |
| ``` |
| |
| `results/summary.json`: |
| |
| ```json |
| { |
| "run_id": "20260508-183413_standard_n1000_s456", |
| "condition": "standard", |
| "n_train": 1000, |
| "seed": 456, |
| "best_id_val": 0.8949940405244339, |
| "best_ood": 0.7449737813624285, |
| "peak_ood_epoch": 1050, |
| "final_ood": 0.5783149528534813, |
| "ood_delta": -0.1666588285089472, |
| "ood_improvement": 0.02752839372633853, |
| "grokking_epoch": -1, |
| "irm_drop_pct": 99.99999467428148, |
| "irm_drop_epoch": 50, |
| "epoch_gap": -1, |
| "final_weight_norm": 792.3747983792097, |
| "final_feature_rank": 35.297889709472656, |
| "final_irm": 7.180730676736857e-09, |
| "final_shortcut_ratio": 0.987733651871859, |
| "final_ood_gap": 0.2792237837376986 |
| } |
| ``` |
| |
| --- |
| |
| ## 18. Full training log: grokking n=1000 seed=42 |
| |
| ``` |
| # launched: 2026-05-05T08:04:45Z |
| # host: ubuntu-Standard-PC-Q35-ICH9-2009 |
| # pwd: /home/garima/CausalGrok |
| # cmd: /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-080445_grokking_n1000_s42 --wandb_project causalgrok --wandb_mode offline |
| ---- |
| |
| Device: cuda |
| Run ID: 20260505-080445_grokking_n1000_s42 |
| Started: 2026-05-05T08:04:50.408035+00:00 |
| Env hospital=0: 181 samples, positive rate=0.53 |
| Env hospital=3: 371 samples, positive rate=0.48 |
| Env hospital=4: 448 samples, positive rate=0.50 |
| Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054 |
| Params: 11,177,538 |
|
|
| ============================================================ |
| GROKKING | Camelyon17 v2 | 3000 epochs |
| WD=0.005 | α=4.0 | n=1000 |
| Tracking: ID val (H3) + OOD test (H4) at every checkpoint |
| Grokking detection: watching OOD acc, not ID val acc |
| IRM envs: 3 hospitals |
| ============================================================ |
| ep 1 | tr 0.734 | id 0.728 | ood 0.499 | gap +0.228 | ‖W‖ 356.1 | rank 109.6 | IRM 0.2004 | sc 1.12x |
| ep 50 | tr 0.996 | id 0.859 | ood 0.697 | gap +0.162 | ‖W‖ 495.4 | rank 91.4 | IRM 0.0001 | sc 0.97x |
| ep 100 | tr 1.000 | id 0.883 | ood 0.592 | gap +0.291 | ‖W‖ 551.6 | rank 81.7 | IRM 0.0000 | sc 0.96x |
| ep 150 | tr 1.000 | id 0.872 | ood 0.641 | gap +0.231 | ‖W‖ 650.2 | rank 72.2 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep00200.pt |
| ep 200 | tr 1.000 | id 0.885 | ood 0.613 | gap +0.272 | ‖W‖ 741.5 | rank 62.6 | IRM 0.0000 | sc 0.97x |
| ep 250 | tr 1.000 | id 0.890 | ood 0.659 | gap +0.231 | ‖W‖ 842.7 | rank 61.1 | IRM 0.0000 | sc 0.99x |
| ep 300 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.227 | ‖W‖ 878.1 | rank 60.7 | IRM 0.0000 | sc 0.99x |
| ep 350 | tr 1.000 | id 0.881 | ood 0.734 | gap +0.147 | ‖W‖ 907.9 | rank 53.9 | IRM 0.0000 | sc 0.97x |
| ✓ Checkpoint → ep00400.pt |
| ep 400 | tr 1.000 | id 0.875 | ood 0.647 | gap +0.229 | ‖W‖ 1005.7 | rank 60.9 | IRM 0.0000 | sc 1.00x |
| ep 450 | tr 1.000 | id 0.876 | ood 0.615 | gap +0.261 | ‖W‖ 1031.2 | rank 57.8 | IRM 0.0000 | sc 0.99x |
| ep 500 | tr 1.000 | id 0.880 | ood 0.607 | gap +0.273 | ‖W‖ 1023.1 | rank 55.7 | IRM 0.0000 | sc 0.99x |
| ep 550 | tr 1.000 | id 0.888 | ood 0.611 | gap +0.276 | ‖W‖ 1101.7 | rank 52.6 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep00600.pt |
| ep 600 | tr 1.000 | id 0.884 | ood 0.546 | gap +0.337 | ‖W‖ 1102.1 | rank 54.5 | IRM 0.0000 | sc 1.00x |
| ep 650 | tr 1.000 | id 0.878 | ood 0.605 | gap +0.273 | ‖W‖ 1142.4 | rank 55.2 | IRM 0.0000 | sc 0.98x |
| ep 700 | tr 1.000 | id 0.889 | ood 0.587 | gap +0.303 | ‖W‖ 1187.3 | rank 47.9 | IRM 0.0000 | sc 0.98x |
| ep 750 | tr 1.000 | id 0.892 | ood 0.670 | gap +0.222 | ‖W‖ 1180.7 | rank 48.4 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep00800.pt |
| ep 800 | tr 1.000 | id 0.879 | ood 0.675 | gap +0.204 | ‖W‖ 1265.4 | rank 48.0 | IRM 0.0000 | sc 0.99x |
| ep 850 | tr 1.000 | id 0.886 | ood 0.627 | gap +0.259 | ‖W‖ 1300.7 | rank 48.3 | IRM 0.0000 | sc 0.98x |
| ep 900 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1290.6 | rank 48.0 | IRM 0.0000 | sc 0.98x |
| ep 950 | tr 1.000 | id 0.883 | ood 0.640 | gap +0.243 | ‖W‖ 1280.4 | rank 47.1 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep01000.pt |
| ep 1000 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1270.3 | rank 47.3 | IRM 0.0000 | sc 0.99x |
| ep 1050 | tr 1.000 | id 0.898 | ood 0.697 | gap +0.201 | ‖W‖ 1289.1 | rank 42.4 | IRM 0.0000 | sc 0.99x |
| ep 1100 | tr 0.999 | id 0.876 | ood 0.641 | gap +0.235 | ‖W‖ 1308.0 | rank 46.4 | IRM 0.0000 | sc 0.99x |
| ep 1150 | tr 1.000 | id 0.894 | ood 0.663 | gap +0.231 | ‖W‖ 1313.5 | rank 45.8 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep01200.pt |
| ep 1200 | tr 0.999 | id 0.878 | ood 0.592 | gap +0.286 | ‖W‖ 1320.4 | rank 42.9 | IRM 0.0000 | sc 0.99x |
| ep 1250 | tr 1.000 | id 0.877 | ood 0.546 | gap +0.330 | ‖W‖ 1373.5 | rank 47.2 | IRM 0.0000 | sc 0.99x |
| ep 1300 | tr 1.000 | id 0.853 | ood 0.600 | gap +0.254 | ‖W‖ 1375.3 | rank 48.0 | IRM 0.0000 | sc 0.98x |
| ep 1350 | tr 1.000 | id 0.881 | ood 0.575 | gap +0.306 | ‖W‖ 1415.1 | rank 49.4 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep01400.pt |
| ep 1400 | tr 1.000 | id 0.892 | ood 0.605 | gap +0.287 | ‖W‖ 1428.3 | rank 42.5 | IRM 0.0000 | sc 1.00x |
| ep 1450 | tr 1.000 | id 0.867 | ood 0.619 | gap +0.248 | ‖W‖ 1438.2 | rank 41.4 | IRM 0.0000 | sc 0.99x |
| ep 1500 | tr 1.000 | id 0.875 | ood 0.652 | gap +0.223 | ‖W‖ 1440.8 | rank 47.2 | IRM 0.0000 | sc 0.97x |
| ep 1550 | tr 1.000 | id 0.879 | ood 0.643 | gap +0.236 | ‖W‖ 1429.5 | rank 45.8 | IRM 0.0000 | sc 0.97x |
| ✓ Checkpoint → ep01600.pt |
| ep 1600 | tr 1.000 | id 0.878 | ood 0.630 | gap +0.248 | ‖W‖ 1418.2 | rank 46.4 | IRM 0.0000 | sc 0.98x |
| ep 1650 | tr 1.000 | id 0.881 | ood 0.636 | gap +0.245 | ‖W‖ 1406.9 | rank 45.9 | IRM 0.0000 | sc 0.98x |
| ep 1700 | tr 1.000 | id 0.883 | ood 0.634 | gap +0.249 | ‖W‖ 1395.7 | rank 47.0 | IRM 0.0000 | sc 0.98x |
| ep 1750 | tr 1.000 | id 0.875 | ood 0.702 | gap +0.173 | ‖W‖ 1441.2 | rank 45.9 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep01800.pt |
| ep 1800 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 1433.0 | rank 45.6 | IRM 0.0000 | sc 0.98x |
| ep 1850 | tr 1.000 | id 0.880 | ood 0.644 | gap +0.236 | ‖W‖ 1421.6 | rank 45.8 | IRM 0.0000 | sc 0.98x |
| ep 1900 | tr 1.000 | id 0.885 | ood 0.644 | gap +0.242 | ‖W‖ 1410.4 | rank 43.9 | IRM 0.0000 | sc 0.98x |
| ep 1950 | tr 1.000 | id 0.883 | ood 0.549 | gap +0.334 | ‖W‖ 1415.3 | rank 49.1 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep02000.pt |
| ep 2000 | tr 1.000 | id 0.889 | ood 0.581 | gap +0.308 | ‖W‖ 1404.6 | rank 47.1 | IRM 0.0000 | sc 0.98x |
| ep 2050 | tr 1.000 | id 0.888 | ood 0.577 | gap +0.311 | ‖W‖ 1393.4 | rank 46.6 | IRM 0.0000 | sc 0.98x |
| ep 2100 | tr 1.000 | id 0.884 | ood 0.617 | gap +0.266 | ‖W‖ 1460.6 | rank 33.9 | IRM 0.0000 | sc 1.00x |
| ep 2150 | tr 1.000 | id 0.870 | ood 0.597 | gap +0.273 | ‖W‖ 1470.9 | rank 37.5 | IRM 0.0000 | sc 1.00x |
| ✓ Checkpoint → ep02200.pt |
| ep 2200 | tr 1.000 | id 0.869 | ood 0.568 | gap +0.301 | ‖W‖ 1460.2 | rank 38.5 | IRM 0.0000 | sc 0.99x |
| ep 2250 | tr 1.000 | id 0.870 | ood 0.588 | gap +0.282 | ‖W‖ 1448.6 | rank 36.9 | IRM 0.0000 | sc 0.98x |
| ep 2300 | tr 0.998 | id 0.872 | ood 0.706 | gap +0.166 | ‖W‖ 1485.2 | rank 41.9 | IRM 0.0000 | sc 0.97x |
| ep 2350 | tr 1.000 | id 0.877 | ood 0.648 | gap +0.229 | ‖W‖ 1506.9 | rank 41.2 | IRM 0.0000 | sc 1.00x |
| ✓ Checkpoint → ep02400.pt |
| ep 2400 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.226 | ‖W‖ 1495.0 | rank 40.1 | IRM 0.0000 | sc 1.00x |
| ep 2450 | tr 1.000 | id 0.869 | ood 0.682 | gap +0.187 | ‖W‖ 1486.6 | rank 36.8 | IRM 0.0000 | sc 0.98x |
| ep 2500 | tr 1.000 | id 0.884 | ood 0.621 | gap +0.263 | ‖W‖ 1510.2 | rank 40.7 | IRM 0.0000 | sc 1.00x |
| ep 2550 | tr 1.000 | id 0.876 | ood 0.667 | gap +0.209 | ‖W‖ 1499.1 | rank 38.9 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep02600.pt |
| ep 2600 | tr 1.000 | id 0.871 | ood 0.635 | gap +0.236 | ‖W‖ 1506.6 | rank 39.4 | IRM 0.0000 | sc 1.01x |
| ep 2650 | tr 1.000 | id 0.874 | ood 0.692 | gap +0.183 | ‖W‖ 1500.6 | rank 36.8 | IRM 0.0000 | sc 0.99x |
| ep 2700 | tr 1.000 | id 0.877 | ood 0.607 | gap +0.269 | ‖W‖ 1510.5 | rank 39.9 | IRM 0.0000 | sc 0.97x |
| ep 2750 | tr 1.000 | id 0.873 | ood 0.599 | gap +0.273 | ‖W‖ 1521.2 | rank 38.5 | IRM 0.0000 | sc 1.00x |
| ✓ Checkpoint → ep02800.pt |
| ep 2800 | tr 1.000 | id 0.878 | ood 0.584 | gap +0.295 | ‖W‖ 1515.0 | rank 36.9 | IRM 0.0000 | sc 1.00x |
| ep 2850 | tr 1.000 | id 0.878 | ood 0.619 | gap +0.259 | ‖W‖ 1504.0 | rank 36.2 | IRM 0.0000 | sc 0.99x |
| ep 2900 | tr 1.000 | id 0.880 | ood 0.614 | gap +0.266 | ‖W‖ 1492.6 | rank 35.7 | IRM 0.0000 | sc 0.99x |
| ep 2950 | tr 1.000 | id 0.871 | ood 0.619 | gap +0.252 | ‖W‖ 1482.1 | rank 36.4 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep03000.pt |
| ep 3000 | tr 1.000 | id 0.871 | ood 0.664 | gap +0.207 | ‖W‖ 1470.5 | rank 35.9 | IRM 0.0000 | sc 0.99x |
|
|
| Best ID val (H3): 0.8976 |
| Best OOD (H4): 0.7336 |
| OOD improvement: +0.0114 ← did OOD grok? |
| Grokking at: None |
| IRM drop: 100.0% |
|
|
| Wall time: 358.5 min |
|
|
| ``` |
| |
| --- |
| |
| ## 19. Full training log: standard n=1000 seed=42 |
| |
| ``` |
| # launched: 2026-05-05T10:07:20Z |
| # host: ubuntu-Standard-PC-Q35-ICH9-2009 |
| # pwd: /home/garima/CausalGrok |
| # cmd: /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition standard --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-100720_standard_n1000_s42 --wandb_project causalgrok --wandb_mode offline --n_epochs 3000 |
| ---- |
|
|
| Device: cuda |
| Run ID: 20260505-100720_standard_n1000_s42 |
| Started: 2026-05-05T10:07:36.596218+00:00 |
| Env hospital=0: 181 samples, positive rate=0.53 |
| Env hospital=3: 371 samples, positive rate=0.48 |
| Env hospital=4: 448 samples, positive rate=0.50 |
| Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054 |
| Params: 11,177,538 |
| |
| ============================================================ |
| STANDARD | Camelyon17 v2 | 3000 epochs |
| WD=0.0001 | α=1.0 | n=1000 |
| Tracking: ID val (H3) + OOD test (H4) at every checkpoint |
| Grokking detection: watching OOD acc, not ID val acc |
| IRM envs: 3 hospitals |
| ============================================================ |
| ep 1 | tr 0.681 | id 0.664 | ood 0.762 | gap -0.098 | ‖W‖ 105.0 | rank 78.5 | IRM 0.1490 | sc 1.25x |
| ep 50 | tr 0.999 | id 0.897 | ood 0.620 | gap +0.276 | ‖W‖ 162.5 | rank 45.1 | IRM 0.0000 | sc 0.99x |
| ep 100 | tr 1.000 | id 0.901 | ood 0.722 | gap +0.179 | ‖W‖ 196.1 | rank 35.1 | IRM 0.0000 | sc 0.98x |
| ep 150 | tr 0.997 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 228.8 | rank 29.4 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep00200.pt |
| ep 200 | tr 1.000 | id 0.886 | ood 0.611 | gap +0.274 | ‖W‖ 268.7 | rank 30.1 | IRM 0.0000 | sc 0.99x |
| ep 250 | tr 1.000 | id 0.890 | ood 0.672 | gap +0.218 | ‖W‖ 294.3 | rank 30.9 | IRM 0.0000 | sc 0.97x |
| ep 300 | tr 1.000 | id 0.900 | ood 0.684 | gap +0.216 | ‖W‖ 323.3 | rank 26.7 | IRM 0.0000 | sc 0.98x |
| ep 350 | tr 1.000 | id 0.891 | ood 0.573 | gap +0.318 | ‖W‖ 343.5 | rank 27.9 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep00400.pt |
| ep 400 | tr 1.000 | id 0.885 | ood 0.642 | gap +0.243 | ‖W‖ 361.4 | rank 28.6 | IRM 0.0000 | sc 0.98x |
| ep 450 | tr 1.000 | id 0.894 | ood 0.700 | gap +0.194 | ‖W‖ 377.9 | rank 31.3 | IRM 0.0000 | sc 0.98x |
| ep 500 | tr 1.000 | id 0.890 | ood 0.705 | gap +0.185 | ‖W‖ 378.2 | rank 29.3 | IRM 0.0000 | sc 0.97x |
| ep 550 | tr 1.000 | id 0.895 | ood 0.656 | gap +0.239 | ‖W‖ 412.9 | rank 26.5 | IRM 0.0000 | sc 0.97x |
| ✓ Checkpoint → ep00600.pt |
| ep 600 | tr 1.000 | id 0.862 | ood 0.717 | gap +0.145 | ‖W‖ 426.0 | rank 29.7 | IRM 0.0000 | sc 0.97x |
| ep 650 | tr 1.000 | id 0.885 | ood 0.713 | gap +0.172 | ‖W‖ 445.0 | rank 25.9 | IRM 0.0000 | sc 0.99x |
| ep 700 | tr 1.000 | id 0.892 | ood 0.639 | gap +0.253 | ‖W‖ 454.4 | rank 28.0 | IRM 0.0000 | sc 0.98x |
| ep 750 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 472.1 | rank 25.5 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep00800.pt |
| ep 800 | tr 1.000 | id 0.888 | ood 0.681 | gap +0.207 | ‖W‖ 489.6 | rank 28.8 | IRM 0.0000 | sc 0.97x |
| ep 850 | tr 1.000 | id 0.887 | ood 0.626 | gap +0.262 | ‖W‖ 506.4 | rank 28.2 | IRM 0.0000 | sc 0.99x |
| ep 900 | tr 1.000 | id 0.888 | ood 0.703 | gap +0.185 | ‖W‖ 515.4 | rank 31.3 | IRM 0.0000 | sc 0.99x |
| ep 950 | tr 1.000 | id 0.883 | ood 0.667 | gap +0.215 | ‖W‖ 526.2 | rank 27.6 | IRM 0.0000 | sc 1.00x |
| ✓ Checkpoint → ep01000.pt |
| ep 1000 | tr 1.000 | id 0.897 | ood 0.674 | gap +0.222 | ‖W‖ 530.5 | rank 26.2 | IRM 0.0000 | sc 1.00x |
| ep 1050 | tr 1.000 | id 0.896 | ood 0.581 | gap +0.315 | ‖W‖ 544.2 | rank 26.3 | IRM 0.0000 | sc 0.98x |
| ep 1100 | tr 1.000 | id 0.877 | ood 0.655 | gap +0.222 | ‖W‖ 559.3 | rank 26.2 | IRM 0.0000 | sc 1.00x |
| ep 1150 | tr 1.000 | id 0.899 | ood 0.694 | gap +0.205 | ‖W‖ 562.6 | rank 23.7 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep01200.pt |
| ep 1200 | tr 1.000 | id 0.892 | ood 0.578 | gap +0.315 | ‖W‖ 581.9 | rank 27.0 | IRM 0.0000 | sc 0.98x |
| ep 1250 | tr 1.000 | id 0.889 | ood 0.647 | gap +0.242 | ‖W‖ 595.8 | rank 30.1 | IRM 0.0000 | sc 0.99x |
| ep 1300 | tr 1.000 | id 0.880 | ood 0.616 | gap +0.264 | ‖W‖ 604.9 | rank 31.7 | IRM 0.0000 | sc 0.98x |
| ep 1350 | tr 1.000 | id 0.878 | ood 0.649 | gap +0.228 | ‖W‖ 606.3 | rank 30.0 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep01400.pt |
| ep 1400 | tr 1.000 | id 0.879 | ood 0.735 | gap +0.144 | ‖W‖ 624.0 | rank 32.7 | IRM 0.0000 | sc 0.96x |
| ep 1450 | tr 1.000 | id 0.885 | ood 0.703 | gap +0.182 | ‖W‖ 625.3 | rank 33.0 | IRM 0.0000 | sc 0.98x |
| ep 1500 | tr 1.000 | id 0.894 | ood 0.686 | gap +0.207 | ‖W‖ 626.2 | rank 30.6 | IRM 0.0000 | sc 0.98x |
| ep 1550 | tr 1.000 | id 0.879 | ood 0.670 | gap +0.209 | ‖W‖ 640.3 | rank 32.3 | IRM 0.0000 | sc 0.97x |
| ✓ Checkpoint → ep01600.pt |
| ep 1600 | tr 1.000 | id 0.871 | ood 0.669 | gap +0.202 | ‖W‖ 653.6 | rank 31.0 | IRM 0.0000 | sc 0.98x |
| ep 1650 | tr 1.000 | id 0.886 | ood 0.540 | gap +0.346 | ‖W‖ 662.7 | rank 34.4 | IRM 0.0000 | sc 0.98x |
| ep 1700 | tr 1.000 | id 0.897 | ood 0.659 | gap +0.239 | ‖W‖ 667.4 | rank 38.2 | IRM 0.0000 | sc 0.97x |
| ep 1750 | tr 1.000 | id 0.871 | ood 0.716 | gap +0.155 | ‖W‖ 678.7 | rank 31.7 | IRM 0.0000 | sc 0.97x |
| ✓ Checkpoint → ep01800.pt |
| ep 1800 | tr 1.000 | id 0.887 | ood 0.665 | gap +0.222 | ‖W‖ 682.0 | rank 33.0 | IRM 0.0000 | sc 0.98x |
| ep 1850 | tr 1.000 | id 0.892 | ood 0.598 | gap +0.294 | ‖W‖ 685.4 | rank 31.1 | IRM 0.0000 | sc 0.98x |
| ep 1900 | tr 1.000 | id 0.890 | ood 0.574 | gap +0.316 | ‖W‖ 690.9 | rank 32.8 | IRM 0.0000 | sc 0.99x |
| ep 1950 | tr 1.000 | id 0.889 | ood 0.623 | gap +0.265 | ‖W‖ 706.3 | rank 30.0 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep02000.pt |
| ep 2000 | tr 1.000 | id 0.890 | ood 0.595 | gap +0.296 | ‖W‖ 714.3 | rank 31.5 | IRM 0.0000 | sc 0.98x |
| ep 2050 | tr 1.000 | id 0.888 | ood 0.596 | gap +0.292 | ‖W‖ 720.1 | rank 31.4 | IRM 0.0000 | sc 0.99x |
| ep 2100 | tr 1.000 | id 0.860 | ood 0.686 | gap +0.173 | ‖W‖ 730.1 | rank 30.6 | IRM 0.0000 | sc 0.99x |
| ep 2150 | tr 1.000 | id 0.892 | ood 0.712 | gap +0.180 | ‖W‖ 732.0 | rank 28.5 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep02200.pt |
| ep 2200 | tr 1.000 | id 0.881 | ood 0.621 | gap +0.260 | ‖W‖ 736.6 | rank 31.0 | IRM 0.0000 | sc 0.99x |
| ep 2250 | tr 1.000 | id 0.882 | ood 0.621 | gap +0.261 | ‖W‖ 736.5 | rank 29.6 | IRM 0.0000 | sc 0.99x |
| ep 2300 | tr 1.000 | id 0.877 | ood 0.641 | gap +0.236 | ‖W‖ 748.9 | rank 32.6 | IRM 0.0000 | sc 0.99x |
| ep 2350 | tr 1.000 | id 0.883 | ood 0.652 | gap +0.230 | ‖W‖ 753.4 | rank 35.2 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep02400.pt |
| ep 2400 | tr 1.000 | id 0.884 | ood 0.673 | gap +0.212 | ‖W‖ 753.4 | rank 34.2 | IRM 0.0000 | sc 1.00x |
| ep 2450 | tr 1.000 | id 0.887 | ood 0.677 | gap +0.209 | ‖W‖ 755.5 | rank 32.2 | IRM 0.0000 | sc 1.00x |
| ep 2500 | tr 1.000 | id 0.887 | ood 0.601 | gap +0.286 | ‖W‖ 768.2 | rank 35.0 | IRM 0.0000 | sc 0.98x |
| ep 2550 | tr 1.000 | id 0.880 | ood 0.645 | gap +0.235 | ‖W‖ 772.8 | rank 35.5 | IRM 0.0000 | sc 0.98x |
| ✓ Checkpoint → ep02600.pt |
| ep 2600 | tr 1.000 | id 0.884 | ood 0.631 | gap +0.253 | ‖W‖ 774.0 | rank 34.8 | IRM 0.0000 | sc 0.99x |
| ep 2650 | tr 1.000 | id 0.885 | ood 0.584 | gap +0.301 | ‖W‖ 776.4 | rank 37.4 | IRM 0.0000 | sc 0.99x |
| ep 2700 | tr 1.000 | id 0.888 | ood 0.597 | gap +0.291 | ‖W‖ 780.4 | rank 36.2 | IRM 0.0000 | sc 0.99x |
| ep 2750 | tr 1.000 | id 0.895 | ood 0.683 | gap +0.212 | ‖W‖ 786.8 | rank 35.2 | IRM 0.0000 | sc 0.99x |
| ✓ Checkpoint → ep02800.pt |
| ep 2800 | tr 1.000 | id 0.894 | ood 0.597 | gap +0.297 | ‖W‖ 794.5 | rank 34.0 | IRM 0.0000 | sc 0.99x |
| ep 2850 | tr 1.000 | id 0.874 | ood 0.623 | gap +0.252 | ‖W‖ 803.4 | rank 32.7 | IRM 0.0000 | sc 0.99x |
| ep 2900 | tr 1.000 | id 0.891 | ood 0.679 | gap +0.212 | ‖W‖ 808.3 | rank 35.3 | IRM 0.0000 | sc 0.98x |
| ep 2950 | tr 1.000 | id 0.886 | ood 0.704 | gap +0.182 | ‖W‖ 811.6 | rank 33.7 | IRM 0.0000 | sc 1.00x |
| ✓ Checkpoint → ep03000.pt |
| ep 3000 | tr 1.000 | id 0.896 | ood 0.648 | gap +0.247 | ‖W‖ 812.6 | rank 33.3 | IRM 0.0000 | sc 0.99x |
| |
| Best ID val (H3): 0.9011 |
| Best OOD (H4): 0.7615 |
| OOD improvement: -0.0224 ← did OOD grok? |
| Grokking at: None |
| IRM drop: 100.0% |
| |
| Wall time: 327.3 min |
| |
| ``` |
| |
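The metric lines in the two logs above share one fixed layout, so they can be turned into rows for plotting or for cross-checking against `results/history.json`. The following is a minimal sketch, assuming that layout holds for every run; `LINE_RE` and `parse_log_line` are illustrative names, not part of the project code.

```python
import re

# Field order (ep, tr, id, ood, gap, ‖W‖, rank, IRM, sc) is taken from the
# metric lines in the logs above; other runs are assumed to use the same layout.
LINE_RE = re.compile(
    r"ep\s+(?P<ep>\d+)\s*\|\s*tr\s+(?P<tr>[\d.]+)\s*\|\s*id\s+(?P<id>[\d.]+)"
    r"\s*\|\s*ood\s+(?P<ood>[\d.]+)\s*\|\s*gap\s+(?P<gap>[+-][\d.]+)"
    r"\s*\|\s*‖W‖\s+(?P<wnorm>[\d.]+)\s*\|\s*rank\s+(?P<rank>[\d.]+)"
    r"\s*\|\s*IRM\s+(?P<irm>[\d.]+)\s*\|\s*sc\s+(?P<sc>[\d.]+)x"
)


def parse_log_line(line):
    """Return a dict of metrics for a metric line, or None for any other line."""
    match = LINE_RE.search(line)
    if match is None:
        return None
    row = {key: float(value) for key, value in match.groupdict().items()}
    row["ep"] = int(row["ep"])
    return row


sample = ("ep  3000 | tr 1.000 | id 0.896 | ood 0.648 | gap +0.247 "
          "| ‖W‖ 812.6 | rank 33.3 | IRM 0.0000 | sc 0.99x")
row = parse_log_line(sample)
print(row["ep"], row["id"], row["ood"], row["gap"])
```

Recomputing `id − ood` from the printed values can differ from the printed `gap` by 0.001 (as in this sample). That would be consistent with the gap being computed before rounding, but this reading is an inference from the logs, not something the logs state.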
| --- |
| |
| ## 20. M5 — Full activation-steering JSONs (8 runs at n=1000) |
| |
| Each run's complete `m5_steering_ep*.json`, verbatim from disk. `head_ood_acc` is the classifier head's OOD accuracy on the steered features `h' = h + alpha * sigma * v_s`. `tumor_probe` is the linear-probe accuracy on the same steered features. `hospital_probe` is NaN by construction (H4 hospital labels do not overlap with the training-hospital probe's class set). The bare `NaN` token is not strict JSON: Python's `json` module reads and writes it by default, but strict parsers may reject these files.
|
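The steering rule `h' = h + alpha * sigma * v_s` can be sketched in a few lines. This is an illustrative toy on plain Python lists, not the code in `experiments/mechinterp_m5_steering.py`; the function name `steer` and the 4-dimensional vectors are assumptions, and `sigma` is simply the per-run scale stored in each JSON.

```python
def steer(h, v_s, sigma, alpha):
    """Return h' = h + alpha * sigma * v_s for one feature vector (plain lists)."""
    if len(h) != len(v_s):
        raise ValueError("h and v_s must have the same length")
    return [h_i + alpha * sigma * v_i for h_i, v_i in zip(h, v_s)]


# Toy 4-d example; the real avgpool features are much higher-dimensional.
h = [0.5, 1.0, 0.0, 2.0]
v_s = [1.0, 0.0, 0.0, 0.0]        # unit-norm direction (v_norm = 1.0)
sigma = 8.685188293457031         # value recorded in the grokking s42 JSON
ALPHAS = [-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]

for alpha in ALPHAS:
    h_prime = steer(h, v_s, sigma, alpha)
    print(f"alpha={alpha:+.1f}  first component={h_prime[0]:+.3f}")
```

At `alpha = 0.0` the features are unchanged, which is why the `alpha: 0.0` row of each sweep serves as that run's unsteered baseline.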
|
| ### `20260505-080445_grokking_n1000_s42` |
| |
| ```json |
| { |
| "run_id": "20260505-080445_grokking_n1000_s42", |
| "epoch": 400, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 1.0, |
| "sigma": 8.685188293457031, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.63, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6625 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.625, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6525 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.64, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.665 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.6325, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.655 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.6525, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.66 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.6575, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.665 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.6575, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6675 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.62, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.65 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.59, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.635 |
| } |
| ] |
| } |
| ``` |
| |
| ### `20260505-100720_grokking_n1000_s123` |
| |
| ```json |
| { |
| "run_id": "20260505-100720_grokking_n1000_s123", |
| "epoch": 400, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 1.0, |
| "sigma": 6.000692844390869, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.7325, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6925 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.7325, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6925 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.73, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.695 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.71, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.695 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.6975, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6925 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.6925, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.69, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.685 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.68, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.685 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.6575, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.685 |
| } |
| ] |
| } |
| ``` |
| |
| ### `20260505-100720_grokking_n1000_s456` |
| |
| ```json |
| { |
| "run_id": "20260505-100720_grokking_n1000_s456", |
| "epoch": 1000, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 1.0, |
| "sigma": 6.7488112449646, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.6825, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.63 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.6575, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.615 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.64, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.605 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.63, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.595 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.6225, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.595 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.62, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5925 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.61, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5925 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.6025, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5825 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.5975, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.59 |
| } |
| ] |
| } |
| ``` |
| |
| ### `20260508-183413_grokking_n1000_s7` |
| |
| ```json |
| { |
| "run_id": "20260508-183413_grokking_n1000_s7", |
| "epoch": 200, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 0.9999999403953552, |
| "sigma": 5.532620429992676, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.495, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.58 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.49, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5625 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.485, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5425 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.485, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5325 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.485, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.52 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.4875, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5125 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.4875, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.52 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.4775, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.51 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.4825, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.51 |
| } |
| ] |
| } |
| ``` |
| |
| ### `20260508-183413_grokking_n1000_s2024` |
| |
| ```json |
| { |
| "run_id": "20260508-183413_grokking_n1000_s2024", |
| "epoch": 400, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 1.0, |
| "sigma": 9.224357604980469, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.7425, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.7775, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.7325, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.715, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.7125, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.705, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.68, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.6225, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.595, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| } |
| ] |
| } |
| ``` |
| |
| ### `20260505-100720_standard_n1000_s42` |
| |
| ```json |
| { |
| "run_id": "20260505-100720_standard_n1000_s42", |
| "epoch": 200, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 1.0, |
| "sigma": 7.23820686340332, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.585, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6125 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.6, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.62 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.61, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.625 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.6175, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.63 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.6175, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.625 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.6175, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6225 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.61, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6275 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.6025, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6325 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.5925, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.625 |
| } |
| ] |
| } |
| ``` |
| |
| ### `20260508-183413_standard_n1000_s123` |
| |
| ```json |
| { |
| "run_id": "20260508-183413_standard_n1000_s123", |
| "epoch": 200, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 1.0, |
| "sigma": 13.366204261779785, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.5425, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.565 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.5975, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.615 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.6725, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.68 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.7125, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.71 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.72, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.7375 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.7175, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.7425 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.6825, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.7425 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.6275, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.695 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.555, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6425 |
| } |
| ] |
| } |
| ``` |
| |
| ### `20260508-183413_standard_n1000_s456` |
| |
| ```json |
| { |
| "run_id": "20260508-183413_standard_n1000_s456", |
| "epoch": 1000, |
| "layer": "avgpool", |
| "max_samples": 800, |
| "v_norm": 1.0, |
| "sigma": 10.473133087158203, |
| "sweep": [ |
| { |
| "alpha": -3.0, |
| "head_ood_acc": 0.6375, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6475 |
| }, |
| { |
| "alpha": -2.0, |
| "head_ood_acc": 0.6575, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.66 |
| }, |
| { |
| "alpha": -1.0, |
| "head_ood_acc": 0.675, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": -0.5, |
| "head_ood_acc": 0.6725, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.69 |
| }, |
| { |
| "alpha": 0.0, |
| "head_ood_acc": 0.6175, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.6575 |
| }, |
| { |
| "alpha": 0.5, |
| "head_ood_acc": 0.58, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5975 |
| }, |
| { |
| "alpha": 1.0, |
| "head_ood_acc": 0.5375, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.585 |
| }, |
| { |
| "alpha": 2.0, |
| "head_ood_acc": 0.505, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.5225 |
| }, |
| { |
| "alpha": 3.0, |
| "head_ood_acc": 0.505, |
| "hospital_probe": NaN, |
| "tumor_probe": 0.51 |
| } |
| ] |
| } |
| ``` |
| |
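Because the files above contain the bare `NaN` token, it may help to see how they load. The sketch below is an assumption-labelled example using an abbreviated, inline stand-in for one file (three of the nine sweep rows of the grokking s42 run); it relies only on the fact that Python's `json.loads` accepts `NaN` by default, and then evaluates the three-point ordering `acc(−3) ≥ acc(0) ≥ acc(+3)`.

```python
import json
import math

# Abbreviated stand-in for one m5_steering JSON: three of the nine sweep rows
# from the grokking s42 file above, with NaN kept exactly as it appears on disk.
raw = """
{"run_id": "20260505-080445_grokking_n1000_s42", "epoch": 400,
 "sigma": 8.685188293457031,
 "sweep": [
  {"alpha": -3.0, "head_ood_acc": 0.63,   "hospital_probe": NaN, "tumor_probe": 0.6625},
  {"alpha":  0.0, "head_ood_acc": 0.6525, "hospital_probe": NaN, "tumor_probe": 0.66},
  {"alpha":  3.0, "head_ood_acc": 0.59,   "hospital_probe": NaN, "tumor_probe": 0.635}
 ]}
"""

data = json.loads(raw)  # Python's json accepts the NaN token by default
acc = {row["alpha"]: row["head_ood_acc"] for row in data["sweep"]}

# hospital_probe loads as float('nan'); it must be tested with math.isnan.
all_nan = all(math.isnan(row["hospital_probe"]) for row in data["sweep"])

# Three-point ordering acc(-3) >= acc(0) >= acc(+3).
three_point = acc[-3.0] >= acc[0.0] >= acc[3.0]
print(data["run_id"], all_nan, three_point)
```

For a real file, `json.load(open(path))` would replace the inline string; the path is left out here because the exact `ep*` suffix differs per run.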
| --- |
|
|
| ## 21. M5 — Aggregated sweep tables |
|
| Entries are `head_ood_acc` from the §20 JSONs; each column header gives the seed, the checkpoint epoch of that run's JSON, and that run's σ.
|
|
| ### Grokking-favorable |
|
|
| | α | s7 (ep200, σ=5.53) | s42 (ep400, σ=8.69) | s123 (ep400, σ=6.00) | s456 (ep1000, σ=6.75) | s2024 (ep400, σ=9.22) | |
| | --- | --- | --- | --- | --- | --- | |
| | −3.0 | 0.4950 | 0.6300 | 0.7325 | 0.6825 | 0.7425 | |
| | −2.0 | 0.4900 | 0.6250 | 0.7325 | 0.6575 | 0.7775 | |
| | −1.0 | 0.4850 | 0.6400 | 0.7300 | 0.6400 | 0.7325 | |
| | −0.5 | 0.4850 | 0.6325 | 0.7100 | 0.6300 | 0.7150 | |
| | 0.0 | 0.4850 | 0.6525 | 0.6975 | 0.6225 | 0.7125 | |
| | +0.5 | 0.4875 | 0.6575 | 0.6925 | 0.6200 | 0.7050 | |
| | +1.0 | 0.4875 | 0.6575 | 0.6900 | 0.6100 | 0.6800 | |
| | +2.0 | 0.4775 | 0.6200 | 0.6800 | 0.6025 | 0.6225 | |
| | +3.0 | 0.4825 | 0.5900 | 0.6575 | 0.5975 | 0.5950 | |
| | Strict mono? | yes | no (acc(0) > acc(−3)) | yes | yes | yes |
|
|
| ### Standard |
|
|
| | α | s42 (ep200, σ=7.24) | s123 (ep200, σ=13.37) | s456 (ep1000, σ=10.47) | |
| | --- | --- | --- | --- | |
| | −3.0 | 0.5850 | 0.5425 | 0.6375 | |
| | −2.0 | 0.6000 | 0.5975 | 0.6575 | |
| | −1.0 | 0.6100 | 0.6725 | 0.6750 | |
| | −0.5 | 0.6175 | 0.7125 | 0.6725 | |
| | 0.0 | 0.6175 | 0.7200 | 0.6175 | |
| | +0.5 | 0.6175 | 0.7175 | 0.5800 | |
| | +1.0 | 0.6100 | 0.6825 | 0.5375 | |
| | +2.0 | 0.6025 | 0.6275 | 0.5050 | |
| | +3.0 | 0.5925 | 0.5550 | 0.5050 | |
| | Strict mono? | no (peak at α=0) | no (peak at α=0) | yes | |
|
|
| **Aggregates and statistics**: |
|
|
| - Strict monotonicity, defined as the three-point ordering `acc(−3) ≥ acc(0) ≥ acc(+3)` (intermediate α values are not checked): **4/5 grokking** vs **1/3 standard**.
| - Mean σ: grokking 7.24, standard 10.36 (ratio 1.43×; this is the σ-scaling confound).
| - Mean Δ(α=0→−3), i.e. `acc(−3) − acc(0)`: grokking **+0.0225 ± 0.0276**, standard **−0.0633 ± 0.0835** (± is the population SD across seeds).
| - Mean Δ(α=0→+3), i.e. `acc(+3) − acc(0)`: grokking **−0.0495 ± 0.0392**, standard **−0.1008 ± 0.0578**.
| - **Fisher exact one-sided p = 0.286**, **Mann-Whitney U one-sided p = 0.071** (continuous statistic, primary), binomial sign test p = 0.188. **None reaches p < 0.05.**
|
|
| --- |
|
|
| ## 22. M6 — Full K-sweep results (per-seed, all K) |
|
|
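| Aggregated from `paper_figures/m6_summary.csv`. Δ(targ−rand) is `head_OOD(top-K shortcut ablated) − mean(head_OOD over 5 K-random ablations)`. A positive value means the targeted shortcut ablation leaves higher head OOD accuracy than the random control. N_+ is the number of seeds with a strictly positive Δ; entries of 0.0000 are not counted.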
|
|
| ### Grokking-favorable (n=1000) |
|
|
| | K | s7 | s42 | s123 | s456 | s2024 | N_+ | |
| | --- | --- | --- | --- | --- | --- | --- | |
| | 0 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | — | |
| | 4 | +0.0015 | +0.0045 | +0.0015 | +0.0025 | −0.0005 | 4/5 | |
| | 8 | −0.0010 | +0.0025 | +0.0010 | 0.0000 | −0.0020 | 2/5 | |
| | 16 | +0.0005 | −0.0010 | +0.0055 | +0.0035 | −0.0015 | 3/5 | |
| | 32 | −0.0045 | −0.0025 | +0.0010 | +0.0045 | −0.0020 | 2/5 | |
| | 64 | −0.0015 | +0.0005 | +0.0115 | +0.0120 | −0.0115 | 3/5 | |
| | 128 | −0.0060 | +0.0005 | +0.0090 | +0.0015 | −0.0225 | 3/5 | |
| | 256 | −0.0345 | −0.0010 | +0.0060 | +0.0075 | −0.0205 | 2/5 | |
| |
| ### Standard (n=1000) |
| |
| | K | s42 | s123 | s456 | N_+ | |
| | --- | --- | --- | --- | --- | |
| | 0 | 0.0000 | 0.0000 | 0.0000 | — | |
| | 4 | +0.0015 | −0.0025 | −0.0025 | 1/3 | |
| | 8 | −0.0035 | −0.0010 | −0.0005 | 0/3 | |
| | 16 | −0.0030 | −0.0025 | +0.0025 | 1/3 | |
| | 32 | −0.0040 | −0.0055 | −0.0060 | 0/3 | |
| | 64 | −0.0035 | −0.0100 | −0.0040 | 0/3 | |
| | 128 | −0.0085 | −0.0090 | −0.0025 | 0/3 | |
| | 256 | −0.0115 | +0.0040 | +0.0075 | 2/3 | |
|
|
| ### Per-seed full ablation rows (K=64 and K=256) |
|
|
| ``` |
| grokking s42 K= 64: base=0.6575 short=0.6600 rand=0.6595±0.0043 morph=0.6475 Δshort-base=+0.0025 Δtarg-rand=+0.0005 |
| grokking s42 K=256: base=0.6575 short=0.6625 rand=0.6635±0.0054 morph=0.5500 Δshort-base=+0.0050 Δtarg-rand=−0.0010 |
| grokking s123 K= 64: base=0.6825 short=0.6950 rand=0.6835±0.0066 morph=0.6625 Δshort-base=+0.0125 Δtarg-rand=+0.0115 |
| grokking s123 K=256: base=0.6825 short=0.7275 rand=0.7215±0.0179 morph=0.5575 Δshort-base=+0.0450 Δtarg-rand=+0.0060 |
| grokking s456 K= 64: base=0.6450 short=0.6425 rand=0.6305±0.0120 morph=0.6550 Δshort-base=−0.0025 Δtarg-rand=+0.0120 |
| grokking s456 K=256: base=0.6450 short=0.6325 rand=0.6250±0.0105 morph=0.5275 Δshort-base=−0.0125 Δtarg-rand=+0.0075 |
| grokking s7 K= 64: base=0.4925 short=0.4875 rand=0.4890±0.0075 morph=0.4475 Δshort-base=−0.0050 Δtarg-rand=−0.0015 |
| grokking s7 K=256: base=0.4925 short=0.4650 rand=0.4995±0.0056 morph=0.3800 Δshort-base=−0.0275 Δtarg-rand=−0.0345 |
| grokking s2024 K= 64: base=0.7100 short=0.7125 rand=0.7240±0.0086 morph=0.7225 Δshort-base=+0.0025 Δtarg-rand=−0.0115 |
| grokking s2024 K=256: base=0.7100 short=0.7225 rand=0.7430±0.0129 morph=0.5050 Δshort-base=+0.0125 Δtarg-rand=−0.0205 |
| |
| standard s42 K= 64: base=0.6150 short=0.6125 rand=0.6160±0.0034 morph=0.6100 Δshort-base=−0.0025 Δtarg-rand=−0.0035 |
| standard s42 K=256: base=0.6150 short=0.6100 rand=0.6215±0.0020 morph=0.5950 Δshort-base=−0.0050 Δtarg-rand=−0.0115 |
| standard s123 K= 64: base=0.7225 short=0.7150 rand=0.7250±0.0016 morph=0.7125 Δshort-base=−0.0075 Δtarg-rand=−0.0100 |
| standard s123 K=256: base=0.7225 short=0.6900 rand=0.6860±0.0025 morph=0.6975 Δshort-base=−0.0325 Δtarg-rand=+0.0040 |
| standard s456 K= 64: base=0.5975 short=0.5975 rand=0.6015±0.0030 morph=0.5850 Δshort-base= 0.0000 Δtarg-rand=−0.0040 |
| standard s456 K=256: base=0.5975 short=0.5500 rand=0.5425±0.0055 morph=0.5525 Δshort-base=−0.0475 Δtarg-rand=+0.0075 |
| ``` |
|
|
| **Aggregates**: |
|
|
| - K=64: grokking **3/5** positive Δ(targ−rand) (mean +0.0022 ± 0.0088); standard **0/3** (mean −0.0058 ± 0.0030). Fisher one-sided p = 0.179. |
| - K=256: grokking 2/5 positive; standard 2/3 positive, so there is no grokking-specific advantage at this K.
| - ID accuracy stays within 0.01 of baseline across all targeted ablations. |
| - The random control is averaged over only 5 samplings per K, which is the main M6 weakness.
|
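The Δ(targ−rand) definition and the K=64 aggregates above can be checked with a short sketch. The function names are illustrative assumptions, the ± values are again treated as population standard deviations, and the inputs are the per-seed K=64 and K=8 entries from the tables in this section.

```python
from math import sqrt


def delta_targ_rand(short_acc, rand_accs):
    """Targeted-ablation accuracy minus the mean accuracy of the random ablations."""
    return short_acc - sum(rand_accs) / len(rand_accs)


def summarize(deltas):
    """Return (count of strictly positive deltas, mean, population SD)."""
    n_pos = sum(d > 0 for d in deltas)
    mean = sum(deltas) / len(deltas)
    sd = sqrt(sum((d - mean) ** 2 for d in deltas) / len(deltas))
    return n_pos, mean, sd


# Only the mean of the five random runs is tabulated in this document, so it is
# passed as a single value here (grokking s123, K=64: short=0.6950, rand=0.6835).
d_s123 = delta_targ_rand(0.6950, [0.6835])

# Per-seed K=64 deltas in table order (grokking: s7, s42, s123, s456, s2024).
GROK_K64 = [-0.0015, 0.0005, 0.0115, 0.0120, -0.0115]
STD_K64 = [-0.0035, -0.0100, -0.0040]
# K=8 grokking row: the 0.0000 entry is not counted as positive.
GROK_K8 = [-0.0010, 0.0025, 0.0010, 0.0000, -0.0020]

for name, deltas in [("grokking K=64", GROK_K64), ("standard K=64", STD_K64),
                     ("grokking K=8", GROK_K8)]:
    n_pos, mean, sd = summarize(deltas)
    print(f"{name}: N_+={n_pos}/{len(deltas)}  mean={mean:+.4f} ± {sd:.4f}")
```

Counting only strictly positive values reproduces the tabulated N_+ of 2/5 for the grokking K=8 row, where one seed sits at exactly 0.0000.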
|
| The complete 88-row reviewer CSV is at `paper_figures/m6_summary.csv` — every run, every K, every condition (baseline / shortcut / random ± sd / morphology). |
|
|
| --- |
|
|
| ## 23. Exact commands |
|
|
| Training (launched detached under nohup via `scripts/launch.sh`): |
|
|
| ```bash |
| python -u -m experiments.causalgrok_camelyon_v2 \ |
| --condition grokking --n_train 1000 --seed 42 \ |
| --run_dir experiments/runs/<run_id> \ |
| --wandb_project causalgrok --wandb_mode offline |
| # standard: --condition standard (wd/init/grokfast set automatically by get_config) |
| ``` |
|
|
| Mechanistic interpretability: |
|
|
| ```bash |
| python -m experiments.mechinterp_m1 --run_dir experiments/runs/<id> --data_root data/wilds |
| python -m experiments.mechinterp_m4_ablation --run_dir experiments/runs/<id> --data_root data/wilds --layer avgpool --all_epochs |
| python -m experiments.mechinterp_m5_steering --run_dir experiments/runs/<id> --data_root data/wilds |
| python -m experiments.mechinterp_m6_neuron_ablation --run_dir experiments/runs/<id> --data_root data/wilds \ |
| --ks "0,4,8,16,32,64,128,256" |
| ``` |
|
|
| Figures: |
|
|
| ```bash |
| python -m experiments.regenerate_all_figures # rebuilds all 7 paper figures from saved JSON in <2 min |
| ``` |
|
|
| --- |
|
|
| ## 24. Output layout (per run) |
|
|
| ``` |
| experiments/runs/<run_id>/ |
| ├── config.json # full hyperparameter config |
| ├── run.pid |
| ├── checkpoints/ |
| │ ├── ep00200.pt … ep03000.pt # 15 periodic checkpoints, ~44 MB each |
| │ └── final.pt |
| ├── logs/ |
| │ ├── train.log # launch cmd + one metric line per evaluation (epoch 1, then every 50 epochs)
| │ └── train.err
| ├── results/
| │ ├── history.json # 61 metric rows (epoch 1, then every 50 epochs)
| │ └── summary.json # final-summary fields |
| ├── wandb/ # offline wandb run metadata |
| └── mechinterp/ |
| ├── m1_probe_data.json + heatmap/curves PNG |
| ├── m4_ablation_avgpool_trajectory.json + PNG |
| ├── m5_steering_ep<E>.json + PNG |
| └── m6_neuron_ablation_ep<E>.json + PNG |
| ``` |
|
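A quick completeness check for a run directory can be sketched from the layout above. The `ep%05d.pt` naming and the 200-epoch checkpoint interval are read off the tree and the logs; the helper name `missing_checkpoints` is an assumption and not part of the project code.

```python
from pathlib import Path

# Assumed from the tree and logs above: periodic checkpoints every 200 epochs
# named ep%05d.pt, and one metric row at epoch 1 plus every 50th epoch.
CKPT_EPOCHS = list(range(200, 3001, 200))
EVAL_EPOCHS = [1] + list(range(50, 3001, 50))


def missing_checkpoints(run_dir):
    """Return the expected checkpoint filenames that are absent from run_dir."""
    ckpt_dir = Path(run_dir) / "checkpoints"
    expected = [f"ep{epoch:05d}.pt" for epoch in CKPT_EPOCHS] + ["final.pt"]
    return [name for name in expected if not (ckpt_dir / name).is_file()]


print(len(CKPT_EPOCHS), "periodic checkpoints,", len(EVAL_EPOCHS), "metric rows")
```

Calling `missing_checkpoints("experiments/runs/<run_id>")` on a complete run should return an empty list; any names it returns are files to re-fetch from the mirror.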
|
| Aggregated: `paper_figures/m6_summary.csv` (88 data rows), `paper_figures/*.{png,pdf}` (7 figures). |
| Checkpoints (~10 GB, 240 `.pt` files) are mirrored to Hugging Face at `nileshsarkar-ai/CausalGrok`. |
|
|
| --- |
|
|
| *All numerical values in this document were read directly from on-disk `config.json` / `results/summary.json` / `results/history.json` / `mechinterp/*.json` / `paper_figures/m6_summary.csv`. Training logs in §18 and §19 are verbatim from each run's `logs/train.log`. Source code in §10–§15 is the exact code that ran for every reported result.* |
|
|