#!/usr/bin/env python3
"""
RMSNorm Micro-Benchmark: Custom CUDA kernel vs PyTorch baseline.

Benchmarks across multiple tensor shapes matching LTX-Video configurations.
Reports speedup and memory bandwidth efficiency.

Usage:
    python benchmark_rmsnorm.py
"""
import time

import torch
import torch.nn as nn
class RMSNormReference(nn.Module):
    """PyTorch reference RMSNorm implementation.

    Computes ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``, accumulating
    the mean-of-squares in float32 for numerical stability regardless of
    the input dtype, then casting the result back to the caller's dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Added inside the rsqrt to avoid division by zero.
        self.eps = eps

    def forward(self, x):
        # Capture the caller's dtype BEFORE x is reassigned below: the
        # original code read x.dtype after x had been promoted to float32
        # (bf16 * fp32 -> fp32), so the final cast was a no-op and the
        # module silently returned float32 for bfloat16 inputs.
        input_dtype = x.dtype
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
def benchmark_kernel(fn, warmup=20, iterations=100):
    """Time *fn* and return its mean wall-clock latency in milliseconds.

    Runs *warmup* untimed calls first (lets CUDA context creation, caches,
    and autotuners settle), then times *iterations* calls as a single batch
    and divides, which amortizes timer overhead.

    The original unconditionally called ``torch.cuda.synchronize()``, which
    raises on CPU-only builds of torch; synchronizing only when CUDA is
    available keeps behavior identical on GPU machines while making the
    helper usable for CPU baselines as well.
    """
    # GPU kernel launches are asynchronous, so syncs are required for
    # honest timing there; on CPU the call would raise, so no-op instead.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        fn()
    sync()  # drain warmup work before starting the clock
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    sync()  # include all queued kernel time in the measurement
    return (time.perf_counter() - start) / iterations * 1000  # ms per call
def main():
    """Run the RMSNorm PyTorch-baseline benchmark across LTX-Video shapes."""
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires a GPU.")
        return

    banner = "=" * 60
    print(banner)
    print("RMSNorm Micro-Benchmark")
    print(banner)
    print(f"Device: {torch.cuda.get_device_name()}")

    dtype = torch.bfloat16
    # (batch, seq_len, hidden) triples matching LTX-Video configurations.
    configs = [
        (1, 1024, 2048),
        (2, 1024, 2048),
        (4, 1024, 2048),
        (2, 4096, 3072),
        (4, 4096, 3072),
        (1, 8192, 2048),
        (2, 8192, 3072),
    ]

    print(f"\nDtype: {dtype}")
    print(f"Warmup: 20 iterations, Benchmark: 100 iterations")
    print()
    print(f"{'Shape':<25} {'PyTorch (ms)':>12} {'Speedup':>8}")
    print("-" * 50)

    for batch, seq, hidden in configs:
        x = torch.randn((batch, seq, hidden), dtype=dtype, device="cuda")
        ref = RMSNormReference(hidden).to(dtype).cuda()
        pytorch_time = benchmark_kernel(lambda: ref(x))
        label = f"[{batch}x{seq}x{hidden}]"
        print(f"{label:<25} {pytorch_time:>12.3f} {'1.00x':>8}")

    print()
    print("NOTE: Custom kernel benchmarks require building ltx_kernels.")
    print("See examples/ltx_video/ for the full benchmark with custom kernels.")


if __name__ == "__main__":
    main()