# Source: BitLinear/benchmarks/benchmark_performance.py
# Uploaded by krisaujla using huggingface_hub (commit fd8c8b9, verified).
"""
Performance benchmarking for BitLinear vs nn.Linear.

This script benchmarks CPU forward-pass time for various layer sizes and batch
sizes, comparing BitLinear and MultiTernaryLinear (Python implementations) with
the standard nn.Linear.
"""
import torch
import torch.nn as nn
import time
from bitlinear import BitLinear, MultiTernaryLinear
import sys
def benchmark_forward_pass(layer, x, n_warmup=10, n_runs=100):
    """
    Benchmark forward pass time for a layer.

    Runs ``n_warmup`` untimed forward passes, then times ``n_runs`` forward
    passes, all under ``torch.no_grad()``, and returns the mean time per pass.

    Args:
        layer: PyTorch module (or any callable taking ``x``) to benchmark
        x: Input passed unchanged to every call
        n_warmup: Number of untimed warmup iterations (0 or less skips warmup)
        n_runs: Number of timed benchmark iterations (must be >= 1)

    Returns:
        Average time per forward pass in milliseconds (float)

    Raises:
        ValueError: If ``n_runs`` is less than 1 (the average would be undefined).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # CUDA kernels are launched asynchronously; without a synchronize the timer
    # would only capture launch overhead. For CPU inputs this is skipped.
    needs_sync = isinstance(x, torch.Tensor) and x.is_cuda

    def _sync():
        if needs_sync:
            torch.cuda.synchronize(x.device)

    with torch.no_grad():
        # Warmup: let allocators, caches and any lazy initialization settle.
        for _ in range(n_warmup):
            layer(x)
        _sync()

        # perf_counter is a monotonic, high-resolution clock. time.time() is a
        # wall clock whose resolution can be coarse enough to record 0 ms for
        # small layers (which later caused division by zero) and can jump.
        start_time = time.perf_counter()
        for _ in range(n_runs):
            layer(x)
        _sync()
        elapsed_s = time.perf_counter() - start_time

    return elapsed_s / n_runs * 1000.0
def _speedup(baseline_ms, candidate_ms):
    """Return baseline/candidate time ratio (> 1 means the candidate is faster).

    Returns ``float('inf')`` when the candidate time is not positive, which can
    happen if the timer could not resolve the run, instead of dividing by zero.
    """
    if candidate_ms <= 0:
        return float("inf")
    return baseline_ms / candidate_ms


def _print_markdown_table(results):
    """Print the collected benchmark results as a markdown table."""
    print("| Layer Size | Batch | Seq Len | nn.Linear (ms) | BitLinear (ms) | Speedup | Multi-Ternary (ms) | Speedup |")
    print("|------------|-------|---------|----------------|----------------|---------|--------------------|---------|")
    for r in results:
        # Pad the combined "in×out" label, not just the out_features part,
        # so the first column lines up under its header.
        layer_size = f"{r['in_features']}×{r['out_features']}"
        print(f"| {layer_size:<10} | {r['batch_size']:5} | {r['seq_len']:7} | "
              f"{r['time_linear']:14.3f} | {r['time_bitlinear']:14.3f} | {r['speedup_bit']:7.2f} | "
              f"{r['time_multi']:18.3f} | {r['speedup_multi']:7.2f} |")


def _print_summary_statistics(results):
    """Print arithmetic-mean speedups over all results, plus a short note."""
    if not results:
        # Avoid dividing by len(results) == 0 when no configs were benchmarked.
        print("No benchmark results collected.")
        return

    avg_speedup_bit = sum(r['speedup_bit'] for r in results) / len(results)
    avg_speedup_multi = sum(r['speedup_multi'] for r in results) / len(results)
    print(f"Average BitLinear speedup: {avg_speedup_bit:.2f}x")
    print(f"Average Multi-Ternary speedup: {avg_speedup_multi:.2f}x")

    if avg_speedup_bit < 1.0:
        # Guard 1/x when every nn.Linear timing was 0 (average speedup of 0).
        slowdown = 1 / avg_speedup_bit if avg_speedup_bit > 0 else float("inf")
        print(f"\nNote: BitLinear is slower than nn.Linear by {slowdown:.2f}x on average.")
        print("This is expected for the Python implementation. C++/CUDA extensions would be faster.")
    else:
        print(f"\nNote: BitLinear is faster than nn.Linear by {avg_speedup_bit:.2f}x on average!")


def run_benchmarks(n_warmup=10, n_runs=100, layer_sizes=None, batch_configs=None):
    """Run comprehensive benchmarks.

    For every (layer size, batch config) combination, times the CPU forward
    pass of ``nn.Linear``, ``BitLinear.from_linear`` and
    ``MultiTernaryLinear.from_linear(k=2)``. Prints per-config timings, a
    markdown summary table and average speedups.

    Args:
        n_warmup: Untimed warmup passes per measurement.
        n_runs: Timed passes per measurement.
        layer_sizes: Iterable of ``(in_features, out_features)``. Defaults to
            square layers of 512, 1024, 2048 and 4096.
        batch_configs: Iterable of ``(batch_size, seq_len)``. Defaults to
            ``(1, 1)``, ``(16, 128)``, ``(32, 128)`` and ``(64, 128)``.

    Returns:
        List of result dicts, one per configuration, with keys ``in_features``,
        ``out_features``, ``batch_size``, ``seq_len``, ``time_linear``,
        ``time_bitlinear``, ``time_multi``, ``speedup_bit`` and ``speedup_multi``
        (times in ms; speedup = nn.Linear time / candidate time).
    """
    if layer_sizes is None:
        layer_sizes = [
            (512, 512),
            (1024, 1024),
            (2048, 2048),
            (4096, 4096),
        ]
    if batch_configs is None:
        batch_configs = [
            (1, 1),      # Single token
            (16, 128),   # Small batch
            (32, 128),   # Medium batch
            (64, 128),   # Large batch
        ]

    print("=" * 100)
    print("BitLinear Performance Benchmarks")
    print("=" * 100)
    print(f"\nPyTorch version: {torch.__version__}")
    print("Device: CPU")
    print(f"Number of warmup runs: {n_warmup}")
    print(f"Number of benchmark runs: {n_runs}")

    results = []
    for in_features, out_features in layer_sizes:
        print(f"\n{'=' * 100}")
        print(f"Layer Size: {in_features}×{out_features}")
        print(f"{'=' * 100}")

        # Layers depend only on the layer size, not the input shape, so build
        # (and quantize) them once per size rather than once per batch config.
        linear = nn.Linear(in_features, out_features)
        bitlinear = BitLinear.from_linear(linear)
        multi_ternary = MultiTernaryLinear.from_linear(linear, k=2)

        for batch_size, seq_len in batch_configs:
            print(f"\nBatch: {batch_size}, Seq Length: {seq_len}")
            print("-" * 100)

            x = torch.randn(batch_size, seq_len, in_features)

            time_linear = benchmark_forward_pass(linear, x, n_warmup=n_warmup, n_runs=n_runs)
            time_bitlinear = benchmark_forward_pass(bitlinear, x, n_warmup=n_warmup, n_runs=n_runs)
            time_multi = benchmark_forward_pass(multi_ternary, x, n_warmup=n_warmup, n_runs=n_runs)

            speedup_bit = _speedup(time_linear, time_bitlinear)
            speedup_multi = _speedup(time_linear, time_multi)

            print(f"nn.Linear: {time_linear:8.3f} ms")
            print(f"BitLinear: {time_bitlinear:8.3f} ms (speedup: {speedup_bit:5.2f}x)")
            print(f"MultiTernaryLinear: {time_multi:8.3f} ms (speedup: {speedup_multi:5.2f}x)")

            results.append({
                'in_features': in_features,
                'out_features': out_features,
                'batch_size': batch_size,
                'seq_len': seq_len,
                'time_linear': time_linear,
                'time_bitlinear': time_bitlinear,
                'time_multi': time_multi,
                'speedup_bit': speedup_bit,
                'speedup_multi': speedup_multi,
            })

    print(f"\n\n{'=' * 100}")
    print("Summary Table (Markdown Format)")
    print(f"{'=' * 100}\n")
    _print_markdown_table(results)

    print(f"\n{'=' * 100}")
    print("Summary Statistics")
    print(f"{'=' * 100}\n")
    _print_summary_statistics(results)

    print(f"\n{'=' * 100}")
    print("Benchmark Complete!")
    print(f"{'=' * 100}")
    return results
if __name__ == "__main__":
    # Run the full benchmark sweep only when executed as a script, not on import.
    run_benchmarks()