| |
| """Find the largest batch size that fits, to fill the GPU for training. |
| |
| Runs real train steps at increasing batch sizes (bf16 + grad-checkpoint), |
| records peak memory, stops at OOM. Use the reported max to set train.batch_size |
| (leave ~10-15% headroom).""" |
|
|
| import argparse |
| import torch |
|
|
| from mapgs.config import load_config |
| from mapgs.data import UnifiedClipDataset, collate_samples |
| from mapgs.train import Trainer |
|
|
|
|
def main():
    """Probe batch sizes with real train steps and report the largest that fits.

    Builds a training config from ``configs/base.yaml`` with the unified
    dataset, 20 frames at 256x384, ``train.amp`` and ``train.grad_checkpoint``
    enabled, followed by any ``--override`` entries. For each candidate batch
    size it runs two ``Trainer.train_step`` calls and prints the peak CUDA
    memory allocated. Probing stops at the first
    ``torch.cuda.OutOfMemoryError``. The largest size that succeeded is
    reported together with a suggested ``train.batch_size`` (about 85% of it
    when that size is above 2).

    Command-line arguments:
        --root: dataset root, passed to both the config and the dataset.
        --override: extra ``key=value`` config overrides, applied last.
        --batches: candidate batch sizes (positive ints). They are probed in
            ascending order with duplicates dropped, regardless of the order given.
        --n-sup: forwarded to the dataset as ``n_sup_views``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", default="/mnt/william/data/unified/av2")
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--batches", nargs="*", type=int, default=[1, 2, 4, 6, 8, 12, 16, 24, 32])
    ap.add_argument("--n-sup", type=int, default=4)
    args = ap.parse_args()

    # The loop below stops at the first OOM, which only finds the maximum if
    # sizes are tried smallest-first. Sort (and dedupe) so an unordered
    # --batches list cannot under-report.
    batch_sizes = sorted(set(args.batches))
    if not batch_sizes:
        ap.error("--batches needs at least one batch size")
    if batch_sizes[0] < 1:
        ap.error("--batches values must be positive integers")

    cfg = load_config("configs/base.yaml", [
        "data.name=unified", f"data.root={args.root}", "data.num_frames=20",
        "data.height=256", "data.width=384", "train.amp=true", "train.grad_checkpoint=true",
    ] + args.override)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU total {total_gb:.0f} GB | model embed={cfg.model.embed_dim} dec={cfg.model.dec_depth} "
          f"tokens={cfg.n_static_tokens} gpt={cfg.model.tokens.gaussians_per_token} feat={cfg.model.feature_dim}")

    ds = UnifiedClipDataset(cfg, roots=args.root, split="train", n_sup_views=args.n_sup)
    if len(ds) == 0:
        ap.error(f"no training samples found under --root {args.root}")
    # Load enough samples for the largest batch once. Cycle through the dataset
    # if it is smaller than that, and slice a prefix for each batch size.
    samples = [ds[i % len(ds)] for i in range(batch_sizes[-1])]
    trainer = Trainer(cfg)

    best = 0
    for B in batch_sizes:
        try:
            # Drop cached blocks from the previous size, then start a fresh peak measurement.
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            batch = collate_samples(samples[:B])
            # NOTE(review): two steps presumably so the peak also covers memory
            # first allocated during/after step one (e.g. optimizer state) — confirm.
            for _ in range(2):
                trainer.train_step(batch)
        except torch.cuda.OutOfMemoryError:
            print(f" B={B:>3} OOM", flush=True)
            torch.cuda.empty_cache()
            break
        else:
            peak = torch.cuda.max_memory_allocated() / 1e9
            print(f" B={B:>3} OK peak {peak:5.1f} GB ({100*peak/total_gb:.0f}% of GPU)", flush=True)
            best = B

    if best == 0:
        print(f"==> no batch size fits (B={batch_sizes[0]} already OOMs)")
        return
    note = " (largest size tried; pass bigger --batches to probe further)" if best == batch_sizes[-1] else ""
    print(f"==> max batch that fits: {best}{note} "
          f"(recommend train.batch_size={int(best*0.85) if best>2 else best})")
|
|
|
|
if __name__ == "__main__":
    # Run the batch-size probe only when executed as a script, never on import.
    main()
|
|