# File size: 5,015 Bytes | 50fa85c
"""
CausalGrok — Ablation Grid Runner
Launches every condition × size × seed as its own subprocess. Each run
gets its own experiments/runs/<run_id>/ directory with isolated logs,
results, checkpoints, and figures.
Use the launchers (they fork this under nohup):
bash scripts/run_quick_ablations.sh
bash scripts/run_full_grid.sh
You can also call directly:
python -m experiments.run_ablations --quick
python -m experiments.run_ablations --parallel
"""
from __future__ import annotations
import argparse
import itertools
import os
import subprocess
import sys
from datetime import datetime, timezone
from utils.run_dir import DEFAULT_BASE, ensure_run_dir
# Ablation grid. Every entry is a 4-tuple:
#     (weight_decay, init_scale, use_grokfast, label)
# The label becomes part of the run directory name; per the original note it
# is also meant to feed wandb tags.
# For Camelyon17 the baseline weight decay is 5e-3 (noted as empirically
# optimal by the original author).
GRID = [
    # control: no regularization at all
    (0.0, 1.0, False, "wd0_a1"),
    # mild weight decay only
    (1e-4, 1.0, False, "wd1e4_a1"),
    # baseline weight decay, standard init
    (5e-3, 1.0, False, "wd5e3_a1"),
    # baseline weight decay + large init
    (5e-3, 4.0, False, "wd5e3_a4"),
    # full recipe (main condition)
    (5e-3, 4.0, True, "wd5e3_a4_gf"),
    # higher weight decay variant of the full recipe
    (1e-2, 4.0, True, "wd1e2_a4_gf"),
]

# Training-set sizes swept for every grid cell.
SIZES = [100, 250, 500, 1000, 2000]

# Random seeds repeated for every (cell, size) pair.
SEEDS = [42, 123, 456]
def cmd_for(wd, alpha, gf, n_train, seed, run_dir, n_epochs=None):
    """
    Assemble the argv list that trains one grid cell.

    The trainer module is invoked with the current interpreter and receives
    --weight_decay, --init_scale, --grokfast {on,off} and --n_epochs as
    overrides, so every cell runs with its own settings.

    Epoch budget (only applied when ``n_epochs`` is None): a cell is given
    3000 epochs when its weight decay is at least 1e-3 or Grokfast is
    enabled, and 300 epochs otherwise. The original author's rationale is
    that cells with low/no weight decay and no Grokfast flatline anyway, so
    a short run is enough to record their baseline accuracy; this was
    reported to cut roughly 30% of the full grid's wall time.

    Args:
        wd: weight decay value, forwarded as ``--weight_decay``.
        alpha: init scale, forwarded as ``--init_scale``.
        gf: truthy to pass ``--grokfast on``, falsy for ``off``.
        n_train: training-set size, forwarded as ``--n_train``.
        seed: random seed, forwarded as ``--seed``.
        run_dir: run directory, forwarded unchanged as ``--run_dir``.
        n_epochs: explicit epoch count; None selects the budget above.

    Returns:
        A list suitable for ``subprocess.run`` / ``subprocess.Popen``.
    """
    if n_epochs is None:
        if wd >= 1e-3 or gf:
            n_epochs = 3000
        else:
            n_epochs = 300

    grokfast_flag = "on" if gf else "off"

    # (flag, value) pairs in the exact order they appear on the command line.
    overrides = (
        ("--condition", "grokking"),
        ("--n_train", str(n_train)),
        ("--seed", str(seed)),
        ("--weight_decay", str(wd)),
        ("--init_scale", str(alpha)),
        ("--n_epochs", str(n_epochs)),
        ("--grokfast", grokfast_flag),
        ("--wandb_project", "causalgrok"),
        ("--run_dir", run_dir),
    )

    argv = [sys.executable, "-m", "experiments.causalgrok_camelyon_v2"]
    for flag, value in overrides:
        argv.append(flag)
        argv.append(value)
    return argv
def build_run_dir(label, n_train, seed):
    """
    Create the directory for one run and return its path.

    The run id is ``<UTC timestamp>_<label>_n<n_train>_s<seed>`` with the
    timestamp formatted as ``YYYYmmdd-HHMMSS``; the directory lives under
    ``DEFAULT_BASE`` and is prepared by ``ensure_run_dir``.
    """
    utc_now = datetime.now(timezone.utc)
    stamp = utc_now.strftime("%Y%m%d-%H%M%S")
    target = os.path.join(DEFAULT_BASE, f"{stamp}_{label}_n{n_train}_s{seed}")
    ensure_run_dir(target)
    return target
def run_quick():
    """
    Minimal sanity probe: control vs. full recipe at n=500, seed 42.

    Runs GRID[0] and GRID[4] one after the other, in the foreground. Each
    run's stdout/stderr go to ``logs/train.log`` and ``logs/train.err``
    inside its run directory. Because the subprocess is run with
    ``check=True``, a non-zero exit raises ``CalledProcessError`` and the
    remaining cell is not started.
    """
    probe_cells = (GRID[0], GRID[4])
    for cell in probe_cells:
        wd, alpha, gf, label = cell
        run_dir = build_run_dir(label, 500, 42)

        logs_dir = os.path.join(run_dir, "logs")
        stdout_path = os.path.join(logs_dir, "train.log")
        stderr_path = os.path.join(logs_dir, "train.err")

        print(f"\n>> {label} n=500 seed=42 → {run_dir}")

        command = cmd_for(wd, alpha, gf, 500, 42, run_dir)
        with open(stdout_path, "w") as stdout_f, open(stderr_path, "w") as stderr_f:
            subprocess.run(command, stdout=stdout_f, stderr=stderr_f, check=True)
def run_parallel(grid, sizes, seeds, n_gpus=None):
    """
    Launch one training subprocess per (grid cell, size, seed) and wait.

    Every combination from ``itertools.product(grid, sizes, seeds)`` is
    started immediately; there is no cap on how many run at once. Job
    number ``idx`` is pinned to GPU ``idx % n_gpus`` through the
    ``CUDA_VISIBLE_DEVICES`` environment variable. Each job writes stdout
    and stderr to ``logs/train.log`` / ``logs/train.err`` inside its own
    run directory. After all jobs are launched, the function waits for
    them in launch order and prints OK / FAILED with the return code.

    Args:
        grid: iterable of (weight_decay, init_scale, use_grokfast, label).
        sizes: iterable of training-set sizes.
        seeds: iterable of random seeds.
        n_gpus: number of GPUs to spread jobs over. None auto-detects via
            ``torch.cuda.device_count()`` and falls back to 1 if torch is
            unavailable or detection fails.

    Raises:
        ValueError: if an explicit ``n_gpus`` is smaller than 1 (this would
            otherwise surface as a ZeroDivisionError or a negative device
            index in the middle of launching).
    """
    if n_gpus is None:
        try:
            import torch
            n_gpus = max(1, torch.cuda.device_count())
        except Exception:
            # Deliberate best effort: a missing or broken torch install must
            # not stop the launcher, so assume a single device.
            n_gpus = 1
    elif n_gpus < 1:
        raise ValueError(f"n_gpus must be >= 1, got {n_gpus}")

    jobs = []
    for idx, ((wd, alpha, gf, label), n, seed) in enumerate(
            itertools.product(grid, sizes, seeds)):
        gpu = idx % n_gpus
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(gpu)
        run_dir = build_run_dir(label, n, seed)
        log = os.path.join(run_dir, "logs", "train.log")
        err = os.path.join(run_dir, "logs", "train.err")
        # The child gets its own copies of the descriptors, so the parent's
        # handles can be closed as soon as Popen returns. This avoids holding
        # two open files per job for the whole grid and guarantees they are
        # closed if open()/Popen raises.
        with open(log, "w") as out_f, open(err, "w") as err_f:
            p = subprocess.Popen(cmd_for(wd, alpha, gf, n, seed, run_dir),
                                 env=env, stdout=out_f, stderr=err_f)
        print(f" GPU {gpu}: {label} n={n} seed={seed} PID={p.pid} → {run_dir}")
        jobs.append((p, run_dir))

    print(f"\nLaunched {len(jobs)} jobs. Waiting...", flush=True)
    for p, run_dir in jobs:
        rc = p.wait()
        status = "OK" if rc == 0 else f"FAILED rc={rc}"
        print(f" {status} {run_dir}", flush=True)
    print("All done.")
if __name__ == "__main__":
    # CLI entry point. Three modes:
    #   --quick     sequential two-cell sanity probe (run_quick)
    #   --parallel  the full GRID x SIZES x SEEDS sweep
    #   (no flag)   a single job: first grid cell, n=500, seed 42
    # --quick takes precedence when both flags are given.
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--n_gpus", type=int, default=None,
                        help="Override torch.cuda.device_count()")
    cli = parser.parse_args()

    if cli.quick:
        run_quick()
    else:
        if cli.parallel:
            chosen_grid, chosen_sizes, chosen_seeds = GRID, SIZES, SEEDS
        else:
            chosen_grid, chosen_sizes, chosen_seeds = GRID[:1], [500], [42]
        run_parallel(chosen_grid, chosen_sizes, chosen_seeds, n_gpus=cli.n_gpus)
|