File size: 5,015 Bytes
50fa85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
CausalGrok — Ablation Grid Runner

Launches every condition × size × seed as its own subprocess. Each run
gets its own experiments/runs/<run_id>/ directory with isolated logs,
results, checkpoints, and figures.

Use the launchers (they fork this under nohup):
    bash scripts/run_quick_ablations.sh
    bash scripts/run_full_grid.sh

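You can also call directly:
    python -m experiments.run_ablations --quick      # control vs. full recipe, n=500, seed 42, run one after the other
    python -m experiments.run_ablations --parallel   # full GRID x SIZES x SEEDS, one subprocess per cell
    python -m experiments.run_ablations              # no flag: single control cell (n=500, seed 42)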
"""

from __future__ import annotations

import argparse
import itertools
import os
import subprocess
import sys
from datetime import datetime, timezone

from utils.run_dir import DEFAULT_BASE, ensure_run_dir

# One tuple per ablation condition:
#   (weight_decay, init_scale, use_grokfast, label)
# The first three fields are forwarded to the trainer by cmd_for() as
# --weight_decay, --init_scale and --grokfast; the label is embedded in the
# run directory name by build_run_dir().
# Original note: labels feed into wandb tags (not visible in this file —
# TODO confirm against the trainer).
# For Camelyon17: baseline WD is 5e-3 (empirically optimal)
#
# NOTE: order matters. run_quick() picks GRID[0] and GRID[4] by index and the
# no-flag CLI default uses GRID[:1]; reordering this list changes what they run.
GRID = [
    (0.0,   1.0, False, "wd0_a1"),       # control — no regularization
    (1e-4,  1.0, False, "wd1e4_a1"),     # mild WD only
    (5e-3,  1.0, False, "wd5e3_a1"),     # baseline WD, standard init
    (5e-3,  4.0, False, "wd5e3_a4"),     # baseline WD + large init
    (5e-3,  4.0, True,  "wd5e3_a4_gf"),  # full recipe (main condition)
    (1e-2,  4.0, True,  "wd1e2_a4_gf"),  # higher WD variant
]

# Training-set sizes swept by the full grid (passed to the trainer as --n_train).
SIZES = [100, 250, 500, 1000, 2000]
# Seeds run for every (condition, size) pair (passed to the trainer as --seed).
SEEDS = [42, 123, 456]


def cmd_for(wd, alpha, gf, n_train, seed, run_dir, n_epochs=None):
    """
    Assemble the argv list that trains a single grid cell.

    The trainer module is invoked with the current interpreter and every
    knob is passed as an explicit command-line override, so each cell runs
    with its own weight decay, init scale, Grokfast switch and epoch count.

    Epoch budget: when ``n_epochs`` is not given, a cell whose weight decay
    is below 1e-3 and that has Grokfast off gets 300 epochs; every other
    cell gets 3000. Per the original author's note, the short cells are not
    expected to grok and only serve as a baseline, so the shorter run is a
    compute saver.

    Args:
        wd: weight decay, forwarded as ``--weight_decay``.
        alpha: init scale, forwarded as ``--init_scale``.
        gf: truthy to pass ``--grokfast on``, falsy for ``off``.
        n_train: forwarded as ``--n_train``.
        seed: forwarded as ``--seed``.
        run_dir: forwarded as ``--run_dir``.
        n_epochs: explicit epoch count; ``None`` selects the budget above.

    Returns:
        A list of strings suitable for ``subprocess.run`` / ``Popen``.
    """
    if n_epochs is None:
        short_baseline = not (wd >= 1e-3 or gf)
        n_epochs = 300 if short_baseline else 3000

    # Insertion order of this dict is the order of flags on the command line.
    overrides = {
        "--condition":     "grokking",
        "--n_train":       str(n_train),
        "--seed":          str(seed),
        "--weight_decay":  str(wd),
        "--init_scale":    str(alpha),
        "--n_epochs":      str(n_epochs),
        "--grokfast":      "on" if gf else "off",
        "--wandb_project": "causalgrok",
        "--run_dir":       run_dir,
    }

    argv = [sys.executable, "-m", "experiments.causalgrok_camelyon_v2"]
    for flag, value in overrides.items():
        argv.extend((flag, value))
    return argv


def build_run_dir(label, n_train, seed):
    """
    Prepare the directory for one grid cell and return its path.

    The run id is ``<UTC timestamp>_<label>_n<n_train>_s<seed>`` with the
    timestamp formatted as ``YYYYMMDD-HHMMSS``; the directory lives under
    ``DEFAULT_BASE``. The path is handed to ``ensure_run_dir`` before being
    returned (presumably that helper creates the directory layout,
    including ``logs/`` — confirm in utils.run_dir).
    """
    utc_now = datetime.now(tz=timezone.utc)
    parts = (
        utc_now.strftime("%Y%m%d-%H%M%S"),
        f"{label}",
        f"n{n_train}",
        f"s{seed}",
    )
    target = os.path.join(DEFAULT_BASE, "_".join(parts))
    ensure_run_dir(target)
    return target


def run_quick():
    """
    Sanity probe: run the control cell and the full-recipe cell back to back.

    Both cells use n_train=500 and seed=42. Each run blocks until the trainer
    exits, with stdout/stderr redirected to ``logs/train.log`` and
    ``logs/train.err`` inside its run directory. A non-zero trainer exit
    raises ``subprocess.CalledProcessError`` (``check=True``), which also
    prevents any remaining cell from running.
    """
    n_train, seed = 500, 42
    # Indices refer to positions in GRID: 0 is the control, 4 the full recipe.
    probe_cells = (GRID[0], GRID[4])

    for wd, alpha, gf, label in probe_cells:
        run_dir = build_run_dir(label, n_train, seed)
        logs_dir = os.path.join(run_dir, "logs")
        stdout_path = os.path.join(logs_dir, "train.log")
        stderr_path = os.path.join(logs_dir, "train.err")
        print(f"\n>> {label} n=500 seed=42  → {run_dir}")
        with open(stdout_path, "w") as out, open(stderr_path, "w") as ferr:
            argv = cmd_for(wd, alpha, gf, n_train, seed, run_dir)
            subprocess.run(argv, stdout=out, stderr=ferr, check=True)


def run_parallel(grid, sizes, seeds, n_gpus=None):
    """
    Launch one trainer subprocess per (condition, size, seed) cell, then wait.

    Every cell of ``grid x sizes x seeds`` is started immediately (there is
    no cap on concurrency) and pinned to a single device through
    ``CUDA_VISIBLE_DEVICES``, assigned round-robin over ``n_gpus`` slots.
    Once everything is launched, the function waits for each child in launch
    order and prints ``OK`` or ``FAILED rc=<code>`` per run directory.
    Failures are reported but do not raise.

    Args:
        grid: iterable of ``(weight_decay, init_scale, use_grokfast, label)``.
        sizes: iterable of training-set sizes.
        seeds: iterable of seeds.
        n_gpus: number of device slots to cycle through. ``None`` uses
            ``torch.cuda.device_count()`` (at least 1), falling back to 1
            if torch cannot be imported or queried.

    Raises:
        ValueError: if ``n_gpus`` is given and is smaller than 1.
    """
    if n_gpus is None:
        try:
            import torch
            n_gpus = max(1, torch.cuda.device_count())
        except Exception:
            # Best effort: without a usable torch, use a single slot.
            n_gpus = 1
    elif n_gpus < 1:
        # Previously 0 crashed with ZeroDivisionError and negative values
        # silently produced negative device ids.
        raise ValueError(f"n_gpus must be >= 1, got {n_gpus}")

    # CUDA_VISIBLE_DEVICES is absolute, not relative to the parent's mask.
    # If this launcher was itself started with a mask (e.g. "2,3"), hand the
    # children entries of that mask instead of raw indices 0..n_gpus-1.
    parent_mask = [
        dev.strip()
        for dev in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
        if dev.strip()
    ]

    procs = []
    for idx, ((wd, alpha, gf, label), n, seed) in enumerate(
            itertools.product(grid, sizes, seeds)):
        slot = idx % n_gpus
        if parent_mask:
            device = parent_mask[slot % len(parent_mask)]
        else:
            device = str(slot)
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = device
        run_dir = build_run_dir(label, n, seed)
        log = os.path.join(run_dir, "logs", "train.log")
        err = os.path.join(run_dir, "logs", "train.err")
        # The child gets its own copies of the descriptors, so the parent can
        # close its handles as soon as Popen returns. This avoids holding two
        # open files per job and leaking them if a later launch raises.
        with open(log, "w") as out_f, open(err, "w") as err_f:
            p = subprocess.Popen(cmd_for(wd, alpha, gf, n, seed, run_dir),
                                 env=env, stdout=out_f, stderr=err_f)
        print(f"  GPU {device}: {label} n={n} seed={seed}  "
              f"PID={p.pid}  → {run_dir}")
        procs.append((p, run_dir))
    print(f"\nLaunched {len(procs)} jobs. Waiting...", flush=True)
    for p, run_dir in procs:
        rc = p.wait()
        status = "OK" if rc == 0 else f"FAILED rc={rc}"
        print(f"  {status}  {run_dir}", flush=True)
    print("All done.")


if __name__ == "__main__":
    # CLI entry point. Mode selection, in priority order:
    #   --quick     sequential two-cell probe (run_quick)
    #   --parallel  the full GRID x SIZES x SEEDS sweep
    #   (no flag)   a single cell: first GRID entry, n=500, seed=42
    # --quick wins if both flags are given.
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument(
        "--n_gpus",
        type=int,
        default=None,
        help="Override torch.cuda.device_count()",
    )
    args = parser.parse_args()

    if args.quick:
        run_quick()
    else:
        if args.parallel:
            conditions, sizes, seeds = GRID, SIZES, SEEDS
        else:
            conditions, sizes, seeds = GRID[:1], [500], [42]
        run_parallel(conditions, sizes, seeds, n_gpus=args.n_gpus)