#!/usr/bin/env python3
"""Combined Waymo + AV2 real-data training (GPU-saturating batch).

Trains on the union of the unified roots (Waymo + AV2), evals interpolation +
lane consistency on the AV2 val split. Batch sized from the max-batch profile
to fill the H200 (batch_size * grad_accum = effective batch).
"""
import argparse
import time

import torch  # noqa: F401  (not referenced below; retained from the original script)

from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset, collate_samples  # noqa: F401
from mapgs.eval import Evaluator
from mapgs.train import Trainer

# Single source of truth for the run directory: used both for the
# ``train.out_dir`` config override and for the final checkpoint path.
OUT_DIR = "runs/mapgs_combined"


def fmt(d: dict) -> dict:
    """Return a copy of *d* with every ``float`` value rounded to 3 decimals.

    Values that are not instances of ``float`` are passed through unchanged.
    """
    return {k: (round(v, 3) if isinstance(v, float) else v) for k, v in d.items()}


def _split_roots(spec: str) -> list:
    """Split a comma-separated root spec into a list of stripped, non-empty paths."""
    return [part.strip() for part in spec.split(",") if part.strip()]


def main() -> None:
    """Parse CLI args, train on the combined roots, and print before/after metrics.

    Steps: build the config from ``configs/base.yaml`` plus overrides, build the
    train/val datasets, evaluate interpolation before training, run
    ``Trainer.fit``, evaluate interpolation and lane consistency after
    training, then save the final checkpoint under ``OUT_DIR``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--iters", type=int, default=4000)
    ap.add_argument("--roots",
                    default="/mnt/william/data/unified/waymo,/mnt/william/data/unified/av2")
    ap.add_argument("--val-root", default="/mnt/william/data/unified/av2")
    # from find_max_batch (57% of 150GB, spike headroom)
    ap.add_argument("--batch", type=int, default=12)
    # effective batch = batch * grad_accum
    ap.add_argument("--grad-accum", type=int, default=2)
    args = ap.parse_args()

    # Parse both root specs the same way, dropping empty entries (e.g. from a
    # trailing comma) so an empty string is never handed over as a dataset root.
    roots = _split_roots(args.roots)
    val_roots = _split_roots(args.val_root)
    if not roots:
        ap.error("--roots must name at least one dataset root")
    if not val_roots:
        ap.error("--val-root must name at least one dataset root")

    cfg = load_config("configs/base.yaml", [
        "data.name=unified",
        f"data.root={','.join(roots)}",
        "data.num_frames=20",
        "data.height=256",
        "data.width=384",
        "model.embed_dim=768",
        "model.enc_depth=4",
        "model.dec_depth=12",
        "model.n_heads=12",
        "model.tokens.gaussians_per_token=32",
        "model.feature_dim=64",
        "train.amp=true",
        "train.grad_checkpoint=true",
        "train.num_workers=4",
        f"train.batch_size={args.batch}",
        f"train.grad_accum={args.grad_accum}",
        "train.lr=1.0e-4",
        "train.warmup=500",
        f"train.iters={args.iters}",
        "train.extrap_ramp_iter=1000",
        "train.log_every=20",
        "train.ckpt_every=500",
        f"train.out_dir={OUT_DIR}",
    ])

    train_ds = UnifiedClipDataset(cfg, roots=roots, split="train", n_sup_views=4)
    # Validation roots are passed as a list, matching the train dataset call.
    val_ds = UnifiedClipDataset(cfg, roots=val_roots, split="val", n_sup_views=6)
    print(f"combined train clips: {len(train_ds)} | val: {len(val_ds)} | "
          f"batch {args.batch} x accum {args.grad_accum} = eff {args.batch*args.grad_accum}",
          flush=True)

    trainer = Trainer(cfg)
    ev = Evaluator(trainer.model, cfg, device="cuda")

    # Evaluate the untrained model in eval mode so BEFORE and AFTER are
    # measured under the same conditions (AFTER already ran in eval mode).
    trainer.model.eval()
    print("BEFORE:", fmt(ev.interpolation(val_ds, max_scenes=40)), flush=True)
    # Restore training mode before fitting.
    # NOTE(review): Trainer.fit may already do this itself — not visible here.
    trainer.model.train()

    t = time.perf_counter()  # monotonic clock for the elapsed-time report
    trainer.fit(train_ds, max_iters=args.iters)
    print(f"trained {args.iters} steps in {time.perf_counter()-t:.0f}s", flush=True)

    trainer.model.eval()
    print("AFTER :", fmt(ev.interpolation(val_ds, max_scenes=40)), flush=True)
    print("LANE  :", fmt(ev.lane_consistency(val_ds, max_scenes=30,
                                             frame=cfg.data.num_frames // 2)), flush=True)
    trainer.save(f"{OUT_DIR}/ckpt_final.pt")


if __name__ == "__main__":
    main()