File size: 2,352 Bytes
88a1dd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
"""
RMSNorm Micro-Benchmark: Custom CUDA kernel vs PyTorch baseline.

Benchmarks across multiple tensor shapes matching LTX-Video configurations.
Reports speedup and memory bandwidth efficiency.

Usage:
    python benchmark_rmsnorm.py
"""

import time

import torch
import torch.nn as nn


class RMSNormReference(nn.Module):
    """PyTorch reference RMSNorm implementation.

    Computes ``weight * x / sqrt(mean(x^2, dim=-1) + eps)``, accumulating
    the mean-square statistic in float32 for numerical stability and
    casting the result back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Per-channel learnable scale, initialized to identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Capture the caller's dtype BEFORE any upcasting. The previous
        # code read `x.dtype` after `x` had been promoted to float32 by
        # the multiplication below, so bf16/fp16 inputs incorrectly came
        # back as float32 outputs.
        input_dtype = x.dtype
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)


def benchmark_kernel(fn, warmup=20, iterations=100):
    """Time ``fn`` and return its average wall-clock latency in milliseconds.

    Args:
        fn: Zero-argument callable to benchmark.
        warmup: Number of untimed calls issued first to exclude one-time
            costs (allocator warmup, kernel caching, autotuning).
        iterations: Number of timed calls averaged into the result.

    Returns:
        Mean time per call, in milliseconds (float).
    """

    def _sync():
        # CUDA launches are asynchronous, so the queue must be drained
        # before reading the clock. Guarded so the helper also works on
        # CPU-only builds, where the original unconditional
        # torch.cuda.synchronize() would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn()
    _sync()

    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    _sync()
    elapsed = (time.perf_counter() - start) / iterations * 1000
    return elapsed


def main():
    """Benchmark the reference RMSNorm across LTX-Video tensor shapes.

    Prints a table with the average per-call latency for each
    (batch, seq, hidden) configuration. Requires a CUDA device; exits
    early with a message when none is available.
    """
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires a GPU.")
        return

    print("=" * 60)
    print("RMSNorm Micro-Benchmark")
    print("=" * 60)
    print(f"Device: {torch.cuda.get_device_name()}")

    dtype = torch.bfloat16
    # (batch, sequence length, hidden size) tuples matching LTX-Video
    # transformer configurations (per the module docstring).
    configs = [
        (1, 1024, 2048),
        (2, 1024, 2048),
        (4, 1024, 2048),
        (2, 4096, 3072),
        (4, 4096, 3072),
        (1, 8192, 2048),
        (2, 8192, 3072),
    ]

    print(f"\nDtype: {dtype}")
    # Plain string literal — the original was an f-string with no
    # placeholders (ruff F541).
    print("Warmup: 20 iterations, Benchmark: 100 iterations")
    print()
    print(f"{'Shape':<25} {'PyTorch (ms)':>12} {'Speedup':>8}")
    print("-" * 50)

    for batch, seq, hidden in configs:
        shape = (batch, seq, hidden)
        x = torch.randn(shape, dtype=dtype, device="cuda")

        ref = RMSNormReference(hidden).to(dtype).cuda()

        pytorch_time = benchmark_kernel(lambda: ref(x))

        label = f"[{batch}x{seq}x{hidden}]"
        # Baseline-only run: speedup is 1.00x by definition until the
        # custom kernel path is wired in (see NOTE below).
        print(f"{label:<25} {pytorch_time:>12.3f} {'1.00x':>8}")

    print()
    print("NOTE: Custom kernel benchmarks require building ltx_kernels.")
    print("See examples/ltx_video/ for the full benchmark with custom kernels.")


if __name__ == "__main__":
    main()