mapvggt / scripts /find_max_batch.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
2.26 kB
#!/usr/bin/env python3
"""Find the largest batch size that fits, to fill the GPU for training.
Runs real train steps at increasing batch sizes (bf16 + grad-checkpoint),
records peak memory, stops at OOM. Use the reported max to set train.batch_size
(leave ~10-15% headroom)."""
import argparse
import torch
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset, collate_samples
from mapgs.train import Trainer
def main():
    """Sweep batch sizes with real train steps and report the largest that fits.

    Command-line options:
        --root      Dataset root passed to the config and the dataset.
        --override  Extra ``key=value`` config overrides appended after the
                    built-in ones.
        --batches   Candidate batch sizes. They are de-duplicated and tried in
                    ascending order, so the sweep can stop at the first OOM.
        --n-sup     Forwarded to the dataset as ``n_sup_views``.

    For every candidate the peak CUDA memory is reset, two train steps are
    run on a batch of that size, and the peak allocated memory is printed.
    The sweep stops at the first ``torch.cuda.OutOfMemoryError``; the last
    size that succeeded is reported together with a recommendation that
    leaves roughly 15% headroom.

    Exits via ``argparse`` (status 2) when ``--batches`` is empty or contains
    a non-positive value, and raises ``SystemExit`` when the dataset is empty.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", default="/mnt/william/data/unified/av2")
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--batches", nargs="*", type=int, default=[1, 2, 4, 6, 8, 12, 16, 24, 32])
    ap.add_argument("--n-sup", type=int, default=4)
    args = ap.parse_args()

    # nargs="*" happily accepts `--batches` with no values, which would make
    # max() below fail with an unhelpful ValueError; reject bad input up front.
    if not args.batches or min(args.batches) < 1:
        ap.error("--batches requires one or more positive integers")
    # The loop breaks at the first OOM and keeps the last success, which is
    # only the true maximum when sizes are visited in ascending order.
    batches = sorted(set(args.batches))

    cfg = load_config("configs/base.yaml", [
        "data.name=unified", f"data.root={args.root}", "data.num_frames=20",
        "data.height=256", "data.width=384", "train.amp=true", "train.grad_checkpoint=true",
    ] + args.override)

    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU total {total_gb:.0f} GB | model embed={cfg.model.embed_dim} dec={cfg.model.dec_depth} "
          f"tokens={cfg.n_static_tokens} gpt={cfg.model.tokens.gaussians_per_token} feat={cfg.model.feature_dim}")

    ds = UnifiedClipDataset(cfg, roots=args.root, split="train", n_sup_views=args.n_sup)
    if len(ds) == 0:
        # Otherwise `i % len(ds)` below dies with a bare ZeroDivisionError.
        raise SystemExit(f"dataset under {args.root} is empty; nothing to benchmark")

    # Load enough samples for the largest candidate once (wrapping around a
    # small dataset); smaller batches reuse a prefix of this list.
    largest = batches[-1]
    samples = [ds[i % len(ds)] for i in range(largest)]
    trainer = Trainer(cfg)

    best = 0
    for B in batches:
        try:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            batch = collate_samples(samples[:B])
            # Two steps rather than one -- presumably so memory that is only
            # allocated lazily on the first step is part of the measured peak.
            for _ in range(2):
                trainer.train_step(batch)
            peak = torch.cuda.max_memory_allocated() / 1e9
            print(f" B={B:>3} OK peak {peak:5.1f} GB ({100*peak/total_gb:.0f}% of GPU)", flush=True)
            best = B
        except torch.cuda.OutOfMemoryError:
            print(f" B={B:>3} OOM", flush=True)
            torch.cuda.empty_cache()
            break

    if best == 0:
        # Even the smallest candidate ran out of memory; recommending
        # train.batch_size=0 would be misleading.
        print(f"==> no tested batch size fits (smallest tried: {batches[0]})")
        return
    print(f"==> max batch that fits: {best} (recommend train.batch_size={int(best*0.85) if best>2 else best})")
# Run the batch-size sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()