| """ |
| benchmark/bench_layer1.py |
| -------------------------- |
| Run this first after connecting to your RunPod instance. |
| Benchmarks every Layer 1 component against its baseline and reports the speedup. |
| |
| Usage: |
| python benchmark/bench_layer1.py |
| """ |
|
|
| import sys, os |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
|
|
| import torch |
| import time |
| from kernels.rank_estimator import sketch_rank, estimate_prune_counts |
| from kernels.varlen_packing import pack_varlen_batch |
| from kernels.sparse_attn import sparse_vision_attn |
|
|
|
|
def timeit(fn, n_warmup=5, n_runs=50, device="cpu"):
    """Return the mean wall-clock time of ``fn()`` in milliseconds.

    Args:
        fn: Zero-argument callable to benchmark.
        n_warmup: Number of untimed calls made before measuring.
        n_runs: Number of timed calls; must be at least 1.
        device: Device the work runs on.  Any value whose string form starts
            with "cuda" (e.g. "cuda", "cuda:0", ``torch.device("cuda")``)
            makes the timer call ``torch.cuda.synchronize()`` (which waits on
            the current CUDA device) before starting and before stopping.

    Returns:
        Average time per call, in milliseconds.

    Raises:
        ValueError: If ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Match "cuda:0" and torch.device objects too; a plain == "cuda" test
    # would skip synchronization for them and time only the kernel launches.
    on_cuda = str(device).startswith("cuda")

    for _ in range(n_warmup):
        fn()
    if on_cuda:
        # Drain warm-up work so it is not billed to the timed region.
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if on_cuda:
        # Wait for all queued kernels to finish before reading the clock.
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n_runs * 1000
|
|
|
|
def bench_rank(device):
    """Benchmark ``sketch_rank`` against a per-sample ``matrix_rank`` baseline.

    For each ``(B, T, V)`` configuration a random ``(B, T, V)`` tensor is
    created on ``device`` and normalised so every row along the last axis sums
    to one.  Both estimators are timed with ``timeit``, then each is run once
    more to report the largest absolute difference between their outputs.

    Args:
        device: Device string used for tensor creation and forwarded to
            ``timeit``.

    Prints one table row per configuration; returns nothing.
    """
    print("\nββ Rank Estimator βββββββββββββββββββββββββββββββββββββββββββββββ")
    print(f"{'Config':<30} {'SVD':>10} {'Sketch':>10} {'Speedup':>10} {'MaxErr':>10}")
    print("β" * 75)

    for B, T, V in [(1, 77, 196), (4, 77, 196), (8, 77, 196), (8, 128, 576)]:
        P = torch.rand(B, T, V, device=device)
        P = P / P.sum(dim=-1, keepdim=True)

        def svd_rank():
            # Baseline: one matrix_rank call per sample, stacked into a 1-D tensor.
            return torch.stack([torch.linalg.matrix_rank(sample) for sample in P])

        svd_ms = timeit(svd_rank, device=device)
        sketch_ms = timeit(lambda: sketch_rank(P), device=device)

        # Accuracy check: compare the two estimates outside the timed region.
        r_svd = svd_rank().float()
        r_skc = sketch_rank(P).float()
        err = (r_svd - r_skc).abs().max().item()

        # Pad the whole label to the 30-character "Config" header so the
        # numeric columns line up under their headings.
        label = f"B={B} T={T} V={V}"
        print(f"{label:<30} {svd_ms:>9.1f}ms {sketch_ms:>9.1f}ms {svd_ms/sketch_ms:>9.1f}x {err:>10.0f}")
|
|
|
|
def bench_packing(device):
    """Benchmark ``pack_varlen_batch`` against ``pad_sequence``.

    For each configuration a list of random ``(length, D)`` tensors is built
    on ``device``.  Padding and packing are timed with ``timeit`` and the
    relative element-count difference between the packed and padded layouts
    is printed as a signed percentage (negative means packing stores fewer
    elements).

    Args:
        device: Device string used for tensor creation and forwarded to
            ``timeit``.

    Prints one table row per configuration; returns nothing.
    """
    from torch.nn.utils.rnn import pad_sequence

    print("\nββ Varlen Packing βββββββββββββββββββββββββββββββββββββββββββββββ")
    print(f"{'Config':<35} {'pad_seq':>10} {'pack':>10} {'Speedup':>10} {'Mem':>10}")
    print("β" * 80)

    configs = [
        (4, 768, [120, 80, 100, 90]),
        (8, 768, [160, 80, 90, 110, 140, 70, 130, 100]),
    ]

    for B, D, lens in configs:
        tokens = [torch.randn(L, D, device=device) for L in lens]

        def run_pad():
            return pad_sequence(tokens, batch_first=True)

        def run_pack():
            return pack_varlen_batch(tokens)

        pad_ms = timeit(run_pad, device=device)
        pack_ms = timeit(run_pack, device=device)

        # Element counts: padded layout stores max(lens) rows per sequence,
        # packed layout stores exactly sum(lens) rows.
        pad_mem = max(lens) * B * D
        pack_mem = sum(lens) * D
        saving = (pack_mem / pad_mem - 1) * 100

        label = f"B={B} D={D} lens=[{min(lens)}..{max(lens)}]"
        print(f"{label:<35} {pad_ms:>9.2f}ms {pack_ms:>9.2f}ms {pad_ms/pack_ms:>9.1f}x {saving:>+9.0f}%")
|
|
|
|
def bench_sparse_attn(device):
    """Benchmark ``sparse_vision_attn`` against a dense ``bmm`` baseline.

    For each ``(B, N_vis, K, T, D)`` configuration, random patch and text
    tensors are created on ``device`` together with ``K`` randomly chosen
    patch indices per batch item.  The dense scaled patch-text product and
    the sparse kernel (called with ``use_triton=False``) are timed with
    ``timeit``; the sparse output is then compared against the dense scores
    gathered at the kept indices.

    Args:
        device: Device string used for tensor creation and forwarded to
            ``timeit``.

    Prints one table row per configuration; returns nothing.
    """
    print("\nββ Sparse Attention βββββββββββββββββββββββββββββββββββββββββββββ")
    print(f"{'Config':<38} {'Dense':>10} {'Sparse':>10} {'Speedup':>10} {'MaxErr':>10}")
    print("β" * 83)

    configs = [
        (1, 196, 80, 77, 768),
        (4, 196, 80, 77, 768),
        (8, 196, 80, 77, 768),
        (8, 576, 127, 77, 1024),
    ]

    for B, N_vis, K, T, D in configs:
        patch = torch.randn(B, N_vis, D, device=device)
        text = torch.randn(B, T, D, device=device)

        # One random subset of K patch indices per batch item.
        picks = []
        for _ in range(B):
            picks.append(torch.randperm(N_vis, device=device)[:K])
        kept = torch.stack(picks)

        scale = D ** -0.5

        def run_dense():
            return torch.bmm(patch, text.transpose(1, 2)) * scale

        def run_sparse():
            return sparse_vision_attn(patch, text, kept, use_triton=False)

        dense_ms = timeit(run_dense, device=device)
        sparse_ms = timeit(run_sparse, device=device)

        # Accuracy check: pick the kept rows out of the dense result and
        # compare them with what the sparse path produced.
        dense_out = run_dense()
        sparse_out = run_sparse()
        idx = kept.unsqueeze(-1).expand(B, K, T)
        dense_kept = dense_out.gather(1, idx)
        err = (dense_kept - sparse_out).abs().max().item()

        label = f"B={B} N={N_vis} K={K} T={T} D={D}"
        print(f"{label:<38} {dense_ms:>9.2f}ms {sparse_ms:>9.2f}ms {dense_ms/sparse_ms:>9.1f}x {err:>10.2e}")
|
|
|
|
if __name__ == "__main__":
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    print(f"\nSparseVLM Layer 1 Benchmark | Device: {device}")

    if has_cuda:
        # Report the name and total memory of device 0.
        gpu_name = torch.cuda.get_device_name(0)
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu_name} | VRAM: {vram_gb:.1f}GB")

    # Run the three Layer 1 benchmarks in order.
    for bench in (bench_rank, bench_packing, bench_sparse_attn):
        bench(device)

    print("\nββ Done. Replace README.md benchmark table with these numbers. ββ\n")
|
|