File size: 19,037 Bytes
50fa85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
"""
CausalGrok — Main Training Loop
Nilesh

Core experiment: does the IRM invariance penalty drop at the SAME epoch
as validation accuracy jumps (the grokking transition)?
If yes → the paper's central claim is confirmed.

Run via the launchers (always nohup-detached so SSH disconnects don't kill it):
    bash scripts/launch.sh grokking 500 42

All artifacts (config, logs, history, checkpoints, figures) for every
invocation land in experiments/runs/<run_id>/ and are kept forever.
"""

from __future__ import annotations

import argparse
import json
import os
import time
from datetime import datetime, timezone

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.models import resnet18
from medmnist import PneumoniaMNIST
import wandb

from utils.metrics import (
    accuracy, weight_norm, feature_rank, irm_penalty, shortcut_ratio,
)
from utils.grokfast import gradfilter_ema
from utils.pseudo_envs import make_brightness_envs
from utils.run_dir import make_run_dir, ensure_run_dir, save_config


# ──────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────

def get_config(condition):
    """Build the hyper-parameter dict for one experimental condition.

    Args:
        condition: Name of the preset, either ``"standard"`` or
            ``"grokking"``.

    Returns:
        A fresh dict containing the shared defaults followed by the values
        of the selected preset (including a ``"condition"`` key).

    Raises:
        ValueError: If ``condition`` is not a known preset. Previously an
            unknown name silently produced a config without ``condition``,
            ``lr``, ``n_epochs`` etc., which only failed later with an
            unrelated KeyError.
    """
    presets = {
        "standard": dict(condition="standard", lr=1e-3, weight_decay=1e-4,
                         n_epochs=300, init_scale=1.0, use_grokfast=False),
        "grokking": dict(condition="grokking", lr=1e-3, weight_decay=1e-3,
                         n_epochs=3000, init_scale=4.0, use_grokfast=True,
                         grokfast_alpha=0.98, grokfast_lamb=2.0),
    }
    if condition not in presets:
        raise ValueError(
            f"unknown condition {condition!r}; expected one of {sorted(presets)}")

    # Shared defaults come first so the key order of the returned dict is
    # the same as before (defaults, then preset-specific values).
    base = dict(
        seed=42, n_train=500, batch_size=32, img_size=28,
        n_classes=2, log_every=50, n_pseudo_envs=3,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    base.update(presets[condition])
    return base


# ──────────────────────────────────────────────
# DATA
# ──────────────────────────────────────────────

class SpuriousColorPatchDataset(Dataset):
    """
    Wraps a (image-tensor, label) dataset and stamps a colored corner
    patch correlated with the label at probability `rho`.

    Encoding (after Normalize mean=.5/std=.5, image is in [-1,1] across
    3 identical grayscale channels):
        encoded label 0 -> channel-0 high, channels 1,2 low (red corner)
        encoded label 1 -> channel-2 high, channels 0,1 low (blue corner)

    With prob rho the encoded label matches the true label -- a usable
    shortcut. With prob (1-rho) it's flipped -- pure noise on the patch.

    The per-sample "correlated or flipped" decision is drawn once in the
    constructor from a generator seeded with `seed`, so the same seed and
    dataset length always yield the same decisions.

    Args:
        base: Indexable dataset returning ``(img, label)`` where ``img``
            has at least 3 channels in its first dimension.
        rho: Probability that a sample's patch encodes its true label.
        patch_size: Side length (pixels) of the top-left square patch.
        seed: Seed for the per-sample correlation decisions.
        hi: Value written into the "on" channel of the patch.
        lo: Value written into the "off" channels of the patch.
    """
    def __init__(self, base, rho=0.8, patch_size=4, seed=0,
                 hi=1.0, lo=-1.0):
        self.base = base
        self.rho = float(rho)
        self.patch_size = int(patch_size)
        self.hi = hi
        self.lo = lo
        rng = torch.Generator().manual_seed(int(seed))
        # One boolean per sample, fixed for the lifetime of the wrapper.
        self.is_correlated = (torch.rand(len(base), generator=rng) < self.rho)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        """Return ``(patched_img, label)``; the label is passed through unchanged."""
        img, label = self.base[idx]
        # Stamp a copy: writing into the object handed back by `base` would
        # permanently corrupt any base dataset that returns stored tensors
        # (e.g. a list of tensors, a TensorDataset, or a caching wrapper).
        img = img.clone() if torch.is_tensor(img) else img.copy()
        # label may be a 1-element tensor/array or a python scalar
        try:
            label_int = int(label.squeeze().item())
        except AttributeError:
            label_int = int(label)
        encoded = label_int if bool(self.is_correlated[idx]) else (1 - label_int)
        ps = self.patch_size
        # Channel 0 is "on" for encoded label 0, channel 2 otherwise;
        # channel 1 is always low.
        ch0, ch2 = (self.hi, self.lo) if encoded == 0 else (self.lo, self.hi)
        img[0, :ps, :ps] = ch0
        img[1, :ps, :ps] = self.lo
        img[2, :ps, :ps] = ch2
        return img, label


def get_dataloaders(cfg, data_root):
    """Create the PneumoniaMNIST loaders used by the training loop.

    Args:
        cfg: Config dict; reads ``img_size``, ``seed``, ``n_train``,
            ``batch_size`` and the optional ``spurious_rho``,
            ``spurious_patch_size`` and ``spurious_seed`` keys.
        data_root: Directory used as the MedMNIST download/cache root.

    Returns:
        ``(train_loader, val_loader, test_loader, train_subset)`` where
        ``train_subset`` is the random ``n_train``-sized subset of the
        training split that ``train_loader`` iterates over.
    """
    # medmnist 3.x raises on a missing root, so create the directory here
    # instead of depending on its default-root fallback.
    os.makedirs(data_root, exist_ok=True)

    side = cfg["img_size"]
    pipeline = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
        # Replicate the single grayscale channel to 3 channels.
        transforms.Lambda(lambda t: t.repeat(3, 1, 1)),
    ])

    split_names = ("train", "val", "test")
    splits = {
        name: PneumoniaMNIST(split=name, transform=pipeline,
                             download=True, root=data_root)
        for name in split_names
    }

    # Optional spurious-feature injection: every split gets the colored
    # corner patch at the same correlation rho, each with its own seed
    # offset (+1 train, +2 val, +3 test).
    rho = cfg.get("spurious_rho")
    if rho:
        patch = cfg.get("spurious_patch_size", 4)
        base_seed = cfg.get("spurious_seed", cfg["seed"])
        for offset, name in enumerate(split_names, start=1):
            splits[name] = SpuriousColorPatchDataset(
                splits[name], rho=rho, patch_size=patch, seed=base_seed + offset)

    # Seed the global RNG so the training subset is reproducible.
    torch.manual_seed(cfg["seed"])
    keep = torch.randperm(len(splits["train"]))[:cfg["n_train"]]
    train_subset = Subset(splits["train"], keep)

    worker_opts = dict(num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"],
                              shuffle=True, **worker_opts)
    val_loader = DataLoader(splits["val"], batch_size=256,
                            shuffle=False, **worker_opts)
    test_loader = DataLoader(splits["test"], batch_size=256,
                             shuffle=False, **worker_opts)
    return train_loader, val_loader, test_loader, train_subset


# ──────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────

def build_model(cfg):
    """Instantiate an untrained ResNet-18 and place it on ``cfg["device"]``.

    When ``cfg["init_scale"]`` differs from 1.0, every parameter whose name
    contains "weight" and that has more than one dimension is multiplied
    by that factor in place; all other parameters keep their default
    initialisation.
    """
    scale = cfg["init_scale"]
    net = resnet18(weights=None, num_classes=cfg["n_classes"])
    if scale == 1.0:
        return net.to(cfg["device"])

    with torch.no_grad():
        multi_dim_weights = (
            param for pname, param in net.named_parameters()
            if "weight" in pname and param.dim() > 1
        )
        for param in multi_dim_weights:
            param.data.mul_(scale)
    return net.to(cfg["device"])


# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────

def _irm_drop_stats(history, grok_epoch):
    """Derive the IRM decision numbers from the logged history rows.

    Returns:
        ``(irm_drop_pct, irm_drop_epoch, epoch_gap)``:
        percentage drop from the first logged ``irm_mean`` to the minimum
        one (NaN if there is no history or the first value is falsy), the
        epoch of the largest absolute step change in ``irm_mean`` between
        consecutive rows (-1 if none), and the absolute distance between
        that epoch and ``grok_epoch`` (-1 if either is missing).
    """
    irm_drop_pct = float("nan")
    irm_drop_ep = -1
    epoch_gap = -1
    if not history:
        return irm_drop_pct, irm_drop_ep, epoch_gap

    irm0 = history[0]["irm_mean"]
    irm_min = min(r["irm_mean"] for r in history)
    if irm0:
        irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100.0

    # Epoch of biggest IRM step-change (proxy for "the IRM drop").
    biggest = 0.0
    for prev, cur in zip(history[:-1], history[1:]):
        d = abs(cur["irm_mean"] - prev["irm_mean"])
        if d > biggest:
            biggest = d
            irm_drop_ep = cur["epoch"]

    if grok_epoch and irm_drop_ep > 0:
        epoch_gap = abs(grok_epoch - irm_drop_ep)
    return irm_drop_pct, irm_drop_ep, epoch_gap


def train(cfg, model, train_loader, val_loader, test_loader,
          pseudo_envs, optimizer, run_dir):
    """Run the full training loop and write all run artifacts.

    Every ``cfg["log_every"]`` epochs (and at epoch 1) the metrics from
    ``utils.metrics`` are evaluated, appended to the history, logged to
    wandb and flushed to ``<run_dir>/results/history.json``. After the last
    epoch the test accuracy, ``results/summary.json`` and
    ``checkpoints/final.pt`` are written.

    Args:
        cfg: Config dict (reads ``condition``, ``n_epochs``,
            ``weight_decay``, ``init_scale``, ``device``, ``log_every``,
            ``run_id``, ``n_train``, ``seed`` and the optional
            ``grad_clip`` / ``use_grokfast`` / ``grokfast_*`` keys).
        model: Network to train (already on ``cfg["device"]``).
        train_loader, val_loader, test_loader: Data loaders yielding
            ``(imgs, labels)`` batches.
        pseudo_envs: Environments passed straight to ``irm_penalty``.
        optimizer: Optimizer bound to ``model``'s parameters.
        run_dir: Run directory; its ``results/`` and ``checkpoints/``
            sub-directories are assumed to exist already.

    Returns:
        The list of per-checkpoint metric dicts (the logged history).
    """
    criterion  = nn.CrossEntropyLoss()
    grads_ema  = None
    history    = []
    best_val   = 0.0
    grok_epoch = None
    irm_base   = None

    print(f"\n{'='*55}")
    print(f"  {cfg['condition'].upper()} | {cfg['n_epochs']} epochs | "
          f"WD={cfg['weight_decay']} | α={cfg['init_scale']}")
    print(f"  run_dir: {run_dir}")
    print(f"{'='*55}", flush=True)

    history_path = os.path.join(run_dir, "results", "history.json")

    grad_clip = cfg.get("grad_clip", 1.0)
    plateau_window = 10
    plateau_eps    = 0.01   # |Δval_acc| within this counts as flat

    for epoch in range(1, cfg["n_epochs"] + 1):
        model.train()
        loss_sum = 0.0
        n_b = 0
        for imgs, labels in train_loader:
            imgs   = imgs.to(cfg["device"])
            # Flatten to (batch,) explicitly. A bare squeeze() would also
            # drop the batch dimension when the final batch holds a single
            # sample, which makes CrossEntropyLoss reject the target.
            labels = labels.reshape(-1).long().to(cfg["device"])
            optimizer.zero_grad()
            loss   = criterion(model(imgs), labels)
            loss.backward()
            # Order matters: Grokfast amplifies, THEN we clip the
            # amplified result. Clipping before Grokfast would let the
            # amplification re-blow up the gradient and partially
            # undo the safety bound.
            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            optimizer.step()
            loss_sum += loss.item()
            n_b += 1

        if epoch % cfg["log_every"] == 0 or epoch == 1:
            tr_acc = accuracy(model, train_loader, cfg["device"])
            vl_acc = accuracy(model, val_loader,   cfg["device"])
            wn     = weight_norm(model)
            fr     = feature_rank(model, val_loader, cfg["device"])
            irm_m, irm_v = irm_penalty(model, pseudo_envs, cfg["device"])
            cconf, bconf = shortcut_ratio(model, val_loader, cfg["device"])

            # First logged IRM value is the baseline for the "IRM drop".
            if irm_base is None:
                irm_base = irm_m

            # Robust grokking detection: require a sustained plateau in
            # val_acc (>= plateau_window-2 of the last `plateau_window`
            # checkpoints flat within `plateau_eps`) BEFORE the jump.
            # Otherwise early-training noise (0.50 -> 0.56) can trigger.
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref  = last[-1]["val_acc"]
                flat = sum(1 for r in last if abs(r["val_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and vl_acc > best_val + 0.05:
                    grok_epoch = epoch
                    irm_drop   = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n  *** GROKKING at epoch {epoch} ***")
                    print(f"      Val: {best_val:.3f}→{vl_acc:.3f} | IRM drop: {irm_drop:.1f}%",
                          flush=True)

            # Updated only after the detection above, which compares the
            # new val_acc against the previous best.
            if vl_acc > best_val:
                best_val = vl_acc

            # Cap the shortcut ratio: early training can give
            # cconf ~ bconf ~ 0, which makes the raw ratio explode.
            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)

            row = dict(epoch=epoch, train_loss=loss_sum / n_b,
                       train_acc=tr_acc, val_acc=vl_acc,
                       weight_norm=wn, feature_rank=fr,
                       irm_mean=irm_m, irm_var=irm_v,
                       center_conf=cconf, border_conf=bconf,
                       shortcut_ratio=sc_ratio,
                       grokking_detected=grok_epoch is not None)
            history.append(row)
            wandb.log(row)

            # Rewrite the whole history at every checkpoint so a killed
            # run still leaves a usable file behind.
            with open(history_path, "w", encoding="utf-8") as f:
                json.dump(history, f, indent=2)

            print(f"  ep {epoch:5d} | loss {loss_sum/n_b:.4f} | "
                  f"tr {tr_acc:.3f} | vl {vl_acc:.3f} | "
                  f"‖W‖ {wn:.1f} | rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | sc {sc_ratio:.2f}x",
                  flush=True)

    test_acc = accuracy(model, test_loader, cfg["device"])
    wandb.log({"test_acc": test_acc, "grokking_epoch": grok_epoch or -1})

    # Compute the decision numbers right here so summary.json is
    # the single source of truth for go/no-go.
    irm_drop_pct, irm_drop_ep, epoch_gap = _irm_drop_stats(history, grok_epoch)

    summary = dict(
        run_id               = cfg["run_id"],
        condition            = cfg["condition"],
        n_train              = cfg["n_train"],
        seed                 = cfg["seed"],
        test_acc             = test_acc,
        best_val             = best_val,
        grokking_epoch       = grok_epoch if grok_epoch else -1,
        irm_drop_pct         = irm_drop_pct,
        irm_drop_epoch       = irm_drop_ep,
        epoch_gap            = epoch_gap,
        final_weight_norm    = history[-1]["weight_norm"]    if history else None,
        final_feature_rank   = history[-1]["feature_rank"]   if history else None,
        final_irm            = history[-1]["irm_mean"]       if history else None,
        final_shortcut_ratio = history[-1]["shortcut_ratio"] if history else None,
    )
    with open(os.path.join(run_dir, "results", "summary.json"), "w",
              encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    ckpt_path = os.path.join(run_dir, "checkpoints", "final.pt")
    torch.save(model.state_dict(), ckpt_path)

    print(f"\n  Test acc: {test_acc:.4f} | Grokking at: {grok_epoch}")
    print(f"  History → {history_path}")
    print(f"  Checkpoint → {ckpt_path}", flush=True)
    return history


# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────

def main():
    """Parse the CLI, assemble the config, and run one training job.

    Steps: parse and validate arguments, apply per-knob overrides on top
    of the ``--condition`` preset, configure CUDA math modes, seed the
    RNGs, create (or adopt) the run directory, start the wandb run, build
    data/model/optimizer, call ``train`` and finish the wandb run.

    Invalid numeric arguments are rejected through ``parser.error`` (exit
    status 2) before any run directory or wandb run is created.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--condition",     default="grokking", choices=["standard", "grokking"])
    p.add_argument("--n_train",       type=int, default=500)
    p.add_argument("--seed",          type=int, default=42)
    p.add_argument("--log_every",     type=int, default=50)
    p.add_argument("--wandb_project", default="causalgrok")
    p.add_argument("--wandb_mode",    default="online",
                   choices=["online", "offline", "disabled"])
    p.add_argument("--run_dir",       default=None,
                   help="Override the auto-generated experiments/runs/<run_id>/ path")
    p.add_argument("--data_root",     default="data",
                   help="Where MedMNIST cache lives")

    # Per-knob overrides for the ablation grid. When set, they override
    # the preset chosen by --condition. When omitted, the preset wins.
    p.add_argument("--weight_decay", type=float, default=None)
    p.add_argument("--init_scale",   type=float, default=None)
    p.add_argument("--n_epochs",     type=int,   default=None)
    p.add_argument("--lr",           type=float, default=None)
    p.add_argument("--grokfast",     choices=["on", "off"], default=None,
                   help="Force Grokfast on/off, overriding the preset")
    p.add_argument("--grad_clip",    type=float, default=1.0,
                   help="Max ℓ2 gradient norm; 0 disables clipping")

    # Spurious-feature injection (Outcome-C variant).
    p.add_argument("--spurious_rho",        type=float, default=None,
                   help="Probability that the colored corner patch is correctly correlated with the label. None/0 disables injection.")
    p.add_argument("--spurious_patch_size", type=int,   default=4)
    p.add_argument("--spurious_seed",       type=int,   default=None,
                   help="Defaults to --seed; controls per-sample correlation decisions")

    args = p.parse_args()

    # Fail fast on values that would otherwise only blow up (or silently
    # misbehave) after the run directory and wandb run already exist:
    # log_every=0 hits a modulo-by-zero after the first trained epoch,
    # a non-positive n_train yields an empty or oddly sliced subset, and
    # rho is a probability.
    if args.log_every < 1:
        p.error("--log_every must be >= 1")
    if args.n_train < 1:
        p.error("--n_train must be >= 1")
    if args.spurious_rho is not None and not 0.0 <= args.spurious_rho <= 1.0:
        p.error("--spurious_rho must be in [0, 1]")

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed, log_every=args.log_every)

    # CLI overrides take precedence over preset
    if args.weight_decay is not None:
        cfg["weight_decay"] = args.weight_decay
    if args.init_scale is not None:
        cfg["init_scale"] = args.init_scale
    if args.n_epochs is not None:
        cfg["n_epochs"] = args.n_epochs
    if args.lr is not None:
        cfg["lr"] = args.lr
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["grad_clip"] = args.grad_clip

    cfg["spurious_rho"]        = args.spurious_rho
    cfg["spurious_patch_size"] = args.spurious_patch_size
    cfg["spurious_seed"]       = args.spurious_seed if args.spurious_seed is not None else args.seed

    # -- Use the remaining compute on a shared GPU more aggressively --
    # TF32 matmuls are A100-native and ~2x faster than fp32 with no
    # measurable effect on grokking dynamics for our scale of model.
    # cudnn.benchmark autotunes conv algorithms for our fixed shape.
    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark        = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32       = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        # Tag spurious runs in the run_id so the dirs are
        # distinguishable on disk and globs like
        # `experiments/runs/*spurious*/` work without ambiguity.
        parts = [cfg["condition"]]
        if cfg.get("spurious_rho"):
            parts.append(f"spurious{cfg['spurious_rho']}")
        parts += [f"n{cfg['n_train']}", f"s{cfg['seed']}"]
        run_dir, run_id = make_run_dir(parts)
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        # Use the last path component of the supplied directory as run id.
        run_id = os.path.basename(os.path.normpath(run_dir))

    cfg["run_id"]  = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    wandb.init(project=args.wandb_project, config=cfg, name=run_id,
               mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started (UTC): {datetime.now(timezone.utc).isoformat()}", flush=True)

    train_loader, val_loader, test_loader, train_subset = get_dataloaders(cfg, args.data_root)
    pseudo_envs = make_brightness_envs(train_subset, cfg["n_pseudo_envs"], cfg["device"])
    model     = build_model(cfg)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    print(f"Train: {len(train_subset)} | Val: {len(val_loader.dataset)} | "
          f"Test: {len(test_loader.dataset)}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}", flush=True)

    t0 = time.time()
    train(cfg, model, train_loader, val_loader, test_loader,
          pseudo_envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time() - t0) / 60:.1f} min", flush=True)
    wandb.finish()


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()