mapvggt / scripts /train_combined.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
2.86 kB
#!/usr/bin/env python3
"""Combined Waymo + AV2 real-data training (GPU-saturating batch).
Trains on the union of the unified roots (Waymo + AV2), evals interpolation +
lane consistency on the AV2 val split. Batch sized from the max-batch profile to
fill the H200 (batch_size * grad_accum = effective batch)."""
import argparse, time
import torch
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset, collate_samples
from mapgs.eval import Evaluator
from mapgs.train import Trainer
def fmt(d):
    """Return a copy of mapping ``d`` with ``float`` values rounded to 3 decimals.

    Values that are not ``float`` instances are passed through unchanged.
    Key order follows ``d.items()``.
    """
    rounded = {}
    for key, value in d.items():
        if isinstance(value, float):
            rounded[key] = round(value, 3)
        else:
            rounded[key] = value
    return rounded
def main():
    """Train on the combined Waymo + AV2 roots and print validation metrics.

    Steps, in order:
      1. Parse the CLI flags.
      2. Build the config from ``configs/base.yaml`` plus the overrides below.
      3. Build the train dataset (all ``--roots``) and the val dataset
         (``--val-root``).
      4. Print interpolation metrics before training.
      5. Run ``Trainer.fit`` for ``--iters`` iterations.
      6. Print interpolation and lane-consistency metrics after training.
      7. Save the final checkpoint under the run's output directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--iters", type=int, default=4000,
                    help="number of training iterations")
    ap.add_argument("--roots",
                    default="/mnt/william/data/unified/waymo,/mnt/william/data/unified/av2",
                    help="comma-separated training dataset roots")
    ap.add_argument("--val-root", default="/mnt/william/data/unified/av2",
                    help="validation dataset root (comma-separated for several)")
    # NOTE: '%' must be written '%%' inside argparse help strings.
    ap.add_argument("--batch", type=int, default=12,
                    help="per-step batch size (from find_max_batch: 57%% of 150GB, "
                         "leaving spike headroom)")
    ap.add_argument("--grad-accum", type=int, default=2,
                    help="gradient accumulation steps; effective batch = batch * grad_accum")
    args = ap.parse_args()

    # Drop empty entries so a trailing or doubled comma does not yield a "" root.
    roots = [r.strip() for r in args.roots.split(",") if r.strip()]
    val_roots = [r.strip() for r in args.val_root.split(",") if r.strip()]
    if not roots:
        ap.error("--roots must name at least one dataset root")
    if not val_roots:
        ap.error("--val-root must name at least one dataset root")

    # Single source of truth for the run directory: used both for the config
    # override and for the final checkpoint path.
    out_dir = "runs/mapgs_combined"

    cfg = load_config("configs/base.yaml", [
        "data.name=unified", f"data.root={args.roots}", "data.num_frames=20",
        "data.height=256", "data.width=384",
        "model.embed_dim=768", "model.enc_depth=4", "model.dec_depth=12", "model.n_heads=12",
        "model.tokens.gaussians_per_token=32", "model.feature_dim=64",
        "train.amp=true", "train.grad_checkpoint=true", "train.num_workers=4",
        f"train.batch_size={args.batch}", f"train.grad_accum={args.grad_accum}",
        "train.lr=1.0e-4", "train.warmup=500", f"train.iters={args.iters}",
        "train.extrap_ramp_iter=1000", "train.log_every=20", "train.ckpt_every=500",
        f"train.out_dir={out_dir}",
    ])

    train_ds = UnifiedClipDataset(cfg, roots=roots, split="train", n_sup_views=4)
    # Pass a list here too, matching the train dataset call above.
    # NOTE(review): presumes UnifiedClipDataset treats `roots` as a sequence of
    # paths, as the train call suggests -- confirm against its definition.
    val_ds = UnifiedClipDataset(cfg, roots=val_roots, split="val", n_sup_views=6)
    print(f"combined train clips: {len(train_ds)} | val: {len(val_ds)} | "
          f"batch {args.batch} x accum {args.grad_accum} = eff {args.batch*args.grad_accum}",
          flush=True)

    trainer = Trainer(cfg)
    ev = Evaluator(trainer.model, cfg, device="cuda")
    # Baseline metrics from the model as constructed by Trainer, before fit().
    print("BEFORE:", fmt(ev.interpolation(val_ds, max_scenes=40)), flush=True)

    t = time.time()
    trainer.fit(train_ds, max_iters=args.iters)
    print(f"trained {args.iters} steps in {time.time()-t:.0f}s", flush=True)

    # Switch to eval mode before computing the post-training metrics.
    trainer.model.eval()
    print("AFTER :", fmt(ev.interpolation(val_ds, max_scenes=40)), flush=True)
    # Lane consistency is evaluated at the middle frame of the clip.
    print("LANE :", fmt(ev.lane_consistency(val_ds, max_scenes=30, frame=cfg.data.num_frames // 2)), flush=True)
    trainer.save(f"{out_dir}/ckpt_final.pt")
# Run the training only when this file is executed as a script, so importing
# the module does not start it.
if __name__ == "__main__":
    main()