| """ |
| CausalGrok — Ablation Grid Runner |
| |
| Launches every condition × size × seed as its own subprocess. Each run |
| gets its own experiments/runs/<run_id>/ directory with isolated logs, |
| results, checkpoints, and figures. |
| |
| Use the launchers (they fork this under nohup): |
| bash scripts/run_quick_ablations.sh |
| bash scripts/run_full_grid.sh |
| |
| You can also call directly: |
| python -m experiments.run_ablations --quick |
| python -m experiments.run_ablations --parallel |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import itertools |
| import os |
| import subprocess |
| import sys |
| from datetime import datetime, timezone |
|
|
| from utils.run_dir import DEFAULT_BASE, ensure_run_dir |
|
|
| |
| |
# Ablation conditions. Each entry is a 4-tuple:
#   (weight_decay, init_scale, grokfast_enabled, label)
# The first three are forwarded to the trainer by cmd_for; the label is
# embedded in the run directory name by build_run_dir.
GRID = [
    (0.0,  1.0, False, "wd0_a1"),
    (1e-4, 1.0, False, "wd1e4_a1"),
    (5e-3, 1.0, False, "wd5e3_a1"),
    (5e-3, 4.0, False, "wd5e3_a4"),
    (5e-3, 4.0, True,  "wd5e3_a4_gf"),
    (1e-2, 4.0, True,  "wd1e2_a4_gf"),
]

# Values passed as --n_train, swept for every GRID entry.
SIZES = [100, 250, 500, 1000, 2000]

# Values passed as --seed, swept for every (GRID entry, size) pair.
SEEDS = [42, 123, 456]
|
|
|
|
def cmd_for(wd, alpha, gf, n_train, seed, run_dir, n_epochs=None):
    """
    Assemble the argv list that trains one grid cell.

    The command runs ``experiments.causalgrok_camelyon_v2`` as a module
    under the current interpreter and forwards this cell's settings as
    command-line flags.

    Args:
        wd: Weight decay, forwarded as ``--weight_decay``.
        alpha: Init scale, forwarded as ``--init_scale``.
        gf: Truthy forwards ``--grokfast on``; falsy forwards ``off``.
        n_train: Forwarded as ``--n_train``.
        seed: Forwarded as ``--seed``.
        run_dir: Forwarded unchanged as ``--run_dir``.
        n_epochs: Explicit epoch budget. When None the budget is picked
            automatically: 3000 if ``wd >= 1e-3`` or ``gf`` is set,
            otherwise 300. The short budget is a compute saver for
            cells treated as non-grokking baselines (weak or no weight
            decay with Grokfast off).

    Returns:
        A list of argument strings, ready for ``subprocess``.
    """
    if n_epochs is None:
        baseline_only = not (wd >= 1e-3 or gf)
        n_epochs = 300 if baseline_only else 3000

    grokfast_switch = "on" if gf else "off"

    argv = [sys.executable, "-m", "experiments.causalgrok_camelyon_v2"]
    argv += ["--condition", "grokking"]

    # Flags whose values must be stringified, in the order they are emitted.
    stringified = (
        ("--n_train", n_train),
        ("--seed", seed),
        ("--weight_decay", wd),
        ("--init_scale", alpha),
        ("--n_epochs", n_epochs),
    )
    for flag, value in stringified:
        argv += [flag, str(value)]

    argv += ["--grokfast", grokfast_switch]
    argv += ["--wandb_project", "causalgrok"]
    argv += ["--run_dir", run_dir]
    return argv
|
|
|
|
def build_run_dir(label, n_train, seed):
    """
    Create the directory for one run and return its path.

    The run id is the current UTC time formatted as ``%Y%m%d-%H%M%S``,
    followed by the label, ``n`` + n_train, and ``s`` + seed, joined
    with underscores. The directory sits directly under DEFAULT_BASE
    and is passed to ensure_run_dir before being returned.

    NOTE: the timestamp has one-second resolution, so two calls with
    identical arguments inside the same second yield the same path.
    """
    utc_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    target = os.path.join(
        DEFAULT_BASE,
        f"{utc_stamp}_{label}_n{n_train}_s{seed}",
    )
    ensure_run_dir(target)
    return target
|
|
|
|
def run_quick():
    """
    Minimal sanity probe: control vs. full recipe at n=500, seed 42.

    Runs GRID[0] and then GRID[4] one after the other, each in its own
    run directory, with stdout and stderr redirected to
    ``logs/train.log`` and ``logs/train.err`` inside that directory.
    A non-zero exit raises ``subprocess.CalledProcessError`` (because of
    ``check=True``), so a failure of the first run prevents the second.
    """
    n_train, seed = 500, 42
    probe_cells = [GRID[0], GRID[4]]
    for cell in probe_cells:
        wd, alpha, gf, label = cell
        run_dir = build_run_dir(label, n_train, seed)
        logs_dir = os.path.join(run_dir, "logs")
        stdout_path = os.path.join(logs_dir, "train.log")
        stderr_path = os.path.join(logs_dir, "train.err")
        print(f"\n>> {label} n=500 seed=42 → {run_dir}")
        with open(stdout_path, "w") as out, open(stderr_path, "w") as ferr:
            command = cmd_for(wd, alpha, gf, n_train, seed, run_dir)
            subprocess.run(command, stdout=out, stderr=ferr, check=True)
|
|
|
|
def run_parallel(grid, sizes, seeds, n_gpus=None):
    """
    Launch every (grid cell, size, seed) combination and wait for all.

    Each combination becomes one subprocess built by cmd_for, with its
    own run directory from build_run_dir and stdout/stderr redirected to
    ``logs/train.log`` / ``logs/train.err`` inside it. Jobs are pinned
    round-robin to devices by setting ``CUDA_VISIBLE_DEVICES`` to
    ``job_index % n_gpus`` in the child's environment.

    NOTE: all jobs are started before any is waited on; there is no cap
    on how many run concurrently per device.

    Args:
        grid: Iterable of ``(wd, alpha, gf, label)`` tuples.
        sizes: Iterable of training-set sizes.
        seeds: Iterable of seeds.
        n_gpus: Number of device slots to cycle through. When None,
            ``torch.cuda.device_count()`` is used (minimum 1), falling
            back to 1 if torch cannot be imported or queried.

    Raises:
        ValueError: if an explicit ``n_gpus`` is smaller than 1.
    """
    if n_gpus is None:
        try:
            import torch
            n_gpus = max(1, torch.cuda.device_count())
        except Exception:
            # Deliberate best-effort probe: any failure to import or query
            # torch just means we schedule onto a single device slot.
            n_gpus = 1
    elif n_gpus < 1:
        # 0 would raise ZeroDivisionError mid-launch; a negative value would
        # silently produce CUDA_VISIBLE_DEVICES="-1". Reject both up front.
        raise ValueError(f"n_gpus must be >= 1, got {n_gpus}")

    procs = []
    for idx, ((wd, alpha, gf, label), n, seed) in enumerate(
            itertools.product(grid, sizes, seeds)):
        gpu = idx % n_gpus
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(gpu)
        run_dir = build_run_dir(label, n, seed)
        log = os.path.join(run_dir, "logs", "train.log")
        err = os.path.join(run_dir, "logs", "train.err")
        # The child receives its own copies of these descriptors, so the
        # parent can release its handles as soon as Popen returns. This keeps
        # the parent's open-file count flat regardless of grid size and
        # guarantees the handles are closed if open() or Popen() raises.
        with open(log, "w") as out_f, open(err, "w") as err_f:
            p = subprocess.Popen(cmd_for(wd, alpha, gf, n, seed, run_dir),
                                 env=env, stdout=out_f, stderr=err_f)
        print(f" GPU {gpu}: {label} n={n} seed={seed} PID={p.pid} → {run_dir}")
        procs.append((p, run_dir))

    print(f"\nLaunched {len(procs)} jobs. Waiting...", flush=True)
    for p, run_dir in procs:
        rc = p.wait()
        status = "OK" if rc == 0 else f"FAILED rc={rc}"
        print(f" {status} {run_dir}", flush=True)
    print("All done.")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point.
    #   --quick     run the two-cell sequential sanity probe (run_quick)
    #   --parallel  run the whole GRID x SIZES x SEEDS sweep
    #   (neither)   run only the first grid cell at n=500, seed=42
    # --quick takes precedence when both flags are given.
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--n_gpus", type=int, default=None,
                        help="Override torch.cuda.device_count()")
    args = parser.parse_args()

    if args.quick:
        run_quick()
    else:
        if args.parallel:
            sweep = (GRID, SIZES, SEEDS)
        else:
            sweep = (GRID[:1], [500], [42])
        run_parallel(*sweep, n_gpus=args.n_gpus)
|
|