""" Performance benchmarking for BitLinear vs nn.Linear. This script benchmarks forward pass time for various layer sizes and batch sizes, comparing BitLinear (Python implementation) with standard nn.Linear. """ import torch import torch.nn as nn import time from bitlinear import BitLinear, MultiTernaryLinear import sys def benchmark_forward_pass(layer, x, n_warmup=10, n_runs=100): """ Benchmark forward pass time for a layer. Args: layer: PyTorch module to benchmark x: Input tensor n_warmup: Number of warmup iterations n_runs: Number of benchmark iterations Returns: Average time per forward pass in milliseconds """ # Warmup with torch.no_grad(): for _ in range(n_warmup): _ = layer(x) # Benchmark start_time = time.time() with torch.no_grad(): for _ in range(n_runs): _ = layer(x) end_time = time.time() avg_time_ms = (end_time - start_time) / n_runs * 1000 return avg_time_ms def run_benchmarks(): """Run comprehensive benchmarks.""" print("=" * 100) print("BitLinear Performance Benchmarks") print("=" * 100) print(f"\nPyTorch version: {torch.__version__}") print(f"Device: CPU") print(f"Number of warmup runs: 10") print(f"Number of benchmark runs: 100") # Test configurations layer_sizes = [ (512, 512), (1024, 1024), (2048, 2048), (4096, 4096), ] batch_configs = [ (1, 1), # Single token (16, 128), # Small batch (32, 128), # Medium batch (64, 128), # Large batch ] results = [] for in_features, out_features in layer_sizes: print(f"\n{'=' * 100}") print(f"Layer Size: {in_features} → {out_features}") print(f"{'=' * 100}") for batch_size, seq_len in batch_configs: print(f"\nBatch: {batch_size}, Seq Length: {seq_len}") print("-" * 100) # Create input x = torch.randn(batch_size, seq_len, in_features) # Create layers linear = nn.Linear(in_features, out_features) bitlinear = BitLinear.from_linear(linear) multi_ternary = MultiTernaryLinear.from_linear(linear, k=2) # Benchmark nn.Linear time_linear = benchmark_forward_pass(linear, x) # Benchmark BitLinear time_bitlinear = benchmark_forward_pass(bitlinear, x) # Benchmark MultiTernaryLinear time_multi = benchmark_forward_pass(multi_ternary, x) # Calculate speedup/slowdown speedup_bit = time_linear / time_bitlinear speedup_multi = time_linear / time_multi # Print results print(f"nn.Linear: {time_linear:8.3f} ms") print(f"BitLinear: {time_bitlinear:8.3f} ms (speedup: {speedup_bit:5.2f}x)") print(f"MultiTernaryLinear: {time_multi:8.3f} ms (speedup: {speedup_multi:5.2f}x)") # Store results results.append({ 'in_features': in_features, 'out_features': out_features, 'batch_size': batch_size, 'seq_len': seq_len, 'time_linear': time_linear, 'time_bitlinear': time_bitlinear, 'time_multi': time_multi, 'speedup_bit': speedup_bit, 'speedup_multi': speedup_multi, }) # Generate markdown table print(f"\n\n{'=' * 100}") print("Summary Table (Markdown Format)") print(f"{'=' * 100}\n") print("| Layer Size | Batch | Seq Len | nn.Linear (ms) | BitLinear (ms) | Speedup | Multi-Ternary (ms) | Speedup |") print("|------------|-------|---------|----------------|----------------|---------|--------------------|---------| ") for r in results: print(f"| {r['in_features']}×{r['out_features']:<4} | {r['batch_size']:5} | {r['seq_len']:7} | " f"{r['time_linear']:14.3f} | {r['time_bitlinear']:14.3f} | {r['speedup_bit']:7.2f} | " f"{r['time_multi']:18.3f} | {r['speedup_multi']:7.2f} |") # Summary statistics print(f"\n{'=' * 100}") print("Summary Statistics") print(f"{'=' * 100}\n") avg_speedup_bit = sum(r['speedup_bit'] for r in results) / len(results) avg_speedup_multi = sum(r['speedup_multi'] for r in results) / len(results) print(f"Average BitLinear speedup: {avg_speedup_bit:.2f}x") print(f"Average Multi-Ternary speedup: {avg_speedup_multi:.2f}x") if avg_speedup_bit < 1.0: print(f"\nNote: BitLinear is slower than nn.Linear by {1/avg_speedup_bit:.2f}x on average.") print("This is expected for the Python implementation. C++/CUDA extensions would be faster.") else: print(f"\nNote: BitLinear is faster than nn.Linear by {avg_speedup_bit:.2f}x on average!") print(f"\n{'=' * 100}") print("Benchmark Complete!") print(f"{'=' * 100}") if __name__ == "__main__": run_benchmarks()