| """ |
| CausalGrok — Main Training Loop |
| Nilesh |
| |
| Core experiment: does the IRM invariance penalty drop at the SAME epoch |
| as validation accuracy jumps (the grokking transition)? |
| If yes — the paper's central claim is confirmed. |
| |
| Run via the launchers (always nohup-detached so SSH disconnects don't kill it): |
| bash scripts/launch.sh grokking 500 42 |
| |
| All artifacts (config, logs, history, checkpoints, figures) for every |
| invocation land in experiments/runs/<run_id>/ and are kept forever. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import time |
| from datetime import datetime, timezone |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Dataset, Subset |
| from torchvision.models import resnet18 |
| from medmnist import PneumoniaMNIST |
| import wandb |
|
|
| from utils.metrics import ( |
| accuracy, weight_norm, feature_rank, irm_penalty, shortcut_ratio, |
| ) |
| from utils.grokfast import gradfilter_ema |
| from utils.pseudo_envs import make_brightness_envs |
| from utils.run_dir import make_run_dir, ensure_run_dir, save_config |
|
|
|
|
| |
| |
| |
|
|
def get_config(condition):
    """Return the hyper-parameter preset for one experimental condition.

    Args:
        condition: ``"standard"`` (300-epoch baseline, unit init scale, no
            Grokfast) or ``"grokking"`` (3000 epochs, 10x weight decay,
            4x init scale, Grokfast enabled).

    Returns:
        A fresh ``dict`` holding the shared settings followed by the preset
        for ``condition``. A new dict is built on every call, so callers may
        mutate it freely.

    Raises:
        ValueError: if ``condition`` is not a known preset. Without this
            check an unknown name produced a partial config (no ``lr``,
            ``n_epochs`` or ``condition`` key) that only failed later with
            an unrelated ``KeyError``.
    """
    presets = {
        "standard": dict(lr=1e-3, weight_decay=1e-4,
                         n_epochs=300, init_scale=1.0, use_grokfast=False),
        "grokking": dict(lr=1e-3, weight_decay=1e-3,
                         n_epochs=3000, init_scale=4.0, use_grokfast=True,
                         grokfast_alpha=0.98, grokfast_lamb=2.0),
    }
    if condition not in presets:
        raise ValueError(
            f"unknown condition {condition!r}; expected one of {sorted(presets)}"
        )

    cfg = dict(
        seed=42, n_train=500, batch_size=32, img_size=28,
        n_classes=2, log_every=50, n_pseudo_envs=3,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # "condition" goes first among the preset keys so the saved config keeps
    # the same key order as before.
    cfg.update(condition=condition, **presets[condition])
    return cfg
|
|
|
|
| |
| |
| |
|
|
class SpuriousColorPatchDataset(Dataset):
    """
    Wraps a (image-tensor, label) dataset and stamps a colored corner
    patch correlated with the label at probability `rho`.

    Encoding (the caller is expected to supply 3-channel images already
    normalized to roughly [-1, 1]; `hi`/`lo` are the values written):
        encoded label 0 → channel-0 high, channels 1,2 low   (red corner)
        encoded label 1 → channel-2 high, channels 0,1 low   (blue corner)

    With prob rho the encoded label matches the true label — a usable
    shortcut. With prob (1-rho) it is flipped (`1 - label`, so labels are
    assumed to be binary 0/1).

    The per-sample correlation decisions are drawn once in `__init__` from
    a generator seeded with `seed`, so the same `seed` and dataset length
    always give the same decisions, and repeated reads of one index always
    return the same patch.

    The image returned by the wrapped dataset is copied before the patch is
    written, so a base dataset that hands back stored tensors (rather than
    building a fresh one per access) is never modified.
    """

    def __init__(self, base, rho=0.8, patch_size=4, seed=0,
                 hi=1.0, lo=-1.0):
        self.base = base
        self.rho = float(rho)
        self.patch_size = int(patch_size)
        self.hi = hi
        self.lo = lo
        rng = torch.Generator().manual_seed(int(seed))
        # One Bernoulli(rho) draw per sample: True = patch agrees with label.
        self.is_correlated = torch.rand(len(base), generator=rng) < self.rho

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]
        # Copy first: stamping in place would corrupt a base dataset that
        # returns its own storage.
        img = img.clone() if isinstance(img, torch.Tensor) else img.copy()

        # Labels may be array-like (e.g. shape (1,)) or plain ints.
        try:
            label_int = int(label.squeeze().item())
        except AttributeError:
            label_int = int(label)

        encoded = label_int if bool(self.is_correlated[idx]) else (1 - label_int)
        hot_channel = 0 if encoded == 0 else 2
        ps = self.patch_size
        for channel in range(3):
            img[channel, :ps, :ps] = self.hi if channel == hot_channel else self.lo
        # The label is returned exactly as the base dataset produced it.
        return img, label
|
|
|
|
def _grayscale_to_rgb(x):
    """Tile a single-channel image tensor into three identical channels."""
    return x.repeat(3, 1, 1)


def get_dataloaders(cfg, data_root):
    """Build the train/val/test loaders for PneumoniaMNIST.

    Args:
        cfg: Run configuration. Reads ``img_size``, ``seed``, ``n_train``,
            ``batch_size`` and ``device``; optionally ``spurious_rho``,
            ``spurious_patch_size``, ``spurious_seed`` and ``num_workers``
            (defaults to 4).
        data_root: Directory used as the MedMNIST download/cache root;
            created if missing.

    Returns:
        ``(train_loader, val_loader, test_loader, train_subset)`` where
        ``train_subset`` is the ``n_train``-sample ``Subset`` that backs
        ``train_loader``.
    """
    os.makedirs(data_root, exist_ok=True)
    transform = transforms.Compose([
        transforms.Resize((cfg["img_size"], cfg["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
        # A named module-level function instead of a lambda: DataLoader
        # workers started with the "spawn" method must pickle the dataset,
        # and lambdas cannot be pickled.
        transforms.Lambda(_grayscale_to_rgb),
    ])
    train_ds = PneumoniaMNIST(split="train", transform=transform, download=True, root=data_root)
    val_ds = PneumoniaMNIST(split="val", transform=transform, download=True, root=data_root)
    test_ds = PneumoniaMNIST(split="test", transform=transform, download=True, root=data_root)

    # Optional shortcut injection. A rho of None or 0 disables it. Each split
    # gets its own seed offset so the per-sample decisions are independent.
    rho = cfg.get("spurious_rho")
    if rho:
        ps = cfg.get("spurious_patch_size", 4)
        sd = cfg.get("spurious_seed", cfg["seed"])
        train_ds = SpuriousColorPatchDataset(train_ds, rho=rho, patch_size=ps, seed=sd + 1)
        val_ds = SpuriousColorPatchDataset(val_ds, rho=rho, patch_size=ps, seed=sd + 2)
        test_ds = SpuriousColorPatchDataset(test_ds, rho=rho, patch_size=ps, seed=sd + 3)

    # The global RNG is reseeded on purpose (not a private generator): code
    # that runs after this call draws from the resulting RNG state, so
    # changing this would change model initialisation for a given seed.
    torch.manual_seed(cfg["seed"])
    indices = torch.randperm(len(train_ds))[:cfg["n_train"]]
    train_subset = Subset(train_ds, indices)

    loader_kwargs = dict(
        num_workers=cfg.get("num_workers", 4),
        # Pinned host memory only helps host-to-GPU copies; asking for it on
        # a CPU-only machine just triggers a warning.
        pin_memory=str(cfg["device"]).startswith("cuda"),
    )
    train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"],
                              shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader, train_subset
|
|
|
|
| |
| |
| |
|
|
def build_model(cfg):
    """Create an untrained ResNet-18 and place it on ``cfg["device"]``.

    The network has ``cfg["n_classes"]`` outputs and no pretrained weights.
    When ``cfg["init_scale"]`` differs from 1.0, every parameter whose name
    contains ``"weight"`` and that has more than one dimension is multiplied
    by that factor. One-dimensional parameters are left untouched.
    """
    net = resnet18(weights=None, num_classes=cfg["n_classes"])

    # Guard clause: the default scale needs no rescaling pass.
    if cfg["init_scale"] == 1.0:
        return net.to(cfg["device"])

    with torch.no_grad():
        for param_name, param in net.named_parameters():
            is_multi_dim_weight = param.dim() > 1 and "weight" in param_name
            if not is_multi_dim_weight:
                continue
            param.data *= cfg["init_scale"]

    return net.to(cfg["device"])
|
|
|
|
| |
| |
| |
|
|
def _train_one_epoch(cfg, model, train_loader, criterion, optimizer,
                     grads_ema, grad_clip):
    """Run one optimisation pass; return ``(mean_batch_loss, grads_ema)``."""
    device = cfg["device"]
    use_grokfast = cfg.get("use_grokfast")

    model.train()
    loss_sum = 0.0
    n_batches = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        # reshape(-1), not squeeze(): squeeze() turns a one-sample batch of
        # shape (1, 1) into a 0-d tensor, which no longer matches (1, C)
        # logits. Assumes one class index per sample — TODO confirm if the
        # dataset ever changes.
        labels = labels.reshape(-1).long().to(device)

        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()

        # Order matters: Grokfast rewrites the gradients first, then the
        # (possibly amplified) gradients are clipped, then the step is taken.
        if use_grokfast:
            grads_ema = gradfilter_ema(
                model, grads_ema,
                alpha=cfg.get("grokfast_alpha", 0.98),
                lamb=cfg.get("grokfast_lamb", 2.0))
        if grad_clip and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
        optimizer.step()

        loss_sum += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_loader produced no batches; nothing to train on")
    return loss_sum / n_batches, grads_ema


def _irm_drop_stats(history, grok_epoch):
    """Summarise the IRM trajectory; return ``(drop_pct, drop_epoch, epoch_gap)``."""
    drop_pct = float("nan")
    drop_epoch = -1
    epoch_gap = -1
    if not history:
        return drop_pct, drop_epoch, epoch_gap

    # Relative drop from the first logged value to the minimum ever logged.
    irm0 = history[0]["irm_mean"]
    if irm0:
        irm_min = min(r["irm_mean"] for r in history)
        drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100.0

    # Epoch of the largest single-step *decrease*. A signed difference is
    # used so an upward spike can never be reported as the "drop" epoch;
    # drop_epoch stays -1 if the penalty never decreases.
    biggest = 0.0
    for prev, cur in zip(history[:-1], history[1:]):
        drop = prev["irm_mean"] - cur["irm_mean"]
        if drop > biggest:
            biggest = drop
            drop_epoch = cur["epoch"]

    if grok_epoch and drop_epoch > 0:
        epoch_gap = abs(grok_epoch - drop_epoch)
    return drop_pct, drop_epoch, epoch_gap


def train(cfg, model, train_loader, val_loader, test_loader,
          pseudo_envs, optimizer, run_dir):
    """Train ``model``, log diagnostics periodically, and persist artifacts.

    Every ``cfg["log_every"]`` epochs (and at epoch 1) the train/val
    accuracy, weight norm, feature rank, IRM penalty and shortcut ratio are
    computed with the ``utils.metrics`` helpers, appended to the history,
    sent to W&B and written to ``<run_dir>/results/history.json``.

    Grokking is flagged at the first logged epoch where most of the last
    ten logged validation accuracies lie within 0.01 of the most recent one
    and the current validation accuracy beats the best seen so far by more
    than 0.05.

    After the last epoch the test accuracy is measured, and
    ``<run_dir>/results/summary.json`` and
    ``<run_dir>/checkpoints/final.pt`` are written.

    Args:
        cfg: Run configuration; reads ``device``, ``condition``,
            ``n_epochs``, ``weight_decay``, ``init_scale``, ``log_every``,
            ``run_id``, ``n_train``, ``seed`` and the optional ``grad_clip``
            and ``use_grokfast`` / ``grokfast_*`` keys.
        model: Network to optimise (modified in place).
        train_loader, val_loader, test_loader: Data loaders for each split.
        pseudo_envs: Environments handed to ``irm_penalty``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        run_dir: Root directory for this run's artifacts; the ``results``
            and ``checkpoints`` sub-directories are created if missing.

    Returns:
        The list of per-log-step metric dicts.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    device = cfg["device"]
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_val = 0.0
    grok_epoch = None
    irm_base = None

    print(f"\n{'='*55}")
    print(f" {cfg['condition'].upper()} | {cfg['n_epochs']} epochs | "
          f"WD={cfg['weight_decay']} | α={cfg['init_scale']}")
    print(f" run_dir: {run_dir}")
    print(f"{'='*55}", flush=True)

    results_dir = os.path.join(run_dir, "results")
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    # Normally created by the run-dir helpers; made here too so train() also
    # works when handed a bare directory.
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    history_path = os.path.join(results_dir, "history.json")

    grad_clip = cfg.get("grad_clip", 1.0)
    plateau_window = 10   # number of logged rows inspected for a plateau
    plateau_eps = 0.01    # max val-acc deviation that still counts as flat

    for epoch in range(1, cfg["n_epochs"] + 1):
        mean_loss, grads_ema = _train_one_epoch(
            cfg, model, train_loader, criterion, optimizer, grads_ema, grad_clip)

        if epoch % cfg["log_every"] == 0 or epoch == 1:
            tr_acc = accuracy(model, train_loader, device)
            vl_acc = accuracy(model, val_loader, device)
            wn = weight_norm(model)
            fr = feature_rank(model, val_loader, device)
            irm_m, irm_v = irm_penalty(model, pseudo_envs, device)
            cconf, bconf = shortcut_ratio(model, val_loader, device)

            # The first logged IRM penalty is the baseline for the live
            # "IRM drop" percentage.
            if irm_base is None:
                irm_base = irm_m

            # Plateau-then-jump detection; fires at most once per run.
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref = last[-1]["val_acc"]
                flat = sum(1 for r in last if abs(r["val_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and vl_acc > best_val + 0.05:
                    grok_epoch = epoch
                    irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n *** GROKKING at epoch {epoch} ***")
                    print(f" Val: {best_val:.3f}→{vl_acc:.3f} | IRM drop: {irm_drop:.1f}%",
                          flush=True)

            # Updated only after detection so the jump is measured against
            # the previous best.
            if vl_acc > best_val:
                best_val = vl_acc

            # Capped at 10 so a near-zero denominator cannot blow up plots.
            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)

            row = dict(epoch=epoch, train_loss=mean_loss,
                       train_acc=tr_acc, val_acc=vl_acc,
                       weight_norm=wn, feature_rank=fr,
                       irm_mean=irm_m, irm_var=irm_v,
                       center_conf=cconf, border_conf=bconf,
                       shortcut_ratio=sc_ratio,
                       grokking_detected=grok_epoch is not None)
            history.append(row)
            wandb.log(row)

            # Rewritten on every log step so a killed run keeps its history.
            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            print(f" ep {epoch:5d} | loss {mean_loss:.4f} | "
                  f"tr {tr_acc:.3f} | vl {vl_acc:.3f} | "
                  f"‖W‖ {wn:.1f} | rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | sc {sc_ratio:.2f}x",
                  flush=True)

    test_acc = accuracy(model, test_loader, device)
    wandb.log({"test_acc": test_acc, "grokking_epoch": grok_epoch or -1})

    irm_drop_pct, irm_drop_ep, epoch_gap = _irm_drop_stats(history, grok_epoch)
    last_row = history[-1] if history else None

    summary = dict(
        run_id=cfg["run_id"],
        condition=cfg["condition"],
        n_train=cfg["n_train"],
        seed=cfg["seed"],
        test_acc=test_acc,
        best_val=best_val,
        grokking_epoch=grok_epoch if grok_epoch else -1,
        irm_drop_pct=irm_drop_pct,
        irm_drop_epoch=irm_drop_ep,
        epoch_gap=epoch_gap,
        final_weight_norm=last_row["weight_norm"] if last_row else None,
        final_feature_rank=last_row["feature_rank"] if last_row else None,
        final_irm=last_row["irm_mean"] if last_row else None,
        final_shortcut_ratio=last_row["shortcut_ratio"] if last_row else None,
    )
    with open(os.path.join(results_dir, "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    ckpt_path = os.path.join(ckpt_dir, "final.pt")
    torch.save(model.state_dict(), ckpt_path)

    print(f"\n Test acc: {test_acc:.4f} | Grokking at: {grok_epoch}")
    print(f" History → {history_path}")
    print(f" Checkpoint → {ckpt_path}", flush=True)
    return history
|
|
|
|
| |
| |
| |
|
|
def main(argv=None):
    """Parse CLI flags, prepare the run directory and W&B, then train.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read the process command line, exactly as before.

    Invalid values for ``--n_train``, ``--log_every`` and ``--spurious_rho``
    are rejected through ``ArgumentParser.error`` before any data is
    downloaded or any directory is created.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--condition", default="grokking", choices=["standard", "grokking"])
    p.add_argument("--n_train", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--log_every", type=int, default=50)
    p.add_argument("--wandb_project", default="causalgrok")
    p.add_argument("--wandb_mode", default="online",
                   choices=["online", "offline", "disabled"])
    p.add_argument("--run_dir", default=None,
                   help="Override the auto-generated experiments/runs/<run_id>/ path")
    p.add_argument("--data_root", default="data",
                   help="Where MedMNIST cache lives")

    # Preset overrides: None means "keep the value from get_config()".
    p.add_argument("--weight_decay", type=float, default=None)
    p.add_argument("--init_scale", type=float, default=None)
    p.add_argument("--n_epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--grokfast", choices=["on", "off"], default=None,
                   help="Force Grokfast on/off, overriding the preset")
    p.add_argument("--grad_clip", type=float, default=1.0,
                   help="Max ℓ2 gradient norm; 0 disables clipping")

    # Spurious-feature injection.
    p.add_argument("--spurious_rho", type=float, default=None,
                   help="Probability that the colored corner patch is correctly correlated with the label. None/0 disables injection.")
    p.add_argument("--spurious_patch_size", type=int, default=4)
    p.add_argument("--spurious_seed", type=int, default=None,
                   help="Defaults to --seed; controls per-sample correlation decisions")

    args = p.parse_args(argv)

    # Fail fast on values that would otherwise break much later (modulo by
    # zero in the logging schedule) or silently misbehave (a negative
    # n_train slices from the end of the permutation).
    if args.n_train < 1:
        p.error("--n_train must be >= 1")
    if args.log_every < 1:
        p.error("--log_every must be >= 1")
    if args.spurious_rho is not None and not 0.0 <= args.spurious_rho <= 1.0:
        p.error("--spurious_rho is a probability and must lie in [0, 1]")

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed, log_every=args.log_every)

    overrides = {
        "weight_decay": args.weight_decay,
        "init_scale": args.init_scale,
        "n_epochs": args.n_epochs,
        "lr": args.lr,
    }
    cfg.update({key: value for key, value in overrides.items() if value is not None})
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["grad_clip"] = args.grad_clip

    cfg["spurious_rho"] = args.spurious_rho
    cfg["spurious_patch_size"] = args.spurious_patch_size
    cfg["spurious_seed"] = args.spurious_seed if args.spurious_seed is not None else args.seed

    # Throughput settings for CUDA; these favour speed over bit-exact
    # reproducibility.
    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        # Auto-generated run name: condition, optional spurious tag, size, seed.
        parts = [cfg["condition"]]
        if cfg.get("spurious_rho"):
            parts.append(f"spurious{cfg['spurious_rho']}")
        parts += [f"n{cfg['n_train']}", f"s{cfg['seed']}"]
        run_dir, run_id = make_run_dir(parts)
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        run_id = os.path.basename(os.path.normpath(run_dir))

    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    wandb.init(project=args.wandb_project, config=cfg, name=run_id,
               mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started (UTC): {datetime.now(timezone.utc).isoformat()}", flush=True)

    train_loader, val_loader, test_loader, train_subset = get_dataloaders(cfg, args.data_root)
    pseudo_envs = make_brightness_envs(train_subset, cfg["n_pseudo_envs"], cfg["device"])
    model = build_model(cfg)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    print(f"Train: {len(train_subset)} | Val: {len(val_loader.dataset)} | "
          f"Test: {len(test_loader.dataset)}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}", flush=True)

    t0 = time.time()
    train(cfg, model, train_loader, val_loader, test_loader,
          pseudo_envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time() - t0) / 60:.1f} min", flush=True)
    wandb.finish()
|
|
|
|
# Script entry point: run the experiment only when this file is executed
# directly, so importing it stays free of side effects.
if __name__ == "__main__":
    main()
|
|