#!/usr/bin/env python3
"""RMSNorm Micro-Benchmark: Custom CUDA kernel vs PyTorch baseline.

Benchmarks across multiple tensor shapes matching LTX-Video configurations.
Reports speedup and memory bandwidth efficiency.

Usage:
    python benchmark_rmsnorm.py
"""

import time

import torch
import torch.nn as nn


class RMSNormReference(nn.Module):
    """PyTorch reference RMSNorm implementation.

    Normalizes the last dimension by the root-mean-square of its elements
    (accumulated in float32 for numerical stability) and applies a learned
    per-channel scale. The output is returned in the input's dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: size of the last (normalized) dimension.
            eps: small constant added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Capture the caller's dtype BEFORE any promotion. The original code
        # read `x.dtype` *after* `x` had been reassigned to the float32-promoted
        # product (bf16 * float32 rsqrt -> float32), so the trailing cast was a
        # no-op and the module silently returned float32 instead of the input
        # dtype.
        input_dtype = x.dtype
        # Mean-of-squares in float32 to avoid bf16/fp16 accumulation error.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)


def benchmark_kernel(fn, warmup=20, iterations=100):
    """Benchmark a kernel function.

    Args:
        fn: zero-argument callable that launches the work to be timed.
        warmup: untimed calls to amortize JIT/caching/allocator effects.
        iterations: timed calls averaged into the result.

    Returns:
        Mean wall-clock latency per call, in milliseconds.
    """
    for _ in range(warmup):
        fn()
    # Drain all warmup work before starting the clock.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    # GPU launches are async; wait for the timed kernels to actually finish.
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / iterations * 1000
    return elapsed


def main():
    """Run the PyTorch-baseline RMSNorm benchmark across LTX-Video shapes."""
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires a GPU.")
        return

    print("=" * 60)
    print("RMSNorm Micro-Benchmark")
    print("=" * 60)
    print(f"Device: {torch.cuda.get_device_name()}")

    dtype = torch.bfloat16
    # (batch, seq_len, hidden_size) combinations matching LTX-Video configs.
    configs = [
        (1, 1024, 2048),
        (2, 1024, 2048),
        (4, 1024, 2048),
        (2, 4096, 3072),
        (4, 4096, 3072),
        (1, 8192, 2048),
        (2, 8192, 3072),
    ]

    print(f"\nDtype: {dtype}")
    print("Warmup: 20 iterations, Benchmark: 100 iterations")
    print()
    print(f"{'Shape':<25} {'PyTorch (ms)':>12} {'Speedup':>8}")
    print("-" * 50)

    for batch, seq, hidden in configs:
        shape = (batch, seq, hidden)
        x = torch.randn(shape, dtype=dtype, device="cuda")
        ref = RMSNormReference(hidden).to(dtype).cuda()
        # Lambda is invoked immediately inside benchmark_kernel, so the
        # loop-variable capture is safe here.
        pytorch_time = benchmark_kernel(lambda: ref(x))
        label = f"[{batch}x{seq}x{hidden}]"
        print(f"{label:<25} {pytorch_time:>12.3f} {'1.00x':>8}")

    print()
    print("NOTE: Custom kernel benchmarks require building ltx_kernels.")
    print("See examples/ltx_video/ for the full benchmark with custom kernels.")


if __name__ == "__main__":
    main()