File size: 2,255 Bytes
8cf92b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python3
"""Find the largest batch size that fits, to fill the GPU for training.

Runs real train steps at increasing batch sizes (bf16 + grad-checkpoint),
records peak memory, stops at OOM. Use the reported max to set train.batch_size
(leave ~10-15% headroom)."""

import argparse
import torch

from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset, collate_samples
from mapgs.train import Trainer


def main():
    """Probe increasing batch sizes and report the largest one that trains without OOM.

    Command-line options:
        --config    Base config file passed to ``load_config``
                    (default ``configs/base.yaml``).
        --root      Dataset root, used both as the ``data.root`` override and
                    as the ``roots`` argument of ``UnifiedClipDataset``.
        --override  Extra config overrides, appended after the fixed probe
                    overrides (unified data, 20 frames, 256x384, AMP and
                    gradient checkpointing enabled).
        --batches   Batch sizes to try. They are de-duplicated and tried in
                    ascending order regardless of the order given.
        --n-sup     Passed to the dataset as ``n_sup_views``.

    For every batch size, two ``Trainer.train_step`` calls are run on a batch
    collated from pre-loaded samples, and the peak CUDA memory allocated is
    printed. Probing stops at the first ``torch.cuda.OutOfMemoryError``.

    Exits through ``argparse`` (status 2) when ``--batches`` is empty or holds
    a non-positive value, and through ``SystemExit`` when the dataset is empty.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default="configs/base.yaml")
    ap.add_argument("--root", default="/mnt/william/data/unified/av2")
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--batches", nargs="*", type=int, default=[1, 2, 4, 6, 8, 12, 16, 24, 32])
    ap.add_argument("--n-sup", type=int, default=4)
    args = ap.parse_args()

    # Validate up front, before the config, dataset or GPU are touched, so a
    # bad command line fails immediately with a usage message instead of a
    # traceback after the expensive setup.
    if not args.batches:
        ap.error("--batches needs at least one batch size")
    if any(b <= 0 for b in args.batches):
        ap.error("--batches values must be positive integers")

    # The loop below stops at the first OOM, which only yields the true
    # maximum when sizes are tried smallest-first.
    batch_sizes = sorted(set(args.batches))

    cfg = load_config(args.config, [
        "data.name=unified", f"data.root={args.root}", "data.num_frames=20",
        "data.height=256", "data.width=384", "train.amp=true", "train.grad_checkpoint=true",
    ] + args.override)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU total {total_gb:.0f} GB | model embed={cfg.model.embed_dim} dec={cfg.model.dec_depth} "
          f"tokens={cfg.n_static_tokens} gpt={cfg.model.tokens.gaussians_per_token} feat={cfg.model.feature_dim}")

    ds = UnifiedClipDataset(cfg, roots=args.root, split="train", n_sup_views=args.n_sup)
    n_items = len(ds)
    if n_items == 0:
        # Without this the modulo below would raise ZeroDivisionError.
        raise SystemExit(f"no training samples found under {args.root}")

    # Load enough samples for the largest batch once; indices wrap around
    # when the dataset has fewer items than the largest batch size.
    largest = batch_sizes[-1]
    samples = [ds[i % n_items] for i in range(largest)]
    trainer = Trainer(cfg)

    best = 0
    for B in batch_sizes:
        try:
            # Start each measurement from a released cache and a fresh peak counter.
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            batch = collate_samples(samples[:B])
            for _ in range(2):
                trainer.train_step(batch)
            peak = torch.cuda.max_memory_allocated() / 1e9
            print(f"  B={B:>3}  OK   peak {peak:5.1f} GB  ({100*peak/total_gb:.0f}% of GPU)", flush=True)
            best = B
        except torch.cuda.OutOfMemoryError:
            print(f"  B={B:>3}  OOM", flush=True)
            torch.cuda.empty_cache()
            break

    if best == 0:
        # Nothing fit: recommending train.batch_size=0 would be meaningless.
        print(f"==> no batch size fits: even B={batch_sizes[0]} ran out of memory")
        return
    print(f"==> max batch that fits: {best}  (recommend train.batch_size={int(best*0.85) if best>2 else best})")


# Run the probe only when executed as a script, not when imported.
if __name__ == "__main__":
    main()