# CausalGrok / code / experiments / run_ablations.py
# Uploaded by nileshsarkar-ai — "Upload code/experiments" (commit 50fa85c, verified)
"""
CausalGrok — Ablation Grid Runner
Launches every condition × size × seed as its own subprocess. Each run
gets its own experiments/runs/<run_id>/ directory with isolated logs,
results, checkpoints, and figures.
Use the launchers (they fork this under nohup):
bash scripts/run_quick_ablations.sh
bash scripts/run_full_grid.sh
You can also call directly:
python -m experiments.run_ablations --quick
python -m experiments.run_ablations --parallel
"""
from __future__ import annotations
import argparse
import itertools
import os
import subprocess
import sys
from datetime import datetime, timezone
from utils.run_dir import DEFAULT_BASE, ensure_run_dir
# Ablation grid. Every row is a 4-tuple:
#     (weight_decay, init_scale, use_grokfast, label)
# The label feeds into wandb tags.
# On Camelyon17 the baseline weight decay is 5e-3 (empirically optimal).
# NOTE: run_quick() selects rows by position (indices 0 and 4), so keep the
# row order stable when editing this table.
GRID = [
    # control — no regularization
    (0.0, 1.0, False, "wd0_a1"),
    # mild WD only
    (1e-4, 1.0, False, "wd1e4_a1"),
    # baseline WD, standard init
    (5e-3, 1.0, False, "wd5e3_a1"),
    # baseline WD + large init
    (5e-3, 4.0, False, "wd5e3_a4"),
    # full recipe (main condition)
    (5e-3, 4.0, True, "wd5e3_a4_gf"),
    # higher WD variant
    (1e-2, 4.0, True, "wd1e2_a4_gf"),
]

# Values passed to the trainer as --n_train.
SIZES = [100, 250, 500, 1000, 2000]

# Values passed to the trainer as --seed.
SEEDS = [42, 123, 456]
def cmd_for(wd, alpha, gf, n_train, seed, run_dir, n_epochs=None):
    """
    Assemble the argv list that trains one grid cell.

    The command runs ``experiments.causalgrok_camelyon_v2`` as a module
    under the current interpreter and forwards this cell's knobs as
    ``--weight_decay``, ``--init_scale``, ``--grokfast {on,off}`` and
    ``--n_epochs``, together with the training-set size, seed and run
    directory.

    Epoch budget: when ``n_epochs`` is not given, a cell with weight decay
    below 1e-3 and Grokfast off is given 300 epochs; every other cell gets
    3000. The short budget is a compute saver for cells that are only
    needed as a non-grokking baseline (the author's note estimates it
    trims roughly 30% off the full-grid wall time). An explicit
    ``n_epochs`` is forwarded unchanged.

    Returns:
        A list suitable for ``subprocess.run`` / ``subprocess.Popen``.
    """
    if n_epochs is None:
        # Long budget only for cells with enough weight decay or Grokfast.
        n_epochs = 3000 if (wd >= 1e-3 or gf) else 300

    flag_values = (
        ("--condition", "grokking"),
        ("--n_train", str(n_train)),
        ("--seed", str(seed)),
        ("--weight_decay", str(wd)),
        ("--init_scale", str(alpha)),
        ("--n_epochs", str(n_epochs)),
        ("--grokfast", "on" if gf else "off"),
        ("--wandb_project", "causalgrok"),
        ("--run_dir", run_dir),
    )

    argv = [sys.executable, "-m", "experiments.causalgrok_camelyon_v2"]
    for flag, value in flag_values:
        argv.extend((flag, value))
    return argv
def build_run_dir(label, n_train, seed):
    """
    Create and return the run directory for one grid cell.

    The directory name is ``<UTC timestamp>_<label>_n<n_train>_s<seed>``,
    with the timestamp formatted as ``YYYYmmdd-HHMMSS``, placed under
    ``DEFAULT_BASE``. The path is handed to ``ensure_run_dir`` before being
    returned.
    """
    utc_now = datetime.now(timezone.utc)
    run_id = "{}_{}_n{}_s{}".format(
        utc_now.strftime("%Y%m%d-%H%M%S"), label, n_train, seed
    )
    target = os.path.join(DEFAULT_BASE, run_id)
    ensure_run_dir(target)
    return target
def run_quick():
    """
    Minimal sanity probe: control vs. full recipe at n=500, seed 42.

    Runs GRID[0] and GRID[4] one after the other, each in a fresh run
    directory, with stdout sent to ``logs/train.log`` and stderr to
    ``logs/train.err`` inside that directory. Because the subprocess is run
    with ``check=True``, a non-zero exit raises ``CalledProcessError`` and
    the remaining cell is not run.
    """
    n_train, seed = 500, 42
    probe_cells = (GRID[0], GRID[4])

    for wd, alpha, gf, label in probe_cells:
        run_dir = build_run_dir(label, n_train, seed)
        # Assumes the logs/ subdirectory already exists once build_run_dir
        # returns — TODO confirm against utils.run_dir.ensure_run_dir.
        logs_dir = os.path.join(run_dir, "logs")
        stdout_path = os.path.join(logs_dir, "train.log")
        stderr_path = os.path.join(logs_dir, "train.err")

        print(f"\n>> {label} n=500 seed=42 → {run_dir}")
        with open(stdout_path, "w") as out, open(stderr_path, "w") as ferr:
            command = cmd_for(wd, alpha, gf, n_train, seed, run_dir)
            subprocess.run(command, stdout=out, stderr=ferr, check=True)
def run_parallel(grid, sizes, seeds, n_gpus=None):
    """
    Launch every (grid cell, size, seed) combination as a concurrent job.

    Jobs are pinned to GPUs round-robin through ``CUDA_VISIBLE_DEVICES``.
    Each job gets its own run directory, with stdout written to
    ``logs/train.log`` and stderr to ``logs/train.err`` inside it.

    NOTE: all jobs are started up front with no throttling, so a large grid
    puts many processes on each GPU at once. Jobs are then waited on in
    launch order.

    Args:
        grid: iterable of ``(weight_decay, init_scale, use_grokfast, label)``.
        sizes: iterable of training-set sizes (forwarded as ``--n_train``).
        seeds: iterable of seeds (forwarded as ``--seed``).
        n_gpus: number of GPU slots to cycle through. ``None`` uses
            ``torch.cuda.device_count()`` (at least 1), or 1 when torch
            cannot be imported or queried.

    Returns:
        The number of jobs that exited with a non-zero return code.

    Raises:
        ValueError: if an explicit ``n_gpus`` is smaller than 1.
    """
    if n_gpus is None:
        try:
            import torch
            n_gpus = max(1, torch.cuda.device_count())
        except Exception:
            # Deliberate best-effort: without a usable torch, fall back to
            # a single GPU slot rather than failing the launcher.
            n_gpus = 1
    elif n_gpus < 1:
        # 0 would divide by zero below; a negative value would produce a
        # negative CUDA_VISIBLE_DEVICES index. Reject both explicitly.
        raise ValueError(f"n_gpus must be >= 1, got {n_gpus}")

    procs = []
    for idx, ((wd, alpha, gf, label), n, seed) in enumerate(
            itertools.product(grid, sizes, seeds)):
        gpu = idx % n_gpus
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(gpu)

        run_dir = build_run_dir(label, n, seed)
        log = os.path.join(run_dir, "logs", "train.log")
        err = os.path.join(run_dir, "logs", "train.err")

        # The child receives its own copies of the log descriptors when it
        # is spawned, so the parent closes its handles right away. This
        # avoids holding two open files per job for the whole grid and
        # releases them even if Popen raises.
        with open(log, "w") as out_f, open(err, "w") as err_f:
            p = subprocess.Popen(cmd_for(wd, alpha, gf, n, seed, run_dir),
                                 env=env, stdout=out_f, stderr=err_f)
        print(f" GPU {gpu}: {label} n={n} seed={seed} PID={p.pid} → {run_dir}")
        procs.append((p, run_dir))

    print(f"\nLaunched {len(procs)} jobs. Waiting...", flush=True)

    failures = 0
    for p, run_dir in procs:
        rc = p.wait()
        if rc == 0:
            status = "OK"
        else:
            status = f"FAILED rc={rc}"
            failures += 1
        print(f" {status} {run_dir}", flush=True)
    print("All done.")
    return failures
if __name__ == "__main__":
    # Command-line entry point:
    #   --quick     run the two-cell sanity probe sequentially
    #   --parallel  run the full GRID x SIZES x SEEDS grid
    #   (neither)   run only the first grid cell at n=500, seed 42
    # --quick takes precedence when both flags are given.
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument(
        "--n_gpus",
        type=int,
        default=None,
        help="Override torch.cuda.device_count()",
    )
    cli = parser.parse_args()

    if cli.quick:
        run_quick()
    else:
        if cli.parallel:
            cells, sizes, seeds = GRID, SIZES, SEEDS
        else:
            cells, sizes, seeds = GRID[:1], [500], [42]
        run_parallel(cells, sizes, seeds, n_gpus=cli.n_gpus)