
CausalGrok — Complete Training, Evaluation, and Mechanistic-Interpretability Reference

Paper: Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training.

This document is the complete preservation archive of the project. It contains the full source of every script that ran, every per-run config and summary, the full training logs of two reference runs, every M5 activation-steering result, the full M6 K-sweep data, and all formulas. All numerical values are read directly from on-disk config.json / results/summary.json / results/history.json / mechinterp/*.json / paper_figures/m6_summary.csv. All source code is the exact code that produced every reported result.

Contents

  1. Environment
  2. Dataset and data pipeline
  3. Model architecture and initialization
  4. Hyperparameter tables
  5. Loss function (cross-entropy) — formula
  6. IRM penalty (diagnostic) — formula
  7. Grokfast EMA — formula
  8. Training loop and checkpointing
  9. Evaluation metrics — formulas
  10. Full source: utils/grokfast.py
  11. Full source: experiments/causalgrok_camelyon_v2.py
  12. Full source: experiments/mechinterp_m1.py
  13. Full source: experiments/mechinterp_m4_ablation.py
  14. Full source: experiments/mechinterp_m5_steering.py
  15. Full source: experiments/mechinterp_m6_neuron_ablation.py
  16. Run inventory and summary results (14 runs)
  17. Per-run config.json and summary.json (all 14 runs)
  18. Full training log: grokking n=1000 seed=42
  19. Full training log: standard n=1000 seed=42
  20. M5 — Full activation-steering JSONs (8 runs at n=1000)
  21. M5 — Aggregated sweep tables
  22. M6 — Full K-sweep results (per-seed, all K)
  23. Exact commands
  24. Output layout

1. Environment

| Item | Value |
| --- | --- |
| Python env | conda env: causalgrok (Python 3.10) |
| Framework | PyTorch, timm, torchvision, scikit-learn, numpy, wandb (offline) |
| Model | timm.create_model("resnet18", pretrained=False, num_classes=2) |
| Device | CUDA (NVIDIA A100 80GB PCIe) |
| Precision | TF32 (set_float32_matmul_precision("high"), cudnn.benchmark=True, allow_tf32=True) |
| Params | 11,177,538 |
| Dataset | Camelyon17 via WILDS (utils.camelyon_data.get_camelyon_subsets, auto-download) |
| Wall time per run | ~5.5 h (grokking n=1000 s42, 3000 epochs, A100) |

2. Dataset and data pipeline

Camelyon17 (WILDS), H&E-stained histopathology patches, binary tumor label. Five hospitals; the WILDS split provides train hospitals, an in-distribution validation set, and a held-out OOD test hospital.

  • ID validation: 33,560 images.
  • OOD test (held-out hospital): 85,054 images.
  • Train: subsampled to n_train. The n=1000 seed-42 run drew hospitals {0, 3, 4} with 181 / 371 / 448 samples (positive rates 0.53 / 0.48 / 0.50).

Transforms: Resize((96, 96)), ToTensor, Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]). Train loader: batch_size=32, shuffle=True; eval loaders: batch_size=256, shuffle=False; num_workers=0, pin_memory=True.

IRM environments — one {x, y} dict per unique training hospital, built once from metadata[:, 0].

3. Model architecture and initialization

ResNet-18 (timm, no ImageNet pretraining), 96×96 input, 2-class head, 11,177,538 parameters. The grokking-favorable regime multiplies every multi-dim weight tensor by init_scale = 4.0 at initialization; standard uses init_scale = 1.0 (no rescaling). avgpool feature dimension D = 512. Six probed stages: stem, layer1, layer2, layer3, layer4, avgpool.
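
To make the rescaling rule concrete, the sketch below applies the same selection test that build_model uses ("weight" in the parameter name and more than one dimension) to a hand-written table of parameter shapes. The names and shapes are illustrative stand-ins for a ResNet-style layout, not values read from the timm model, and the helper functions are hypothetical.

```python
# Hypothetical parameter names and shapes in a ResNet-style layout
# (illustrative only, not read from the timm model).
PARAM_SHAPES = {
    "conv1.weight": (64, 3, 7, 7),            # conv kernel, 4-dim
    "bn1.weight": (64,),                      # BatchNorm scale, 1-dim
    "bn1.bias": (64,),                        # BatchNorm shift, 1-dim
    "layer1.0.conv1.weight": (64, 64, 3, 3),  # conv kernel, 4-dim
    "fc.weight": (2, 512),                    # classifier matrix, 2-dim
    "fc.bias": (2,),                          # classifier bias, 1-dim
}


def is_rescaled(name: str, shape: tuple) -> bool:
    """Mirror the selection test: 'weight' in the name and more than one dim."""
    return "weight" in name and len(shape) > 1


def rescaled_names(param_shapes: dict, init_scale: float) -> list:
    """Names that would be multiplied by init_scale (none when it equals 1.0)."""
    if init_scale == 1.0:
        return []
    return [n for n, s in param_shapes.items() if is_rescaled(n, s)]


grokking_scaled = rescaled_names(PARAM_SHAPES, init_scale=4.0)
standard_scaled = rescaled_names(PARAM_SHAPES, init_scale=1.0)
print(grokking_scaled)
```

Under this test, convolution kernels and the classifier matrix are rescaled, while 1-dim BatchNorm weights and all biases keep their default initialization.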

4. Hyperparameter tables

| Hyperparameter | Standard | Grokking-favorable |
| --- | --- | --- |
| Optimizer | AdamW | AdamW |
| Learning rate | 1e-3 | 1e-3 |
| Weight decay | 1e-4 | 5e-3 (50×) |
| Epochs | 3000 | 3000 |
| Init scale | 1.0 | 4.0 |
| Grokfast EMA | off | on |
| Grokfast alpha (EMA decay) | n/a | 0.98 |
| Grokfast lamb (slow-grad amplification) | n/a | 2.0 |
| Gradient clip (max-norm) | 1.0 | 1.0 |
| Batch size | 32 | 32 |
| Image size | 96 × 96 | 96 × 96 |
| log_every (metric cadence) | 50 epochs | 50 epochs |
| Checkpoint cadence | 200 epochs | 200 epochs |
| IRM weight in loss | 0.0 | 0.0 |

Three-axis confound: weight decay (50× ratio), init scale (4× ratio), and Grokfast EMA (on vs off) differ simultaneously between regimes. No single-axis ablation was run.

5. Loss function — cross-entropy

L = CE(f_theta(x), y) = - (1/B) * sum_i log [ exp(z_{i, y_i}) / sum_k exp(z_{i, k}) ]

where z = f_theta(x) are the logits. The training objective is pure CE for every reported run (irm_weight = 0.0).
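
As a small worked instance of the formula, the sketch below evaluates the batch-mean cross-entropy in plain Python. The logits and labels are made-up toy values and the function name is hypothetical; the training code itself uses nn.CrossEntropyLoss.

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch of logit rows and integer labels."""
    total = 0.0
    for z, y in zip(logits, labels):
        m = max(z)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(v - m) for v in z))
        total += log_norm - z[y]  # equals -log softmax(z)[y]
    return total / len(labels)


# Toy batch (made-up logits): one undecided example, one confident and correct.
toy_logits = [[0.0, 0.0], [4.0, -4.0]]
toy_labels = [1, 0]
loss = cross_entropy(toy_logits, toy_labels)
print(round(loss, 4))
```

An undecided two-class example contributes log 2 to the sum, and a confidently correct one contributes almost nothing, so the batch mean falls below log 2.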

6. IRM penalty (diagnostic, not in loss)

IRMv1 penalty (Arjovsky et al. 2019):

IRM_e = || grad_{w=1.0}  CE( w · f_theta(x^e), y^e ) ||^2
IRM   = (1/|E|) sum_{e in E} IRM_e

Logged at every checkpoint as irm_mean, irm_var. Collapses from ~0.20 to ~1e-13 within ~50–150 epochs in every run (a CE-memorization consequence, not invariance learning).
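
One way to see why the penalty collapses once the training set is fit: for a single example, the derivative of CE(w · z, y) with respect to w at w = 1 is sum_k p_k z_k - z_y with p = softmax(z), which goes to zero as the logits saturate on the correct label. The sketch below checks this in plain Python on made-up logits; the function names are hypothetical, and the diagnostic in the training script obtains the same gradient through torch.autograd.grad.

```python
import math


def softmax(z):
    """Numerically stable softmax of one logit row."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def irm_penalty_env(logits, labels):
    """Squared d/dw of the mean CE(w * z, y) at w = 1 for one environment.

    Per example the derivative is sum_k p_k * z_k - z_y, with p = softmax(z).
    """
    grad = 0.0
    for z, y in zip(logits, labels):
        p = softmax(z)
        grad += sum(pk * zk for pk, zk in zip(p, z)) - z[y]
    grad /= len(labels)
    return grad ** 2


# Made-up logits: moderately confident early in training, saturated after
# the training set has been memorized.
labels = [0, 1]
early_penalty = irm_penalty_env([[1.0, -1.0], [-1.0, 1.0]], labels)
memorized_penalty = irm_penalty_env([[20.0, -20.0], [-20.0, 20.0]], labels)
print(early_penalty, memorized_penalty)
```

Averaging irm_penalty_env over the per-hospital environments gives the IRM quantity defined above.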

7. Grokfast EMA (grokking-favorable only)

Grokfast (Lee et al. 2024, arXiv:2405.20233):

g_ema  <-  alpha * g_ema + (1 - alpha) * g          # alpha = 0.98
g      <-  g + lamb * g_ema                         # lamb  = 2.0

Applied after loss.backward(), before optimizer.step().
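
The two update lines can be traced on a scalar gradient. The sketch below is a minimal stand-in with a hypothetical function name and a made-up gradient sequence; it follows gradfilter_ema in initializing the EMA to the first gradient it sees.

```python
def grokfast_ema_step(grad, ema, alpha=0.98, lamb=2.0):
    """One Grokfast EMA update for a scalar gradient.

    Returns (filtered_grad, new_ema). On the first call (ema is None) the
    EMA is initialized to the current gradient.
    """
    if ema is None:
        new_ema = grad
    else:
        new_ema = alpha * ema + (1.0 - alpha) * grad
    return grad + lamb * new_ema, new_ema


# Made-up gradient sequence: a steady component of 1.0 with one sign flip.
ema = None
filtered = []
for g in [1.0, 1.0, -1.0]:
    g_out, ema = grokfast_ema_step(g, ema)
    filtered.append(g_out)
print(filtered)
```

With alpha = 0.98 the EMA barely moves on a single flipped step, so the amplified slow component outweighs the flipped raw gradient and the filtered value stays positive.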

8. Training loop and checkpointing

  • Loop over epoch ∈ [1, 3000].
  • Per minibatch: forward, CE loss, loss.backward(), Grokfast filter (grokking only), clip_grad_norm_(max_norm=1.0), optimizer.step().
  • Metrics computed every 50 epochs (+ epoch 1) → results/history.json (61 rows/run).
  • Checkpoints every 200 epochs → 15 .pt per run + final.pt.
  • OOD-aware early stopping exists but defaults off; all reported runs train the full 3000 epochs.
  • Grokking detection watches OOD-accuracy plateau-then-jump; grokking_epoch = -1 (never fires) for every run.
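
The cadence counts and the plateau-then-jump rule in the list above can be sketched in plain Python. The function names are hypothetical and the accuracy values are made up; the thresholds (window of 10 logged rows, 0.01 plateau tolerance, 0.05 jump) are the ones set in the training script.

```python
def logged_epochs(n_epochs=3000, log_every=50):
    """Epochs at which metrics are written: epoch 1 plus every log_every."""
    return [e for e in range(1, n_epochs + 1) if e % log_every == 0 or e == 1]


def checkpoint_epochs(n_epochs=3000, log_every=50, ckpt_every=200):
    """Checkpoints are saved inside the metric block, so on logged epochs only."""
    return [e for e in logged_epochs(n_epochs, log_every) if e % ckpt_every == 0]


def ood_grokking_fires(history_ood, ood_acc, best_ood,
                       plateau_window=10, plateau_eps=0.01, jump=0.05):
    """Plateau-then-jump test on OOD accuracy.

    Fires when at least plateau_window - 2 of the last plateau_window logged
    values lie within plateau_eps of the most recent one, and the new value
    exceeds the best so far by more than jump.
    """
    if len(history_ood) < plateau_window:
        return False
    last = history_ood[-plateau_window:]
    ref = last[-1]
    flat = sum(1 for v in last if abs(v - ref) < plateau_eps)
    return flat >= plateau_window - 2 and ood_acc > best_ood + jump


n_rows = len(logged_epochs())
n_ckpts = len(checkpoint_epochs())
flat_history = [0.70] * 10  # made-up plateau of OOD accuracy
print(n_rows, n_ckpts, ood_grokking_fires(flat_history, 0.80, best_ood=0.70))
```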

9. Evaluation metrics — formulas

| Metric | Definition |
| --- | --- |
| Accuracy | argmax accuracy on a loader; train_acc, id_val_acc, ood_acc. ood_gap = id_val_acc - ood_acc. |
| Weight norm | Global L2 norm over all parameters: sqrt(sum_p ‖p‖_2^2). |
| Effective feature rank | exp(- sum_i ŝ_i log ŝ_i) where ŝ_i = s_i / sum_j s_j are the normalized SVD singular values of layer4 avgpool features on ≤300 samples from the ID validation loader. |
| Shortcut ratio | min(border_conf / center_conf, 10). >1 = stain-reliant, <1 = tissue-reliant. |
| IRM penalty | See §6; computed across the 3 training-hospital environments at each logged epoch. |
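
The effective-rank and shortcut-ratio definitions can be checked on toy inputs. The sketch below is plain Python with hypothetical function names and made-up values; it takes singular values as given rather than computing an SVD, and it reuses the small epsilon terms that the training script adds for numerical safety.

```python
import math


def effective_rank(singular_values, eps=1e-10):
    """exp(entropy) of the singular values after normalizing them to sum to 1."""
    total = sum(singular_values) + eps
    s_hat = [s / total for s in singular_values]
    entropy = -sum(s * math.log(s + eps) for s in s_hat)
    return math.exp(entropy)


def shortcut_ratio(border_conf, center_conf, cap=10.0, eps=1e-8):
    """Border-to-center confidence ratio, capped at cap."""
    return min(border_conf / (center_conf + eps), cap)


# Made-up spectra: four equal directions vs. a single dominant direction.
rank_uniform = effective_rank([1.0, 1.0, 1.0, 1.0])
rank_single = effective_rank([1.0, 0.0, 0.0])
ratio = shortcut_ratio(border_conf=0.9, center_conf=0.6)
print(rank_uniform, rank_single, ratio)
```

A flat spectrum gives an effective rank close to the number of directions, and a spectrum with one nonzero singular value gives a rank close to 1.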

Summary fields (summary.json): best_id_val, best_ood, peak_ood_epoch, final_ood, ood_delta (= final_ood − best_ood, ungrokking signal), ood_improvement, grokking_epoch, irm_drop_pct, irm_drop_epoch, epoch_gap, final_weight_norm, final_feature_rank, final_irm, final_shortcut_ratio, final_ood_gap.
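
For orientation, the sketch below derives the three OOD summary fields from a made-up accuracy history. It assumes, as in the training script, that ood_improvement is the mean of the last five logged values minus the mean of the first five. The function name and the numbers are hypothetical.

```python
def ood_summary(history_ood, final_ood):
    """Derive best_ood, ood_delta and ood_improvement from logged OOD accuracy."""
    best_ood = max(history_ood)
    first = history_ood[:5]
    last = history_ood[-5:]
    early = sum(first) / len(first)
    late = sum(last) / len(last)
    return {
        "best_ood": best_ood,
        "ood_delta": final_ood - best_ood,  # negative = OOD fell after its peak
        "ood_improvement": late - early,    # late-training mean minus early mean
    }


# Made-up history: OOD accuracy peaks early and then declines.
toy_history = [0.60, 0.70, 0.80, 0.85, 0.80, 0.75, 0.72, 0.70, 0.70, 0.70]
summary = ood_summary(toy_history, final_ood=0.70)
print(summary)
```

In this toy case both ood_delta and ood_improvement are negative, the pattern the document calls an ungrokking signal.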


10. Full source: utils/grokfast.py

"""
utils.grokfast — accelerated grokking by amplifying slow-varying gradient
components (Lee et al. 2024, arXiv:2405.20233).

Maintain an EMA of gradients across steps; the slow-EMA component
corresponds to the generalising circuit. Adding it back into the live
gradient (scaled by `lamb`) accelerates the grokking transition 20-100×.
"""

from __future__ import annotations


def gradfilter_ema(model, grads_ema, alpha: float = 0.98, lamb: float = 2.0):
    """
    Call this AFTER `loss.backward()` and BEFORE `optimizer.step()`.

    Args:
        model:     the network whose gradients we are filtering.
        grads_ema: dict {param_name: ema_grad}, or None on the first call.
        alpha:     EMA decay (0.98 → very slow, emphasises persistent grads).
        lamb:      amplification factor for the slow component.

    Returns:
        Updated `grads_ema` dict — pass it back in on the next step.
    """
    if grads_ema is None:
        grads_ema = {}

    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            if name not in grads_ema:
                grads_ema[name] = p.grad.data.detach().clone()
            else:
                grads_ema[name] = (
                    grads_ema[name] * alpha
                    + p.grad.data.detach() * (1 - alpha)
                )
            p.grad.data = p.grad.data + grads_ema[name] * lamb

    return grads_ema

11. Full source: experiments/causalgrok_camelyon_v2.py

"""
CausalGrok — Camelyon17 Training Loop v2
Nilesh

KEY CHANGE FROM v1:
    OOD test accuracy (H4 — unseen hospital) is now tracked at EVERY
    checkpoint, not just at the end. Grokking detection watches OOD acc,
    not ID val acc. This is the correct signal.

    The paper claim: after ID accuracy converges (fast, expected), the model
    undergoes a delayed phase transition in OOD generalization — grokking
    the cross-hospital invariant causal features. This co-occurs with a drop
    in IRM penalty. That is the grokking we care about for clinical deployment.

    Two curves to watch:
        val_acc  (H3 ID val)  — converges fast, expected ~0.86 by ep 50
        ood_acc  (H4 OOD test) — should plateau then JUMP (the grokking)

Run via:
    python -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 300
"""

from __future__ import annotations

import argparse
import json
import os
import time
from datetime import datetime, timezone

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import timm
try:
    import wandb
except ImportError:
    wandb = None

from utils.grokfast import gradfilter_ema
from utils.camelyon_data import get_camelyon_subsets
from utils.run_dir import make_run_dir, ensure_run_dir, save_config


# ──────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────

def get_config(condition):
    base = dict(
        seed=42, n_train=300, batch_size=32, img_size=96,
        n_classes=2, log_every=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    if condition == "standard":
        base.update(dict(
            condition="standard",
            lr=1e-3, weight_decay=1e-4,
            # Default 3000 epochs to match grokking config and the
            # paper's reported runs; previously defaulted to 300 which
            # made the standard baseline trivially under-trained
            # relative to grokking. See paper Limitations §M3.
            n_epochs=3000, init_scale=1.0, use_grokfast=False,
        ))
    elif condition == "grokking":
        base.update(dict(
            condition="grokking",
            lr=1e-3, weight_decay=5e-3,
            n_epochs=3000, init_scale=4.0, use_grokfast=True,
            grokfast_alpha=0.98, grokfast_lamb=2.0,
        ))
    return base


# ──────────────────────────────────────────────
# WILDS-SAFE METRICS
# All handle the (imgs, labels, metadata) 3-tuple WILDS batch format.
# ──────────────────────────────────────────────

@torch.no_grad()
def accuracy_wilds(model, loader, device, max_samples=None):
    model.eval()
    correct = total = 0
    for batch in loader:
        imgs   = batch[0].to(device)
        labels = batch[1].squeeze().long().to(device)
        preds  = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total   += labels.size(0)
        if max_samples and total >= max_samples:
            break
    return correct / max(total, 1)


@torch.no_grad()
def weight_norm_fn(model):
    return sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5


@torch.no_grad()
def feature_rank_wilds(model, loader, device, n=300):
    model.eval()
    feats = []

    def hook_fn(module, input, output):
        avg_pool = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
        feats.append(avg_pool.view(avg_pool.size(0), -1).cpu())

    hook  = model.layer4[-1].register_forward_hook(hook_fn)
    count = 0
    for batch in loader:
        model(batch[0].to(device))
        count += batch[0].size(0)
        if count >= n:
            break
    hook.remove()
    if not feats:
        return float("nan")
    F_mat = torch.cat(feats)[:n]
    try:
        _, s, _ = torch.svd(F_mat)
        s = s / (s.sum() + 1e-10)
        return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()
    except Exception:
        return float("nan")


@torch.no_grad()
def shortcut_ratio_wilds(model, loader, device, n_samples=200):
    """
    Stain shortcut proxy: compare model confidence on center crop
    (tissue — causal features) vs. border region (stain — spurious).

    sc > 1.0 = relying on border stain more than tissue (shortcut)
    sc < 1.0 = relying on tissue center more than stain (causal)

    The transition from > 1.0 to < 1.0 during training is the
    attribution-level signature of the grokking transition.
    """
    model.eval()
    cc, bc = [], []
    count  = 0
    for batch in loader:
        if count >= n_samples:
            break
        imgs   = batch[0].to(device)
        B, C, H, W = imgs.shape
        hs, he = H // 4, 3 * H // 4
        ws, we = W // 4, 3 * W // 4
        center = F.interpolate(
            imgs[:, :, hs:he, ws:we], size=(H, W),
            mode="bilinear", align_corners=False
        )
        border = imgs.clone()
        border[:, :, hs:he, ws:we] = 0.0
        cc.append(F.softmax(model(center), 1).max(1).values.mean().item())
        bc.append(F.softmax(model(border), 1).max(1).values.mean().item())
        count += imgs.size(0)
    cconf = float(np.mean(cc)) if cc else 0.5
    bconf = float(np.mean(bc)) if bc else 0.5
    return cconf, bconf


def irm_penalty_wilds(model, envs, device):
    """
    IRMv1 penalty across TRAINING hospital environments (H0-H2).
    Diagnostic version: uses create_graph=False, returns floats. Used as a
    monitoring metric only (logged per epoch).
    """
    model.eval()
    penalties = []
    for env in envs:
        w      = torch.tensor(1.0, requires_grad=True, device=device)
        logits = model(env["x"]) * w
        loss   = F.cross_entropy(logits, env["y"])
        grad   = torch.autograd.grad(loss, w, create_graph=False)[0]
        penalties.append(grad.item() ** 2)
    t = torch.tensor(penalties)
    return t.mean().item(), t.var().item()


def irm_penalty_train_time(logits_list, y_list):
    """
    IRMv1 penalty for use INSIDE the training loss (differentiable).
    Splits a batch by environment, computes per-env loss with a virtual
    scale variable, takes the squared gradient of each per-env loss w.r.t.
    that scale, returns the mean across envs.

    Args:
        logits_list: list of (per-env) logits tensors
        y_list: list of (per-env) label tensors

    Returns:
        scalar tensor (differentiable), the IRM penalty contribution.
    """
    penalty = 0.0
    n = 0
    for logits, y in zip(logits_list, y_list):
        if logits.shape[0] == 0:
            continue
        scale  = torch.tensor(1.0, requires_grad=True, device=logits.device)
        loss   = F.cross_entropy(logits * scale, y)
        grad   = torch.autograd.grad(loss, scale, create_graph=True)[0]
        penalty = penalty + grad ** 2
        n += 1
    if n == 0:
        return torch.tensor(0.0, device=logits_list[0].device)
    return penalty / n


def eval_irm_penalty_wilds(model, id_val_loader, ood_test_loader, device):
    """
    IRM penalty evaluated on HELD-OUT environments (H3 and H4).
    This avoids the measurement artifact of training on H0-H2 where loss→0.
    HIGH penalty = model relies on hospital-discriminating features = shortcuts.
    LOW penalty  = model ignores hospital labels = causal features.
    """
    model.eval()
    penalties = []

    # Create environment views from eval data
    for loader, hospital_label in [
        (id_val_loader, "H3"),
        (ood_test_loader, "H4"),
    ]:
        xs, ys = [], []
        count = 0
        with torch.no_grad():
            for batch in loader:
                imgs  = batch[0].to(device)
                labels = batch[1].squeeze().long().to(device)
                xs.append(model(imgs))
                ys.append(labels)
                count += imgs.size(0)
                if count >= 500:
                    break
        if xs:
            x = torch.cat(xs)
            y = torch.cat(ys)
            w = torch.tensor(1.0, requires_grad=True, device=device)
            logits = x * w
            loss = F.cross_entropy(logits, y)
            try:
                grad = torch.autograd.grad(loss, w, create_graph=False)[0]
                penalties.append(grad.item() ** 2)
            except Exception:
                penalties.append(float("nan"))

    if penalties and not any(np.isnan(p) for p in penalties):
        return float(np.mean(penalties)), float(np.var(penalties))
    else:
        return float("nan"), float("nan")


# ──────────────────────────────────────────────
# DATA
# ──────────────────────────────────────────────

class TransformWrapper:
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def get_dataloaders(cfg, data_root):
    transform = transforms.Compose([
        transforms.Resize((cfg["img_size"], cfg["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=True)

    # Subsample training set
    torch.manual_seed(cfg["seed"])
    indices      = torch.randperm(len(train_ds))[:cfg["n_train"]]
    train_subset = Subset(train_ds, indices)

    # Wrap with TransformWrapper to apply transforms
    train_subset = TransformWrapper(train_subset, transform)
    id_val_ds    = TransformWrapper(id_val_ds, transform)
    ood_test_ds  = TransformWrapper(ood_test_ds, transform)

    train_loader    = DataLoader(train_subset, batch_size=cfg["batch_size"],
                                 shuffle=True,  num_workers=0, pin_memory=True)
    id_val_loader   = DataLoader(id_val_ds,    batch_size=256,
                                 shuffle=False, num_workers=0, pin_memory=True)
    ood_test_loader = DataLoader(ood_test_ds,  batch_size=256,
                                 shuffle=False, num_workers=0, pin_memory=True)

    return train_loader, id_val_loader, ood_test_loader, train_subset


def get_hospital_environments(train_subset, device):
    """
    Build IRM environments from ground-truth hospital labels.
    Returns list of {x, y} dicts — one per unique hospital in the subset.
    Hospitals in the WILDS Camelyon17 train split: 0, 3, 4.
    """
    loader = DataLoader(train_subset, batch_size=512,
                        shuffle=False, num_workers=4)
    all_imgs, all_labels, all_meta = [], [], []
    for imgs, labels, meta in loader:
        all_imgs.append(imgs)
        all_labels.append(labels.squeeze().long())
        all_meta.append(meta)

    all_imgs   = torch.cat(all_imgs)
    all_labels = torch.cat(all_labels)
    hospitals  = torch.cat(all_meta)[:, 0].long()  # field 0 = hospital ID

    envs = []
    for h in torch.unique(hospitals):
        mask = hospitals == h
        n    = mask.sum().item()
        envs.append({
            "x":        all_imgs[mask].to(device),
            "y":        all_labels[mask].to(device),
            "hospital": int(h),
        })
        pos_rate = all_labels[mask].float().mean().item()
        print(f"  Env hospital={int(h)}: {n} samples, "
              f"positive rate={pos_rate:.2f}")
    return envs


# ──────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────

def build_model(cfg):
    model = timm.create_model("resnet18", pretrained=False,
                               num_classes=cfg["n_classes"])
    if cfg["init_scale"] != 1.0:
        with torch.no_grad():
            for name, p in model.named_parameters():
                if "weight" in name and p.dim() > 1:
                    p.data *= cfg["init_scale"]
    return model.to(cfg["device"])


# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────

def train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir):

    criterion       = nn.CrossEntropyLoss()
    grads_ema       = None
    history         = []
    best_id_val     = 0.0
    best_ood        = 0.0
    peak_ood_epoch  = None  # Epoch where best_ood was achieved
    grok_epoch      = None
    irm_base        = None
    history_path    = os.path.join(run_dir, "results", "history.json")
    grad_clip       = cfg.get("grad_clip", 1.0)

    # Grokking detection parameters.
    # We watch OOD accuracy (H4), not ID val accuracy (H3).
    # ID val converges fast (expected). OOD is what should grok.
    plateau_window  = 10
    plateau_eps     = 0.01

    # Ungrokking early stopping parameters.
    # If OOD peaks then declines, stop at the peak rather than training to convergence.
    ood_patience    = cfg.get("ood_patience", 20)   # checkpoints to wait before stopping
    ood_min_delta   = cfg.get("ood_min_delta", 0.01)  # minimum improvement threshold
    use_ood_early_stop = cfg.get("use_ood_early_stop", False)

    print(f"\n{'='*60}")
    print(f"  {cfg['condition'].upper()} | Camelyon17 v2 | {cfg['n_epochs']} epochs")
    print(f"  WD={cfg['weight_decay']} | α={cfg['init_scale']} | n={cfg['n_train']}")
    print(f"  Tracking: ID val (H3) + OOD test (H4) at every checkpoint")
    print(f"  Grokking detection: watching OOD acc, not ID val acc")
    print(f"  IRM envs: {len(envs)} hospitals")
    print(f"{'='*60}", flush=True)

    irm_weight = float(cfg.get("irm_weight", 0.0))
    use_irm_in_loss = irm_weight > 0.0
    if use_irm_in_loss:
        print(f"  IRM-in-loss: ENABLED, alpha={irm_weight}", flush=True)
    else:
        print(f"  IRM-in-loss: disabled (CE-only training; IRM penalty is diagnostic)", flush=True)

    for epoch in range(1, cfg["n_epochs"] + 1):
        # ── Train step ────────────────────────────────────────────────
        model.train()
        loss_sum = n_b = 0
        for imgs, labels, metadata in train_loader:
            imgs   = imgs.to(cfg["device"])
            labels = labels.squeeze().long().to(cfg["device"])
            optimizer.zero_grad()
            logits = model(imgs)
            ce_loss = criterion(logits, labels)

            if use_irm_in_loss:
                # Split this batch by the training hospitals actually present
                # (WILDS hospital IDs, e.g. {0, 3, 4}) and compute the IRMv1
                # penalty as a differentiable scalar.
                hosp_ids = metadata[:, 0].long().to(cfg["device"])
                logits_per_env, y_per_env = [], []
                for h in torch.unique(hosp_ids):
                    mask = (hosp_ids == h)
                    if mask.sum() < 2:
                        continue
                    logits_per_env.append(logits[mask])
                    y_per_env.append(labels[mask])
                if len(logits_per_env) >= 2:
                    irm_term = irm_penalty_train_time(logits_per_env, y_per_env)
                    loss = ce_loss + irm_weight * irm_term
                else:
                    loss = ce_loss
            else:
                loss = ce_loss

            loss.backward()
            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=grad_clip)
            optimizer.step()
            loss_sum += loss.item()
            n_b      += 1

        # ── Checkpoint metrics ────────────────────────────────────────
        if epoch % cfg["log_every"] == 0 or epoch == 1:
            tr_acc  = accuracy_wilds(model, train_loader,   cfg["device"])
            id_acc  = accuracy_wilds(model, id_val_loader,  cfg["device"])
            ood_acc = accuracy_wilds(model, ood_test_loader, cfg["device"])  # KEY
            wn      = weight_norm_fn(model)
            fr      = feature_rank_wilds(model, id_val_loader, cfg["device"])
            irm_m, irm_v = irm_penalty_wilds(model, envs, cfg["device"])
            cconf, bconf  = shortcut_ratio_wilds(
                model, id_val_loader, cfg["device"])

            if irm_base is None:
                irm_base = irm_m

            # ── OOD grokking detection ────────────────────────────────
            # Require sustained plateau in OOD acc before the jump.
            # The ID val acc plateau is expected and not grokking.
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref  = last[-1]["ood_acc"]
                flat = sum(1 for r in last
                           if abs(r["ood_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and ood_acc > best_ood + 0.05:
                    grok_epoch = epoch
                    irm_drop   = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n  *** OOD GROKKING at epoch {epoch} ***")
                    print(f"      OOD: {best_ood:.3f} → {ood_acc:.3f} | "
                          f"IRM drop: {irm_drop:.1f}%", flush=True)

            if id_acc  > best_id_val: best_id_val = id_acc
            if ood_acc > best_ood:
                best_ood      = ood_acc
                peak_ood_epoch = epoch  # Track when peak was achieved

            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)

            # OOD gap: how much worse is OOD vs ID?
            # This should shrink at the grokking transition.
            ood_gap = id_acc - ood_acc

            row = dict(
                epoch          = epoch,
                train_loss     = loss_sum / n_b,
                train_acc      = tr_acc,
                id_val_acc     = id_acc,
                ood_acc        = ood_acc,       # ← primary grokking signal
                ood_gap        = ood_gap,       # ← should narrow at transition
                weight_norm    = wn,
                feature_rank   = fr,
                irm_mean       = irm_m,
                irm_var        = irm_v,
                center_conf    = cconf,
                border_conf    = bconf,
                shortcut_ratio = sc_ratio,
                grokking_detected = grok_epoch is not None,
            )
            history.append(row)
            if wandb:
                wandb.log(row)

            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            # Save periodic checkpoint for M1 analysis (every 200 epochs)
            if epoch % 200 == 0:
                ckpt_dir = os.path.join(run_dir, "checkpoints")
                os.makedirs(ckpt_dir, exist_ok=True)
                ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:05d}.pt")
                torch.save(model.state_dict(), ckpt_path)
                print(f"  ✓ Checkpoint → ep{epoch:05d}.pt", flush=True)

            # ── OOD-aware early stopping (if ungrokking detected) ───────
            # If OOD peaks then declines, stop at the peak rather than full epochs.
            if use_ood_early_stop and peak_ood_epoch is not None and len(history) >= ood_patience:
                recent_ood = [r["ood_acc"] for r in history[-ood_patience:]]
                ood_trend = max(recent_ood) - min(recent_ood)

                if ood_acc < best_ood - ood_min_delta:
                    print(f"\n  *** EARLY STOP (OOD declining) at epoch {epoch} ***", flush=True)
                    print(f"      Peak OOD: {best_ood:.4f} at epoch {peak_ood_epoch}", flush=True)
                    print(f"      Current:  {ood_acc:.4f} ({ood_acc-best_ood:+.4f})", flush=True)

                    # Save peak checkpoint separately for clinical deployment
                    if peak_ood_epoch and peak_ood_epoch % 200 == 0:
                        peak_src = os.path.join(run_dir, "checkpoints", f"ep{peak_ood_epoch:05d}.pt")
                        peak_dst = os.path.join(run_dir, "checkpoints", "peak_ood.pt")
                        if os.path.exists(peak_src):
                            import shutil
                            shutil.copy(peak_src, peak_dst)
                            print(f"      Saved peak → checkpoints/peak_ood.pt", flush=True)

                    break  # Exit training loop

            print(f"  ep {epoch:5d} | "
                  f"tr {tr_acc:.3f} | "
                  f"id {id_acc:.3f} | "
                  f"ood {ood_acc:.3f} | "
                  f"gap {ood_gap:+.3f} | "  # + means OOD worse than ID
                  f"‖W‖ {wn:.1f} | "
                  f"rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | "
                  f"sc {sc_ratio:.2f}x",
                  flush=True)

    # ── Final summary ─────────────────────────────────────────────────
    # One final OOD eval at the very end
    final_ood = accuracy_wilds(model, ood_test_loader, cfg["device"])
    if wandb:
        wandb.log({"final_ood_acc": final_ood,
                   "grokking_epoch": grok_epoch or -1})

    # Decision numbers
    irm_drop_pct = float("nan")
    irm_drop_ep  = epoch_gap = -1
    if history:
        irm0    = history[0]["irm_mean"]
        irm_min = min(r["irm_mean"] for r in history)
        if irm0:
            irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100
        if len(history) > 1:
            biggest = 0.0
            for prev, cur in zip(history[:-1], history[1:]):
                d = abs(cur["irm_mean"] - prev["irm_mean"])
                if d > biggest:
                    biggest    = d
                    irm_drop_ep = cur["epoch"]
        if grok_epoch and irm_drop_ep > 0:
            epoch_gap = abs(grok_epoch - irm_drop_ep)

    # OOD grokking: did OOD acc improve over the course of training?
    # Measure: mean OOD acc over the last 5 logged checkpoints minus the
    # mean over the first 5 logged checkpoints.
    ood_early = np.mean([r["ood_acc"] for r in history[:5]]) if history else 0
    ood_late  = np.mean([r["ood_acc"] for r in history[-5:]]) if history else 0
    ood_improvement = ood_late - ood_early

    # Ungrokking detection: did OOD collapse after peaking?
    ood_delta = final_ood - best_ood  # Negative = ungrokking

    summary = dict(
        run_id            = cfg["run_id"],
        condition         = cfg["condition"],
        n_train           = cfg["n_train"],
        seed              = cfg["seed"],
        best_id_val       = best_id_val,
        best_ood          = best_ood,
        peak_ood_epoch    = peak_ood_epoch or -1,  # When peak was achieved
        final_ood         = final_ood,
        ood_delta         = ood_delta,              # final - best (ungrokking signal)
        ood_improvement   = ood_improvement,  # ← key: did OOD grok?
        grokking_epoch    = grok_epoch or -1,
        irm_drop_pct      = irm_drop_pct,
        irm_drop_epoch    = irm_drop_ep,
        epoch_gap         = epoch_gap,
        final_weight_norm = history[-1]["weight_norm"]    if history else None,
        final_feature_rank= history[-1]["feature_rank"]   if history else None,
        final_irm         = history[-1]["irm_mean"]       if history else None,
        final_shortcut_ratio = history[-1]["shortcut_ratio"] if history else None,
        final_ood_gap     = history[-1]["ood_gap"]        if history else None,
    )
    with open(os.path.join(run_dir, "results", "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    torch.save(model.state_dict(),
               os.path.join(run_dir, "checkpoints", "final.pt"))

    print(f"\n  Best ID val (H3): {best_id_val:.4f}")
    print(f"  Best OOD (H4):    {best_ood:.4f}")
    print(f"  OOD improvement:  {ood_improvement:+.4f}  ← did OOD grok?")
    print(f"  Grokking at:      {grok_epoch}")
    print(f"  IRM drop:         {irm_drop_pct:.1f}%",
          flush=True)
    return history


# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--condition",     default="grokking",
                   choices=["standard", "grokking"])
    p.add_argument("--n_train",       type=int,   default=300)
    p.add_argument("--seed",          type=int,   default=42)
    p.add_argument("--log_every",     type=int,   default=50)
    p.add_argument("--wandb_project", default="causalgrok")
    p.add_argument("--wandb_mode",    default="offline",
                   choices=["online", "offline", "disabled"])
    p.add_argument("--run_dir",       default=None)
    p.add_argument("--data_root",     default="data/wilds")
    p.add_argument("--weight_decay",  type=float, default=None)
    p.add_argument("--init_scale",    type=float, default=None)
    p.add_argument("--n_epochs",      type=int,   default=None)
    p.add_argument("--lr",            type=float, default=None)
    p.add_argument("--grokfast",      choices=["on", "off"], default=None)
    p.add_argument("--grad_clip",     type=float, default=1.0)
    p.add_argument("--irm_weight",    type=float, default=0.0,
                   help="IRMv1 penalty weight added to training loss "
                        "(0 = pure cross-entropy / diagnostic-only IRM).")
    args = p.parse_args()

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed,
               log_every=args.log_every, grad_clip=args.grad_clip)

    if args.weight_decay is not None: cfg["weight_decay"] = args.weight_decay
    if args.init_scale   is not None: cfg["init_scale"]   = args.init_scale
    if args.n_epochs     is not None: cfg["n_epochs"]     = args.n_epochs
    if args.lr           is not None: cfg["lr"]           = args.lr
    if args.grokfast     is not None: cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["irm_weight"] = args.irm_weight

    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark         = True
        torch.backends.cuda.matmul.allow_tf32  = True
        torch.backends.cudnn.allow_tf32        = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        run_dir, run_id = make_run_dir(
            ["camelyon_v2", cfg["condition"],
             f"n{cfg['n_train']}", f"s{cfg['seed']}"])
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        run_id  = os.path.basename(os.path.normpath(run_dir))

    cfg["run_id"]  = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    if wandb:
        wandb.init(project=args.wandb_project, config=cfg, name=run_id,
                   mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice:  {cfg['device']}")
    print(f"Run ID:  {run_id}")
    print(f"Started: {datetime.now(timezone.utc).isoformat()}", flush=True)

    train_loader, id_val_loader, ood_test_loader, train_subset = \
        get_dataloaders(cfg, args.data_root)

    envs  = get_hospital_environments(train_subset, cfg["device"])
    model = build_model(cfg)

    print(f"Train: {len(train_subset)} | "
          f"ID val (H3): {len(id_val_loader.dataset)} | "
          f"OOD test (H4): {len(ood_test_loader.dataset)}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}",
          flush=True)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    t0 = time.time()
    train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time()-t0)/60:.1f} min", flush=True)
    if wandb:
        wandb.finish()


if __name__ == "__main__":
    main()
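
The override logic in `main()` above gives the optional flags a default of `None` and tests `is not None` rather than truthiness, so an explicit `--weight_decay 0` still replaces the condition preset. A minimal standalone sketch of that pattern follows. `PRESETS` and its values are hypothetical stand-ins for `get_config()`, whose real presets are not reproduced here, and `resolve_config` is an illustrative name, not a function from the repository.

```python
import argparse

# Hypothetical preset values, for illustration only.
# The real per-condition presets come from get_config().
PRESETS = {
    "standard": {"weight_decay": 0.0, "use_grokfast": False},
    "grokking": {"weight_decay": 1.0, "use_grokfast": True},
}


def resolve_config(argv):
    """Start from the condition preset, then apply only the flags that were given."""
    p = argparse.ArgumentParser()
    p.add_argument("--condition", default="grokking", choices=sorted(PRESETS))
    p.add_argument("--weight_decay", type=float, default=None)
    p.add_argument("--grokfast", choices=["on", "off"], default=None)
    args = p.parse_args(argv)

    cfg = dict(PRESETS[args.condition])
    cfg["condition"] = args.condition
    # `is not None` (not truthiness) so that an explicit 0.0 is honored.
    if args.weight_decay is not None:
        cfg["weight_decay"] = args.weight_decay
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    return cfg


if __name__ == "__main__":
    print(resolve_config(["--condition", "grokking", "--weight_decay", "0"]))
```

With a truthiness check (`if args.weight_decay:`), the `--weight_decay 0` override in the usage line would be silently ignored and the preset value kept.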

12. Full source: experiments/mechinterp_m1.py

"""
CausalGrok — M1: Layer-wise Linear Probing
Nilesh

The mechanistic claim:
    Before grokking: hospital probe HIGH (model uses stain shortcut),
                     tumor probe LOW in early layers
    At transition:   hospital probe DROPS, tumor probe RISES
    After grokking:  inverted — tumor high, hospital low

    If the OOD accuracy jump, the hospital probe drop, and the tumor probe rise
    all happen at the same epoch → evidence for the mechanistic claim
    (correlational only; the causal test is the M4 ablation).
    That's Figure 2 of the paper.

Usage:
    # Run on all saved checkpoints from a run
    python -m experiments.mechinterp_m1 \
        --run_dir experiments/runs/<run_id> \
        --data_root data/wilds

    # Run on latest checkpoint only (quick check while training)
    python -m experiments.mechinterp_m1 \
        --run_dir experiments/runs/<run_id> \
        --data_root data/wilds \
        --latest_only

    # Run on ALL camelyon_v2 grokking runs
    python -m experiments.mechinterp_m1 \
        --all_runs \
        --data_root data/wilds

Output per run:
    experiments/runs/<run_id>/mechinterp/
        m1_probe_heatmap.png       ← epoch × layer, two panels: hospital probe (H3), tumor probe (H4)
        m1_probe_curves.png        ← hospital vs tumor probe over epochs (avgpool and layer2)
        m1_probe_data.json         ← raw numbers for paper tables
"""

from __future__ import annotations

import argparse
import glob
import json
import os
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import matplotlib
import matplotlib.pyplot as plt
import timm
import warnings
warnings.filterwarnings("ignore")

matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150})


# ──────────────────────────────────────────────
# RESNET-18 LAYER HOOKS
# Extract features after each of the 6 measurable stages:
#   stem → layer1 → layer2 → layer3 → layer4 → avgpool
# ──────────────────────────────────────────────

LAYER_NAMES = [
    "stem",     # After initial conv + bn + relu + maxpool
    "layer1",   # ResNet block 1 (64 channels)
    "layer2",   # ResNet block 2 (128 channels)
    "layer3",   # ResNet block 3 (256 channels)
    "layer4",   # ResNet block 4 (512 channels)
    "avgpool",  # Global average pool — penultimate representation
]


def register_hooks(model):
    """
    Register forward hooks on all 6 extraction points.
    Returns (hooks, features_dict).
    """
    features = {name: [] for name in LAYER_NAMES}
    hooks    = []

    def make_hook(name):
        def hook_fn(module, input, output):
            if output.dim() == 4:
                feat = output.mean(dim=[2, 3])
            else:
                feat = output.view(output.size(0), -1)
            features[name].append(feat.detach().cpu())
        return hook_fn

    hooks.append(model.maxpool.register_forward_hook(make_hook("stem")))
    hooks.append(model.layer1.register_forward_hook(make_hook("layer1")))
    hooks.append(model.layer2.register_forward_hook(make_hook("layer2")))
    hooks.append(model.layer3.register_forward_hook(make_hook("layer3")))
    hooks.append(model.layer4.register_forward_hook(make_hook("layer4")))
    hooks.append(model.global_pool.register_forward_hook(make_hook("avgpool")))

    return hooks, features


def extract_features(model, loader, device, max_samples=1000):
    """
    Run forward pass and collect features at all 6 layers.
    """
    model.eval()
    hooks, feat_dict = register_hooks(model)

    all_hospital = []
    all_tumor    = []
    count        = 0

    with torch.no_grad():
        for batch in loader:
            imgs     = batch[0].to(device)
            labels   = batch[1].squeeze().long()
            metadata = batch[2]
            model(imgs)
            all_hospital.append(metadata[:, 0].long())
            all_tumor.append(labels)
            count += imgs.size(0)
            if count >= max_samples:
                break

    for h in hooks:
        h.remove()

    features     = {k: torch.cat(v).numpy() for k, v in feat_dict.items()}
    hospital_ids = torch.cat(all_hospital).numpy()
    tumor_labels = torch.cat(all_tumor).numpy()

    n = min(max_samples, len(hospital_ids))
    features     = {k: v[:n] for k, v in features.items()}
    hospital_ids = hospital_ids[:n]
    tumor_labels = tumor_labels[:n]

    return features, hospital_ids, tumor_labels


def train_probe(X_train, y_train, X_val, y_val):
    """
    Train logistic regression probe on frozen features.
    """
    if len(np.unique(y_train)) < 2:
        return 0.5

    scaler  = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val   = scaler.transform(X_val)

    clf = LogisticRegression(
        max_iter=500,
        C=1.0,
        solver="lbfgs",
        multi_class="auto",
        n_jobs=-1,
    )
    try:
        clf.fit(X_train, y_train)
        return clf.score(X_val, y_val)
    except Exception:
        return float("nan")


# ──────────────────────────────────────────────
# CHECKPOINT DISCOVERY
# ──────────────────────────────────────────────

def find_checkpoints(run_dir: str) -> List[tuple]:
    """
    Find all checkpoints in a run directory.
    Returns list of (epoch, checkpoint_path) sorted by epoch.
    """
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    if not os.path.isdir(ckpt_dir):
        return []

    checkpoints = []

    # Periodic checkpoints: ep050.pt, ep100.pt, etc.
    for f in sorted(glob.glob(os.path.join(ckpt_dir, "ep*.pt"))):
        epoch_str = os.path.basename(f).replace("ep", "").replace(".pt", "")
        try:
            epoch = int(epoch_str)
            checkpoints.append((epoch, f))
        except ValueError:
            continue

    # Final checkpoint
    final = os.path.join(ckpt_dir, "final.pt")
    if os.path.isfile(final):
        hist_path = os.path.join(run_dir, "results", "history.json")
        if os.path.isfile(hist_path):
            try:
                with open(hist_path) as hf:
                    hist = json.load(hf)
                epoch = hist[-1]["epoch"] if hist else 9999
            except Exception:
                epoch = 9999
        else:
            epoch = 9999
        checkpoints.append((epoch, final))

    return sorted(checkpoints, key=lambda x: x[0])


def load_model_from_checkpoint(ckpt_path: str, n_classes: int = 2,
                                device: str = "cuda") -> nn.Module:
    model = timm.create_model("resnet18", pretrained=False,
                               num_classes=n_classes)
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model.to(device)


# ──────────────────────────────────────────────
# MAIN PROBE ANALYSIS
# ──────────────────────────────────────────────

def run_probe_analysis(run_dir: str, data_root: str,
                       device: str = "cuda",
                       max_samples: int = 800,
                       latest_only: bool = False) -> Optional[Dict]:
    """
    For each checkpoint in a run, extract features at all 6 layers
    and train hospital + tumor probes.
    """
    from utils.camelyon_data import get_camelyon_subsets

    cfg_path = os.path.join(run_dir, "config.json")
    if not os.path.isfile(cfg_path):
        print(f"  No config.json in {run_dir}, skipping")
        return None

    with open(cfg_path) as cf:
        cfg    = json.load(cf)
    condition  = cfg.get("condition", "unknown")
    n_train    = cfg.get("n_train", 300)
    seed       = cfg.get("seed", 42)

    print(f"\n{'='*55}")
    print(f"  M1 Probe Analysis: {os.path.basename(run_dir)}")
    print(f"  condition={condition}, n_train={n_train}, seed={seed}")
    print(f"{'='*55}")

    checkpoints = find_checkpoints(run_dir)
    if not checkpoints:
        print(f"  No checkpoints found — skipping")
        return None

    if latest_only:
        checkpoints = checkpoints[-1:]

    print(f"  Found {len(checkpoints)} checkpoints: "
          f"epochs {[e for e,_ in checkpoints]}")
    print(f"  Hospital probe: fits on training data (H0-H2), "
          f"evaluates on H3 and H4 separately")

    # ── Data ─────────────────────────────────────────────────────────
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, full_ds = get_camelyon_subsets(
        root_dir=data_root, download=False)

    # Wrap datasets with transform (WILDS returns PIL images)
    class _TransformWrapper:
        def __init__(self, dataset, transform):
            self.dataset = dataset
            self.transform = transform
        def __len__(self):
            return len(self.dataset)
        def __getitem__(self, idx):
            img, label, metadata = self.dataset[idx]
            return self.transform(img), label, metadata

    id_val_t = _TransformWrapper(id_val_ds, transform)
    ood_test_t = _TransformWrapper(ood_test_ds, transform)
    train_t = _TransformWrapper(train_ds, transform)

    torch.manual_seed(seed)
    probe_idx = torch.randperm(len(id_val_t))[:max_samples // 2]
    ood_idx   = torch.randperm(len(ood_test_t))[:max_samples // 2]
    train_idx = torch.randperm(len(train_t))[:max_samples]

    probe_loader = DataLoader(
        Subset(id_val_t, probe_idx),
        batch_size=128, shuffle=False, num_workers=0)
    ood_loader   = DataLoader(
        Subset(ood_test_t, ood_idx),
        batch_size=128, shuffle=False, num_workers=0)
    train_loader = DataLoader(
        Subset(train_t, train_idx),
        batch_size=128, shuffle=False, num_workers=0)

    # ── Results storage ──────────────────────────────────────────────
    results = {
        "run_id":     os.path.basename(run_dir),
        "condition":  condition,
        "n_train":    n_train,
        "seed":       seed,
        "epochs":     [],
        "layers":     LAYER_NAMES,
        "hospital_probe_id":  [],  # Hospital accuracy on H3
        "hospital_probe_ood": [],  # Hospital accuracy on H4
        "tumor_probe_id":     [],  # Tumor accuracy on H3
        "tumor_probe_ood":    [],  # Tumor accuracy on H4
    }

    # ── Per-checkpoint analysis ───────────────────────────────────────
    for epoch, ckpt_path in checkpoints:
        print(f"\n  Epoch {epoch} | {os.path.basename(ckpt_path)}")

        try:
            model = load_model_from_checkpoint(
                ckpt_path, n_classes=2, device=device)
        except Exception as e:
            print(f"    Failed to load checkpoint: {e}")
            continue

        # Extract features from all three datasets
        feats_train, hosp_train, tumor_train = extract_features(
            model, train_loader, device, max_samples=max_samples)

        feats_id, hosp_id, tumor_id = extract_features(
            model, probe_loader, device, max_samples=max_samples // 2)

        feats_ood, hosp_ood, tumor_ood = extract_features(
            model, ood_loader, device, max_samples=max_samples // 2)

        epoch_hosp_id = []
        epoch_hosp_ood = []
        epoch_tumor_id = []
        epoch_tumor_ood = []

        for layer_name in LAYER_NAMES:
            # Fit probes on training features, evaluate on H3 and H4
            X_train_layer = feats_train[layer_name]
            X_id_layer = feats_id[layer_name]
            X_ood_layer = feats_ood[layer_name]

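            # Hospital probe: fit on training-hospital IDs, then score on H3/H4.
            # High H3 accuracy → hospital (stain) identity is linearly decodable.
            # H4 accuracy is 0 by construction (H4 is not a training class).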
            h_acc_id = train_probe(X_train_layer, hosp_train,
                                   X_id_layer,  hosp_id)
            h_acc_ood = train_probe(X_train_layer, hosp_train,
                                    X_ood_layer, hosp_ood)

            # Tumor probe: can model distinguish tumor vs normal?
            t_acc_id = train_probe(X_train_layer, tumor_train,
                                   X_id_layer,  tumor_id)
            t_acc_ood = train_probe(X_train_layer, tumor_train,
                                    X_ood_layer, tumor_ood)

            epoch_hosp_id.append(h_acc_id)
            epoch_hosp_ood.append(h_acc_ood)
            epoch_tumor_id.append(t_acc_id)
            epoch_tumor_ood.append(t_acc_ood)

            print(f"    {layer_name:8s}: "
                  f"hosp_H3={h_acc_id:.3f} hosp_H4={h_acc_ood:.3f}  "
                  f"tumor_H3={t_acc_id:.3f} tumor_H4={t_acc_ood:.3f}")

        results["epochs"].append(epoch)
        results["hospital_probe_id"].append(epoch_hosp_id)
        results["hospital_probe_ood"].append(epoch_hosp_ood)
        results["tumor_probe_id"].append(epoch_tumor_id)
        results["tumor_probe_ood"].append(epoch_tumor_ood)

        del model

    # ── Save raw data ─────────────────────────────────────────────────
    out_dir = os.path.join(run_dir, "mechinterp")
    os.makedirs(out_dir, exist_ok=True)

    data_path = os.path.join(out_dir, "m1_probe_data.json")
    with open(data_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n  Probe data → {data_path}")

    # ── Plots ─────────────────────────────────────────────────────────
    _plot_probe_heatmaps(results, out_dir)
    _plot_probe_curves(results, out_dir)

    print(f"  Figures → {out_dir}/")
    return results


def _plot_probe_heatmaps(results: Dict, out_dir: str):
    """
    Epoch (x) × layer (y), color = probe accuracy.

    Hospital probe shown on H3 (held-out samples whose hospital classes
    overlap with training). The H4 version is degenerate by construction: the
    probe is fit on the training-hospital class set, which does not include
    H4, so hospital_probe_ood ≡ 0 across all epochs and layers.

    Tumor probe shown on H4 (truly OOD hospital) since tumor labels are
    binary and shared across hospitals — H4 captures the causal-feature
    transferability we care about.
    """
    epochs = results["epochs"]
    layers = results["layers"]

    if not epochs:
        return

    hosp_matrix  = np.array(results["hospital_probe_id"])    # H3 — has signal
    tumor_matrix = np.array(results["tumor_probe_ood"])      # H4 — true OOD

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for ax, matrix, title, cmap in [
        (axes[0], hosp_matrix,  "Hospital probe on H3 (shortcut recoverability)\nHigh = stain still encoded = BAD", "Reds"),
        (axes[1], tumor_matrix, "Tumor probe on H4 (causal, OOD)\nHigh = causal feature transfers = GOOD",          "Greens"),
    ]:
        im = ax.imshow(
            matrix.T,
            aspect="auto",
            cmap=cmap,
            vmin=0.0, vmax=1.0,
            interpolation="nearest",
            origin="lower",
        )
        ax.set_xticks(range(len(epochs)))
        ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=8)
        ax.set_yticks(range(len(layers)))
        ax.set_yticklabels(layers, fontsize=9)
        ax.set_xlabel("Training epoch")
        ax.set_ylabel("ResNet layer")
        ax.set_title(title, fontsize=10, fontweight="bold")
        plt.colorbar(im, ax=ax, label="Probe accuracy")

    fig.suptitle(
        f"M1 — Layer-wise Linear Probing: {results['run_id']}\n"
        "Circuit signature: deep-layer hospital-probe drop (Reds) + sustained tumor recoverability (Greens)",
        fontsize=10, y=1.02
    )
    plt.tight_layout()
    out = os.path.join(out_dir, "m1_probe_heatmap.png")
    plt.savefig(out, bbox_inches="tight")
    plt.close()


def _plot_probe_curves(results: Dict, out_dir: str):
    """
    Line plot per-layer: hospital probe (H3 — recoverability) + tumor probe
    (H4 — causal-feature transfer to truly unseen hospital), with OOD accuracy
    from history.json overlaid.
    """
    epochs     = results["epochs"]
    run_dir    = os.path.join(out_dir, "..")
    layers     = results["layers"]
    avgpool_idx = layers.index("avgpool")
    layer2_idx  = layers.index("layer2")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for ax, layer_idx, layer_label in [
        (axes[0], avgpool_idx, "avgpool (penultimate)"),
        (axes[1], layer2_idx,  "layer2 (early)"),
    ]:
        hosp  = [results["hospital_probe_id"][i][layer_idx]
                 for i in range(len(epochs))]
        tumor = [results["tumor_probe_ood"][i][layer_idx]
                 for i in range(len(epochs))]

        ax.plot(epochs, hosp,  "r-o", markersize=4, lw=2,
                label="Hospital probe on H3 (shortcut recoverability ↓ want)")
        ax.plot(epochs, tumor, "g-s", markersize=4, lw=2,
                label="Tumor probe on H4 (causal transfer ↑ want)")

        hist_path = os.path.join(run_dir, "results", "history.json")
        if os.path.isfile(hist_path):
            try:
                with open(hist_path) as hf:
                    hist = json.load(hf)
                hist_eps = [r["epoch"]   for r in hist]
                ood_accs = [r.get("ood_acc", float("nan")) for r in hist]
                ax.plot(hist_eps, ood_accs, "b--", lw=1.5, alpha=0.7,
                        label="OOD accuracy (H4)")
            except Exception:
                pass

        ax.axhline(0.5, color="gray", ls=":", lw=1, alpha=0.5,
                   label="Chance (0.5)")
        ax.set_xlabel("Training epoch")
        ax.set_ylabel("Probe / OOD accuracy")
        ax.set_title(f"Layer: {layer_label}", fontweight="bold")
        ax.legend(fontsize=9)
        ax.set_ylim([0, 1.05])
        ax.grid(alpha=0.3)

    fig.suptitle(
        f"M1 — Probe Curves: {results['run_id']}\n"
        "Hospital recoverability (H3) drops in deep layers + tumor transfers (H4) — circuit signature",
        fontsize=10, y=1.02
    )
    plt.tight_layout()
    out = os.path.join(out_dir, "m1_probe_curves.png")
    plt.savefig(out, bbox_inches="tight")
    plt.close()


def main():
    p = argparse.ArgumentParser(
        description="M1: Layer-wise linear probing for CausalGrok")
    p.add_argument("--run_dir",     default=None,
                   help="Single run directory to analyze")
    p.add_argument("--all_runs",    action="store_true",
                   help="Analyze all camelyon_v2 grokking runs")
    p.add_argument("--data_root",   default="data/wilds")
    p.add_argument("--device",      default="cuda")
    p.add_argument("--max_samples", type=int, default=800)
    p.add_argument("--latest_only", action="store_true",
                   help="Analyze only latest checkpoint (quick check)")
    args = p.parse_args()

    if args.all_runs:
        run_dirs = sorted(glob.glob(
            "experiments/runs/*camelyon_v2*grokking*"))
        print(f"Found {len(run_dirs)} grokking runs")
        all_results = []
        for rd in run_dirs:
            r = run_probe_analysis(rd, args.data_root,
                                   device=args.device,
                                   max_samples=args.max_samples,
                                   latest_only=args.latest_only)
            if r:
                all_results.append(r)

        if all_results:
            os.makedirs("paper_figures", exist_ok=True)
            with open("paper_figures/m1_all_probes.json", "w") as f:
                json.dump(all_results, f, indent=2)
            print(f"\nCombined → paper_figures/m1_all_probes.json")

    elif args.run_dir:
        run_probe_analysis(args.run_dir, args.data_root,
                           device=args.device,
                           max_samples=args.max_samples,
                           latest_only=args.latest_only)
    else:
        print("Specify --run_dir <path> or --all_runs")


if __name__ == "__main__":
    main()
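
The `_plot_probe_heatmaps` docstring above notes that the hospital probe is degenerate on H4: a classifier fit on the training-hospital label set can only ever predict those labels, so its accuracy on a hospital outside that set is exactly zero regardless of the features. The sketch below illustrates this on synthetic data. It is an assumption-laden toy: a nearest-centroid classifier stands in for the logistic-regression probe, the features are random, and the label values 0, 1, 2 and 4 simply mirror the H0-H2 / H4 naming used in this document.

```python
import numpy as np


def fit_centroids(X, y):
    """Nearest-centroid 'probe' (a stand-in for logistic regression)."""
    classes = np.unique(y)
    centroids = np.stack([X[y == c].mean(axis=0) for c in classes])
    return classes, centroids


def predict(classes, centroids, X):
    # Squared Euclidean distance from every sample to every class centroid.
    d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return classes[d.argmin(axis=1)]


rng = np.random.default_rng(0)
X_train = rng.normal(size=(300, 8))
y_train = rng.integers(0, 3, size=300)   # training hospitals: labels 0, 1, 2
X_ood = rng.normal(size=(100, 8))
y_ood = np.full(100, 4)                  # unseen hospital: label 4

classes, centroids = fit_centroids(X_train, y_train)
pred = predict(classes, centroids, X_ood)

# Predictions are restricted to the training label set, so no sample
# can ever be assigned label 4.
ood_hospital_acc = float((pred == y_ood).mean())
print(ood_hospital_acc)
```

This is why the heatmap and curve plots use the H3 hospital probe and reserve H4 for the tumor probe, whose binary labels are shared across hospitals.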

13. Full source: experiments/mechinterp_m4_ablation.py
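
Step 4 of the pipeline in this script's docstring defines the ablation as `h - P h`, with `P` the orthogonal projector onto the row space of a basis `W`. The sketch below is a standalone NumPy illustration of that construction on random data, not an excerpt from the script. It builds `P` from an SVD basis, compares it with the closed form `W^T (W W^T)^-1 W` when `W` has full row rank, and checks that the ablated features have no component left along any row of `W`. The name `build_projector`, the array sizes, and the random inputs are all assumptions made for illustration.

```python
import numpy as np


def build_projector(W):
    """Orthogonal projector (d, d) onto the row space of W (k, d), via SVD."""
    _, s, Vt = np.linalg.svd(W, full_matrices=False)
    # Drop directions whose singular value is numerically zero.
    tol = max(W.shape) * np.finfo(s.dtype).eps * (s.max() if s.size else 0.0)
    basis = Vt[s > tol]              # orthonormal rows spanning rowspace(W)
    return basis.T @ basis


rng = np.random.default_rng(0)

# Full-row-rank case: the SVD projector agrees with the closed form.
W_full = rng.normal(size=(3, 16))
P_svd = build_projector(W_full)
P_closed = W_full.T @ np.linalg.inv(W_full @ W_full.T) @ W_full

# Rank-deficient case: the extra row is a sum of two others, so W W^T is
# singular and the closed form is not usable, but the SVD route still works.
W_deficient = np.vstack([W_full, W_full[0] + W_full[1]])
P = build_projector(W_deficient)

h = rng.normal(size=(5, 16))         # five synthetic feature vectors (rows)
h_ablated = h - h @ P                # P is symmetric, so this is h - P h per row

print(np.allclose(P_svd, P_closed))
print(np.abs(W_deficient @ h_ablated.T).max())
```

Because `P` is idempotent, applying the ablation twice gives the same result as applying it once, which is a quick sanity check when debugging an intervention of this kind.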

"""M4 — Representation Ablation: causal intervention on the shortcut subspace.

Pipeline:
  1. Pick a checkpoint (peak-OOD epoch by default).
  2. Extract features at avgpool (or `--layer`) for train (H0-H2) + OOD (H4) splits.
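  3. Build a basis W for the *shortcut subspace* in feature space. With
     subspace_method='probe', W is the weight rows of a hospital-classification
     logistic-regression probe fit on train features; the default
     subspace_method='lda' uses between-hospital mean directions plus top PCs.
  4. Build the projector P = W^T (W W^T)^-1 W onto that subspace (computed
     via SVD, so a rank-deficient W is handled) and define
     `ablate(h) = h - P h`.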
  5. Re-classify OOD images with the *same* trained classifier head, fed:
        (a) raw features h           — baseline OOD accuracy
        (b) ablated features h - Ph  — post-intervention OOD accuracy
  6. Also report:
        (c) shortcut accuracy (probe.score on h vs h-Ph)
        (d) tumor probe accuracy on h vs h-Ph (sanity: the causal feature
            should survive the intervention)
        (e) head's tumor classification accuracy on H4 with raw vs ablated features

If the intervention is causal:
   - shortcut probe accuracy: collapses
   - OOD accuracy: improves (or at least doesn't decay as much)
   - tumor probe accuracy: largely preserved

Usage
-----
    python -m experiments.mechinterp_m4_ablation \\
        --run_dir experiments/runs/<id> \\
        --data_root data/wilds \\
        --layer avgpool \\
        [--epoch 50]    # default: peak_ood_epoch from summary.json
        [--max_samples 1000]

Output:
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.json
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.png
"""
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, Tuple

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Re-use M1 helpers — hooks, model loader, feature extraction, ckpt discovery.
from experiments.mechinterp_m1 import (
    register_hooks,
    extract_features,
    load_model_from_checkpoint,
    find_checkpoints,
)
from utils.camelyon_data import get_camelyon_subsets


class _TransformWrapper:
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def _build_loaders(data_root: str, max_samples: int, seed: int = 42):
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    train_t = _TransformWrapper(train_ds, transform)
    ood_t   = _TransformWrapper(ood_test_ds, transform)

    torch.manual_seed(seed)
    train_idx = torch.randperm(len(train_t))[:max_samples]
    ood_idx   = torch.randperm(len(ood_t))[:max_samples // 2]

    train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128,
                              shuffle=False, num_workers=0)
    ood_loader   = DataLoader(Subset(ood_t,   ood_idx),   batch_size=128,
                              shuffle=False, num_workers=0)
    return train_loader, ood_loader


def _select_epoch(run_dir: Path, requested: int | None) -> Tuple[int, Path]:
    ckpts = find_checkpoints(str(run_dir))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints in {run_dir}/checkpoints/")

    if requested is not None:
        for ep, p in ckpts:
            if ep == requested:
                return ep, Path(p)
        raise ValueError(f"Requested epoch {requested} not in checkpoints "
                         f"({[ep for ep, _ in ckpts]})")

    # default: peak OOD epoch from summary.json
    summary_path = run_dir / "results" / "summary.json"
    peak = None
    if summary_path.exists():
        s = json.loads(summary_path.read_text())
        peak = s.get("peak_ood_epoch", None)

    if peak is not None and peak > 0:
        # nearest periodic checkpoint
        nearest = min(ckpts, key=lambda x: abs(x[0] - peak))
        return nearest[0], Path(nearest[1])

    # fall back to last checkpoint
    return ckpts[-1][0], Path(ckpts[-1][1])


def _build_projector(W: np.ndarray) -> np.ndarray:
    """W has shape (k, d). Returns P (d, d) projecting onto rowspace(W)."""
    # Use SVD for a stable orthonormal basis of rowspace
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    # rowspace basis = Vt rows where singular values > tol
    tol = max(W.shape) * np.finfo(s.dtype).eps * (s.max() if s.size else 0.0)
    keep = s > tol
    basis = Vt[keep]                     # (k', d)
    return basis.T @ basis                # (d, d) projector onto rowspace


def _build_shortcut_subspace(
    X: np.ndarray, hospital_ids: np.ndarray,
    method: str = "lda", subspace_dim: int = 32
) -> np.ndarray:
    """Return a (k, d) basis whose row-span is the 'shortcut subspace'.

    method='probe'  : rows of a hospital logistic-regression probe's coef_
                      (at most n_classes rows; small subspace).
    method='lda'    : per-hospital means in feature space, centered on the
                      global mean (n_classes rows). If subspace_dim exceeds
                      n_classes, the top subspace_dim PCs of the globally
                      centered features are appended, ordered by their maximum
                      absolute correlation with the one-hot hospital
                      indicators, so k = n_classes + subspace_dim. The
                      ordering does not drop any PC.
    Any other method (including 'pca-class') raises ValueError.
    """
    if method == "probe":
        clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                  multi_class="auto", n_jobs=-1)
        clf.fit(X, hospital_ids)
        return clf.coef_

    if method == "lda":
        classes = np.unique(hospital_ids)
        global_mean = X.mean(axis=0, keepdims=True)
        between = []
        for c in classes:
            mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True)
            between.append(mu_c - global_mean)
        between = np.vstack(between)            # (n_classes, d)
        # If more dimensions are requested than there are hospitals, append
        # the top principal components of the globally centered features.
        if subspace_dim > between.shape[0]:
            # Within-hospital residuals. NOTE: R is computed but not used
            # below; the PCs come from the globally centered features.
            residuals = []
            for c in classes:
                mu_c = X[hospital_ids == c].mean(axis=0, keepdims=True)
                residuals.append(X[hospital_ids == c] - mu_c)
            R = np.vstack(residuals)
            U, s, Vt = np.linalg.svd(X - global_mean, full_matrices=False)
            top = Vt[:subspace_dim]
            # Score each PC by how much it correlates with hospital-id variance
            # (one-hot expansion); keep top by that correlation.
            one_hot = np.eye(len(classes))[
                np.searchsorted(classes, hospital_ids)
            ]                                          # (N, n_classes)
            proj = (X - global_mean) @ top.T           # (N, subspace_dim)
            corrs = np.array([
                np.max(np.abs([np.corrcoef(proj[:, k], one_hot[:, c])[0, 1]
                                for c in range(len(classes))]))
                for k in range(subspace_dim)
            ])
            # Order the PCs by hospital correlation. `order` has one entry per
            # PC, so this slice reorders the PCs without dropping any.
            order = np.argsort(-np.nan_to_num(corrs))
            top_hosp = top[order[:subspace_dim]]
            # combine: between-class means + top-hospital-correlated PCs
            return np.vstack([between, top_hosp])

        return between

    raise ValueError(f"Unknown method: {method}")


def _classifier_logits_from_features(
    model: nn.Module, features: np.ndarray, layer: str, device: str
) -> np.ndarray:
    """Apply the *post-`layer`* part of the network to the (modified) features
    and return the model's binary-classification logits.

    For ResNet, `avgpool` features have shape (N, C). The classifier head
    `model.fc` (timm: `model.get_classifier()`) maps C → 2. For non-avgpool
    layers we do not currently support full propagation — caller should use
    layer='avgpool' for OOD-accuracy interventions."""
    if layer != "avgpool":
        raise NotImplementedError(
            "Re-applying the classifier head from intermediate spatial layers "
            "is not yet supported. Use --layer avgpool for the head-level "
            "ablation."
        )

    # Find the classifier head (timm convention: model.fc or model.get_classifier())
    if hasattr(model, "get_classifier"):
        head = model.get_classifier()
    elif hasattr(model, "fc"):
        head = model.fc
    elif hasattr(model, "classifier"):
        head = model.classifier
    else:
        raise RuntimeError("Could not locate classifier head on the model.")

    head = head.to(device).eval()
    with torch.no_grad():
        x = torch.tensor(features, dtype=torch.float32, device=device)
        logits = head(x).cpu().numpy()
    return logits


def _accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    if logits.ndim == 1 or logits.shape[1] == 1:
        pred = (logits.flatten() > 0).astype(int)
    else:
        pred = logits.argmax(axis=1)
    return float((pred == labels).mean())


def run_ablation(
    run_dir: Path,
    data_root: str,
    layer: str = "avgpool",
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    subspace_method: str = "lda",
    subspace_dim: int = 32,
) -> Dict:
    epoch, ckpt_path = _select_epoch(run_dir, epoch)

    print(f"\n  M4 — Representation Ablation")
    print(f"  run_dir : {run_dir.name}")
    print(f"  epoch   : {epoch} ({ckpt_path.name})")
    print(f"  layer   : {layer}")

    # Load model and dataloaders
    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)
    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)

    # Extract features
    print(f"  Extracting features ({max_samples} samples per split)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )

    if layer not in feats_train:
        raise KeyError(f"Layer '{layer}' not in extracted features "
                       f"({list(feats_train.keys())})")

    X_tr = np.asarray(feats_train[layer])    # (N_tr, D)
    X_ood = np.asarray(feats_ood[layer])     # (N_ood, D)
    if X_tr.ndim > 2:    # spatial map; flatten
        X_tr  = X_tr.reshape(X_tr.shape[0], -1)
        X_ood = X_ood.reshape(X_ood.shape[0], -1)

    # Normalize features (probe is sensitive to scale; classifier head was
    # trained on un-normalized features so we keep two parallel pipelines).
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    # ──────────── 1. Fit hospital probe + build shortcut subspace
    print(f"  Fitting hospital probe on H0/H1/H2 train features...")
    hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                   multi_class="auto", n_jobs=-1)
    hosp_clf.fit(X_tr_n, hosp_train)
    hosp_acc_train = hosp_clf.score(X_tr_n, hosp_train)

    # Build a richer shortcut subspace via LDA-style between-class +
    # hospital-correlated top PCs. This catches more shortcut variance than
    # the (n_classes - 1)-D probe-rowspace alone.
    W = _build_shortcut_subspace(X_tr_n, np.asarray(hosp_train),
                                  method=subspace_method,
                                  subspace_dim=subspace_dim)
    P = _build_projector(W)               # (D, D)
    rank_subspace = int(np.linalg.matrix_rank(P, tol=1e-8))
    print(f"  Shortcut subspace: dim={rank_subspace} method={subspace_method}  "
          f"(probe train acc {hosp_acc_train:.3f})")

    # ──────────── 2. Build ablated versions of features
    # Apply the projection in the *normalized* feature space, then un-scale
    # for re-feeding to the classifier head (which was trained on raw features).
    def ablate_norm(X_n):
        return X_n - X_n @ P.T

    X_ood_ablated_n = ablate_norm(X_ood_n)
    # un-scale
    X_ood_ablated = scaler.inverse_transform(X_ood_ablated_n)

    # Sanity probe metrics
    print(f"  Re-fitting tumor probe on train features...")
    tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                    multi_class="auto", n_jobs=-1)
    tumor_clf.fit(X_tr_n, tumor_train)
    tumor_acc_train = tumor_clf.score(X_tr_n, tumor_train)

    # Probe accuracies on raw vs ablated OOD features
    hosp_acc_ood_raw     = hosp_clf.score(X_ood_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan")
    hosp_acc_ood_ablated = hosp_clf.score(X_ood_ablated_n, hosp_ood) if len(np.unique(hosp_ood)) > 1 else float("nan")
    tumor_acc_ood_raw     = tumor_clf.score(X_ood_n, tumor_ood)
    tumor_acc_ood_ablated = tumor_clf.score(X_ood_ablated_n, tumor_ood)

    # ──────────── 3. Head-level OOD classification accuracy
    print(f"  Re-classifying OOD with model head (raw vs ablated features)...")
    logits_raw     = _classifier_logits_from_features(model, X_ood,         layer, device)
    logits_ablated = _classifier_logits_from_features(model, X_ood_ablated, layer, device)

    head_acc_raw     = _accuracy(logits_raw,     tumor_ood)
    head_acc_ablated = _accuracy(logits_ablated, tumor_ood)

    # ──────────── 4. Pack + report
    result = {
        "run_id":   run_dir.name,
        "epoch":    epoch,
        "layer":    layer,
        "max_samples": max_samples,
        "shortcut_subspace_dim": rank_subspace,
        "hospital_probe_train_acc": hosp_acc_train,
        "tumor_probe_train_acc":    tumor_acc_train,
        "hospital_probe_ood_raw":     hosp_acc_ood_raw,
        "hospital_probe_ood_ablated": hosp_acc_ood_ablated,
        "tumor_probe_ood_raw":     tumor_acc_ood_raw,
        "tumor_probe_ood_ablated": tumor_acc_ood_ablated,
        "head_ood_acc_raw":     head_acc_raw,
        "head_ood_acc_ablated": head_acc_ablated,
        "intervention_effect": {
            "shortcut_collapse":  hosp_acc_ood_raw - hosp_acc_ood_ablated,
            "ood_improvement":    head_acc_ablated - head_acc_raw,
            "tumor_preservation": tumor_acc_ood_ablated - tumor_acc_ood_raw,
        },
    }

    print(f"\n  RESULTS")
    print(f"    hospital probe (OOD): {hosp_acc_ood_raw:.3f} → {hosp_acc_ood_ablated:.3f}  "
          f"(Δ {result['intervention_effect']['shortcut_collapse']:+.3f})")
    print(f"    tumor probe (OOD)   : {tumor_acc_ood_raw:.3f} → {tumor_acc_ood_ablated:.3f}  "
          f"(Δ {result['intervention_effect']['tumor_preservation']:+.3f})")
    print(f"    head OOD acc        : {head_acc_raw:.3f} → {head_acc_ablated:.3f}  "
          f"(Δ {result['intervention_effect']['ood_improvement']:+.3f})")

    return result


def plot_ablation(result: Dict, out_path: Path):
    metrics = ["hospital_probe_ood", "tumor_probe_ood", "head_ood_acc"]
    raw_keys     = ["hospital_probe_ood_raw",     "tumor_probe_ood_raw",     "head_ood_acc_raw"]
    ablated_keys = ["hospital_probe_ood_ablated", "tumor_probe_ood_ablated", "head_ood_acc_ablated"]
    labels = ["Hospital probe\n(↓ = causal effect)",
              "Tumor probe\n(stable = good)",
              "Head OOD acc\n(↑ = causal effect)"]
    raws     = [result[k] for k in raw_keys]
    ablateds = [result[k] for k in ablated_keys]

    fig, ax = plt.subplots(figsize=(9, 5))
    x = np.arange(len(metrics))
    w = 0.35
    b1 = ax.bar(x - w / 2, raws,     w, label="raw features",       color="#444")
    b2 = ax.bar(x + w / 2, ablateds, w, label="shortcut-ablated",   color="#c33")
    for bars in (b1, b2):
        for b in bars:
            ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.005,
                    f"{b.get_height():.3f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylim(0, 1.05); ax.set_ylabel("Accuracy")
    ax.set_title(f"M4 — Causal Ablation of Shortcut Subspace\n"
                 f"{result['run_id']}  •  ep{result['epoch']}  •  layer={result['layer']}  "
                 f"•  subspace dim={result['shortcut_subspace_dim']}",
                 fontsize=10, fontweight="bold")
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3, axis="y")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--run_dir",   required=True)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--layer",     default="avgpool",
                   choices=["avgpool"])  # head-level intervention only at avgpool
    p.add_argument("--epoch",     type=int, default=None,
                   help="Specific checkpoint epoch; default = peak_ood_epoch from summary.json")
    p.add_argument("--max_samples", type=int, default=1000)
    p.add_argument("--device",    default="cuda")
    p.add_argument("--subspace_method", default="lda",
                   choices=["lda", "probe"],
                   help="lda = LDA-style between-class + hospital-correlated PCs; "
                        "probe = LR probe row-space (small, often only 2-D)")
    p.add_argument("--subspace_dim", type=int, default=32,
                   help="Target subspace dim for lda method")
    p.add_argument("--all_epochs", action="store_true",
                   help="Sweep across all periodic checkpoints")
    args = p.parse_args()

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.all_epochs:
        # Sweep across every periodic checkpoint, build a trajectory.
        ckpts = find_checkpoints(str(run_dir))
        # de-duplicate (final.pt may share epoch with last ep*.pt)
        seen = set(); uniq = []
        for ep, p in ckpts:
            if ep in seen:
                continue
            seen.add(ep); uniq.append((ep, p))

        traj = []
        for ep, _ in uniq:
            try:
                r = run_ablation(
                    run_dir=run_dir, data_root=args.data_root, layer=args.layer,
                    epoch=ep, max_samples=args.max_samples, device=args.device,
                    subspace_method=args.subspace_method,
                    subspace_dim=args.subspace_dim,
                )
                traj.append(r)
            except Exception as e:
                print(f"  [skip ep{ep}] {e}")

        out = out_dir / f"m4_ablation_{args.layer}_trajectory.json"
        out.write_text(json.dumps(traj, indent=2))
        plot_trajectory(traj, out.with_suffix(".png"))
        print(f"\n  → {out}")
        print(f"  → {out.with_suffix('.png')}")
        return

    result = run_ablation(
        run_dir=run_dir,
        data_root=args.data_root,
        layer=args.layer,
        epoch=args.epoch,
        max_samples=args.max_samples,
        device=args.device,
        subspace_method=args.subspace_method,
        subspace_dim=args.subspace_dim,
    )

    base = out_dir / f"m4_ablation_{args.layer}_ep{result['epoch']:05d}"
    (base.with_suffix(".json")).write_text(json.dumps(result, indent=2))
    plot_ablation(result, base.with_suffix(".png"))
    print(f"\n  → {base.with_suffix('.json')}")
    print(f"  → {base.with_suffix('.png')}")


def plot_trajectory(traj, out_path: Path):
    """Plot the intervention effect across training epochs."""
    eps = [r["epoch"] for r in traj]
    head_raw = [r["head_ood_acc_raw"]     for r in traj]
    head_abl = [r["head_ood_acc_ablated"] for r in traj]
    tum_raw  = [r["tumor_probe_ood_raw"]      for r in traj]
    tum_abl  = [r["tumor_probe_ood_ablated"]  for r in traj]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Panel A: head OOD acc raw vs ablated
    ax = axes[0]
    ax.plot(eps, head_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, head_abl, "r-s", lw=2, label="shortcut-ablated features")
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a > b for a, b in zip(head_abl, head_raw)],
                    color="seagreen", alpha=0.3, label="ablation helps")
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a < b for a, b in zip(head_abl, head_raw)],
                    color="salmon", alpha=0.3, label="ablation hurts")
    ax.set_xlabel("Training epoch"); ax.set_ylabel("OOD (H4) head accuracy")
    ax.set_title("Head OOD accuracy: raw vs shortcut-ablated", fontweight="bold")
    ax.legend(fontsize=9); ax.grid(alpha=0.3)

    # Panel B: tumor probe survival
    ax = axes[1]
    ax.plot(eps, tum_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, tum_abl, "g-s", lw=2, label="shortcut-ablated features")
    ax.set_xlabel("Training epoch"); ax.set_ylabel("Tumor probe OOD accuracy")
    ax.set_title("Tumor probe survival under ablation\n(stable line = causal feature preserved)",
                 fontweight="bold")
    ax.legend(fontsize=9); ax.grid(alpha=0.3); ax.set_ylim(0.4, 1.0)

    rid = traj[0]["run_id"] if traj else "?"
    layer = traj[0]["layer"] if traj else "?"
    fig.suptitle(f"M4 — Causal Ablation Trajectory: {rid}  •  layer={layer}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    main()
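
The projection step used in `run_ablation` (`X_n - X_n @ P.T`) can be illustrated on synthetic data. The following is a minimal sketch, not the project's code: `orthogonal_projector` is a hypothetical stand-in for `_build_projector` (whose source is not shown in this section), built here from an SVD basis, and the toy features skip the `StandardScaler` step for brevity.

```python
import numpy as np


def orthogonal_projector(W: np.ndarray, tol: float = 1e-8) -> np.ndarray:
    """Hypothetical helper: projector onto the row space of W, shape (k, D)."""
    _, s, Vt = np.linalg.svd(W, full_matrices=False)
    basis = Vt[s > tol * s.max()]      # orthonormal rows spanning the row space
    return basis.T @ basis             # (D, D), symmetric and idempotent


rng = np.random.default_rng(0)
D = 16

# Toy features: three equally sized "hospitals" with distinct mean offsets.
hospital_ids = np.repeat([0, 1, 2], 50)
offsets = rng.normal(size=(3, D)) * 3.0
X = rng.normal(size=(150, D)) + offsets[hospital_ids]

# Between-class directions: hospital means minus the global mean.
global_mean = X.mean(axis=0, keepdims=True)
between = np.vstack([
    X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean
    for c in np.unique(hospital_ids)
])

P = orthogonal_projector(between)
X_ablated = X - X @ P.T                # remove the component inside the subspace

# After ablation, no feature vector has any component along the
# between-hospital directions (up to floating-point error).
leak = float(np.abs(X_ablated @ between.T).max())
rank = int(np.linalg.matrix_rank(P, tol=1e-8))
print(f"projector rank={rank}, max leak={leak:.2e}")
```

With three equally sized groups the centered means sum to zero, so the removable subspace has rank 2 here, which is the `(n_classes - 1)` limit mentioned in the M4 comments.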

14. Full source: experiments/mechinterp_m5_steering.py

"""M5 — Activation Steering: causally manipulate the shortcut direction.

For one checkpoint (default: peak_ood_epoch from summary.json), we:
  1. Extract avgpool features for train (H0-H2) + OOD (H4) splits.
  2. Identify the dominant shortcut direction `v_s` as the top right-singular
     vector of the centered hospital-mean matrix (an LDA-style between-class
     direction, without within-class whitening).
  3. Sweep α (default {-3, -2, -1, -0.5, 0, +0.5, +1, +2, +3}) and apply
        h' = h + α · σ_align · v_s
     in standardized feature space, where σ_align is the std of the train
     features projected onto v_s (so α counts in 'standard deviations of
     shortcut activation').
  4. Re-classify OOD with the original head.
  5. Score the hospital + tumor probes (fit once on un-steered train features)
     on the steered OOD features and report accuracy curves.

Strong mechanistic claim if:
  - tumor-head OOD acc declines monotonically as |α| grows
  - hospital-probe acc on steered features rises with |α|
  - tumor-probe acc on steered features stays approximately flat (the
    *causal* feature isn't aligned with the shortcut direction)

Usage
-----
    python -m experiments.mechinterp_m5_steering \\
        --run_dir experiments/runs/<id> \\
        --data_root data/wilds \\
        [--epoch 50]    # default: peak_ood_epoch from summary.json
        [--max_samples 1000] [--alphas " -3,-2,-1,0,1,2,3 "]
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from experiments.mechinterp_m1 import (
    register_hooks, extract_features, load_model_from_checkpoint,
    find_checkpoints,
)
from experiments.mechinterp_m4_ablation import (
    _select_epoch, _build_loaders,
    _classifier_logits_from_features, _accuracy,
)


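def _top_lda_direction(X: np.ndarray, hospital_ids: np.ndarray) -> np.ndarray:
    """Return a unit vector aligned with the dominant between-hospital direction
    in feature space (LDA-style: the top singular vector of the unweighted,
    globally centered hospital means, with no within-class whitening)."""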
    classes = np.unique(hospital_ids)
    global_mean = X.mean(axis=0, keepdims=True)
    means = np.vstack([
        X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean
        for c in classes
    ])
    # SVD: rows of Vt are the orthonormal between-class directions ranked by
    # singular value (variance explained between hospitals).
    U, s, Vt = np.linalg.svd(means, full_matrices=False)
    return Vt[0]   # (D,) unit vector


def run_steering(
    run_dir: Path,
    data_root: str,
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    alphas: List[float] | None = None,
) -> Dict:
    if alphas is None:
        alphas = [-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]

    epoch, ckpt_path = _select_epoch(run_dir, epoch)

    print(f"\n  M5 — Activation Steering")
    print(f"  run_dir : {run_dir.name}")
    print(f"  epoch   : {epoch} ({ckpt_path.name})")
    print(f"  alphas  : {alphas}")

    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)
    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)

    print(f"  Extracting features...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )

    layer = "avgpool"
    X_tr  = np.asarray(feats_train[layer]); X_tr = X_tr.reshape(X_tr.shape[0], -1)
    X_ood = np.asarray(feats_ood[layer]);   X_ood = X_ood.reshape(X_ood.shape[0], -1)
    hosp_train = np.asarray(hosp_train)
    hosp_ood   = np.asarray(hosp_ood)
    tumor_train = np.asarray(tumor_train)
    tumor_ood   = np.asarray(tumor_ood)

    # Standardize for probe-fitting; un-standardize when feeding head
    scaler = StandardScaler().fit(X_tr)
    X_tr_n  = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    # 1. Top LDA direction in normalized feature space
    v = _top_lda_direction(X_tr_n, hosp_train)            # (D,) unit vec
    # Std of training features projected onto v (scale unit for α)
    sigma = float(np.std(X_tr_n @ v))
    print(f"  Top hospital direction v_s :  ‖v‖={np.linalg.norm(v):.3f}, "
          f"σ(X_tr·v)={sigma:.3f}")

    # 2. Pre-fit reference probes on un-steered train features
    hosp_clf  = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                    multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train)
    tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                    multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train)

    # 3. Sweep α
    sweep = []
    for alpha in alphas:
        # Steer features along v in normalized space, then un-scale for the head.
        X_ood_steered_n = X_ood_n + alpha * sigma * v[None, :]
        X_ood_steered   = scaler.inverse_transform(X_ood_steered_n)

        # Head OOD accuracy
        logits = _classifier_logits_from_features(model, X_ood_steered, layer, device)
        head_acc = _accuracy(logits, tumor_ood)

        # Probe accuracies on steered features
        if len(np.unique(hosp_ood)) > 1:
            hosp_acc = hosp_clf.score(X_ood_steered_n, hosp_ood)
        else:
            hosp_acc = float("nan")
        tumor_acc = tumor_clf.score(X_ood_steered_n, tumor_ood)

        sweep.append({
            "alpha":          float(alpha),
            "head_ood_acc":   head_acc,
            "hospital_probe": hosp_acc,
            "tumor_probe":    tumor_acc,
        })
        # float("nan") formats as "nan" under :.3f, so no special case is needed.
        print(f"  α={alpha:+.2f}  head_ood={head_acc:.3f}  "
              f"hosp_probe={hosp_acc:.3f}  "
              f"tumor_probe={tumor_acc:.3f}")

    return {
        "run_id":   run_dir.name,
        "epoch":    epoch,
        "layer":    layer,
        "max_samples": max_samples,
        "v_norm":   float(np.linalg.norm(v)),
        "sigma":    sigma,
        "sweep":    sweep,
    }


def plot_steering(result: Dict, out_path: Path):
    sweep = result["sweep"]
    a = [r["alpha"] for r in sweep]
    head = [r["head_ood_acc"] for r in sweep]
    hosp = [r["hospital_probe"] for r in sweep]
    tumor = [r["tumor_probe"] for r in sweep]

    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # Panel A — Head OOD acc vs α
    ax = axes[0]
    ax.plot(a, head, "k-o", lw=2, ms=7)
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (in σ-units of shortcut direction)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title("Causal effect of steering activations along v_s\n"
                 "(monotonic decline as |α| grows = causal evidence)",
                 fontweight="bold", fontsize=10)
    ax.grid(alpha=0.3)
    ax.set_ylim(0.4, max(0.85, max(head) + 0.05))

    # Panel B — Probe accuracies vs α
    ax = axes[1]
    ax.plot(a, hosp,  "r-s", lw=2, ms=7, label="Hospital probe (↑ with |α| = good)")
    ax.plot(a, tumor, "g-^", lw=2, ms=7, label="Tumor probe (flat = causal disjoint)")
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α")
    ax.set_ylabel("Probe accuracy")
    ax.set_title("Probe responses to steering", fontweight="bold", fontsize=10)
    ax.legend(loc="best", fontsize=9); ax.grid(alpha=0.3)
    ax.set_ylim(0, 1.05)

    fig.suptitle(f"M5 — Activation Steering: {result['run_id']}  "
                 f"•  ep{result['epoch']}  •  layer={result['layer']}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--run_dir",   required=True)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--epoch",     type=int, default=None)
    p.add_argument("--max_samples", type=int, default=1000)
    p.add_argument("--device",    default="cuda")
    p.add_argument("--alphas",    default=None,
                   help="Comma-separated α values, e.g. ' -3,-2,-1,0,1,2,3 '")
    p.add_argument("--all_epochs", action="store_true",
                   help="Sweep across all periodic checkpoints; output a trajectory")
    args = p.parse_args()

    alphas = None
    if args.alphas is not None:
        alphas = [float(x) for x in args.alphas.split(",")]

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.all_epochs:
        # Trajectory mode: run M5 at every periodic checkpoint
        ckpts = find_checkpoints(str(run_dir))
        seen = set(); uniq = []
        for ep, p in ckpts:
            if ep in seen:
                continue
            seen.add(ep); uniq.append((ep, p))
        traj = []
        for ep, _ in uniq:
            try:
                r = run_steering(
                    run_dir=run_dir, data_root=args.data_root, epoch=ep,
                    max_samples=args.max_samples, device=args.device, alphas=alphas,
                )
                traj.append(r)
            except Exception as e:
                print(f"  [skip ep{ep}] {e}")
        out = out_dir / "m5_steering_trajectory.json"
        out.write_text(json.dumps(traj, indent=2))
        print(f"\n  → {out}")
        return

    result = run_steering(
        run_dir=run_dir, data_root=args.data_root, epoch=args.epoch,
        max_samples=args.max_samples, device=args.device, alphas=alphas,
    )
    base = out_dir / f"m5_steering_ep{result['epoch']:05d}"
    base.with_suffix(".json").write_text(json.dumps(result, indent=2))
    plot_steering(result, base.with_suffix(".png"))
    print(f"\n  → {base.with_suffix('.json')}")
    print(f"  → {base.with_suffix('.png')}")


if __name__ == "__main__":
    main()
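
The steering rule `h' = h + α · σ · v_s` from the M5 docstring can be checked on synthetic data. This is a minimal sketch under stated assumptions, not the project's pipeline: `top_between_class_direction` and `steer` are illustrative names, the toy features are random, and the standardize / un-standardize step and the classifier head are left out.

```python
import numpy as np


def top_between_class_direction(X: np.ndarray, groups: np.ndarray) -> np.ndarray:
    """Top right-singular vector of the globally centered group means."""
    global_mean = X.mean(axis=0, keepdims=True)
    means = np.vstack([
        X[groups == c].mean(axis=0, keepdims=True) - global_mean
        for c in np.unique(groups)
    ])
    _, _, Vt = np.linalg.svd(means, full_matrices=False)
    return Vt[0]                       # (D,) unit vector


rng = np.random.default_rng(0)
D = 8

# Toy features: three groups that differ only along axis 0.
groups = np.repeat([0, 1, 2], 40)
offsets = np.zeros((3, D))
offsets[:, 0] = [-4.0, 0.0, 4.0]
X = rng.normal(size=(120, D)) + offsets[groups]

v = top_between_class_direction(X, groups)
sigma = float(np.std(X @ v))           # scale unit: one std of the projection


def steer(features: np.ndarray, alpha: float) -> np.ndarray:
    """Shift every feature vector by alpha standard deviations along v."""
    return features + alpha * sigma * v[None, :]


X_steered = steer(X, 2.0)
shift = (X_steered - X) @ v            # projection shift, identical per sample
off_axis = (X_steered - X) - np.outer(shift, v)   # change orthogonal to v

print(f"|v|={np.linalg.norm(v):.3f}  sigma={sigma:.3f}  shift={shift[0]:.3f}")
```

Every sample moves by exactly `alpha * sigma` along `v` and not at all in the orthogonal complement, which is why α can be read in standard deviations of the shortcut activation.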

15. Full source: experiments/mechinterp_m6_neuron_ablation.py

"""M6 — Neuron-level Ablation (the textbook reviewer-asked intervention).

Pipeline:
  1. At a chosen checkpoint (default: peak_ood_epoch), extract avgpool
     features for train (H0-H2) and OOD (H4) splits.
  2. Score each of the 512 avgpool channels by *how predictive its activation
     is of hospital ID*: we fit a single multiclass logistic regression over
     all standardized channels and use the column norm of its coefficient
     matrix as the per-neuron shortcut score:
         score_c = ‖β_{:,c}‖₂       (β = LR coefficient matrix, n_classes × 512)
     ↑ score_c → channel c is more strongly stain-shortcut-aligned.
  3. Sweep top-K (default {0, 4, 8, 16, 32, 64, 128, 256}) ablated neurons
     (set their standardized activations to 0, i.e. to the train mean) and
     measure:
        - head OOD acc (raw vs ablated)
        - hospital-probe acc on raw vs ablated features
        - tumor-probe acc on raw vs ablated features
  4. Strong mechanistic claim:
        - hospital-probe acc collapses sharply with K (these neurons are
          carrying hospital info)
        - head OOD acc *improves* (or at least preserves) at small K (the
          model was using shortcut neurons to harm OOD)
        - tumor-probe acc stays flat (causal info is distributed elsewhere)

Usage
-----
    python -m experiments.mechinterp_m6_neuron_ablation \\
        --run_dir experiments/runs/<id> \\
        --data_root data/wilds \\
        [--epoch 50] [--max_samples 1000] \\
        [--ks "0,4,8,16,32,64,128,256"]
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from torch.utils.data import DataLoader, Subset
from torchvision import transforms

from experiments.mechinterp_m1 import (
    register_hooks, extract_features, load_model_from_checkpoint,
)
from experiments.mechinterp_m4_ablation import (
    _select_epoch, _TransformWrapper,
    _classifier_logits_from_features, _accuracy,
)
from utils.camelyon_data import get_camelyon_subsets


def _build_loaders_with_id(data_root: str, max_samples: int, seed: int = 42):
    """Like M4's _build_loaders but also returns an ID validation loader so
    we can track ID acc and compute the OOD/ID degradation ratio."""
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    train_t = _TransformWrapper(train_ds, transform)
    id_t    = _TransformWrapper(id_val_ds, transform)
    ood_t   = _TransformWrapper(ood_test_ds, transform)

    torch.manual_seed(seed)
    train_idx = torch.randperm(len(train_t))[:max_samples]
    id_idx    = torch.randperm(len(id_t))[:max_samples // 2]
    ood_idx   = torch.randperm(len(ood_t))[:max_samples // 2]

    train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128,
                              shuffle=False, num_workers=0)
    id_loader    = DataLoader(Subset(id_t,    id_idx),    batch_size=128,
                              shuffle=False, num_workers=0)
    ood_loader   = DataLoader(Subset(ood_t,   ood_idx),   batch_size=128,
                              shuffle=False, num_workers=0)
    return train_loader, id_loader, ood_loader


def _per_neuron_shortcut_scores(X_n: np.ndarray, hosp: np.ndarray) -> np.ndarray:
    """Return a (D,) array — score per channel c, larger = more hospital-predictive.

    A one-feature-at-a-time logistic-regression fit's |coef| would be dominated
    by feature scale; instead we fit a single multiclass LR over all features
    and use the L2 norm of β_{:,c} (the column norm of the LR coefficient
    matrix) as channel c's hospital-discrimination score.
    """
    clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                              multi_class="auto", n_jobs=-1).fit(X_n, hosp)
    W = clf.coef_                       # (n_classes, D)
    # column norms — large means many class-discriminations rely on this neuron
    return np.linalg.norm(W, axis=0)    # (D,)


def _ablate_and_eval(
    X_n, mask, scaler, head_target, model, layer, device,
    hosp_clf, tumor_clf, hosp_target, tumor_target,
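):
    """Apply mask to normalized features, unscale, evaluate everything.

    Masked channels are set to 0 in standardized space, so after
    `scaler.inverse_transform` they sit at their training mean (mean ablation
    rather than a literal zero in raw activation space)."""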
    X_ablated_n = X_n * mask[None, :]
    X_ablated   = scaler.inverse_transform(X_ablated_n)
    logits = _classifier_logits_from_features(model, X_ablated, layer, device)
    head_acc = _accuracy(logits, head_target)
    hosp_acc = hosp_clf.score(X_ablated_n, hosp_target) if hosp_clf is not None and len(np.unique(hosp_target)) > 1 else float("nan")
    tumor_acc = tumor_clf.score(X_ablated_n, tumor_target)
    return head_acc, hosp_acc, tumor_acc


def run_neuron_ablation(
    run_dir: Path,
    data_root: str,
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    ks: List[int] | None = None,
    n_random_samples: int = 5,
    include_morphology: bool = True,
    include_id: bool = True,
) -> Dict:
    if ks is None:
        # Dose-response curve emphasizing small K (per reviewer guidance)
        ks = [0, 4, 8, 16, 32, 64, 128, 256]

    epoch, ckpt_path = _select_epoch(run_dir, epoch)

    print(f"\n  M6 — Neuron Ablation (with random + morphology controls)")
    print(f"  run_dir : {run_dir.name}")
    print(f"  epoch   : {epoch} ({ckpt_path.name})")
    print(f"  ks      : {ks}")
    print(f"  random ablation: {n_random_samples} samplings per K")

    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)

    if include_id:
        train_loader, id_loader, ood_loader = _build_loaders_with_id(data_root, max_samples, seed=seed)
    else:
        from experiments.mechinterp_m4_ablation import _build_loaders as _bl
        train_loader, ood_loader = _bl(data_root, max_samples, seed=seed)
        id_loader = None

    print(f"  Extracting features (train + id + ood)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )
    feats_id, hosp_id, tumor_id = (None, None, None)
    if id_loader is not None:
        feats_id, hosp_id, tumor_id = extract_features(
            model, id_loader, device, max_samples=max_samples // 2
        )

    layer = "avgpool"
    def _to_2d(arr):
        a = np.asarray(arr); return a.reshape(a.shape[0], -1)
    X_tr  = _to_2d(feats_train[layer])
    X_ood = _to_2d(feats_ood[layer])
    X_id  = _to_2d(feats_id[layer]) if feats_id is not None else None
    hosp_train  = np.asarray(hosp_train);  hosp_ood  = np.asarray(hosp_ood)
    tumor_train = np.asarray(tumor_train); tumor_ood = np.asarray(tumor_ood)
    if X_id is not None:
        hosp_id  = np.asarray(hosp_id);  tumor_id = np.asarray(tumor_id)

    scaler = StandardScaler().fit(X_tr)
    X_tr_n  = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)
    X_id_n  = scaler.transform(X_id) if X_id is not None else None

    # 1. Per-neuron scores: shortcut (hospital) and morphology (tumor)
    print(f"  Scoring {X_tr.shape[1]} avgpool channels...")
    shortcut_scores  = _per_neuron_shortcut_scores(X_tr_n, hosp_train)
    morphology_scores = _per_neuron_shortcut_scores(X_tr_n, tumor_train) if include_morphology else None
    rank_shortcut   = np.argsort(-shortcut_scores)
    rank_morphology = np.argsort(-morphology_scores) if morphology_scores is not None else None

    hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                   multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train)
    tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                    multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train)

    rng = np.random.default_rng(seed)
    D = X_tr.shape[1]

    sweep = []
    for k in ks:
        row = {"k": int(k)}

        # Mask helpers
        def make_mask(indices):
            m = np.ones(D)
            if k > 0:
                m[indices[:k]] = 0.0
            return m

        # ── A: top-K SHORTCUT neurons (the targeted ablation) ──
        mask_s = make_mask(rank_shortcut)
        h_ood, hp_ood, tp_ood = _ablate_and_eval(
            X_ood_n, mask_s, scaler, tumor_ood, model, layer, device,
            hosp_clf, tumor_clf, hosp_ood, tumor_ood,
        )
        row["shortcut_head_ood"]  = float(h_ood)
        row["shortcut_hosp_probe"] = float(hp_ood)
        row["shortcut_tumor_probe"] = float(tp_ood)
        if X_id_n is not None:
            h_id, _, _ = _ablate_and_eval(
                X_id_n, mask_s, scaler, tumor_id, model, layer, device,
                None, tumor_clf, hosp_id, tumor_id,
            )
            row["shortcut_head_id"] = float(h_id)

        # ── B: top-K MORPHOLOGY neurons (control: ablate the causal neurons) ──
        if include_morphology and rank_morphology is not None:
            mask_m = make_mask(rank_morphology)
            h_ood_m, _, _ = _ablate_and_eval(
                X_ood_n, mask_m, scaler, tumor_ood, model, layer, device,
                None, tumor_clf, hosp_ood, tumor_ood,
            )
            row["morphology_head_ood"] = float(h_ood_m)
            if X_id_n is not None:
                h_id_m, _, _ = _ablate_and_eval(
                    X_id_n, mask_m, scaler, tumor_id, model, layer, device,
                    None, tumor_clf, hosp_id, tumor_id,
                )
                row["morphology_head_id"] = float(h_id_m)

        # ── C: K RANDOM neurons (control: damage uniformly) ──
        if k > 0:
            r_oods, r_ids = [], []
            for s_ in range(n_random_samples):
                idx = rng.permutation(D)[:k]
                m = np.ones(D); m[idx] = 0.0
                h_ood_r, _, _ = _ablate_and_eval(
                    X_ood_n, m, scaler, tumor_ood, model, layer, device,
                    None, tumor_clf, hosp_ood, tumor_ood,
                )
                r_oods.append(h_ood_r)
                if X_id_n is not None:
                    h_id_r, _, _ = _ablate_and_eval(
                        X_id_n, m, scaler, tumor_id, model, layer, device,
                        None, tumor_clf, hosp_id, tumor_id,
                    )
                    r_ids.append(h_id_r)
            row["random_head_ood_mean"] = float(np.mean(r_oods))
            row["random_head_ood_std"]  = float(np.std(r_oods))
            if r_ids:
                row["random_head_id_mean"] = float(np.mean(r_ids))
                row["random_head_id_std"]  = float(np.std(r_ids))
        else:
            row["random_head_ood_mean"] = row["shortcut_head_ood"]   # K=0 same as baseline
            row["random_head_ood_std"]  = 0.0
            if X_id_n is not None:
                row["random_head_id_mean"] = row.get("shortcut_head_id", float("nan"))
                row["random_head_id_std"]  = 0.0

        sweep.append(row)
        # Concise log line
        print(f"    K={k:>4}  shortcut={row['shortcut_head_ood']:.3f}  "
              f"random={row.get('random_head_ood_mean', float('nan')):.3f}±"
              f"{row.get('random_head_ood_std', 0):.3f}  "
              + (f"morphology={row.get('morphology_head_ood', float('nan')):.3f}"
                 if include_morphology else ""))

    return {
        "run_id":   run_dir.name,
        "epoch":    epoch,
        "layer":    layer,
        "max_samples": max_samples,
        "feature_dim": int(X_tr.shape[1]),
        "shortcut_scores_top10":  [int(i) for i in rank_shortcut[:10]],
        "morphology_scores_top10": ([int(i) for i in rank_morphology[:10]]
                                    if rank_morphology is not None else []),
        "n_random_samples": n_random_samples,
        "include_id":       include_id,
        "include_morphology": include_morphology,
        "sweep": sweep,
    }


def plot_neuron_ablation(result: Dict, out_path: Path):
    sweep = result["sweep"]
    ks = [r["k"] for r in sweep]

    has_id = result.get("include_id", False)
    has_morph = result.get("include_morphology", False)

    fig, axes = plt.subplots(1, 2 if has_id else 1, figsize=(13, 5)) if has_id else \
                plt.subplots(1, 1, figsize=(8, 5))
    if not has_id:
        axes = [axes]

    # Panel A — Head OOD: shortcut vs random (vs morphology)
    ax = axes[0]
    shortcut_ood   = [r.get("shortcut_head_ood")    for r in sweep]
    random_ood_mu  = [r.get("random_head_ood_mean") for r in sweep]
    random_ood_sd  = [r.get("random_head_ood_std", 0) for r in sweep]
    morphology_ood = [r.get("morphology_head_ood")  for r in sweep] if has_morph else None

    ax.plot(ks, shortcut_ood, "r-o", lw=2.2, ms=7, label="top-K shortcut neurons (targeted)")
    ax.plot(ks, random_ood_mu, "k-s", lw=1.8, ms=6, label="K random neurons (control)")
    ax.fill_between(ks,
                    [m - s for m, s in zip(random_ood_mu, random_ood_sd)],
                    [m + s for m, s in zip(random_ood_mu, random_ood_sd)],
                    color="black", alpha=0.15)
    if has_morph and morphology_ood is not None:
        ax.plot(ks, morphology_ood, "g-^", lw=1.8, ms=6,
                label="top-K morphology neurons (control)")

    base = shortcut_ood[0]
    ax.axhline(base, color="gray", ls=":", lw=1, alpha=0.5,
               label=f"K=0 baseline ({base:.3f})")
    ax.set_xlabel("K (neurons zeroed at avgpool)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_xscale("symlog", linthresh=4)
    ax.set_title("Targeted vs random ablation — OOD effect\n"
                 "(separation = shortcut neurons selectively hurt OOD)",
                 fontweight="bold", fontsize=10)
    ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3)

    # Panel B — ID/OOD tradeoff
    if has_id:
        ax = axes[1]
        shortcut_id  = [r.get("shortcut_head_id")  for r in sweep]
        random_id_mu = [r.get("random_head_id_mean") for r in sweep]
        random_id_sd = [r.get("random_head_id_std", 0) for r in sweep]
        ax.plot(ks, shortcut_id,  "r--o", lw=2, ms=7, alpha=0.85, label="ID (shortcut ablation)")
        ax.plot(ks, shortcut_ood, "r-o",  lw=2, ms=7,             label="OOD (shortcut ablation)")
        ax.plot(ks, random_id_mu, "k--s", lw=1.6, ms=5, alpha=0.7, label="ID (random ablation)")
        ax.plot(ks, random_ood_mu, "k-s",  lw=1.6, ms=5, alpha=0.7, label="OOD (random ablation)")
        ax.set_xlabel("K (neurons zeroed at avgpool)")
        ax.set_ylabel("Head accuracy")
        ax.set_xscale("symlog", linthresh=4)
        ax.set_title("ID vs OOD degradation tradeoff\n"
                     "(targeted: OOD steady or ↑ while ID slowly ↓ = good)",
                     fontweight="bold", fontsize=10)
        ax.legend(fontsize=8, loc="best"); ax.grid(alpha=0.3)

    fig.suptitle(f"M6 — Targeted Neuron Ablation vs Random Control: {result['run_id']}  "
                 f"•  ep{result['epoch']}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--run_dir",   required=True)
    p.add_argument("--data_root", default="data/wilds")
    p.add_argument("--epoch",     type=int, default=None)
    p.add_argument("--max_samples", type=int, default=1000)
    p.add_argument("--device",    default="cuda")
    p.add_argument("--ks",        default=None,
                   help="Comma-separated K values, e.g. '0,4,8,16,32,64,128,256'")
    p.add_argument("--n_random_samples", type=int, default=5,
                   help="Random ablation: averages over this many random K-subsets")
    p.add_argument("--no_morphology", action="store_true",
                   help="Skip the morphology-targeted ablation control")
    p.add_argument("--no_id", action="store_true",
                   help="Skip ID accuracy evaluation (faster but loses ID/OOD ratio)")
    args = p.parse_args()

    ks = None
    if args.ks is not None:
        ks = [int(x) for x in args.ks.split(",")]

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    result = run_neuron_ablation(
        run_dir=run_dir, data_root=args.data_root, epoch=args.epoch,
        max_samples=args.max_samples, device=args.device, ks=ks,
        n_random_samples=args.n_random_samples,
        include_morphology=not args.no_morphology,
        include_id=not args.no_id,
    )
    base = out_dir / f"m6_neuron_ablation_ep{result['epoch']:05d}"
    base.with_suffix(".json").write_text(json.dumps(result, indent=2))
    plot_neuron_ablation(result, base.with_suffix(".png"))
    print(f"\n  → {base.with_suffix('.json')}")
    print(f"  → {base.with_suffix('.png')}")


if __name__ == "__main__":
    main()
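
The result dictionary written by `main()` to `mechinterp/m6_neuron_ablation_ep{epoch:05d}.json` can be post-processed without re-running the model. The following is a minimal sketch, not part of the original scripts, that computes the per-K difference between targeted (shortcut) and random ablation on head OOD accuracy. The helper name `ablation_gaps` and the toy numbers are hypothetical; only the keys `sweep`, `k`, `shortcut_head_ood`, `random_head_ood_mean`, and `random_head_ood_std` come from the script above.

```python
from typing import Dict, List


def ablation_gaps(result: Dict) -> List[Dict]:
    """For each K, return targeted-minus-random head OOD accuracy."""
    rows = []
    for row in result["sweep"]:
        rows.append({
            "k": row["k"],
            # Positive gap: zeroing shortcut-ranked neurons leaves OOD
            # accuracy higher than zeroing the same number of random ones.
            "gap": row["shortcut_head_ood"] - row["random_head_ood_mean"],
            "random_std": row["random_head_ood_std"],
        })
    return rows


# Toy input with made-up accuracies (NOT values from any run in this archive).
toy_result = {
    "sweep": [
        {"k": 0, "shortcut_head_ood": 0.60,
         "random_head_ood_mean": 0.60, "random_head_ood_std": 0.0},
        {"k": 8, "shortcut_head_ood": 0.65,
         "random_head_ood_mean": 0.59, "random_head_ood_std": 0.01},
    ]
}

gaps = ablation_gaps(toy_result)
for g in gaps:
    print(f"K={g['k']:>4}  gap={g['gap']:+.3f}  (random sd {g['random_std']:.3f})")
```

For a real run, `toy_result` would be replaced by `json.loads(path.read_text())` on the saved JSON file.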

16. Run inventory and summary results (14 runs)

14 runs at the canonical 3000-epoch config. The 11 runs with full M4/M5/M6 trajectories carry the paper's mechanistic claims; the 3 n=300 grokking runs have M1 only.

n = 1000 (5 grokking-favorable + 3 standard)

| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ (ungrok) | Best ID | Final ‖W‖ | Final rank | Final IRM |
|---|---|---|---|---|---|---|---|---|---|---|
| Grok | 7 | 20260508-183413_grokking_n1000_s7 | 0.6876 | 50 | 0.5882 | −0.0995 | 0.8797 | 1516.7 | 69.83 | 4.47e-12 |
| Grok | 42 | 20260505-080445_grokking_n1000_s42 | 0.7336 | 350 | 0.6639 | −0.0696 | 0.8976 | 1470.5 | 35.87 | 4.34e-12 |
| Grok | 123 | 20260505-100720_grokking_n1000_s123 | 0.7270 | 350 | 0.6447 | −0.0823 | 0.8994 | 1457.2 | 56.65 | 6.73e-15 |
| Grok | 456 | 20260505-100720_grokking_n1000_s456 | 0.6722 | 1100 | 0.5224 | −0.1498 | 0.8824 | 1493.6 | 64.54 | 2.87e-09 |
| Grok | 2024 | 20260508-183413_grokking_n1000_s2024 | 0.7056 | 400 | 0.5506 | −0.1550 | 0.8959 | 1632.4 | 65.77 | 4.56e-07 |
| Std | 42 | 20260505-100720_standard_n1000_s42 | 0.7615 | 1 | 0.6482 | −0.1133 | 0.9011 | 812.6 | 33.35 | 2.24e-13 |
| Std | 123 | 20260508-183413_standard_n1000_s123 | 0.8880* | 1 | 0.6645 | −0.2235 | 0.8957 | 798.3 | 37.08 | 9.53e-14 |
| Std | 456 | 20260508-183413_standard_n1000_s456 | 0.7450 | 1050 | 0.5783 | −0.1667 | 0.8950 | 792.4 | 35.30 | 7.18e-09 |

* Std s123 peaks at epoch 1 on the random initialization (artifact).

Aggregates (n=1000): grokking 5-seed mean peak = 0.7052 ± 0.0237, mean Δ = −0.1112 ± 0.0345. Standard corrected 2-seed mean peak (s42, s456) = 0.7533; raw 3-seed mean = 0.7982.
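
The means quoted above can be re-derived from the rounded table entries. The sketch below does only that; it does not attempt the ± spreads, because the text does not state which standard-deviation convention (sample or population) was used. Variable names are illustrative.

```python
from statistics import mean

# Peak OOD and delta for the five n=1000 grokking seeds (7, 42, 123, 456, 2024),
# copied from the table at four-decimal precision.
grok_peak = [0.6876, 0.7336, 0.7270, 0.6722, 0.7056]
grok_delta = [-0.0995, -0.0696, -0.0823, -0.1498, -0.1550]

# Peak OOD for the three n=1000 standard seeds.
std_peak = {42: 0.7615, 123: 0.8880, 456: 0.7450}

grok_peak_mean = mean(grok_peak)
grok_delta_mean = mean(grok_delta)
# "Corrected" mean drops seed 123, whose peak is flagged as an artifact.
std_corrected_mean = mean([std_peak[42], std_peak[456]])
std_raw_mean = mean(list(std_peak.values()))

print(f"grokking mean peak  = {grok_peak_mean:.4f}")
print(f"grokking mean delta = {grok_delta_mean:.4f}")
print(f"standard 2-seed mean peak = {std_corrected_mean:.4f}")
print(f"standard 3-seed mean peak = {std_raw_mean:.4f}")
```

Because the inputs are already rounded, the last digit can differ slightly from a calculation on the full-precision values in the per-run `summary.json` files.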

n = 500 and n = 300

| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD | Δ | Best ID |
|---|---|---|---|---|---|---|---|
| Grok | 42 | 20260505-080442_grokking_n500_s42 | 0.7924 | 50 | 0.5514 | −0.2410 | 0.8874 |
| Std | 42 | 20260505-100720_standard_n500_s42 | 0.7576 | 1050 | 0.6526 | −0.1050 | 0.8867 |
| Grok | 42 | 20260502-214859_grokking_n300_s42 | 0.7162 | 250 | 0.5189 | −0.1974 | 0.8664 |
| Grok | 123 | 20260502-214859_grokking_n300_s123 | 0.6961 | 50 | 0.5154 | −0.1807 | 0.8388 |
| Grok | 456 | 20260502-214859_grokking_n300_s456 | 0.6654 | 750 | 0.5469 | −0.1184 | 0.8522 |
| Std | 42 | 20260505-080836_standard_n300_s42 | 0.7647 | 250 | 0.7052 | −0.0596 | 0.8584 |

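Universal finding: every run ungroks (Δ < 0 for all 14). No run shows a plateau-then-jump: grokking_epoch = -1 in every summary, and irm_drop_pct ≈ 100% in every summary that records it (the three n=300 grokking summaries omit that field and report ungrokking_detected instead).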
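
The sign check behind this finding can be reproduced from the saved summaries. The sketch below is not part of the original scripts. It assumes the layout `<runs_root>/<run_id>/results/summary.json`, inferred from the `run_dir` values in the configs, and the function names are hypothetical. The two inline examples use values copied from the per-run summaries in Section 17.

```python
import json
from pathlib import Path
from typing import Dict, Iterable, List


def load_summaries(runs_root: str) -> List[Dict]:
    """Read every <run>/results/summary.json under runs_root (assumed layout)."""
    paths = sorted(Path(runs_root).glob("*/results/summary.json"))
    return [json.loads(p.read_text()) for p in paths]


def runs_not_ungrokking(summaries: Iterable[Dict]) -> List[str]:
    """Return run_ids that break the pattern: non-negative OOD delta
    or a detected grokking epoch (anything other than -1)."""
    return [
        s["run_id"]
        for s in summaries
        if s["ood_delta"] >= 0 or s["grokking_epoch"] != -1
    ]


# Two summaries from Section 17, abridged to the fields used here.
examples = [
    {"run_id": "20260502-214859_grokking_n300_s42",
     "ood_delta": -0.19735697321701506, "grokking_epoch": -1},
    {"run_id": "20260505-100720_grokking_n1000_s123",
     "ood_delta": -0.08231241329038019, "grokking_epoch": -1},
]

violations = runs_not_ungrokking(examples)
print(violations)  # an empty list means every listed run ungroks
```

On the full archive, `runs_not_ungrokking(load_summaries("experiments/runs"))` would perform the same check over all saved runs.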


17. Per-run config.json and summary.json (all 14 runs)

Every run's exact saved JSON, verbatim from disk.

20260502-214859_grokking_n300_s42

config.json:

{
  "seed": 42,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260502-214859_grokking_n300_s42",
  "run_dir": "experiments/runs/20260502-214859_grokking_n300_s42"
}

results/summary.json:

{
  "run_id": "20260502-214859_grokking_n300_s42",
  "condition": "grokking",
  "n_train": 300,
  "seed": 42,
  "best_id_val": 0.866388557806913,
  "best_ood": 0.7162273379264937,
  "peak_ood_epoch": 250,
  "final_ood": 0.5188703647094787,
  "ood_delta": -0.19735697321701506,
  "ood_improvement": 0.01951701272132994,
  "grokking_epoch": -1,
  "final_weight_norm": 1131.2872887615824,
  "final_feature_rank": 39.32365036010742,
  "final_irm": 6.867336560523185e-14,
  "final_shortcut_ratio": 1.0053635176200695,
  "final_ood_gap": 0.3284478712857537,
  "ungrokking_detected": true
}
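
In the summary above, `ood_delta` appears to equal `final_ood` minus `best_ood`, which matches the Δ column of the Section 16 tables. The sketch below checks that relation on this run's values. The tolerance and the helper name are illustrative choices, and the relation is an observation from the saved numbers rather than a statement about how the training script computes the field.

```python
import math
from typing import Dict


def delta_is_consistent(summary: Dict, tol: float = 1e-12) -> bool:
    """Check ood_delta == final_ood - best_ood within a small tolerance."""
    expected = summary["final_ood"] - summary["best_ood"]
    return math.isclose(summary["ood_delta"], expected, rel_tol=0.0, abs_tol=tol)


# Fields copied from the summary.json of 20260502-214859_grokking_n300_s42.
summary_n300_s42 = {
    "best_ood": 0.7162273379264937,
    "final_ood": 0.5188703647094787,
    "ood_delta": -0.19735697321701506,
}

consistent = delta_is_consistent(summary_n300_s42)
print(consistent)
```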

20260502-214859_grokking_n300_s123

config.json:

{
  "seed": 123,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260502-214859_grokking_n300_s123",
  "run_dir": "experiments/runs/20260502-214859_grokking_n300_s123"
}

results/summary.json:

{
  "run_id": "20260502-214859_grokking_n300_s123",
  "condition": "grokking",
  "n_train": 300,
  "seed": 123,
  "best_id_val": 0.8387663885578069,
  "best_ood": 0.6961459778493663,
  "peak_ood_epoch": 50,
  "final_ood": 0.515413737155219,
  "ood_delta": -0.18073224069414728,
  "ood_improvement": 0.015413737155218987,
  "grokking_epoch": -1,
  "final_weight_norm": 981.2013665038359,
  "final_feature_rank": 44.795772552490234,
  "final_irm": 5.876544368309256e-13,
  "final_shortcut_ratio": 1.0042143792951101,
  "final_ood_gap": 0.23887708525240914,
  "ungrokking_detected": true
}

20260502-214859_grokking_n300_s456

config.json:

{
  "seed": 456,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260502-214859_grokking_n300_s456",
  "run_dir": "experiments/runs/20260502-214859_grokking_n300_s456"
}

results/summary.json:

{
  "run_id": "20260502-214859_grokking_n300_s456",
  "condition": "grokking",
  "n_train": 300,
  "seed": 456,
  "best_id_val": 0.8521752085816449,
  "best_ood": 0.6653655324852447,
  "peak_ood_epoch": 750,
  "final_ood": 0.5469348884238249,
  "ood_delta": -0.11843064406141979,
  "ood_improvement": 0.04693488842382487,
  "grokking_epoch": -1,
  "final_weight_norm": 1109.5530019538146,
  "final_feature_rank": 49.14884567260742,
  "final_irm": 8.174740884214771e-10,
  "final_shortcut_ratio": 0.9680124053677005,
  "final_ood_gap": 0.25908418189798677,
  "ungrokking_detected": true
}

20260505-080836_standard_n300_s42

config.json:

{
  "seed": 42,
  "n_train": 300,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "standard",
  "lr": 0.001,
  "weight_decay": 0.0001,
  "n_epochs": 3000,
  "init_scale": 1.0,
  "use_grokfast": false,
  "grad_clip": 1.0,
  "run_id": "20260505-080836_standard_n300_s42",
  "run_dir": "experiments/runs/20260505-080836_standard_n300_s42"
}

results/summary.json:

{
  "run_id": "20260505-080836_standard_n300_s42",
  "condition": "standard",
  "n_train": 300,
  "seed": 42,
  "best_id_val": 0.858373063170441,
  "best_ood": 0.7647259388153408,
  "peak_ood_epoch": 250,
  "final_ood": 0.7051637783055471,
  "ood_delta": -0.05956216050979368,
  "ood_improvement": -0.026599572036588692,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99996794236725,
  "irm_drop_epoch": 100,
  "epoch_gap": -1,
  "final_weight_norm": 452.7899771783757,
  "final_feature_rank": 17.35470962524414,
  "final_irm": 1.032695706726372e-09,
  "final_shortcut_ratio": 1.0080115087353168,
  "final_ood_gap": 0.12806029797574014
}

20260505-080442_grokking_n500_s42

config.json:

{
  "seed": 42,
  "n_train": 500,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-080442_grokking_n500_s42",
  "run_dir": "experiments/runs/20260505-080442_grokking_n500_s42"
}

results/summary.json:

{
  "run_id": "20260505-080442_grokking_n500_s42",
  "condition": "grokking",
  "n_train": 500,
  "seed": 42,
  "best_id_val": 0.8873957091775924,
  "best_ood": 0.7924495026688927,
  "peak_ood_epoch": 50,
  "final_ood": 0.5514496672702048,
  "ood_delta": -0.24099983539868786,
  "ood_improvement": -0.06643779246126003,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999995667906,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1046.672482929869,
  "final_feature_rank": 34.301780700683594,
  "final_irm": 6.552892841682478e-07,
  "final_shortcut_ratio": 0.9969108279290858,
  "final_ood_gap": 0.32080897396936603
}

20260505-100720_standard_n500_s42

config.json:

{
  "seed": 42,
  "n_train": 500,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "standard",
  "lr": 0.001,
  "weight_decay": 0.0001,
  "n_epochs": 3000,
  "init_scale": 1.0,
  "use_grokfast": false,
  "grad_clip": 1.0,
  "run_id": "20260505-100720_standard_n500_s42",
  "run_dir": "experiments/runs/20260505-100720_standard_n500_s42"
}

results/summary.json:

{
  "run_id": "20260505-100720_standard_n500_s42",
  "condition": "standard",
  "n_train": 500,
  "seed": 42,
  "best_id_val": 0.8866507747318236,
  "best_ood": 0.7575775389752393,
  "peak_ood_epoch": 1050,
  "final_ood": 0.6525854163237472,
  "ood_delta": -0.10499212265149205,
  "ood_improvement": -0.08494603428410186,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99998953242944,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 555.1170332945212,
  "final_feature_rank": 20.722705841064453,
  "final_irm": 5.05333608291636e-10,
  "final_shortcut_ratio": 0.9843676046096169,
  "final_ood_gap": 0.19765296269889876
}

20260505-080445_grokking_n1000_s42

config.json:

{
  "seed": 42,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-080445_grokking_n1000_s42",
  "run_dir": "experiments/runs/20260505-080445_grokking_n1000_s42"
}

results/summary.json:

{
  "run_id": "20260505-080445_grokking_n1000_s42",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 42,
  "best_id_val": 0.8976460071513707,
  "best_ood": 0.7335575046441084,
  "peak_ood_epoch": 350,
  "final_ood": 0.6639076351494345,
  "ood_delta": -0.06964986949467389,
  "ood_improvement": 0.011399816587109313,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999500910222,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1470.4773196265805,
  "final_feature_rank": 35.86945724487305,
  "final_irm": 4.340286376830482e-12,
  "final_shortcut_ratio": 0.9871634909839839,
  "final_ood_gap": 0.20680154244293736
}

20260505-100720_grokking_n1000_s123

config.json:

{
  "seed": 123,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-100720_grokking_n1000_s123",
  "run_dir": "experiments/runs/20260505-100720_grokking_n1000_s123"
}

results/summary.json:

{
  "run_id": "20260505-100720_grokking_n1000_s123",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 123,
  "best_id_val": 0.8994338498212158,
  "best_ood": 0.7269734521598044,
  "peak_ood_epoch": 350,
  "final_ood": 0.6446610388694242,
  "ood_delta": -0.08231241329038019,
  "ood_improvement": 0.01956639311496222,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.9999991755092,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1457.2260357210637,
  "final_feature_rank": 56.64516067504883,
  "final_irm": 6.7278793620213426e-15,
  "final_shortcut_ratio": 0.9760735232865748,
  "final_ood_gap": 0.23534492060614198
}

20260505-100720_grokking_n1000_s456

config.json:

{
  "seed": 456,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260505-100720_grokking_n1000_s456",
  "run_dir": "experiments/runs/20260505-100720_grokking_n1000_s456"
}

results/summary.json:

{
  "run_id": "20260505-100720_grokking_n1000_s456",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 456,
  "best_id_val": 0.8824493444576877,
  "best_ood": 0.6721847297011311,
  "peak_ood_epoch": 1100,
  "final_ood": 0.522397535683213,
  "ood_delta": -0.14978719401791807,
  "ood_improvement": -0.08302960472170617,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999977269624,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1493.580040593733,
  "final_feature_rank": 64.54296875,
  "final_irm": 2.8693030174054e-09,
  "final_shortcut_ratio": 1.0356636404914459,
  "final_ood_gap": 0.3221495441737596
}

20260508-183413_grokking_n1000_s7

config.json:

{
  "seed": 7,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260508-183413_grokking_n1000_s7",
  "run_dir": "experiments/runs/20260508-183413_grokking_n1000_s7"
}

results/summary.json:

{
  "run_id": "20260508-183413_grokking_n1000_s7",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 7,
  "best_id_val": 0.8797079856972586,
  "best_ood": 0.6876454958026665,
  "peak_ood_epoch": 50,
  "final_ood": 0.5881910315799375,
  "ood_delta": -0.09945446422272908,
  "ood_improvement": -0.03798527993980294,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999355254296,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1516.661424559645,
  "final_feature_rank": 69.8335952758789,
  "final_irm": 4.471252361415434e-12,
  "final_shortcut_ratio": 0.996297612800058,
  "final_ood_gap": 0.26586141180504463
}

20260508-183413_grokking_n1000_s2024

config.json:

{
  "seed": 2024,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "grokking",
  "lr": 0.001,
  "weight_decay": 0.005,
  "n_epochs": 3000,
  "init_scale": 4.0,
  "use_grokfast": true,
  "grokfast_alpha": 0.98,
  "grokfast_lamb": 2.0,
  "grad_clip": 1.0,
  "run_id": "20260508-183413_grokking_n1000_s2024",
  "run_dir": "experiments/runs/20260508-183413_grokking_n1000_s2024"
}

results/summary.json:

{
  "run_id": "20260508-183413_grokking_n1000_s2024",
  "condition": "grokking",
  "n_train": 1000,
  "seed": 2024,
  "best_id_val": 0.8959177592371871,
  "best_ood": 0.7056105532955534,
  "peak_ood_epoch": 400,
  "final_ood": 0.5506031462365085,
  "ood_delta": -0.15500740705904492,
  "ood_improvement": -0.04230488865896964,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.9999998741651,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 1632.3948325021925,
  "final_feature_rank": 65.77389526367188,
  "final_irm": 4.5635979972757923e-07,
  "final_shortcut_ratio": 0.9633737610070339,
  "final_ood_gap": 0.308663838268855
}

20260505-100720_standard_n1000_s42

config.json:

{
  "seed": 42,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "standard",
  "lr": 0.001,
  "weight_decay": 0.0001,
  "n_epochs": 3000,
  "init_scale": 1.0,
  "use_grokfast": false,
  "grad_clip": 1.0,
  "run_id": "20260505-100720_standard_n1000_s42",
  "run_dir": "experiments/runs/20260505-100720_standard_n1000_s42"
}

results/summary.json:

{
  "run_id": "20260505-100720_standard_n1000_s42",
  "condition": "standard",
  "n_train": 1000,
  "seed": 42,
  "best_id_val": 0.9011025029797378,
  "best_ood": 0.7615162132292426,
  "peak_ood_epoch": 1,
  "final_ood": 0.6482234815528958,
  "ood_delta": -0.11329273167634679,
  "ood_improvement": -0.022357561078844124,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999329006182,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 812.5540534619066,
  "final_feature_rank": 33.34842300415039,
  "final_irm": 2.2439123855463178e-13,
  "final_shortcut_ratio": 0.99423543483118,
  "final_ood_gap": 0.24736650652815306
}

20260508-183413_standard_n1000_s123

config.json:

{
  "seed": 123,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "standard",
  "lr": 0.001,
  "weight_decay": 0.0001,
  "n_epochs": 3000,
  "init_scale": 1.0,
  "use_grokfast": false,
  "grad_clip": 1.0,
  "run_id": "20260508-183413_standard_n1000_s123",
  "run_dir": "experiments/runs/20260508-183413_standard_n1000_s123"
}

results/summary.json:

{
  "run_id": "20260508-183413_standard_n1000_s123",
  "condition": "standard",
  "n_train": 1000,
  "seed": 123,
  "best_id_val": 0.8957091775923719,
  "best_ood": 0.8879652926376185,
  "peak_ood_epoch": 1,
  "final_ood": 0.6644837397418111,
  "ood_delta": -0.2234815528958074,
  "ood_improvement": -0.10371998965363183,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.9999667168104,
  "irm_drop_epoch": 150,
  "epoch_gap": -1,
  "final_weight_norm": 798.3091903337586,
  "final_feature_rank": 37.0809326171875,
  "final_irm": 9.526699497634447e-14,
  "final_shortcut_ratio": 0.9913218948516812,
  "final_ood_gap": 0.22869266073494698
}

20260508-183413_standard_n1000_s456

config.json:

{
  "seed": 456,
  "n_train": 1000,
  "batch_size": 32,
  "img_size": 96,
  "n_classes": 2,
  "log_every": 50,
  "device": "cuda",
  "condition": "standard",
  "lr": 0.001,
  "weight_decay": 0.0001,
  "n_epochs": 3000,
  "init_scale": 1.0,
  "use_grokfast": false,
  "grad_clip": 1.0,
  "run_id": "20260508-183413_standard_n1000_s456",
  "run_dir": "experiments/runs/20260508-183413_standard_n1000_s456"
}

results/summary.json:

{
  "run_id": "20260508-183413_standard_n1000_s456",
  "condition": "standard",
  "n_train": 1000,
  "seed": 456,
  "best_id_val": 0.8949940405244339,
  "best_ood": 0.7449737813624285,
  "peak_ood_epoch": 1050,
  "final_ood": 0.5783149528534813,
  "ood_delta": -0.1666588285089472,
  "ood_improvement": 0.02752839372633853,
  "grokking_epoch": -1,
  "irm_drop_pct": 99.99999467428148,
  "irm_drop_epoch": 50,
  "epoch_gap": -1,
  "final_weight_norm": 792.3747983792097,
  "final_feature_rank": 35.297889709472656,
  "final_irm": 7.180730676736857e-09,
  "final_shortcut_ratio": 0.987733651871859,
  "final_ood_gap": 0.2792237837376986
}

18. Full training log: grokking n=1000 seed=42

# launched: 2026-05-05T08:04:45Z
# host:     ubuntu-Standard-PC-Q35-ICH9-2009
# pwd:      /home/garima/CausalGrok
# cmd:      /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-080445_grokking_n1000_s42 --wandb_project causalgrok --wandb_mode offline
----

Device:  cuda
Run ID:  20260505-080445_grokking_n1000_s42
Started: 2026-05-05T08:04:50.408035+00:00
  Env hospital=0: 181 samples, positive rate=0.53
  Env hospital=3: 371 samples, positive rate=0.48
  Env hospital=4: 448 samples, positive rate=0.50
Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054
Params: 11,177,538

============================================================
  GROKKING | Camelyon17 v2 | 3000 epochs
  WD=0.005 | α=4.0 | n=1000
  Tracking: ID val (H3) + OOD test (H4) at every checkpoint
  Grokking detection: watching OOD acc, not ID val acc
  IRM envs: 3 hospitals
============================================================
  ep     1 | tr 0.734 | id 0.728 | ood 0.499 | gap +0.228 | ‖W‖ 356.1 | rank 109.6 | IRM 0.2004 | sc 1.12x
  ep    50 | tr 0.996 | id 0.859 | ood 0.697 | gap +0.162 | ‖W‖ 495.4 | rank 91.4 | IRM 0.0001 | sc 0.97x
  ep   100 | tr 1.000 | id 0.883 | ood 0.592 | gap +0.291 | ‖W‖ 551.6 | rank 81.7 | IRM 0.0000 | sc 0.96x
  ep   150 | tr 1.000 | id 0.872 | ood 0.641 | gap +0.231 | ‖W‖ 650.2 | rank 72.2 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep00200.pt
  ep   200 | tr 1.000 | id 0.885 | ood 0.613 | gap +0.272 | ‖W‖ 741.5 | rank 62.6 | IRM 0.0000 | sc 0.97x
  ep   250 | tr 1.000 | id 0.890 | ood 0.659 | gap +0.231 | ‖W‖ 842.7 | rank 61.1 | IRM 0.0000 | sc 0.99x
  ep   300 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.227 | ‖W‖ 878.1 | rank 60.7 | IRM 0.0000 | sc 0.99x
  ep   350 | tr 1.000 | id 0.881 | ood 0.734 | gap +0.147 | ‖W‖ 907.9 | rank 53.9 | IRM 0.0000 | sc 0.97x
  ✓ Checkpoint → ep00400.pt
  ep   400 | tr 1.000 | id 0.875 | ood 0.647 | gap +0.229 | ‖W‖ 1005.7 | rank 60.9 | IRM 0.0000 | sc 1.00x
  ep   450 | tr 1.000 | id 0.876 | ood 0.615 | gap +0.261 | ‖W‖ 1031.2 | rank 57.8 | IRM 0.0000 | sc 0.99x
  ep   500 | tr 1.000 | id 0.880 | ood 0.607 | gap +0.273 | ‖W‖ 1023.1 | rank 55.7 | IRM 0.0000 | sc 0.99x
  ep   550 | tr 1.000 | id 0.888 | ood 0.611 | gap +0.276 | ‖W‖ 1101.7 | rank 52.6 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep00600.pt
  ep   600 | tr 1.000 | id 0.884 | ood 0.546 | gap +0.337 | ‖W‖ 1102.1 | rank 54.5 | IRM 0.0000 | sc 1.00x
  ep   650 | tr 1.000 | id 0.878 | ood 0.605 | gap +0.273 | ‖W‖ 1142.4 | rank 55.2 | IRM 0.0000 | sc 0.98x
  ep   700 | tr 1.000 | id 0.889 | ood 0.587 | gap +0.303 | ‖W‖ 1187.3 | rank 47.9 | IRM 0.0000 | sc 0.98x
  ep   750 | tr 1.000 | id 0.892 | ood 0.670 | gap +0.222 | ‖W‖ 1180.7 | rank 48.4 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep00800.pt
  ep   800 | tr 1.000 | id 0.879 | ood 0.675 | gap +0.204 | ‖W‖ 1265.4 | rank 48.0 | IRM 0.0000 | sc 0.99x
  ep   850 | tr 1.000 | id 0.886 | ood 0.627 | gap +0.259 | ‖W‖ 1300.7 | rank 48.3 | IRM 0.0000 | sc 0.98x
  ep   900 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1290.6 | rank 48.0 | IRM 0.0000 | sc 0.98x
  ep   950 | tr 1.000 | id 0.883 | ood 0.640 | gap +0.243 | ‖W‖ 1280.4 | rank 47.1 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep01000.pt
  ep  1000 | tr 1.000 | id 0.886 | ood 0.653 | gap +0.233 | ‖W‖ 1270.3 | rank 47.3 | IRM 0.0000 | sc 0.99x
  ep  1050 | tr 1.000 | id 0.898 | ood 0.697 | gap +0.201 | ‖W‖ 1289.1 | rank 42.4 | IRM 0.0000 | sc 0.99x
  ep  1100 | tr 0.999 | id 0.876 | ood 0.641 | gap +0.235 | ‖W‖ 1308.0 | rank 46.4 | IRM 0.0000 | sc 0.99x
  ep  1150 | tr 1.000 | id 0.894 | ood 0.663 | gap +0.231 | ‖W‖ 1313.5 | rank 45.8 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep01200.pt
  ep  1200 | tr 0.999 | id 0.878 | ood 0.592 | gap +0.286 | ‖W‖ 1320.4 | rank 42.9 | IRM 0.0000 | sc 0.99x
  ep  1250 | tr 1.000 | id 0.877 | ood 0.546 | gap +0.330 | ‖W‖ 1373.5 | rank 47.2 | IRM 0.0000 | sc 0.99x
  ep  1300 | tr 1.000 | id 0.853 | ood 0.600 | gap +0.254 | ‖W‖ 1375.3 | rank 48.0 | IRM 0.0000 | sc 0.98x
  ep  1350 | tr 1.000 | id 0.881 | ood 0.575 | gap +0.306 | ‖W‖ 1415.1 | rank 49.4 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep01400.pt
  ep  1400 | tr 1.000 | id 0.892 | ood 0.605 | gap +0.287 | ‖W‖ 1428.3 | rank 42.5 | IRM 0.0000 | sc 1.00x
  ep  1450 | tr 1.000 | id 0.867 | ood 0.619 | gap +0.248 | ‖W‖ 1438.2 | rank 41.4 | IRM 0.0000 | sc 0.99x
  ep  1500 | tr 1.000 | id 0.875 | ood 0.652 | gap +0.223 | ‖W‖ 1440.8 | rank 47.2 | IRM 0.0000 | sc 0.97x
  ep  1550 | tr 1.000 | id 0.879 | ood 0.643 | gap +0.236 | ‖W‖ 1429.5 | rank 45.8 | IRM 0.0000 | sc 0.97x
  ✓ Checkpoint → ep01600.pt
  ep  1600 | tr 1.000 | id 0.878 | ood 0.630 | gap +0.248 | ‖W‖ 1418.2 | rank 46.4 | IRM 0.0000 | sc 0.98x
  ep  1650 | tr 1.000 | id 0.881 | ood 0.636 | gap +0.245 | ‖W‖ 1406.9 | rank 45.9 | IRM 0.0000 | sc 0.98x
  ep  1700 | tr 1.000 | id 0.883 | ood 0.634 | gap +0.249 | ‖W‖ 1395.7 | rank 47.0 | IRM 0.0000 | sc 0.98x
  ep  1750 | tr 1.000 | id 0.875 | ood 0.702 | gap +0.173 | ‖W‖ 1441.2 | rank 45.9 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep01800.pt
  ep  1800 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 1433.0 | rank 45.6 | IRM 0.0000 | sc 0.98x
  ep  1850 | tr 1.000 | id 0.880 | ood 0.644 | gap +0.236 | ‖W‖ 1421.6 | rank 45.8 | IRM 0.0000 | sc 0.98x
  ep  1900 | tr 1.000 | id 0.885 | ood 0.644 | gap +0.242 | ‖W‖ 1410.4 | rank 43.9 | IRM 0.0000 | sc 0.98x
  ep  1950 | tr 1.000 | id 0.883 | ood 0.549 | gap +0.334 | ‖W‖ 1415.3 | rank 49.1 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep02000.pt
  ep  2000 | tr 1.000 | id 0.889 | ood 0.581 | gap +0.308 | ‖W‖ 1404.6 | rank 47.1 | IRM 0.0000 | sc 0.98x
  ep  2050 | tr 1.000 | id 0.888 | ood 0.577 | gap +0.311 | ‖W‖ 1393.4 | rank 46.6 | IRM 0.0000 | sc 0.98x
  ep  2100 | tr 1.000 | id 0.884 | ood 0.617 | gap +0.266 | ‖W‖ 1460.6 | rank 33.9 | IRM 0.0000 | sc 1.00x
  ep  2150 | tr 1.000 | id 0.870 | ood 0.597 | gap +0.273 | ‖W‖ 1470.9 | rank 37.5 | IRM 0.0000 | sc 1.00x
  ✓ Checkpoint → ep02200.pt
  ep  2200 | tr 1.000 | id 0.869 | ood 0.568 | gap +0.301 | ‖W‖ 1460.2 | rank 38.5 | IRM 0.0000 | sc 0.99x
  ep  2250 | tr 1.000 | id 0.870 | ood 0.588 | gap +0.282 | ‖W‖ 1448.6 | rank 36.9 | IRM 0.0000 | sc 0.98x
  ep  2300 | tr 0.998 | id 0.872 | ood 0.706 | gap +0.166 | ‖W‖ 1485.2 | rank 41.9 | IRM 0.0000 | sc 0.97x
  ep  2350 | tr 1.000 | id 0.877 | ood 0.648 | gap +0.229 | ‖W‖ 1506.9 | rank 41.2 | IRM 0.0000 | sc 1.00x
  ✓ Checkpoint → ep02400.pt
  ep  2400 | tr 1.000 | id 0.876 | ood 0.650 | gap +0.226 | ‖W‖ 1495.0 | rank 40.1 | IRM 0.0000 | sc 1.00x
  ep  2450 | tr 1.000 | id 0.869 | ood 0.682 | gap +0.187 | ‖W‖ 1486.6 | rank 36.8 | IRM 0.0000 | sc 0.98x
  ep  2500 | tr 1.000 | id 0.884 | ood 0.621 | gap +0.263 | ‖W‖ 1510.2 | rank 40.7 | IRM 0.0000 | sc 1.00x
  ep  2550 | tr 1.000 | id 0.876 | ood 0.667 | gap +0.209 | ‖W‖ 1499.1 | rank 38.9 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep02600.pt
  ep  2600 | tr 1.000 | id 0.871 | ood 0.635 | gap +0.236 | ‖W‖ 1506.6 | rank 39.4 | IRM 0.0000 | sc 1.01x
  ep  2650 | tr 1.000 | id 0.874 | ood 0.692 | gap +0.183 | ‖W‖ 1500.6 | rank 36.8 | IRM 0.0000 | sc 0.99x
  ep  2700 | tr 1.000 | id 0.877 | ood 0.607 | gap +0.269 | ‖W‖ 1510.5 | rank 39.9 | IRM 0.0000 | sc 0.97x
  ep  2750 | tr 1.000 | id 0.873 | ood 0.599 | gap +0.273 | ‖W‖ 1521.2 | rank 38.5 | IRM 0.0000 | sc 1.00x
  ✓ Checkpoint → ep02800.pt
  ep  2800 | tr 1.000 | id 0.878 | ood 0.584 | gap +0.295 | ‖W‖ 1515.0 | rank 36.9 | IRM 0.0000 | sc 1.00x
  ep  2850 | tr 1.000 | id 0.878 | ood 0.619 | gap +0.259 | ‖W‖ 1504.0 | rank 36.2 | IRM 0.0000 | sc 0.99x
  ep  2900 | tr 1.000 | id 0.880 | ood 0.614 | gap +0.266 | ‖W‖ 1492.6 | rank 35.7 | IRM 0.0000 | sc 0.99x
  ep  2950 | tr 1.000 | id 0.871 | ood 0.619 | gap +0.252 | ‖W‖ 1482.1 | rank 36.4 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep03000.pt
  ep  3000 | tr 1.000 | id 0.871 | ood 0.664 | gap +0.207 | ‖W‖ 1470.5 | rank 35.9 | IRM 0.0000 | sc 0.99x

  Best ID val (H3): 0.8976
  Best OOD (H4):    0.7336
  OOD improvement:  +0.0114  ← did OOD grok?
  Grokking at:      None
  IRM drop:         100.0%

Wall time: 358.5 min

19. Full training log: standard n=1000 seed=42

# launched: 2026-05-05T10:07:20Z
# host:     ubuntu-Standard-PC-Q35-ICH9-2009
# pwd:      /home/garima/CausalGrok
# cmd:      /home/garima/anaconda3/envs/causalgrok/bin/python -u -m experiments.causalgrok_camelyon_v2 --condition standard --n_train 1000 --seed 42 --run_dir experiments/runs/20260505-100720_standard_n1000_s42 --wandb_project causalgrok --wandb_mode offline --n_epochs 3000
----

Device:  cuda
Run ID:  20260505-100720_standard_n1000_s42
Started: 2026-05-05T10:07:36.596218+00:00
  Env hospital=0: 181 samples, positive rate=0.53
  Env hospital=3: 371 samples, positive rate=0.48
  Env hospital=4: 448 samples, positive rate=0.50
Train: 1000 | ID val (H3): 33560 | OOD test (H4): 85054
Params: 11,177,538

============================================================
  STANDARD | Camelyon17 v2 | 3000 epochs
  WD=0.0001 | α=1.0 | n=1000
  Tracking: ID val (H3) + OOD test (H4) at every checkpoint
  Grokking detection: watching OOD acc, not ID val acc
  IRM envs: 3 hospitals
============================================================
  ep     1 | tr 0.681 | id 0.664 | ood 0.762 | gap -0.098 | ‖W‖ 105.0 | rank 78.5 | IRM 0.1490 | sc 1.25x
  ep    50 | tr 0.999 | id 0.897 | ood 0.620 | gap +0.276 | ‖W‖ 162.5 | rank 45.1 | IRM 0.0000 | sc 0.99x
  ep   100 | tr 1.000 | id 0.901 | ood 0.722 | gap +0.179 | ‖W‖ 196.1 | rank 35.1 | IRM 0.0000 | sc 0.98x
  ep   150 | tr 0.997 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 228.8 | rank 29.4 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep00200.pt
  ep   200 | tr 1.000 | id 0.886 | ood 0.611 | gap +0.274 | ‖W‖ 268.7 | rank 30.1 | IRM 0.0000 | sc 0.99x
  ep   250 | tr 1.000 | id 0.890 | ood 0.672 | gap +0.218 | ‖W‖ 294.3 | rank 30.9 | IRM 0.0000 | sc 0.97x
  ep   300 | tr 1.000 | id 0.900 | ood 0.684 | gap +0.216 | ‖W‖ 323.3 | rank 26.7 | IRM 0.0000 | sc 0.98x
  ep   350 | tr 1.000 | id 0.891 | ood 0.573 | gap +0.318 | ‖W‖ 343.5 | rank 27.9 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep00400.pt
  ep   400 | tr 1.000 | id 0.885 | ood 0.642 | gap +0.243 | ‖W‖ 361.4 | rank 28.6 | IRM 0.0000 | sc 0.98x
  ep   450 | tr 1.000 | id 0.894 | ood 0.700 | gap +0.194 | ‖W‖ 377.9 | rank 31.3 | IRM 0.0000 | sc 0.98x
  ep   500 | tr 1.000 | id 0.890 | ood 0.705 | gap +0.185 | ‖W‖ 378.2 | rank 29.3 | IRM 0.0000 | sc 0.97x
  ep   550 | tr 1.000 | id 0.895 | ood 0.656 | gap +0.239 | ‖W‖ 412.9 | rank 26.5 | IRM 0.0000 | sc 0.97x
  ✓ Checkpoint → ep00600.pt
  ep   600 | tr 1.000 | id 0.862 | ood 0.717 | gap +0.145 | ‖W‖ 426.0 | rank 29.7 | IRM 0.0000 | sc 0.97x
  ep   650 | tr 1.000 | id 0.885 | ood 0.713 | gap +0.172 | ‖W‖ 445.0 | rank 25.9 | IRM 0.0000 | sc 0.99x
  ep   700 | tr 1.000 | id 0.892 | ood 0.639 | gap +0.253 | ‖W‖ 454.4 | rank 28.0 | IRM 0.0000 | sc 0.98x
  ep   750 | tr 1.000 | id 0.880 | ood 0.648 | gap +0.232 | ‖W‖ 472.1 | rank 25.5 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep00800.pt
  ep   800 | tr 1.000 | id 0.888 | ood 0.681 | gap +0.207 | ‖W‖ 489.6 | rank 28.8 | IRM 0.0000 | sc 0.97x
  ep   850 | tr 1.000 | id 0.887 | ood 0.626 | gap +0.262 | ‖W‖ 506.4 | rank 28.2 | IRM 0.0000 | sc 0.99x
  ep   900 | tr 1.000 | id 0.888 | ood 0.703 | gap +0.185 | ‖W‖ 515.4 | rank 31.3 | IRM 0.0000 | sc 0.99x
  ep   950 | tr 1.000 | id 0.883 | ood 0.667 | gap +0.215 | ‖W‖ 526.2 | rank 27.6 | IRM 0.0000 | sc 1.00x
  ✓ Checkpoint → ep01000.pt
  ep  1000 | tr 1.000 | id 0.897 | ood 0.674 | gap +0.222 | ‖W‖ 530.5 | rank 26.2 | IRM 0.0000 | sc 1.00x
  ep  1050 | tr 1.000 | id 0.896 | ood 0.581 | gap +0.315 | ‖W‖ 544.2 | rank 26.3 | IRM 0.0000 | sc 0.98x
  ep  1100 | tr 1.000 | id 0.877 | ood 0.655 | gap +0.222 | ‖W‖ 559.3 | rank 26.2 | IRM 0.0000 | sc 1.00x
  ep  1150 | tr 1.000 | id 0.899 | ood 0.694 | gap +0.205 | ‖W‖ 562.6 | rank 23.7 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep01200.pt
  ep  1200 | tr 1.000 | id 0.892 | ood 0.578 | gap +0.315 | ‖W‖ 581.9 | rank 27.0 | IRM 0.0000 | sc 0.98x
  ep  1250 | tr 1.000 | id 0.889 | ood 0.647 | gap +0.242 | ‖W‖ 595.8 | rank 30.1 | IRM 0.0000 | sc 0.99x
  ep  1300 | tr 1.000 | id 0.880 | ood 0.616 | gap +0.264 | ‖W‖ 604.9 | rank 31.7 | IRM 0.0000 | sc 0.98x
  ep  1350 | tr 1.000 | id 0.878 | ood 0.649 | gap +0.228 | ‖W‖ 606.3 | rank 30.0 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep01400.pt
  ep  1400 | tr 1.000 | id 0.879 | ood 0.735 | gap +0.144 | ‖W‖ 624.0 | rank 32.7 | IRM 0.0000 | sc 0.96x
  ep  1450 | tr 1.000 | id 0.885 | ood 0.703 | gap +0.182 | ‖W‖ 625.3 | rank 33.0 | IRM 0.0000 | sc 0.98x
  ep  1500 | tr 1.000 | id 0.894 | ood 0.686 | gap +0.207 | ‖W‖ 626.2 | rank 30.6 | IRM 0.0000 | sc 0.98x
  ep  1550 | tr 1.000 | id 0.879 | ood 0.670 | gap +0.209 | ‖W‖ 640.3 | rank 32.3 | IRM 0.0000 | sc 0.97x
  ✓ Checkpoint → ep01600.pt
  ep  1600 | tr 1.000 | id 0.871 | ood 0.669 | gap +0.202 | ‖W‖ 653.6 | rank 31.0 | IRM 0.0000 | sc 0.98x
  ep  1650 | tr 1.000 | id 0.886 | ood 0.540 | gap +0.346 | ‖W‖ 662.7 | rank 34.4 | IRM 0.0000 | sc 0.98x
  ep  1700 | tr 1.000 | id 0.897 | ood 0.659 | gap +0.239 | ‖W‖ 667.4 | rank 38.2 | IRM 0.0000 | sc 0.97x
  ep  1750 | tr 1.000 | id 0.871 | ood 0.716 | gap +0.155 | ‖W‖ 678.7 | rank 31.7 | IRM 0.0000 | sc 0.97x
  ✓ Checkpoint → ep01800.pt
  ep  1800 | tr 1.000 | id 0.887 | ood 0.665 | gap +0.222 | ‖W‖ 682.0 | rank 33.0 | IRM 0.0000 | sc 0.98x
  ep  1850 | tr 1.000 | id 0.892 | ood 0.598 | gap +0.294 | ‖W‖ 685.4 | rank 31.1 | IRM 0.0000 | sc 0.98x
  ep  1900 | tr 1.000 | id 0.890 | ood 0.574 | gap +0.316 | ‖W‖ 690.9 | rank 32.8 | IRM 0.0000 | sc 0.99x
  ep  1950 | tr 1.000 | id 0.889 | ood 0.623 | gap +0.265 | ‖W‖ 706.3 | rank 30.0 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep02000.pt
  ep  2000 | tr 1.000 | id 0.890 | ood 0.595 | gap +0.296 | ‖W‖ 714.3 | rank 31.5 | IRM 0.0000 | sc 0.98x
  ep  2050 | tr 1.000 | id 0.888 | ood 0.596 | gap +0.292 | ‖W‖ 720.1 | rank 31.4 | IRM 0.0000 | sc 0.99x
  ep  2100 | tr 1.000 | id 0.860 | ood 0.686 | gap +0.173 | ‖W‖ 730.1 | rank 30.6 | IRM 0.0000 | sc 0.99x
  ep  2150 | tr 1.000 | id 0.892 | ood 0.712 | gap +0.180 | ‖W‖ 732.0 | rank 28.5 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep02200.pt
  ep  2200 | tr 1.000 | id 0.881 | ood 0.621 | gap +0.260 | ‖W‖ 736.6 | rank 31.0 | IRM 0.0000 | sc 0.99x
  ep  2250 | tr 1.000 | id 0.882 | ood 0.621 | gap +0.261 | ‖W‖ 736.5 | rank 29.6 | IRM 0.0000 | sc 0.99x
  ep  2300 | tr 1.000 | id 0.877 | ood 0.641 | gap +0.236 | ‖W‖ 748.9 | rank 32.6 | IRM 0.0000 | sc 0.99x
  ep  2350 | tr 1.000 | id 0.883 | ood 0.652 | gap +0.230 | ‖W‖ 753.4 | rank 35.2 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep02400.pt
  ep  2400 | tr 1.000 | id 0.884 | ood 0.673 | gap +0.212 | ‖W‖ 753.4 | rank 34.2 | IRM 0.0000 | sc 1.00x
  ep  2450 | tr 1.000 | id 0.887 | ood 0.677 | gap +0.209 | ‖W‖ 755.5 | rank 32.2 | IRM 0.0000 | sc 1.00x
  ep  2500 | tr 1.000 | id 0.887 | ood 0.601 | gap +0.286 | ‖W‖ 768.2 | rank 35.0 | IRM 0.0000 | sc 0.98x
  ep  2550 | tr 1.000 | id 0.880 | ood 0.645 | gap +0.235 | ‖W‖ 772.8 | rank 35.5 | IRM 0.0000 | sc 0.98x
  ✓ Checkpoint → ep02600.pt
  ep  2600 | tr 1.000 | id 0.884 | ood 0.631 | gap +0.253 | ‖W‖ 774.0 | rank 34.8 | IRM 0.0000 | sc 0.99x
  ep  2650 | tr 1.000 | id 0.885 | ood 0.584 | gap +0.301 | ‖W‖ 776.4 | rank 37.4 | IRM 0.0000 | sc 0.99x
  ep  2700 | tr 1.000 | id 0.888 | ood 0.597 | gap +0.291 | ‖W‖ 780.4 | rank 36.2 | IRM 0.0000 | sc 0.99x
  ep  2750 | tr 1.000 | id 0.895 | ood 0.683 | gap +0.212 | ‖W‖ 786.8 | rank 35.2 | IRM 0.0000 | sc 0.99x
  ✓ Checkpoint → ep02800.pt
  ep  2800 | tr 1.000 | id 0.894 | ood 0.597 | gap +0.297 | ‖W‖ 794.5 | rank 34.0 | IRM 0.0000 | sc 0.99x
  ep  2850 | tr 1.000 | id 0.874 | ood 0.623 | gap +0.252 | ‖W‖ 803.4 | rank 32.7 | IRM 0.0000 | sc 0.99x
  ep  2900 | tr 1.000 | id 0.891 | ood 0.679 | gap +0.212 | ‖W‖ 808.3 | rank 35.3 | IRM 0.0000 | sc 0.98x
  ep  2950 | tr 1.000 | id 0.886 | ood 0.704 | gap +0.182 | ‖W‖ 811.6 | rank 33.7 | IRM 0.0000 | sc 1.00x
  ✓ Checkpoint → ep03000.pt
  ep  3000 | tr 1.000 | id 0.896 | ood 0.648 | gap +0.247 | ‖W‖ 812.6 | rank 33.3 | IRM 0.0000 | sc 0.99x

  Best ID val (H3): 0.9011
  Best OOD (H4):    0.7615
  OOD improvement:  -0.0224  ← did OOD grok?
  Grokking at:      None
  IRM drop:         100.0%

Wall time: 327.3 min

20. M5 — Full activation-steering JSONs (8 runs at n=1000)

Every run's full m5_steering_ep*.json, verbatim from disk. head_ood_acc is head OOD accuracy on the steered features h' = h + alpha * sigma * v_s. tumor_probe is the linear-probe accuracy on the same steered features. hospital_probe is NaN by construction (H4 hospital labels do not overlap with the training-hospital probe's class set).
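
The steering rule h' = h + alpha * sigma * v_s can be illustrated with a minimal, dependency-free sketch. It is not the code from experiments/mechinterp_m5_steering.py: the function name `steer`, the toy vectors, and the explicit unit-normalisation of v_s (chosen to be consistent with the `v_norm` ≈ 1.0 field in the records below) are assumptions made for illustration only.

```python
import math

def steer(h, v_s, alpha, sigma):
    """Return h' = h + alpha * sigma * v_s for one feature vector (lists of floats)."""
    if len(h) != len(v_s):
        raise ValueError("h and v_s must have the same length")
    norm = math.sqrt(sum(x * x for x in v_s))
    if norm == 0.0:
        raise ValueError("steering direction must be non-zero")
    # Assumption: the direction is used at unit norm (the JSONs report v_norm ~ 1.0).
    unit = [x / norm for x in v_s]
    return [h_i + alpha * sigma * u_i for h_i, u_i in zip(h, unit)]

# Toy 4-dimensional example with hypothetical values (not project data).
h = [0.5, 1.0, -0.25, 2.0]
v_s = [3.0, 0.0, 4.0, 0.0]          # norm 5, so the unit direction is [0.6, 0, 0.8, 0]
steered = steer(h, v_s, alpha=-1.0, sigma=2.0)
unchanged = steer(h, v_s, alpha=0.0, sigma=2.0)
print(steered)
print(unchanged == h)
```

With alpha = 0 the features are returned unchanged, which is why the alpha = 0.0 row of each sweep serves as the unsteered baseline.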

20260505-080445_grokking_n1000_s42

{
  "run_id": "20260505-080445_grokking_n1000_s42",
  "epoch": 400,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 1.0,
  "sigma": 8.685188293457031,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.63,
      "hospital_probe": NaN,
      "tumor_probe": 0.6625
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.625,
      "hospital_probe": NaN,
      "tumor_probe": 0.6525
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.64,
      "hospital_probe": NaN,
      "tumor_probe": 0.665
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.6325,
      "hospital_probe": NaN,
      "tumor_probe": 0.655
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.6525,
      "hospital_probe": NaN,
      "tumor_probe": 0.66
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.6575,
      "hospital_probe": NaN,
      "tumor_probe": 0.665
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.6575,
      "hospital_probe": NaN,
      "tumor_probe": 0.6675
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.62,
      "hospital_probe": NaN,
      "tumor_probe": 0.65
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.59,
      "hospital_probe": NaN,
      "tumor_probe": 0.635
    }
  ]
}

20260505-100720_grokking_n1000_s123

{
  "run_id": "20260505-100720_grokking_n1000_s123",
  "epoch": 400,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 1.0,
  "sigma": 6.000692844390869,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.7325,
      "hospital_probe": NaN,
      "tumor_probe": 0.6925
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.7325,
      "hospital_probe": NaN,
      "tumor_probe": 0.6925
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.73,
      "hospital_probe": NaN,
      "tumor_probe": 0.695
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.71,
      "hospital_probe": NaN,
      "tumor_probe": 0.695
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.6975,
      "hospital_probe": NaN,
      "tumor_probe": 0.6925
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.6925,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.69,
      "hospital_probe": NaN,
      "tumor_probe": 0.685
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.68,
      "hospital_probe": NaN,
      "tumor_probe": 0.685
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.6575,
      "hospital_probe": NaN,
      "tumor_probe": 0.685
    }
  ]
}

20260505-100720_grokking_n1000_s456

{
  "run_id": "20260505-100720_grokking_n1000_s456",
  "epoch": 1000,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 1.0,
  "sigma": 6.7488112449646,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.6825,
      "hospital_probe": NaN,
      "tumor_probe": 0.63
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.6575,
      "hospital_probe": NaN,
      "tumor_probe": 0.615
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.64,
      "hospital_probe": NaN,
      "tumor_probe": 0.605
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.63,
      "hospital_probe": NaN,
      "tumor_probe": 0.595
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.6225,
      "hospital_probe": NaN,
      "tumor_probe": 0.595
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.62,
      "hospital_probe": NaN,
      "tumor_probe": 0.5925
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.61,
      "hospital_probe": NaN,
      "tumor_probe": 0.5925
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.6025,
      "hospital_probe": NaN,
      "tumor_probe": 0.5825
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.5975,
      "hospital_probe": NaN,
      "tumor_probe": 0.59
    }
  ]
}

20260508-183413_grokking_n1000_s7

{
  "run_id": "20260508-183413_grokking_n1000_s7",
  "epoch": 200,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 0.9999999403953552,
  "sigma": 5.532620429992676,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.495,
      "hospital_probe": NaN,
      "tumor_probe": 0.58
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.49,
      "hospital_probe": NaN,
      "tumor_probe": 0.5625
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.485,
      "hospital_probe": NaN,
      "tumor_probe": 0.5425
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.485,
      "hospital_probe": NaN,
      "tumor_probe": 0.5325
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.485,
      "hospital_probe": NaN,
      "tumor_probe": 0.52
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.4875,
      "hospital_probe": NaN,
      "tumor_probe": 0.5125
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.4875,
      "hospital_probe": NaN,
      "tumor_probe": 0.52
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.4775,
      "hospital_probe": NaN,
      "tumor_probe": 0.51
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.4825,
      "hospital_probe": NaN,
      "tumor_probe": 0.51
    }
  ]
}

20260508-183413_grokking_n1000_s2024

{
  "run_id": "20260508-183413_grokking_n1000_s2024",
  "epoch": 400,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 1.0,
  "sigma": 9.224357604980469,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.7425,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.7775,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.7325,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.715,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.7125,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.705,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.68,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.6225,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.595,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    }
  ]
}

20260505-100720_standard_n1000_s42

{
  "run_id": "20260505-100720_standard_n1000_s42",
  "epoch": 200,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 1.0,
  "sigma": 7.23820686340332,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.585,
      "hospital_probe": NaN,
      "tumor_probe": 0.6125
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.6,
      "hospital_probe": NaN,
      "tumor_probe": 0.62
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.61,
      "hospital_probe": NaN,
      "tumor_probe": 0.625
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.6175,
      "hospital_probe": NaN,
      "tumor_probe": 0.63
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.6175,
      "hospital_probe": NaN,
      "tumor_probe": 0.625
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.6175,
      "hospital_probe": NaN,
      "tumor_probe": 0.6225
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.61,
      "hospital_probe": NaN,
      "tumor_probe": 0.6275
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.6025,
      "hospital_probe": NaN,
      "tumor_probe": 0.6325
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.5925,
      "hospital_probe": NaN,
      "tumor_probe": 0.625
    }
  ]
}

20260508-183413_standard_n1000_s123

{
  "run_id": "20260508-183413_standard_n1000_s123",
  "epoch": 200,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 1.0,
  "sigma": 13.366204261779785,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.5425,
      "hospital_probe": NaN,
      "tumor_probe": 0.565
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.5975,
      "hospital_probe": NaN,
      "tumor_probe": 0.615
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.6725,
      "hospital_probe": NaN,
      "tumor_probe": 0.68
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.7125,
      "hospital_probe": NaN,
      "tumor_probe": 0.71
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.72,
      "hospital_probe": NaN,
      "tumor_probe": 0.7375
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.7175,
      "hospital_probe": NaN,
      "tumor_probe": 0.7425
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.6825,
      "hospital_probe": NaN,
      "tumor_probe": 0.7425
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.6275,
      "hospital_probe": NaN,
      "tumor_probe": 0.695
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.555,
      "hospital_probe": NaN,
      "tumor_probe": 0.6425
    }
  ]
}

20260508-183413_standard_n1000_s456

{
  "run_id": "20260508-183413_standard_n1000_s456",
  "epoch": 1000,
  "layer": "avgpool",
  "max_samples": 800,
  "v_norm": 1.0,
  "sigma": 10.473133087158203,
  "sweep": [
    {
      "alpha": -3.0,
      "head_ood_acc": 0.6375,
      "hospital_probe": NaN,
      "tumor_probe": 0.6475
    },
    {
      "alpha": -2.0,
      "head_ood_acc": 0.6575,
      "hospital_probe": NaN,
      "tumor_probe": 0.66
    },
    {
      "alpha": -1.0,
      "head_ood_acc": 0.675,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": -0.5,
      "head_ood_acc": 0.6725,
      "hospital_probe": NaN,
      "tumor_probe": 0.69
    },
    {
      "alpha": 0.0,
      "head_ood_acc": 0.6175,
      "hospital_probe": NaN,
      "tumor_probe": 0.6575
    },
    {
      "alpha": 0.5,
      "head_ood_acc": 0.58,
      "hospital_probe": NaN,
      "tumor_probe": 0.5975
    },
    {
      "alpha": 1.0,
      "head_ood_acc": 0.5375,
      "hospital_probe": NaN,
      "tumor_probe": 0.585
    },
    {
      "alpha": 2.0,
      "head_ood_acc": 0.505,
      "hospital_probe": NaN,
      "tumor_probe": 0.5225
    },
    {
      "alpha": 3.0,
      "head_ood_acc": 0.505,
      "hospital_probe": NaN,
      "tumor_probe": 0.51
    }
  ]
}
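
The records above contain the bare token NaN, which strict JSON does not allow but Python's standard `json` module accepts by default. The sketch below shows one way to read a record back and apply the acc(−3) ≥ acc(0) ≥ acc(+3) criterion used in §21. The helper names `load_sweep` and `is_strictly_monotone` are hypothetical, and the embedded string is an abridged copy of the s42 grokking record (three of the nine sweep rows).

```python
import json
import math

def load_sweep(text):
    """Parse an M5 steering JSON string and return {alpha: head_ood_acc}."""
    record = json.loads(text)  # json.loads accepts the NaN token by default
    return {row["alpha"]: row["head_ood_acc"] for row in record["sweep"]}

def is_strictly_monotone(acc_by_alpha):
    """Apply the criterion acc(-3) >= acc(0) >= acc(+3)."""
    return acc_by_alpha[-3.0] >= acc_by_alpha[0.0] >= acc_by_alpha[3.0]

# Abridged copy of the s42 grokking record (alpha = -3, 0, +3 only).
example = """
{"run_id": "20260505-080445_grokking_n1000_s42", "epoch": 400, "sweep": [
  {"alpha": -3.0, "head_ood_acc": 0.63,   "hospital_probe": NaN, "tumor_probe": 0.6625},
  {"alpha": 0.0,  "head_ood_acc": 0.6525, "hospital_probe": NaN, "tumor_probe": 0.66},
  {"alpha": 3.0,  "head_ood_acc": 0.59,   "hospital_probe": NaN, "tumor_probe": 0.635}
]}
"""

acc = load_sweep(example)
hospital = json.loads(example)["sweep"][0]["hospital_probe"]
print(is_strictly_monotone(acc))   # acc(0) exceeds acc(-3) for this run
print(math.isnan(hospital))
```

Tools that enforce strict JSON may reject these files, so a reader written in another language would need a NaN-tolerant parser or a preprocessing step.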

21. M5 — Aggregated sweep tables

Grokking-favorable

| α | s7 (ep200, σ=5.53) | s42 (ep400, σ=8.69) | s123 (ep400, σ=6.00) | s456 (ep1000, σ=6.75) | s2024 (ep400, σ=9.22) |
| --- | --- | --- | --- | --- | --- |
| −3.0 | 0.4950 | 0.6300 | 0.7325 | 0.6825 | 0.7425 |
| −2.0 | 0.4900 | 0.6250 | 0.7325 | 0.6575 | 0.7775 |
| −1.0 | 0.4850 | 0.6400 | 0.7300 | 0.6400 | 0.7325 |
| −0.5 | 0.4850 | 0.6325 | 0.7100 | 0.6300 | 0.7150 |
| 0.0 | 0.4850 | 0.6525 | 0.6975 | 0.6225 | 0.7125 |
| +0.5 | 0.4875 | 0.6575 | 0.6925 | 0.6200 | 0.7050 |
| +1.0 | 0.4875 | 0.6575 | 0.6900 | 0.6100 | 0.6800 |
| +2.0 | 0.4775 | 0.6200 | 0.6800 | 0.6025 | 0.6225 |
| +3.0 | 0.4825 | 0.5900 | 0.6575 | 0.5975 | 0.5950 |
| Strict mono? | yes | no (acc(0) > acc(−3)) | yes | yes | yes |

Standard

| α | s42 (ep200, σ=7.24) | s123 (ep200, σ=13.37) | s456 (ep1000, σ=10.47) |
| --- | --- | --- | --- |
| −3.0 | 0.5850 | 0.5425 | 0.6375 |
| −2.0 | 0.6000 | 0.5975 | 0.6575 |
| −1.0 | 0.6100 | 0.6725 | 0.6750 |
| −0.5 | 0.6175 | 0.7125 | 0.6725 |
| 0.0 | 0.6175 | 0.7200 | 0.6175 |
| +0.5 | 0.6175 | 0.7175 | 0.5800 |
| +1.0 | 0.6100 | 0.6825 | 0.5375 |
| +2.0 | 0.6025 | 0.6275 | 0.5050 |
| +3.0 | 0.5925 | 0.5550 | 0.5050 |
| Strict mono? | no (plateau over α = −0.5 to +0.5) | no (peak at α=0) | yes |

Aggregates and statistics (± values match the population standard deviation across seeds):

  • Strict monotonicity (acc(−3) ≥ acc(0) ≥ acc(+3)): 4/5 grokking vs 1/3 standard.
  • Mean σ: grokking 7.24, standard 10.36, a 1.43× ratio (the σ-scaling confound).
  • Mean Δ(α=0→−3): grokking +0.0225 ± 0.0276, standard −0.0633 ± 0.0835.
  • Mean Δ(α=0→+3): grokking −0.0495 ± 0.039, standard −0.101 ± 0.058.
  • Fisher exact one-sided p = 0.286, Mann-Whitney U one-sided p = 0.071 (continuous statistic, primary), binomial sign test p = 0.188. None reaches p < 0.05.
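
The three p-values can be recomputed exactly with the standard library. The sketch below is a reconstruction under stated assumptions, not the project's analysis script: it assumes the Fisher test was run on the strict-monotonicity counts (4/5 vs 1/3), the Mann-Whitney test on the per-seed Δ(α=0→−3) values taken from the tables in this section, and the sign test on 4 successes out of 5 grokking seeds. All function names are hypothetical.

```python
from itertools import combinations
from math import comb

def fisher_one_sided(a, b, c, d):
    """P(X >= a) for the 2x2 table [[a, b], [c, d]] under the hypergeometric null."""
    row1, col1, total = a + b, a + c, a + b + c + d
    tail = sum(comb(col1, k) * comb(total - col1, row1 - k)
               for k in range(a, min(row1, col1) + 1))
    return tail / comb(total, row1)

def mann_whitney_exact_greater(x, y):
    """Exact one-sided P(U >= U_obs), enumerating every group assignment (assumes no ties)."""
    def u_stat(xs, ys):
        return sum(1 for a in xs for b in ys if a > b)
    pooled = list(x) + list(y)
    observed = u_stat(x, y)
    hits = total = 0
    for idx in combinations(range(len(pooled)), len(x)):
        chosen = set(idx)
        xs = [pooled[i] for i in chosen]
        ys = [pooled[i] for i in range(len(pooled)) if i not in chosen]
        hits += u_stat(xs, ys) >= observed
        total += 1
    return hits / total

def sign_test_one_sided(successes, n):
    """P(X >= successes) for X ~ Binomial(n, 0.5)."""
    return sum(comb(n, k) for k in range(successes, n + 1)) / 2 ** n

# head_ood_acc at alpha = -3 and alpha = 0, copied from the tables (s7, s42, s123, s456, s2024).
grok_m3, grok_0 = [0.4950, 0.6300, 0.7325, 0.6825, 0.7425], [0.4850, 0.6525, 0.6975, 0.6225, 0.7125]
# Standard runs (s42, s123, s456).
std_m3, std_0 = [0.5850, 0.5425, 0.6375], [0.6175, 0.7200, 0.6175]

grok_delta = [a - b for a, b in zip(grok_m3, grok_0)]
std_delta = [a - b for a, b in zip(std_m3, std_0)]

p_fisher = fisher_one_sided(4, 1, 1, 2)   # rows: grokking, standard; columns: monotone yes, no
p_mwu = mann_whitney_exact_greater(grok_delta, std_delta)
p_sign = sign_test_one_sided(4, 5)
print(f"{p_fisher:.3f} {p_mwu:.3f} {p_sign:.3f}")
```

With 5 and 3 seeds there are only 56 possible group assignments, so the smallest attainable one-sided p-value for the Fisher and Mann-Whitney tests is 1/56 ≈ 0.018, which limits how strong any result at this sample size can be.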

22. M6 — Full K-sweep results (per-seed, all K)

Aggregated from paper_figures/m6_summary.csv. Δ(targ−rand) is head_OOD(top-K shortcut ablated) − mean(head_OOD over 5 K-random ablations). Positive = targeted shortcut ablation beats random.
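
As a minimal sketch of this definition: the function below computes Δ(targ−rand) for a single K from one targeted accuracy and a list of random-ablation accuracies. The per-draw values are hypothetical (the document reports only the mean ± sd of the 5 draws), chosen so that their mean equals the 0.6595 reported for grokking s42 at K=64. Using the population standard deviation for the random draws is also an assumption.

```python
from statistics import mean, pstdev

def delta_targeted_vs_random(targeted_acc, random_accs):
    """Return (delta, random_mean, random_sd) for one value of K."""
    rand_mean = mean(random_accs)
    # Assumption: the spread of the random control is a population sd over the draws.
    return targeted_acc - rand_mean, rand_mean, pstdev(random_accs)

# Hypothetical accuracies for 5 random K-neuron ablations (not read from disk).
random_accs = [0.6550, 0.6600, 0.6650, 0.6550, 0.6625]
targeted_acc = 0.6600  # head OOD accuracy with the top-K shortcut neurons ablated

delta, rand_mean, rand_sd = delta_targeted_vs_random(targeted_acc, random_accs)
print(f"delta={delta:+.4f} rand={rand_mean:.4f}±{rand_sd:.4f}")
```

A positive delta means the targeted ablation left head OOD accuracy higher than the average random ablation of the same size.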

Grokking-favorable (n=1000)

| K | s7 | s42 | s123 | s456 | s2024 | N_+ |
| --- | --- | --- | --- | --- | --- | --- |
| 0 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | |
| 4 | +0.0015 | +0.0045 | +0.0015 | +0.0025 | −0.0005 | 4/5 |
| 8 | −0.0010 | +0.0025 | +0.0010 | 0.0000 | −0.0020 | 2/5 |
| 16 | +0.0005 | −0.0010 | +0.0055 | +0.0035 | −0.0015 | 3/5 |
| 32 | −0.0045 | −0.0025 | +0.0010 | +0.0045 | −0.0020 | 2/5 |
| 64 | −0.0015 | +0.0005 | +0.0115 | +0.0120 | −0.0115 | 3/5 |
| 128 | −0.0060 | +0.0005 | +0.0090 | +0.0015 | −0.0225 | 3/5 |
| 256 | −0.0345 | −0.0010 | +0.0060 | +0.0075 | −0.0205 | 2/5 |

Standard (n=1000)

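| K | s42 | s123 | s456 | N_+ |
| --- | --- | --- | --- | --- |
| 0 | 0.0000 | 0.0000 | 0.0000 | |
| 4 | +0.0015 | −0.0025 | −0.0025 | 1/3 |
| 8 | −0.0035 | −0.0010 | −0.0005 | 0/3 |
| 16 | −0.0030 | −0.0025 | +0.0025 | 1/3 |
| 32 | −0.0040 | −0.0055 | −0.0060 | 0/3 |
| 64 | −0.0035 | −0.0100 | −0.0040 | 0/3 |
| 128 | −0.0085 | −0.0090 | −0.0025 | 0/3 |
| 256 | −0.0115 | +0.0040 | +0.0075 | 2/3 |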

Per-seed full ablation rows (K=64 and K=256)

grokking s42   K= 64: base=0.6575  short=0.6600  rand=0.6595±0.0043  morph=0.6475  Δshort-base=+0.0025  Δtarg-rand=+0.0005
grokking s42   K=256: base=0.6575  short=0.6625  rand=0.6635±0.0054  morph=0.5500  Δshort-base=+0.0050  Δtarg-rand=−0.0010
grokking s123  K= 64: base=0.6825  short=0.6950  rand=0.6835±0.0066  morph=0.6625  Δshort-base=+0.0125  Δtarg-rand=+0.0115
grokking s123  K=256: base=0.6825  short=0.7275  rand=0.7215±0.0179  morph=0.5575  Δshort-base=+0.0450  Δtarg-rand=+0.0060
grokking s456  K= 64: base=0.6450  short=0.6425  rand=0.6305±0.0120  morph=0.6550  Δshort-base=−0.0025  Δtarg-rand=+0.0120
grokking s456  K=256: base=0.6450  short=0.6325  rand=0.6250±0.0105  morph=0.5275  Δshort-base=−0.0125  Δtarg-rand=+0.0075
grokking s7    K= 64: base=0.4925  short=0.4875  rand=0.4890±0.0075  morph=0.4475  Δshort-base=−0.0050  Δtarg-rand=−0.0015
grokking s7    K=256: base=0.4925  short=0.4650  rand=0.4995±0.0056  morph=0.3800  Δshort-base=−0.0275  Δtarg-rand=−0.0345
grokking s2024 K= 64: base=0.7100  short=0.7125  rand=0.7240±0.0086  morph=0.7225  Δshort-base=+0.0025  Δtarg-rand=−0.0115
grokking s2024 K=256: base=0.7100  short=0.7225  rand=0.7430±0.0129  morph=0.5050  Δshort-base=+0.0125  Δtarg-rand=−0.0205

standard s42   K= 64: base=0.6150  short=0.6125  rand=0.6160±0.0034  morph=0.6100  Δshort-base=−0.0025  Δtarg-rand=−0.0035
standard s42   K=256: base=0.6150  short=0.6100  rand=0.6215±0.0020  morph=0.5950  Δshort-base=−0.0050  Δtarg-rand=−0.0115
standard s123  K= 64: base=0.7225  short=0.7150  rand=0.7250±0.0016  morph=0.7125  Δshort-base=−0.0075  Δtarg-rand=−0.0100
standard s123  K=256: base=0.7225  short=0.6900  rand=0.6860±0.0025  morph=0.6975  Δshort-base=−0.0325  Δtarg-rand=+0.0040
standard s456  K= 64: base=0.5975  short=0.5975  rand=0.6015±0.0030  morph=0.5850  Δshort-base= 0.0000  Δtarg-rand=−0.0040
standard s456  K=256: base=0.5975  short=0.5500  rand=0.5425±0.0055  morph=0.5525  Δshort-base=−0.0475  Δtarg-rand=+0.0075

Aggregates:

  • K=64: grokking 3/5 positive Δ(targ−rand) (mean +0.0022 ± 0.0088); standard 0/3 (mean −0.0058 ± 0.0030). Fisher one-sided p = 0.179.
  • K=256: grokking 2/5 positive vs standard 2/3 positive, so there is no grokking-over-standard advantage at the largest K.
  • ID accuracy stays within 0.01 of baseline across all targeted ablations.
  • The random control is averaged over only 5 samplings per K, which is the main M6 weakness.
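
The K=64 aggregates can be reproduced from the per-seed Δ(targ−rand) values listed above. The sketch below uses the population standard deviation, which is the convention that matches the reported ± values; the helper name `summarise` is hypothetical.

```python
from statistics import mean, pstdev

# Per-seed delta(targ - rand) at K=64, copied from the tables above.
grokking_k64 = {"s7": -0.0015, "s42": 0.0005, "s123": 0.0115, "s456": 0.0120, "s2024": -0.0115}
standard_k64 = {"s42": -0.0035, "s123": -0.0100, "s456": -0.0040}

def summarise(deltas):
    """Return (n_positive, n_seeds, mean, population_sd) for one condition."""
    values = list(deltas.values())
    n_pos = sum(v > 0 for v in values)
    return n_pos, len(values), mean(values), pstdev(values)

for name, deltas in (("grokking", grokking_k64), ("standard", standard_k64)):
    n_pos, n, mu, sd = summarise(deltas)
    print(f"{name}: {n_pos}/{n} positive, mean {mu:+.4f} ± {sd:.4f}")

grok_summary = summarise(grokking_k64)
std_summary = summarise(standard_k64)
```

Swapping `pstdev` for `stdev` would give the sample standard deviation instead, which is larger at these seed counts.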

The complete 88-row reviewer CSV is at paper_figures/m6_summary.csv — every run, every K, every condition (baseline / shortcut / random ± sd / morphology).


23. Exact commands

Training (launched detached under nohup via scripts/launch.sh):

python -u -m experiments.causalgrok_camelyon_v2 \
    --condition grokking --n_train 1000 --seed 42 \
    --run_dir experiments/runs/<run_id> \
    --wandb_project causalgrok --wandb_mode offline
# standard: --condition standard   (wd/init/grokfast set automatically by get_config)

Mechanistic interpretability:

python -m experiments.mechinterp_m1                 --run_dir experiments/runs/<id> --data_root data/wilds
python -m experiments.mechinterp_m4_ablation        --run_dir experiments/runs/<id> --data_root data/wilds --layer avgpool --all_epochs
python -m experiments.mechinterp_m5_steering        --run_dir experiments/runs/<id> --data_root data/wilds
python -m experiments.mechinterp_m6_neuron_ablation --run_dir experiments/runs/<id> --data_root data/wilds \
    --ks "0,4,8,16,32,64,128,256"

Figures:

python -m experiments.regenerate_all_figures   # rebuilds all 7 paper figures from saved JSON in <2 min

24. Output layout (per run)

experiments/runs/<run_id>/
├── config.json                              # full hyperparameter config
├── run.pid
├── checkpoints/
│   ├── ep00200.pt … ep03000.pt              # 15 periodic checkpoints, ~44 MB each
│   └── final.pt
├── logs/
│   ├── train.log                            # launch cmd + per-checkpoint log lines
│   └── train.err
├── results/
│   ├── history.json                         # 61 per-checkpoint metric rows
│   └── summary.json                         # final-summary fields
├── wandb/                                   # offline wandb run metadata
└── mechinterp/
    ├── m1_probe_data.json                   + heatmap/curves PNG
    ├── m4_ablation_avgpool_trajectory.json  + PNG
    ├── m5_steering_ep<E>.json               + PNG
    └── m6_neuron_ablation_ep<E>.json        + PNG

Aggregated: paper_figures/m6_summary.csv (88 data rows), paper_figures/*.{png,pdf} (7 figures). Checkpoints (~10 GB, 240 .pt files) are mirrored to Hugging Face at nileshsarkar-ai/CausalGrok.
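
A small completeness check for a run directory can be sketched from the layout above. This is an illustration rather than a project script: the function name `missing_artifacts` is hypothetical, only a subset of the listed files is checked, and the checkpoint schedule (every 200 epochs up to 3000) is taken from the tree. The demo builds a throwaway directory instead of touching a real run.

```python
import tempfile
from pathlib import Path

# Core files from the layout; mechinterp outputs are omitted because their epoch suffix varies per run.
EXPECTED = ["config.json", "checkpoints/final.pt", "logs/train.log",
            "results/history.json", "results/summary.json"]
# 15 periodic checkpoints: ep00200.pt, ep00400.pt, ..., ep03000.pt
PERIODIC = [f"checkpoints/ep{ep:05d}.pt" for ep in range(200, 3001, 200)]

def missing_artifacts(run_dir):
    """Return the relative paths from EXPECTED + PERIODIC that are absent under run_dir."""
    run_dir = Path(run_dir)
    return [rel for rel in EXPECTED + PERIODIC if not (run_dir / rel).is_file()]

# Demo on a temporary directory with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    run = Path(tmp) / "example_run"
    for rel in EXPECTED + PERIODIC:
        path = run / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    complete = missing_artifacts(run)
    (run / "results" / "summary.json").unlink()
    incomplete = missing_artifacts(run)

print(complete)
print(incomplete)
```

Against a real run one would call `missing_artifacts("experiments/runs/<run_id>")` with the placeholder replaced by an actual run ID.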


All numerical values in this document were read directly from on-disk config.json / results/summary.json / results/history.json / mechinterp/*.json / paper_figures/m6_summary.csv. Training logs in §18 and §19 are verbatim from each run's logs/train.log. Source code in §10–§15 is the exact code that ran for every reported result.