SparseVLM / benchmark /bench_layer1.py
Aryan3108's picture
Upload folder using huggingface_hub
176b11a verified
Raw
History Blame Contribute Delete
4.61 kB
"""
benchmark/bench_layer1.py
--------------------------
Run this first when you connect your RunPod instance.
Benchmarks every Layer 1 component (rank estimator, varlen packing,
sparse attention) against its baseline and prints a timing table for each.

Usage:
    python benchmark/bench_layer1.py
"""
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import time
from kernels.rank_estimator import sketch_rank, estimate_prune_counts
from kernels.varlen_packing import pack_varlen_batch
from kernels.sparse_attn import sparse_vision_attn
def timeit(fn, n_warmup=5, n_runs=50, device="cpu"):
    """Return the mean wall-clock time of one ``fn()`` call, in milliseconds.

    ``fn`` is first called ``n_warmup`` times untimed, then ``n_runs`` times
    inside a single ``time.perf_counter`` window whose length is divided by
    ``n_runs``.

    Args:
        fn: Zero-argument callable to measure.
        n_warmup: Number of untimed calls made before measuring.
        n_runs: Number of timed calls; must be at least 1.
        device: Device the work runs on. Any CUDA spelling (``"cuda"``,
            ``"cuda:0"``, ``torch.device("cuda")``) enables synchronization
            around the timed window.

    Returns:
        Average time per call in milliseconds.

    Raises:
        ValueError: If ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        # Fail with a clear message instead of a ZeroDivisionError at the end.
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Compare on the string form so "cuda:0" and torch.device objects are
    # recognised too; a plain `device == "cuda"` would skip synchronization
    # for them and time only the kernel launches.
    on_cuda = str(device).startswith("cuda")

    for _ in range(n_warmup):
        fn()
    if on_cuda:
        # Drain warm-up work so it is not billed to the timed window.
        torch.cuda.synchronize(torch.device(device))

    t0 = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if on_cuda:
        # Wait for all queued GPU work before reading the clock.
        torch.cuda.synchronize(torch.device(device))
    elapsed = time.perf_counter() - t0

    return elapsed / n_runs * 1000
def bench_rank(device):
    """Print a table comparing ``sketch_rank`` against per-matrix SVD rank.

    For each hard-coded ``(B, T, V)`` configuration a random ``(B, T, V)``
    tensor is drawn on ``device`` and normalised so every row along the last
    axis sums to 1. The baseline calls ``torch.linalg.matrix_rank`` once per
    batch item and stacks the results; the candidate calls ``sketch_rank`` on
    the whole tensor. Each row reports both mean times, their ratio, and the
    largest absolute difference between the two rank estimates.

    Args:
        device: Device on which tensors are created; also forwarded to
            ``timeit``.
    """
    print("\n── Rank Estimator ───────────────────────────────────────────────")
    print(f"{'Config':<30} {'SVD':>10} {'Sketch':>10} {'Speedup':>10} {'MaxErr':>10}")
    print("─" * 75)
    for B, T, V in [(1, 77, 196), (4, 77, 196), (8, 77, 196), (8, 128, 576)]:
        P = torch.rand(B, T, V, device=device)
        P = P / P.sum(dim=-1, keepdim=True)

        def svd_rank():
            # Baseline: one matrix_rank call per batch item.
            return torch.stack([torch.linalg.matrix_rank(P[i]) for i in range(B)])

        svd_ms = timeit(svd_rank, device=device)
        sketch_ms = timeit(lambda: sketch_rank(P), device=device)

        # Recompute once outside the timed loops to compare the estimates.
        # NOTE(review): assumes sketch_rank returns one rank per batch item so
        # the subtraction lines up with the stacked baseline; confirm.
        r_svd = svd_rank().float()
        r_skc = sketch_rank(P).float()
        err = (r_svd - r_skc).abs().max().item()

        # Pad the whole label to the header's 30-column width (as the other
        # bench_* tables do) so the numeric columns line up under the header.
        label = f"B={B} T={T} V={V}"
        print(f"{label:<30} {svd_ms:>9.1f}ms {sketch_ms:>9.1f}ms {svd_ms/sketch_ms:>9.1f}x {err:>10.0f}")
def bench_packing(device):
    """Print a table comparing ``pack_varlen_batch`` with ``pad_sequence``.

    Each case is a list of random ``(length, D)`` tensors on ``device``. Both
    functions are timed on the same list. The ``Mem`` column is the signed
    percentage difference between the packed element count
    (``sum(lens) * D``) and the padded one (``max(lens) * B * D``), so a
    negative value means packing stores fewer elements.

    Args:
        device: Device on which tensors are created; also forwarded to
            ``timeit``.
    """
    print("\n── Varlen Packing ───────────────────────────────────────────────")
    print(f"{'Config':<35} {'pad_seq':>10} {'pack':>10} {'Speedup':>10} {'Mem':>10}")
    print("─" * 80)
    from torch.nn.utils.rnn import pad_sequence

    cases = (
        (4, 768, [120, 80, 100, 90]),
        (8, 768, [160, 80, 90, 110, 140, 70, 130, 100]),
    )
    for batch, dim, lens in cases:
        tokens = [torch.randn(n, dim, device=device) for n in lens]

        def run_pad():
            return pad_sequence(tokens, batch_first=True)

        def run_pack():
            return pack_varlen_batch(tokens)

        pad_ms = timeit(run_pad, device=device)
        pack_ms = timeit(run_pack, device=device)

        # Element counts only; nothing is measured on the allocator.
        longest = max(lens)
        packed_elems = sum(lens) * dim
        padded_elems = longest * batch * dim
        saving = (packed_elems / padded_elems - 1) * 100

        label = f"B={batch} D={dim} lens=[{min(lens)}..{longest}]"
        speedup = pad_ms / pack_ms
        print(f"{label:<35} {pad_ms:>9.2f}ms {pack_ms:>9.2f}ms {speedup:>9.1f}x {saving:>+9.0f}%")
def bench_sparse_attn(device):
    """Print a table comparing ``sparse_vision_attn`` with a dense ``bmm``.

    For each ``(B, N_vis, K, T, D)`` case, random patch ``(B, N_vis, D)`` and
    text ``(B, T, D)`` tensors are drawn, plus ``K`` distinct patch indices
    per batch item. The dense baseline is ``patch @ text^T`` scaled by
    ``D ** -0.5``; the candidate is ``sparse_vision_attn`` with
    ``use_triton=False``. ``MaxErr`` is the largest absolute difference
    between the sparse output and the dense rows selected by the kept indices.

    Args:
        device: Device on which tensors are created; also forwarded to
            ``timeit``.
    """
    print("\n── Sparse Attention ─────────────────────────────────────────────")
    print(f"{'Config':<38} {'Dense':>10} {'Sparse':>10} {'Speedup':>10} {'MaxErr':>10}")
    print("─" * 83)

    cases = (
        (1, 196, 80, 77, 768),
        (4, 196, 80, 77, 768),
        (8, 196, 80, 77, 768),
        (8, 576, 127, 77, 1024),
    )
    for batch, n_vis, n_keep, n_text, dim in cases:
        # Draw order (patch, text, indices) matches the RNG stream usage.
        patch = torch.randn(batch, n_vis, dim, device=device)
        text = torch.randn(batch, n_text, dim, device=device)
        picks = [torch.randperm(n_vis, device=device)[:n_keep] for _ in range(batch)]
        kept = torch.stack(picks)
        scale = dim ** -0.5

        def dense_scores():
            return torch.bmm(patch, text.transpose(1, 2)) * scale

        def sparse_scores():
            return sparse_vision_attn(patch, text, kept, use_triton=False)

        dense_ms = timeit(dense_scores, device=device)
        sparse_ms = timeit(sparse_scores, device=device)

        dense_out = dense_scores()
        sparse_out = sparse_scores()

        # Select the kept patch rows of the dense result for comparison.
        # NOTE(review): assumes sparse_vision_attn returns (B, K, T); confirm.
        row_idx = kept.unsqueeze(-1).expand(batch, n_keep, n_text)
        dense_kept = dense_out.gather(1, row_idx)
        err = (dense_kept - sparse_out).abs().max().item()

        label = f"B={batch} N={n_vis} K={n_keep} T={n_text} D={dim}"
        speedup = dense_ms / sparse_ms
        print(f"{label:<38} {dense_ms:>9.2f}ms {sparse_ms:>9.2f}ms {speedup:>9.1f}x {err:>10.2e}")
if __name__ == "__main__":
    # Benchmark on the GPU when torch can see one, otherwise on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"\nSparseVLM Layer 1 Benchmark | Device: {device}")

    if device == "cuda":
        # Report the first visible GPU and its total memory in decimal GB.
        gpu_name = torch.cuda.get_device_name(0)
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu_name} | VRAM: {vram_gb:.1f}GB")

    # Run the three benchmark tables in a fixed order.
    for bench in (bench_rank, bench_packing, bench_sparse_attn):
        bench(device)

    print("\n── Done. Replace README.md benchmark table with these numbers. ──\n")