#!/usr/bin/env python3
"""Find the largest batch size that fits, to fill the GPU for training.

Runs real train steps at increasing batch sizes (bf16 + grad-checkpoint),
records peak memory, and stops at the first OOM. Use the reported max to set
train.batch_size (leave ~10-15% headroom).
"""
import argparse

import torch

from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset, collate_samples
from mapgs.train import Trainer

# Fraction of the largest fitting batch that is recommended (~15% headroom).
_HEADROOM_FACTOR = 0.85


def _recommended_batch_size(max_fit):
    """Return the suggested train.batch_size given the largest batch that fit.

    Values <= 2 are returned unchanged: scaling them by the headroom factor
    would truncate 2 -> 1 and 1 -> 0.
    """
    return int(max_fit * _HEADROOM_FACTOR) if max_fit > 2 else max_fit


def main():
    """Probe batch sizes smallest-first and report the largest that trains without OOM."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", default="/mnt/william/data/unified/av2")
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--batches", nargs="*", type=int,
                    default=[1, 2, 4, 6, 8, 12, 16, 24, 32])
    ap.add_argument("--n-sup", type=int, default=4)
    args = ap.parse_args()

    # The loop below stops at the first OOM and records the last success, so
    # sizes must be probed in ascending order (and duplicates are pointless).
    batch_sizes = sorted(set(args.batches))
    if not batch_sizes:
        ap.error("--batches needs at least one value")
    if batch_sizes[0] < 1:
        ap.error("--batches values must be positive integers")
    if not torch.cuda.is_available():
        raise SystemExit("no CUDA device available; this probe needs a GPU")

    cfg = load_config("configs/base.yaml", [
        "data.name=unified", f"data.root={args.root}", "data.num_frames=20",
        "data.height=256", "data.width=384",
        "train.amp=true", "train.grad_checkpoint=true",
    ] + args.override)

    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU total {total_gb:.0f} GB | model embed={cfg.model.embed_dim} dec={cfg.model.dec_depth} "
          f"tokens={cfg.n_static_tokens} gpt={cfg.model.tokens.gaussians_per_token} feat={cfg.model.feature_dim}")

    ds = UnifiedClipDataset(cfg, roots=args.root, split="train", n_sup_views=args.n_sup)
    if len(ds) == 0:
        raise SystemExit(f"no training samples found under {args.root}")
    # Load samples for the largest size once; smaller batches use a prefix.
    # Samples are reused cyclically when the dataset is smaller than the batch.
    samples = [ds[i % len(ds)] for i in range(batch_sizes[-1])]
    trainer = Trainer(cfg)

    best = 0
    batch = None
    for B in batch_sizes:
        # Drop the previous batch before measuring this one; matters if
        # collate_samples places tensors on the GPU.
        batch = None
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        oom = False
        try:
            batch = collate_samples(samples[:B])
            # Two steps, presumably so state allocated lazily on the first
            # step is also reflected in the peak.
            for _ in range(2):
                trainer.train_step(batch)
        except torch.cuda.OutOfMemoryError:
            oom = True
        if oom:
            # Recover outside the except clause: while it is active, the
            # traceback keeps the failed step's tensors alive, so
            # empty_cache() could not release them.
            batch = None
            torch.cuda.empty_cache()
            print(f"  B={B:>3} OOM", flush=True)
            break
        peak = torch.cuda.max_memory_allocated() / 1e9
        print(f"  B={B:>3} OK   peak {peak:5.1f} GB  ({100*peak/total_gb:.0f}% of GPU)", flush=True)
        best = B

    if best == 0:
        print(f"==> max batch that fits: 0 (even B={batch_sizes[0]} ran out of memory)")
        return
    print(f"==> max batch that fits: {best}  (recommend train.batch_size={_recommended_batch_size(best)})")
    if best == batch_sizes[-1]:
        print("    (this was the largest size probed; a bigger batch may also fit)")


if __name__ == "__main__":
    main()