"""Benchmark SSMGamma and GammaSingleBlock performance.

Reports whether the CUDA-optimized ops from ``gamma_space_model`` are
available (``HAS_CUDA_OPS``) and measures forward-pass latency, throughput,
and peak GPU memory usage for several (batch, sequence, width) configurations.
"""

import argparse
import time
from pathlib import Path
from typing import Dict, Tuple

import torch
import torch.nn as nn

from gamma_space_model import SSMGamma, GammaSingleBlock, HAS_CUDA_OPS

_BYTES_PER_MB = 1024 ** 2


def _is_cuda(device) -> bool:
    """Return True if ``device`` ("cuda", "cuda:1", a torch.device, ...) is a CUDA device."""
    # A plain ``device == "cuda"`` comparison would miss indexed devices like "cuda:0".
    return torch.device(device).type == "cuda"


def benchmark_forward_pass(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    d_model: int,
    num_iterations: int = 100,
    warmup_iterations: int = 20,
    device: str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> Dict[str, float]:
    """
    Benchmark forward pass latency.

    The model is put in eval mode and run under ``torch.no_grad()`` on a random
    input of shape (batch_size, seq_len, d_model). On CUDA devices the stream is
    synchronized around every timed call so the wall-clock time covers the
    actual kernel execution.

    Args:
        model: Model to benchmark
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension (state_dim for SSMGamma)
        num_iterations: Number of timed iterations (must be >= 1)
        warmup_iterations: Number of untimed warmup iterations (must be >= 0)
        device: Device to run on (e.g. "cpu", "cuda", "cuda:1")
        dtype: Data type of the random input

    Returns:
        Dictionary with latency statistics in milliseconds
        (``mean/median/min/max/std_latency_ms``) and throughput:
        ``throughput_samples_per_sec`` (sequences/sec) and
        ``throughput_tokens_per_sec`` (batch_size * seq_len positions/sec).

    Raises:
        ValueError: If ``num_iterations < 1`` or ``warmup_iterations < 0``.
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if warmup_iterations < 0:
        raise ValueError(f"warmup_iterations must be >= 0, got {warmup_iterations}")

    model.eval()
    on_cuda = _is_cuda(device)

    def _sync() -> None:
        # CUDA launches are asynchronous; wait for queued work before reading the clock.
        if on_cuda:
            torch.cuda.synchronize()

    # Create dummy input
    x = torch.randn(batch_size, seq_len, d_model, dtype=dtype, device=device)

    print(f" Warming up for {warmup_iterations} iterations...")
    with torch.no_grad():
        for _ in range(warmup_iterations):
            _ = model(x)
    _sync()

    print(f" Benchmarking {num_iterations} iterations...")
    times_ms = []
    with torch.no_grad():
        for _ in range(num_iterations):
            _sync()
            start = time.perf_counter()
            _ = model(x)
            _sync()
            times_ms.append((time.perf_counter() - start) * 1000.0)

    times = torch.tensor(times_ms, dtype=torch.float64)
    mean_ms = times.mean().item()
    # Sample std is undefined for a single measurement (torch returns nan); report 0.
    std_ms = times.std().item() if times.numel() > 1 else 0.0
    calls_per_sec = 1000.0 / mean_ms if mean_ms > 0 else float("inf")

    return {
        "mean_latency_ms": mean_ms,
        "median_latency_ms": times.median().item(),
        "min_latency_ms": times.min().item(),
        "max_latency_ms": times.max().item(),
        "std_latency_ms": std_ms,
        # Sequences per second (one sample = one sequence of seq_len tokens).
        "throughput_samples_per_sec": calls_per_sec * batch_size,
        # Token positions per second.
        "throughput_tokens_per_sec": calls_per_sec * batch_size * seq_len,
    }


def benchmark_memory(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    d_model: int,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> Dict[str, float]:
    """
    Benchmark GPU memory usage of a single forward pass.

    The CUDA caching allocator is emptied and its peak statistics are reset
    before the pass. The peak therefore includes everything already resident on
    the device (this model's parameters and any other live tensors) plus the
    activations of this forward pass.

    Args:
        model: Model to benchmark
        batch_size: Batch size
        seq_len: Sequence length
        d_model: Model dimension
        device: Device; non-CUDA devices return all-zero statistics
        dtype: Data type of the random input

    Returns:
        Dictionary with ``gpu_memory_allocated_mb``, ``gpu_memory_reserved_mb``
        and ``gpu_memory_max_allocated_mb`` (all in MiB). The same keys are
        returned for every device.
    """
    if not _is_cuda(device):
        return {
            "gpu_memory_allocated_mb": 0.0,
            "gpu_memory_reserved_mb": 0.0,
            "gpu_memory_max_allocated_mb": 0.0,
        }

    dev = torch.device(device)
    model.eval()

    # Clear cached blocks and restart peak tracking from the current allocation.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(dev)

    x = torch.randn(batch_size, seq_len, d_model, dtype=dtype, device=dev)

    with torch.no_grad():
        _ = model(x)
    torch.cuda.synchronize(dev)

    return {
        "gpu_memory_allocated_mb": torch.cuda.memory_allocated(dev) / _BYTES_PER_MB,
        "gpu_memory_reserved_mb": torch.cuda.memory_reserved(dev) / _BYTES_PER_MB,
        "gpu_memory_max_allocated_mb": torch.cuda.max_memory_allocated(dev) / _BYTES_PER_MB,
    }


def _benchmark_model(
    label: str,
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    d_model: int,
    device: str,
    iterations: int,
    warmup: int,
) -> Dict[str, float]:
    """Time one model, print latency/throughput (and GPU memory on CUDA), return the timing dict."""
    print(f"\n {label} benchmark:")
    timing = benchmark_forward_pass(
        model,
        batch_size,
        seq_len,
        d_model,
        num_iterations=iterations,
        warmup_iterations=warmup,
        device=device,
    )
    print(f" Latency: {timing['mean_latency_ms']:.3f} ± {timing['std_latency_ms']:.3f} ms")
    print(f" Throughput: {timing['throughput_tokens_per_sec']:.0f} tokens/sec")
    if _is_cuda(device):
        mem = benchmark_memory(model, batch_size, seq_len, d_model, device=device)
        print(f" GPU Memory: {mem['gpu_memory_max_allocated_mb']:.1f} MB")
    return timing


def run_benchmarks(args):
    """
    Run comprehensive benchmarks.

    Args:
        args: Namespace with ``device`` ("cpu" or "cuda"), ``iterations`` and,
            optionally, ``warmup`` (defaults to 20 when absent), as produced by
            ``main``'s argument parser.

    Returns:
        Dict mapping each configuration name to
        ``{"ssm_gamma": <timing dict>, "gamma_block": <timing dict>}``, where the
        timing dicts come from ``benchmark_forward_pass``.
    """
    print("\n" + "=" * 80)
    print("SSM GAMMA BENCHMARK SUITE")
    print("=" * 80)

    cuda_requested = args.device == "cuda"
    device = "cuda" if cuda_requested and torch.cuda.is_available() else "cpu"
    if cuda_requested and device != "cuda":
        print("\nWARNING: CUDA requested but not available; falling back to CPU.")
    print(f"\nDevice: {device}")
    print(f"CUDA optimizations available: {HAS_CUDA_OPS}")

    iterations = args.iterations
    # Callers that build their own Namespace may omit ``warmup``; keep the old default.
    warmup = getattr(args, "warmup", 20)

    # Test configurations: (batch_size, seq_len, d_model, name)
    configs = [
        (4, 128, 64, "Small (seq=128)"),
        (8, 512, 128, "Medium (seq=512)"),
        (16, 2048, 256, "Large (seq=2048)"),
    ]

    results = {}
    for batch_size, seq_len, d_model, config_name in configs:
        print(f"\n{'-' * 80}")
        print(f"Configuration: {config_name}")
        print(f" Batch size: {batch_size}")
        print(f" Sequence length: {seq_len}")
        print(f" Model dimension: {d_model}")
        print(f" Total tokens: {batch_size * seq_len:,}")

        ssm = SSMGamma(state_dim=d_model, hidden_dim=d_model * 2).to(device)
        ssm_results = _benchmark_model(
            "SSMGamma", ssm, batch_size, seq_len, d_model, device, iterations, warmup
        )
        # Free the model before building the next one so its parameters do not
        # inflate the next peak-memory reading.
        del ssm
        if device == "cuda":
            torch.cuda.empty_cache()

        block = GammaSingleBlock(d_model=d_model, hidden_dim=d_model * 2).to(device)
        block_results = _benchmark_model(
            "GammaSingleBlock", block, batch_size, seq_len, d_model, device, iterations, warmup
        )
        del block
        if device == "cuda":
            torch.cuda.empty_cache()

        results[config_name] = {
            "ssm_gamma": ssm_results,
            "gamma_block": block_results,
        }

    print(f"\n{'=' * 80}")
    print("Benchmark complete!")
    print(f"{'=' * 80}\n")

    return results


def main():
    """Parse command-line arguments and run the benchmark suite."""
    parser = argparse.ArgumentParser(description="Benchmark SSM Gamma blocks")
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to benchmark on",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=100,
        help="Number of benchmark iterations",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=20,
        help="Number of warmup iterations",
    )
    args = parser.parse_args()
    if args.iterations < 1:
        parser.error("--iterations must be >= 1")
    if args.warmup < 0:
        parser.error("--warmup must be >= 0")

    run_benchmarks(args)


if __name__ == "__main__":
    main()