#!/usr/bin/env python3
"""
RMSNorm Micro-Benchmark: Custom CUDA kernel vs PyTorch baseline.
Benchmarks across multiple tensor shapes matching LTX-Video configurations.
Reports speedup and memory bandwidth efficiency.
Usage:
python benchmark_rmsnorm.py
"""
import time
import torch
import torch.nn as nn
class RMSNormReference(nn.Module):
    """PyTorch reference RMSNorm implementation.

    Variance is computed in float32 for numerical stability, then the
    normalized result is cast back to the *input* dtype (HF-style RMSNorm).
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps  # guards against division by zero for all-zero rows

    def forward(self, x):
        # BUG FIX: capture the input dtype BEFORE the multiply below.
        # `x * torch.rsqrt(...)` promotes bf16/fp16 inputs to float32
        # (rsqrt of the float32 variance), so the original `.to(x.dtype)`
        # read the post-promotion dtype and returned float32 instead of
        # the caller's dtype.
        input_dtype = x.dtype
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
def benchmark_kernel(fn, warmup=20, iterations=100):
    """Benchmark a callable and return mean wall time per call in milliseconds.

    Args:
        fn: zero-argument callable to time (e.g. a lambda closing over inputs).
        warmup: untimed calls issued first so caches/JIT/clock state settle.
        iterations: timed calls averaged into the result.

    Returns:
        float: average milliseconds per call over `iterations`.
    """
    for _ in range(warmup):
        fn()
    # Generalization: only synchronize when CUDA exists, so this helper also
    # works on CPU-only machines. On GPU hosts the sync is required so queued
    # async kernels finish before/after the timed region.
    on_gpu = torch.cuda.is_available()
    if on_gpu:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    if on_gpu:
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / iterations * 1000
    return elapsed
def main():
    """Run the RMSNorm micro-benchmark over every configured tensor shape.

    Prints a table of per-shape PyTorch baseline timings; exits early with a
    message when no CUDA device is present.
    """
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires a GPU.")
        return

    banner = "=" * 60
    print(banner)
    print("RMSNorm Micro-Benchmark")
    print(banner)
    print(f"Device: {torch.cuda.get_device_name()}")

    dtype = torch.bfloat16
    # (batch, sequence length, hidden size) triples matching LTX-Video configs.
    configs = [
        (1, 1024, 2048),
        (2, 1024, 2048),
        (4, 1024, 2048),
        (2, 4096, 3072),
        (4, 4096, 3072),
        (1, 8192, 2048),
        (2, 8192, 3072),
    ]

    print(f"\nDtype: {dtype}")
    print("Warmup: 20 iterations, Benchmark: 100 iterations")
    print()
    print(f"{'Shape':<25} {'PyTorch (ms)':>12} {'Speedup':>8}")
    print("-" * 50)

    for cfg in configs:
        batch, seq, hidden = cfg
        x = torch.randn(cfg, dtype=dtype, device="cuda")
        ref = RMSNormReference(hidden).to(dtype).cuda()
        pytorch_time = benchmark_kernel(lambda: ref(x))
        label = f"[{batch}x{seq}x{hidden}]"
        print(f"{label:<25} {pytorch_time:>12.3f} {'1.00x':>8}")

    print()
    print("NOTE: Custom kernel benchmarks require building ltx_kernels.")
    print("See examples/ltx_video/ for the full benchmark with custom kernels.")
# Entry-point guard: run the benchmark only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()