# Source: test/skill_example/scripts/benchmark_rmsnorm.py
# Author: Jack-Khuu (Demo, commit 88a1dd2)
#!/usr/bin/env python3
"""
RMSNorm Micro-Benchmark: Custom CUDA kernel vs PyTorch baseline.
Benchmarks across multiple tensor shapes matching LTX-Video configurations.
Reports speedup and memory bandwidth efficiency.
Usage:
python benchmark_rmsnorm.py
"""
import time
import torch
import torch.nn as nn
class RMSNormReference(nn.Module):
    """PyTorch reference RMSNorm: y = weight * x / sqrt(mean(x^2) + eps).

    The mean-of-squares is accumulated in float32 for numerical stability,
    then the result is cast back to the caller's dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to identity (all ones).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Capture the caller's dtype BEFORE x is reassigned below.  The
        # original read `x.dtype` after the float32 promotion, so bf16/fp16
        # inputs were silently returned as float32.
        input_dtype = x.dtype
        # Variance (mean of squares) computed in float32 for stability.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
def benchmark_kernel(fn, warmup=20, iterations=100):
    """Benchmark a zero-argument callable; return mean latency in milliseconds.

    Args:
        fn: Callable to time (invoked with no arguments).
        warmup: Number of untimed warmup calls issued first.
        iterations: Number of timed calls the average is taken over.

    Returns:
        Mean wall-clock time per call, in milliseconds.
    """
    def _sync():
        # CUDA kernels launch asynchronously, so we must synchronize before
        # reading the clock or we'd only time launch overhead.  Guarding the
        # call also lets this helper run on CPU-only machines, where the
        # original unconditional synchronize() raised.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn()
    _sync()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    _sync()
    return (time.perf_counter() - start) / iterations * 1000
def main():
    """Entry point: print baseline RMSNorm timings for several tensor shapes."""
    # The timed baseline runs on CUDA; nothing to measure without a GPU.
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires a GPU.")
        return

    banner = "=" * 60
    print(banner)
    print("RMSNorm Micro-Benchmark")
    print(banner)
    print(f"Device: {torch.cuda.get_device_name()}")

    dtype = torch.bfloat16
    # (batch, sequence length, hidden size) triples matching LTX-Video configs.
    configs = [
        (1, 1024, 2048),
        (2, 1024, 2048),
        (4, 1024, 2048),
        (2, 4096, 3072),
        (4, 4096, 3072),
        (1, 8192, 2048),
        (2, 8192, 3072),
    ]

    print(f"\nDtype: {dtype}")
    print("Warmup: 20 iterations, Benchmark: 100 iterations")
    print()
    print(f"{'Shape':<25} {'PyTorch (ms)':>12} {'Speedup':>8}")
    print("-" * 50)

    for cfg in configs:
        batch, seq, hidden = cfg
        inp = torch.randn((batch, seq, hidden), dtype=dtype, device="cuda")
        baseline = RMSNormReference(hidden).to(dtype).cuda()
        latency_ms = benchmark_kernel(lambda: baseline(inp))
        label = f"[{batch}x{seq}x{hidden}]"
        print(f"{label:<25} {latency_ms:>12.3f} {'1.00x':>8}")

    print()
    print("NOTE: Custom kernel benchmarks require building ltx_kernels.")
    print("See examples/ltx_video/ for the full benchmark with custom kernels.")


if __name__ == "__main__":
    main()