CausalGrok / code /experiments /causalgrok_baseline.py
nileshsarkar-ai's picture
Upload code/experiments
50fa85c verified
"""
CausalGrok — Main Training Loop
Nilesh
Core experiment: does the IRM invariance penalty drop at the SAME epoch
as validation accuracy jumps (the grokking transition)?
If yes → the paper's central claim is confirmed.
Run via the launchers (always nohup-detached so SSH disconnects don't kill it):
bash scripts/launch.sh grokking 500 42
All artifacts (config, logs, history, checkpoints, figures) for every
invocation land in experiments/runs/<run_id>/ and are kept forever.
"""
from __future__ import annotations
import argparse
import json
import os
import time
from datetime import datetime, timezone
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.models import resnet18
from medmnist import PneumoniaMNIST
import wandb
from utils.metrics import (
accuracy, weight_norm, feature_rank, irm_penalty, shortcut_ratio,
)
from utils.grokfast import gradfilter_ema
from utils.pseudo_envs import make_brightness_envs
from utils.run_dir import make_run_dir, ensure_run_dir, save_config
# ──────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────
def get_config(condition):
    """Return the training configuration for a named experimental condition.

    Args:
        condition: ``"standard"`` (weight decay 1e-4, unit init scale, no
            Grokfast, 300 epochs) or ``"grokking"`` (weight decay 1e-3,
            init scale 4.0, Grokfast EMA filter, 3000 epochs).

    Returns:
        A fresh ``dict`` holding the shared defaults (seed, data sizes,
        logging cadence, device) merged with the condition's hyperparameters.

    Raises:
        ValueError: if ``condition`` is not a known preset. An unknown name
            used to return a config silently missing ``lr``/``n_epochs``/...,
            which only failed much later with a ``KeyError``.
    """
    presets = {
        "standard": dict(lr=1e-3, weight_decay=1e-4, n_epochs=300,
                         init_scale=1.0, use_grokfast=False),
        "grokking": dict(lr=1e-3, weight_decay=1e-3, n_epochs=3000,
                         init_scale=4.0, use_grokfast=True,
                         grokfast_alpha=0.98, grokfast_lamb=2.0),
    }
    if condition not in presets:
        raise ValueError(
            f"unknown condition {condition!r}; expected one of {sorted(presets)}")
    base = dict(
        seed=42, n_train=500, batch_size=32, img_size=28,
        n_classes=2, log_every=50, n_pseudo_envs=3,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # Same key order as before: shared defaults, then condition, then preset.
    base.update(condition=condition, **presets[condition])
    return base
# ──────────────────────────────────────────────
# DATA
# ──────────────────────────────────────────────
class SpuriousColorPatchDataset(Dataset):
    """
    Wrap a (image-tensor, label) dataset and stamp a coloured corner patch
    whose colour is correlated with the label with probability ``rho``.

    Encoding (after Normalize mean=.5/std=.5 the image lies in [-1, 1]
    across 3 identical grayscale channels):
      encoded label 0 → channel 0 high, channels 1, 2 low  (red corner)
      encoded label 1 → channel 2 high, channels 0, 1 low  (blue corner)

    With probability ``rho`` the encoded label equals the true label — a
    usable shortcut. Otherwise it is flipped, so that sample's patch shows
    the other class's colour.

    The correlated/flipped decision for every index is drawn once, at
    construction, from a ``torch.Generator`` seeded with ``seed``. The same
    seed and dataset length therefore reproduce the same decisions, and a
    given index gets the same patch on every access. Callers that want
    independent decisions per split pass a different seed per split.

    Args:
        base: indexable dataset returning ``(img, label)``. ``img`` is
            expected to be a tensor with >= 3 channels on dim 0; ``label`` is
            a Python int or anything with ``.squeeze().item()`` (e.g. a
            1-element tensor/array).
        rho: probability that the patch colour matches the true label.
        patch_size: side length in pixels of the square top-left patch.
        seed: seed for the per-sample correlation decisions.
        hi, lo: pixel values written into the "on" / "off" channels.
    """

    def __init__(self, base, rho=0.8, patch_size=4, seed=0,
                 hi=1.0, lo=-1.0):
        self.base = base
        self.rho = float(rho)
        self.patch_size = int(patch_size)
        self.hi = hi
        self.lo = lo
        rng = torch.Generator().manual_seed(int(seed))
        # One Bernoulli(rho) draw per sample: True → patch matches label.
        self.is_correlated = (torch.rand(len(base), generator=rng) < self.rho)

    def __len__(self):
        return len(self.base)

    @staticmethod
    def _label_to_int(label):
        """Convert a scalar-like label (tensor/array or plain number) to int."""
        try:
            return int(label.squeeze().item())
        except AttributeError:
            return int(label)

    def __getitem__(self, idx):
        img, label = self.base[idx]
        label_int = self._label_to_int(label)
        encoded = label_int if bool(self.is_correlated[idx]) else (1 - label_int)
        # Work on a copy: writing the patch in place would modify the base
        # dataset's own storage whenever it returns a stored tensor (e.g. a
        # TensorDataset), corrupting it for other wrappers or later readers.
        img = img.clone() if isinstance(img, torch.Tensor) else img.copy()
        ps = self.patch_size
        if encoded == 0:
            colour = (self.hi, self.lo, self.lo)  # red
        else:
            colour = (self.lo, self.lo, self.hi)  # blue
        for channel, value in enumerate(colour):
            img[channel, :ps, :ps] = value
        return img, label
def _repeat_to_three_channels(x):
    """Tile a channel-first image tensor 3x along dim 0 (grayscale → 3-channel).

    A named module-level function rather than a lambda so the transform
    pipeline stays picklable when DataLoader workers are started with the
    ``spawn`` method (the default on macOS and Windows).
    """
    return x.repeat(3, 1, 1)


def get_dataloaders(cfg, data_root):
    """Build PneumoniaMNIST train/val/test loaders, optionally with a shortcut.

    Images are resized to ``cfg["img_size"]``, normalised to [-1, 1] and
    tiled to 3 channels. If ``cfg["spurious_rho"]`` is truthy, every split is
    wrapped in ``SpuriousColorPatchDataset`` (split seeds ``spurious_seed``
    + 1/2/3). The training set is a random ``cfg["n_train"]``-sized subset.

    Args:
        cfg: run config (reads img_size, seed, n_train, batch_size, device,
            and optionally spurious_rho / spurious_patch_size / spurious_seed).
        data_root: MedMNIST cache directory; created if missing.

    Returns:
        ``(train_loader, val_loader, test_loader, train_subset)``.
    """
    # medmnist 3.x raises if root doesn't exist; create it ourselves
    # rather than relying on its default-root fallback.
    os.makedirs(data_root, exist_ok=True)
    transform = transforms.Compose([
        transforms.Resize((cfg["img_size"], cfg["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
        transforms.Lambda(_repeat_to_three_channels),
    ])
    train_ds = PneumoniaMNIST(split="train", transform=transform, download=True, root=data_root)
    val_ds = PneumoniaMNIST(split="val", transform=transform, download=True, root=data_root)
    test_ds = PneumoniaMNIST(split="test", transform=transform, download=True, root=data_root)
    # Spurious-feature injection: colored corner patch at correlation rho.
    # Same rho on all splits so the shortcut model plateaus at val≈rho;
    # grokking transition is the model breaking through that ceiling.
    rho = cfg.get("spurious_rho")
    if rho:
        ps = cfg.get("spurious_patch_size", 4)
        sd = cfg.get("spurious_seed", cfg["seed"])
        train_ds = SpuriousColorPatchDataset(train_ds, rho=rho, patch_size=ps, seed=sd + 1)
        val_ds = SpuriousColorPatchDataset(val_ds, rho=rho, patch_size=ps, seed=sd + 2)
        test_ds = SpuriousColorPatchDataset(test_ds, rho=rho, patch_size=ps, seed=sd + 3)
    # Reseeding the *global* RNG is deliberate: the subset indices and every
    # later global draw (model init, shuffling) follow from it, so switching
    # to a private Generator would silently change results of existing runs.
    torch.manual_seed(cfg["seed"])
    indices = torch.randperm(len(train_ds))[:cfg["n_train"]]
    train_subset = Subset(train_ds, indices)
    # Page-locked host memory only speeds up host→GPU copies; on CPU it just
    # triggers a warning.
    pin = str(cfg["device"]).startswith("cuda")
    loader_kw = dict(num_workers=4, pin_memory=pin)
    train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"], shuffle=True, **loader_kw)
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, **loader_kw)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, **loader_kw)
    return train_loader, val_loader, test_loader, train_subset
# ──────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────
def build_model(cfg):
    """Create an untrained ResNet-18 classifier and move it to ``cfg["device"]``.

    If ``cfg["init_scale"]`` is not 1.0, every parameter whose name contains
    "weight" and that has more than one dimension (conv / linear kernels —
    1-D norm-layer weights and biases are left alone) is multiplied by that
    factor. The "grokking" preset uses this with a scale of 4.0.
    """
    net = resnet18(weights=None, num_classes=cfg["n_classes"])
    scale = cfg["init_scale"]
    if scale == 1.0:
        return net.to(cfg["device"])
    with torch.no_grad():
        kernels = (param for pname, param in net.named_parameters()
                   if "weight" in pname and param.dim() > 1)
        for param in kernels:
            param.data *= scale
    return net.to(cfg["device"])
# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────
def _irm_drop_stats(history, grok_epoch):
    """Summarise the logged IRM trajectory for the go/no-go decision.

    Returns ``(irm_drop_pct, irm_drop_epoch, epoch_gap)``:
      * irm_drop_pct   — % fall from the first logged IRM mean to its minimum
                         (NaN if there is no history or the first value is 0);
      * irm_drop_epoch — epoch of the largest absolute change between
                         consecutive logged IRM means (-1 if none);
      * epoch_gap      — |grok_epoch - irm_drop_epoch| (-1 if either is missing).
    """
    irm_drop_pct = float("nan")
    irm_drop_ep = -1
    epoch_gap = -1
    if not history:
        return irm_drop_pct, irm_drop_ep, epoch_gap
    irm0 = history[0]["irm_mean"]
    irm_min = min(r["irm_mean"] for r in history)
    if irm0:
        irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100.0
    # Epoch of biggest IRM step-change (proxy for "the IRM drop")
    biggest = 0.0
    for prev, cur in zip(history, history[1:]):
        d = abs(cur["irm_mean"] - prev["irm_mean"])
        if d > biggest:
            biggest = d
            irm_drop_ep = cur["epoch"]
    if grok_epoch and irm_drop_ep > 0:
        epoch_gap = abs(grok_epoch - irm_drop_ep)
    return irm_drop_pct, irm_drop_ep, epoch_gap


def train(cfg, model, train_loader, val_loader, test_loader,
          pseudo_envs, optimizer, run_dir):
    """Train ``model`` for ``cfg["n_epochs"]`` epochs and detect grokking.

    Each step uses cross-entropy, then (if ``cfg["use_grokfast"]``) the
    Grokfast EMA gradient filter, then gradient-norm clipping at
    ``cfg["grad_clip"]`` (0/None disables), then ``optimizer.step()``.

    At epoch 1 and every ``cfg["log_every"]`` epochs it logs train/val
    accuracy, weight norm, feature rank, IRM penalty mean/var over
    ``pseudo_envs``, and the border/centre confidence ratio (capped at 10).
    Each row is appended to the history, sent to wandb, and the whole history
    is rewritten to ``<run_dir>/results/history.json``.

    A grokking epoch is declared at most once. It fires at a log step when at
    least ``plateau_window - 2`` of the last ``plateau_window`` logged val
    accuracies lie within ``plateau_eps`` of the most recent one, AND the new
    val accuracy beats the best so far by more than 0.05.

    After training it evaluates on ``test_loader`` and writes
    ``results/summary.json`` and ``checkpoints/final.pt`` under ``run_dir``
    (those subdirectories are presumably created by the run-dir helpers —
    they are not created here).

    Returns:
        list[dict]: the logged history rows.
    """
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_val = 0.0
    grok_epoch = None
    irm_base = None
    print(f"\n{'='*55}")
    print(f" {cfg['condition'].upper()} | {cfg['n_epochs']} epochs | "
          f"WD={cfg['weight_decay']} | Ξ±={cfg['init_scale']}")
    print(f" run_dir: {run_dir}")
    print(f"{'='*55}", flush=True)
    history_path = os.path.join(run_dir, "results", "history.json")
    grad_clip = cfg.get("grad_clip", 1.0)
    plateau_window = 10
    plateau_eps = 0.01  # |Δval_acc| within this counts as flat
    for epoch in range(1, cfg["n_epochs"] + 1):
        model.train()
        loss_sum = 0.0
        n_b = 0
        for imgs, labels in train_loader:
            imgs = imgs.to(cfg["device"])
            # reshape(-1) rather than squeeze(): labels carry a trailing
            # singleton dim (1-element label per sample), and squeeze() would
            # turn a final batch of size 1 into a 0-d target that
            # CrossEntropyLoss rejects (e.g. n_train=33 with batch_size=32).
            labels = labels.reshape(-1).long().to(cfg["device"])
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()
            # Order matters: Grokfast amplifies, THEN we clip the
            # amplified result. Clipping before Grokfast would let the
            # amplification re-blow up the gradient and partially
            # undo the safety bound.
            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            optimizer.step()
            loss_sum += loss.item()
            n_b += 1
        if not (epoch % cfg["log_every"] == 0 or epoch == 1):
            continue
        # Guard against an empty loader instead of dividing by zero.
        train_loss = loss_sum / n_b if n_b else float("nan")
        tr_acc = accuracy(model, train_loader, cfg["device"])
        vl_acc = accuracy(model, val_loader, cfg["device"])
        wn = weight_norm(model)
        fr = feature_rank(model, val_loader, cfg["device"])
        irm_m, irm_v = irm_penalty(model, pseudo_envs, cfg["device"])
        cconf, bconf = shortcut_ratio(model, val_loader, cfg["device"])
        if irm_base is None:
            irm_base = irm_m
        # Robust grokking detection — require a sustained plateau in
        # val_acc (≥ plateau_window-2 of the last `plateau_window`
        # checkpoints flat within `plateau_eps`) BEFORE the jump.
        # Otherwise early-training noise (0.50 → 0.56) can trigger.
        if grok_epoch is None and len(history) >= plateau_window:
            last = history[-plateau_window:]
            ref = last[-1]["val_acc"]
            flat = sum(1 for r in last if abs(r["val_acc"] - ref) < plateau_eps)
            if flat >= plateau_window - 2 and vl_acc > best_val + 0.05:
                grok_epoch = epoch
                irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                print(f"\n *** GROKKING at epoch {epoch} ***")
                print(f" Val: {best_val:.3f}β†’{vl_acc:.3f} | IRM drop: {irm_drop:.1f}%",
                      flush=True)
        if vl_acc > best_val:
            best_val = vl_acc
        # Cap the shortcut ratio — early training can give cconf≈bconf≈0
        # which makes the raw ratio explode.
        sc_ratio = min(bconf / (cconf + 1e-8), 10.0)
        row = dict(epoch=epoch, train_loss=train_loss,
                   train_acc=tr_acc, val_acc=vl_acc,
                   weight_norm=wn, feature_rank=fr,
                   irm_mean=irm_m, irm_var=irm_v,
                   center_conf=cconf, border_conf=bconf,
                   shortcut_ratio=sc_ratio,
                   grokking_detected=grok_epoch is not None)
        history.append(row)
        wandb.log(row)
        with open(history_path, "w") as f:
            json.dump(history, f, indent=2)
        print(f" ep {epoch:5d} | loss {train_loss:.4f} | "
              f"tr {tr_acc:.3f} | vl {vl_acc:.3f} | "
              f"β€–Wβ€– {wn:.1f} | rank {fr:.1f} | "
              f"IRM {irm_m:.4f} | sc {sc_ratio:.2f}x",
              flush=True)
    test_acc = accuracy(model, test_loader, cfg["device"])
    wandb.log({"test_acc": test_acc, "grokking_epoch": grok_epoch or -1})
    # Compute the four decision numbers right here so summary.json is
    # the single source of truth for go/no-go.
    irm_drop_pct, irm_drop_ep, epoch_gap = _irm_drop_stats(history, grok_epoch)
    last_row = history[-1] if history else {}
    summary = dict(
        run_id=cfg["run_id"],
        condition=cfg["condition"],
        n_train=cfg["n_train"],
        seed=cfg["seed"],
        test_acc=test_acc,
        best_val=best_val,
        grokking_epoch=grok_epoch if grok_epoch else -1,
        irm_drop_pct=irm_drop_pct,
        irm_drop_epoch=irm_drop_ep,
        epoch_gap=epoch_gap,
        final_weight_norm=last_row.get("weight_norm"),
        final_feature_rank=last_row.get("feature_rank"),
        final_irm=last_row.get("irm_mean"),
        final_shortcut_ratio=last_row.get("shortcut_ratio"),
    )
    with open(os.path.join(run_dir, "results", "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    ckpt_path = os.path.join(run_dir, "checkpoints", "final.pt")
    torch.save(model.state_dict(), ckpt_path)
    print(f"\n Test acc: {test_acc:.4f} | Grokking at: {grok_epoch}")
    print(f" History β†’ {history_path}")
    print(f" Checkpoint β†’ {ckpt_path}", flush=True)
    return history
# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────
def main():
    """CLI entry point: parse flags, build the run config, train, and log.

    The ``--condition`` preset supplies the defaults. Any per-knob flag that
    is given (``--weight_decay``, ``--init_scale``, ``--n_epochs``, ``--lr``,
    ``--grokfast``) overrides it. Each invocation gets its own run directory
    (auto-named unless ``--run_dir`` is given) and a wandb run of that name.
    If anything fails after ``wandb.init``, the wandb run is closed with
    exit code 1 before the exception propagates.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--condition", default="grokking", choices=["standard", "grokking"])
    parser.add_argument("--n_train", type=int, default=500)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log_every", type=int, default=50)
    parser.add_argument("--wandb_project", default="causalgrok")
    parser.add_argument("--wandb_mode", default="online",
                        choices=["online", "offline", "disabled"])
    parser.add_argument("--run_dir", default=None,
                        help="Override the auto-generated experiments/runs/<run_id>/ path")
    parser.add_argument("--data_root", default="data",
                        help="Where MedMNIST cache lives")
    # Per-knob overrides for the ablation grid. When set, they override
    # the preset chosen by --condition. When omitted, the preset wins.
    parser.add_argument("--weight_decay", type=float, default=None)
    parser.add_argument("--init_scale", type=float, default=None)
    parser.add_argument("--n_epochs", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--grokfast", choices=["on", "off"], default=None,
                        help="Force Grokfast on/off, overriding the preset")
    parser.add_argument("--grad_clip", type=float, default=1.0,
                        help="Max β„“2 gradient norm; 0 disables clipping")
    # Spurious-feature injection (Outcome-C variant).
    parser.add_argument("--spurious_rho", type=float, default=None,
                        help="Probability that the colored corner patch is correctly correlated with the label. None/0 disables injection.")
    parser.add_argument("--spurious_patch_size", type=int, default=4)
    parser.add_argument("--spurious_seed", type=int, default=None,
                        help="Defaults to --seed; controls per-sample correlation decisions")
    args = parser.parse_args()
    # rho is a probability. A negative value would otherwise silently mean
    # "always flip" (and still enable injection, unlike rho=0), and values
    # > 1 would silently mean "always match".
    if args.spurious_rho is not None and not 0.0 <= args.spurious_rho <= 1.0:
        parser.error(f"--spurious_rho must be in [0, 1], got {args.spurious_rho}")

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed, log_every=args.log_every)
    # CLI overrides take precedence over preset
    if args.weight_decay is not None:
        cfg["weight_decay"] = args.weight_decay
    if args.init_scale is not None:
        cfg["init_scale"] = args.init_scale
    if args.n_epochs is not None:
        cfg["n_epochs"] = args.n_epochs
    if args.lr is not None:
        cfg["lr"] = args.lr
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["grad_clip"] = args.grad_clip
    cfg["spurious_rho"] = args.spurious_rho
    cfg["spurious_patch_size"] = args.spurious_patch_size
    cfg["spurious_seed"] = args.spurious_seed if args.spurious_seed is not None else args.seed

    # ── Use the remaining compute on a shared GPU more aggressively ──
    # TF32 matmuls are A100-native and ~2× faster than fp32 with no
    # measurable effect on grokking dynamics for our scale of model.
    # cudnn.benchmark autotunes conv algorithms for our fixed shape.
    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        # Tag spurious runs in the run_id so the dirs are
        # distinguishable on disk and globs like
        # `experiments/runs/*spurious*/` work without ambiguity.
        parts = [cfg["condition"]]
        if cfg.get("spurious_rho"):
            parts.append(f"spurious{cfg['spurious_rho']}")
        parts += [f"n{cfg['n_train']}", f"s{cfg['seed']}"]
        run_dir, run_id = make_run_dir(parts)
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        run_id = os.path.basename(os.path.normpath(run_dir))
    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    wandb.init(project=args.wandb_project, config=cfg, name=run_id,
               mode=args.wandb_mode, dir=run_dir)
    try:
        print(f"\nDevice: {cfg['device']}")
        print(f"Run ID: {run_id}")
        print(f"Started (UTC): {datetime.now(timezone.utc).isoformat()}", flush=True)
        train_loader, val_loader, test_loader, train_subset = get_dataloaders(cfg, args.data_root)
        pseudo_envs = make_brightness_envs(train_subset, cfg["n_pseudo_envs"], cfg["device"])
        model = build_model(cfg)
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=cfg["lr"], weight_decay=cfg["weight_decay"])
        print(f"Train: {len(train_subset)} | Val: {len(val_loader.dataset)} | "
              f"Test: {len(test_loader.dataset)}")
        print(f"Params: {sum(p.numel() for p in model.parameters()):,}", flush=True)
        t0 = time.time()
        train(cfg, model, train_loader, val_loader, test_loader,
              pseudo_envs, optimizer, run_dir)
        print(f"\nWall time: {(time.time() - t0) / 60:.1f} min", flush=True)
    except BaseException:
        # Close the wandb run as failed (and flush it) before propagating,
        # instead of leaving it dangling when data loading or training dies.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
if __name__ == "__main__":
    # Run the experiment only when executed as a script, so the module can
    # be imported (e.g. to reuse get_config or the dataset wrapper) without
    # starting a training run.
    main()