"""Combined Waymo + AV2 real-data training (GPU-saturating batch).

Trains on the union of the unified roots (Waymo + AV2), then evaluates
interpolation and lane consistency on the AV2 val split. The batch is sized
from the max-batch profile to fill the H200
(batch_size * grad_accum = effective batch).
"""

import argparse
import time
from pathlib import Path

import torch

from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset, collate_samples
from mapgs.eval import Evaluator
from mapgs.train import Trainer


def fmt(d):
    """Return a copy of mapping *d* with every ``float`` value rounded to 3 places.

    Values of any other type (ints, bools, strings, ...) are passed through
    unchanged, so the result is safe to print as a compact metrics summary.

    Args:
        d: Mapping of metric name to value.

    Returns:
        A new ``dict`` with the same keys as *d*.
    """
    return {
        key: (round(value, 3) if isinstance(value, float) else value)
        for key, value in d.items()
    }


def main(argv=None):
    """Train on the combined roots and report before/after metrics on the val root.

    Steps: parse CLI flags, build the config from ``configs/base.yaml`` plus
    overrides, build the train/val datasets, score the untrained model, run
    ``Trainer.fit``, score again (interpolation + lane consistency), and save
    a final checkpoint under ``--out-dir``.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so ``main()`` behaves as a
            normal script entry point.

    Raises:
        SystemExit: With status 2 (via ``ArgumentParser.error``) when a flag
            is malformed, ``--batch``/``--grad-accum`` is below 1, ``--iters``
            is negative, or ``--roots`` names no dataset root.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--iters", type=int, default=4000)
    ap.add_argument("--roots", default="/mnt/william/data/unified/waymo,/mnt/william/data/unified/av2")
    ap.add_argument("--val-root", default="/mnt/william/data/unified/av2")
    ap.add_argument("--batch", type=int, default=12)
    ap.add_argument("--grad-accum", type=int, default=2)
    ap.add_argument("--out-dir", default="runs/mapgs_combined")
    args = ap.parse_args(argv)

    # Fail fast on nonsensical sizes before any dataset or model is built.
    if args.iters < 0:
        ap.error("--iters must be >= 0")
    if args.batch < 1:
        ap.error("--batch must be >= 1")
    if args.grad_accum < 1:
        ap.error("--grad-accum must be >= 1")
    roots = _split_roots(args.roots)
    if not roots:
        ap.error("--roots must name at least one dataset root")

    cfg = load_config("configs/base.yaml", [
        "data.name=unified", f"data.root={args.roots}", "data.num_frames=20",
        "data.height=256", "data.width=384",
        "model.embed_dim=768", "model.enc_depth=4", "model.dec_depth=12", "model.n_heads=12",
        "model.tokens.gaussians_per_token=32", "model.feature_dim=64",
        "train.amp=true", "train.grad_checkpoint=true", "train.num_workers=4",
        f"train.batch_size={args.batch}", f"train.grad_accum={args.grad_accum}",
        "train.lr=1.0e-4", "train.warmup=500", f"train.iters={args.iters}",
        "train.extrap_ramp_iter=1000", "train.log_every=20", "train.ckpt_every=500",
        f"train.out_dir={args.out_dir}",
    ])
    train_ds = UnifiedClipDataset(cfg, roots=roots, split="train", n_sup_views=4)
    # Pass a list here too, so both datasets receive `roots` in the same form.
    val_ds = UnifiedClipDataset(cfg, roots=[args.val_root], split="val", n_sup_views=6)
    print(f"combined train clips: {len(train_ds)} | val: {len(val_ds)} | "
          f"batch {args.batch} x accum {args.grad_accum} = eff {args.batch*args.grad_accum}", flush=True)

    trainer = Trainer(cfg)
    ev = Evaluator(trainer.model, cfg, device="cuda")

    # Score the baseline in eval mode, like the post-training numbers below,
    # then put the model back in whatever mode it was in before training.
    # NOTE(review): assumes trainer.model is a torch.nn.Module — confirm.
    was_training = trainer.model.training
    trainer.model.eval()
    print("BEFORE:", fmt(ev.interpolation(val_ds, max_scenes=40)), flush=True)
    trainer.model.train(was_training)

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments during a long run.
    start = time.perf_counter()
    trainer.fit(train_ds, max_iters=args.iters)
    print(f"trained {args.iters} steps in {time.perf_counter()-start:.0f}s", flush=True)

    trainer.model.eval()
    print("AFTER :", fmt(ev.interpolation(val_ds, max_scenes=40)), flush=True)
    print("LANE :", fmt(ev.lane_consistency(val_ds, max_scenes=30, frame=cfg.data.num_frames // 2)), flush=True)

    # Write the final checkpoint next to the periodic ones (train.out_dir).
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    trainer.save(str(out_dir / "ckpt_final.pt"))


def _split_roots(spec):
    """Split a comma-separated list of dataset roots, dropping blank entries."""
    return [part.strip() for part in spec.split(",") if part.strip()]


if __name__ == "__main__":
    # Run the train/eval pipeline only when executed as a script, not on import.
    main()