File size: 4,608 Bytes
176b11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
benchmark/bench_layer1.py
--------------------------
Run this first after connecting to your RunPod instance.
Benchmarks every Layer 1 component against its baseline and reports the speedup.

Usage:
    python benchmark/bench_layer1.py
"""

import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import time
from kernels.rank_estimator import sketch_rank, estimate_prune_counts
from kernels.varlen_packing  import pack_varlen_batch
from kernels.sparse_attn     import sparse_vision_attn


def timeit(fn, n_warmup=5, n_runs=50, device="cpu"):
    """Return the mean wall-clock time of ``fn()`` in milliseconds.

    ``fn`` is called ``n_warmup`` times outside the timed window, then
    ``n_runs`` times inside it. On CUDA devices the device is synchronized
    before the clock starts and before it stops, so queued asynchronous
    kernels are included in the measurement.

    Args:
        fn: Zero-argument callable to benchmark; its return value is ignored.
        n_warmup: Number of untimed calls made before measuring.
        n_runs: Number of timed calls; must be at least 1.
        device: Device the work runs on, either a string ("cpu", "cuda",
            "cuda:1", ...) or a ``torch.device``.

    Returns:
        Mean time per call in milliseconds.

    Raises:
        ValueError: If ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Match any CUDA device spec, not only the bare string "cuda": "cuda:0" or
    # torch.device("cuda") must also synchronize, otherwise only kernel
    # launches would be timed. Synchronize the requested device specifically.
    if str(device).startswith("cuda"):
        def sync():
            torch.cuda.synchronize(device)
    else:
        def sync():
            pass

    for _ in range(n_warmup):
        fn()
    sync()
    t0 = time.perf_counter()
    for _ in range(n_runs):
        fn()
    sync()
    return (time.perf_counter() - t0) / n_runs * 1000


def bench_rank(device):
    """Compare exact per-matrix rank against ``sketch_rank`` and print a table.

    For each (B, T, V) config a random tensor ``P`` of shape (B, T, V) is
    built and normalized so every row along the last dim sums to 1. The
    baseline computes ``torch.linalg.matrix_rank`` for each batch item in a
    Python loop; the candidate is ``sketch_rank(P)``. MaxErr is the largest
    absolute difference between the two rank results across the batch.

    Args:
        device: Device passed to tensor creation and to ``timeit``.
    """
    print("\n── Rank Estimator ───────────────────────────────────────────────")
    print(f"{'Config':<30} {'SVD':>10} {'Sketch':>10} {'Speedup':>10} {'MaxErr':>10}")
    print("─" * 75)

    for B, T, V in [(1, 77, 196), (4, 77, 196), (8, 77, 196), (8, 128, 576)]:
        P = torch.rand(B, T, V, device=device)
        P = P / P.sum(dim=-1, keepdim=True)

        # Closures over this iteration's P/B; only called within the iteration.
        def svd_ranks():
            # Baseline: exact rank, one (T, V) matrix at a time.
            return torch.stack([torch.linalg.matrix_rank(P[i]) for i in range(B)])

        svd_ms = timeit(svd_ranks, device=device)
        sketch_ms = timeit(lambda: sketch_rank(P), device=device)

        r_svd = svd_ranks().float()
        r_skc = sketch_rank(P).float()
        err = (r_svd - r_skc).abs().max().item()

        # Pad the label to the Config column and keep "<num>ms" exactly 10
        # wide so every row lines up with the header.
        label = f"B={B} T={T} V={V}"
        print(
            f"{label:<30} {svd_ms:>8.1f}ms {sketch_ms:>8.1f}ms "
            f"{svd_ms / sketch_ms:>9.1f}x {err:>10.0f}"
        )


def bench_packing(device):
    """Compare ``pad_sequence`` against ``pack_varlen_batch`` and print a table.

    For each (D, lens) config, one random (L, D) tensor is created per entry
    of ``lens``. Both functions are timed on that list. The Mem column is the
    relative change in element count of a packed layout (``sum(lens) * D``)
    versus a padded layout (``max(lens) * batch * D``). A negative value means
    packing stores fewer elements. This is an estimate from the lengths, not
    a measurement of either function's actual output.

    Args:
        device: Device passed to tensor creation and to ``timeit``.
    """
    print("\n── Varlen Packing ───────────────────────────────────────────────")
    print(f"{'Config':<35} {'pad_seq':>10} {'pack':>10} {'Speedup':>10} {'Mem':>10}")
    print("─" * 80)

    from torch.nn.utils.rnn import pad_sequence

    for D, lens in [
        (768, [120, 80, 100, 90]),
        (768, [160, 80, 90, 110, 140, 70, 130, 100]),
    ]:
        # Derive the batch size from lens so the padded-memory estimate
        # cannot drift from the actual number of sequences.
        B = len(lens)
        tokens = [torch.randn(L, D, device=device) for L in lens]
        pad_ms = timeit(lambda: pad_sequence(tokens, batch_first=True), device=device)
        pack_ms = timeit(lambda: pack_varlen_batch(tokens), device=device)

        pack_mem = sum(lens) * D
        pad_mem = max(lens) * B * D
        saving = (pack_mem / pad_mem - 1) * 100

        # "<num>ms" kept exactly 10 wide to align with the header columns.
        label = f"B={B} D={D} lens=[{min(lens)}..{max(lens)}]"
        print(
            f"{label:<35} {pad_ms:>8.2f}ms {pack_ms:>8.2f}ms "
            f"{pad_ms / pack_ms:>9.1f}x {saving:>+9.0f}%"
        )


def bench_sparse_attn(device):
    """Compare dense patch-text scores with ``sparse_vision_attn`` and print a table.

    For each (B, N_vis, K, T, D) config, random ``patch`` (B, N_vis, D) and
    ``text`` (B, T, D) tensors are created, plus ``kept``: K distinct patch
    indices per batch item, shape (B, K). The dense baseline computes the
    scaled scores ``patch @ text^T * D**-0.5`` for every patch. MaxErr
    compares the dense scores gathered at ``kept`` with the sparse output.

    Args:
        device: Device passed to tensor creation and to ``timeit``.
    """
    print("\n── Sparse Attention ─────────────────────────────────────────────")
    print(f"{'Config':<38} {'Dense':>10} {'Sparse':>10} {'Speedup':>10} {'MaxErr':>10}")
    print("─" * 83)

    for B, N_vis, K, T, D in [
        (1, 196, 80, 77, 768),
        (4, 196, 80, 77, 768),
        (8, 196, 80, 77, 768),
        (8, 576, 127, 77, 1024),
    ]:
        patch = torch.randn(B, N_vis, D, device=device)
        text = torch.randn(B, T, D, device=device)
        kept = torch.stack([torch.randperm(N_vis, device=device)[:K] for _ in range(B)])
        scale = D ** -0.5

        # Closures over this iteration's tensors; only called within the iteration.
        def dense_scores():
            # Baseline: scores for all N_vis patches, shape (B, N_vis, T).
            return torch.bmm(patch, text.transpose(1, 2)) * scale

        def sparse_scores():
            return sparse_vision_attn(patch, text, kept, use_triton=False)

        dense_ms = timeit(dense_scores, device=device)
        sparse_ms = timeit(sparse_scores, device=device)

        # NOTE(review): assumes sparse_vision_attn returns (B, K, T) scores for
        # the kept rows in `kept` order — confirm against the kernel.
        idx = kept.unsqueeze(-1).expand(B, K, T)
        err = (torch.gather(dense_scores(), 1, idx) - sparse_scores()).abs().max().item()

        # "<num>ms" kept exactly 10 wide to align with the header columns.
        label = f"B={B} N={N_vis} K={K} T={T} D={D}"
        print(
            f"{label:<38} {dense_ms:>8.2f}ms {sparse_ms:>8.2f}ms "
            f"{dense_ms / sparse_ms:>9.1f}x {err:>10.2e}"
        )


if __name__ == "__main__":
    # Run every Layer 1 benchmark on the GPU when one is visible, else on CPU.
    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"

    print(f"\nSparseVLM Layer 1 Benchmark | Device: {device}")
    if on_gpu:
        gpu_name = torch.cuda.get_device_name(0)
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu_name} | VRAM: {vram_gb:.1f}GB")

    for bench in (bench_rank, bench_packing, bench_sparse_attn):
        bench(device)

    print("\n── Done. Replace README.md benchmark table with these numbers. ──\n")