File size: 5,415 Bytes
fd8c8b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
"""
Performance benchmarking for BitLinear vs nn.Linear.
This script benchmarks forward pass time for various layer sizes and batch sizes,
comparing BitLinear (Python implementation) with standard nn.Linear.
"""
import torch
import torch.nn as nn
import time
from bitlinear import BitLinear, MultiTernaryLinear
import sys
def benchmark_forward_pass(layer, x, n_warmup=10, n_runs=100):
    """
    Benchmark the average forward-pass time of a layer.

    Args:
        layer: PyTorch module to benchmark.
        x: Input tensor fed to the layer on every iteration.
        n_warmup: Number of warmup iterations (excluded from timing).
        n_runs: Number of timed iterations.

    Returns:
        Average time per forward pass in milliseconds (float).
    """
    with torch.no_grad():
        # Warmup: let allocators/caches settle before timing.
        for _ in range(n_warmup):
            _ = layer(x)
        # Use perf_counter(), not time(): it is monotonic and high-resolution,
        # whereas time.time() can jump with system clock adjustments and has
        # coarse granularity on some platforms.
        start_time = time.perf_counter()
        for _ in range(n_runs):
            _ = layer(x)
        end_time = time.perf_counter()
    avg_time_ms = (end_time - start_time) / n_runs * 1000
    return avg_time_ms
def run_benchmarks():
    """
    Run comprehensive benchmarks.

    Compares nn.Linear against BitLinear and MultiTernaryLinear across a
    grid of layer sizes and (batch, seq_len) input shapes. Prints
    per-configuration timings, a markdown summary table, and aggregate
    speedup statistics to stdout.
    """
    print("=" * 100)
    print("BitLinear Performance Benchmarks")
    print("=" * 100)
    print(f"\nPyTorch version: {torch.__version__}")
    # Plain strings here: these had f-prefixes with no placeholders.
    print("Device: CPU")
    print("Number of warmup runs: 10")
    print("Number of benchmark runs: 100")
    # Square layer sizes to benchmark: (in_features, out_features).
    layer_sizes = [
        (512, 512),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
    ]
    # Input shapes: (batch_size, seq_len).
    batch_configs = [
        (1, 1),     # Single token
        (16, 128),  # Small batch
        (32, 128),  # Medium batch
        (64, 128),  # Large batch
    ]
    results = []
    for in_features, out_features in layer_sizes:
        print(f"\n{'=' * 100}")
        print(f"Layer Size: {in_features} → {out_features}")
        print(f"{'=' * 100}")
        for batch_size, seq_len in batch_configs:
            print(f"\nBatch: {batch_size}, Seq Length: {seq_len}")
            print("-" * 100)
            # Fresh random input for this configuration.
            x = torch.randn(batch_size, seq_len, in_features)
            # Derive both quantized layers from the same nn.Linear weights
            # so all three benchmarks compare like-for-like.
            linear = nn.Linear(in_features, out_features)
            bitlinear = BitLinear.from_linear(linear)
            multi_ternary = MultiTernaryLinear.from_linear(linear, k=2)
            time_linear = benchmark_forward_pass(linear, x)
            time_bitlinear = benchmark_forward_pass(bitlinear, x)
            time_multi = benchmark_forward_pass(multi_ternary, x)
            # speedup > 1.0 means the quantized layer beat nn.Linear.
            speedup_bit = time_linear / time_bitlinear
            speedup_multi = time_linear / time_multi
            print(f"nn.Linear: {time_linear:8.3f} ms")
            print(f"BitLinear: {time_bitlinear:8.3f} ms (speedup: {speedup_bit:5.2f}x)")
            print(f"MultiTernaryLinear: {time_multi:8.3f} ms (speedup: {speedup_multi:5.2f}x)")
            results.append({
                'in_features': in_features,
                'out_features': out_features,
                'batch_size': batch_size,
                'seq_len': seq_len,
                'time_linear': time_linear,
                'time_bitlinear': time_bitlinear,
                'time_multi': time_multi,
                'speedup_bit': speedup_bit,
                'speedup_multi': speedup_multi,
            })
    # Markdown summary table (copy-pasteable into a README).
    print(f"\n\n{'=' * 100}")
    print("Summary Table (Markdown Format)")
    print(f"{'=' * 100}\n")
    print("| Layer Size | Batch | Seq Len | nn.Linear (ms) | BitLinear (ms) | Speedup | Multi-Ternary (ms) | Speedup |")
    # Fixed: the separator row previously ended with a stray trailing space,
    # which breaks strict markdown table parsers.
    print("|------------|-------|---------|----------------|----------------|---------|--------------------|---------|")
    for r in results:
        print(f"| {r['in_features']}×{r['out_features']:<4} | {r['batch_size']:5} | {r['seq_len']:7} | "
              f"{r['time_linear']:14.3f} | {r['time_bitlinear']:14.3f} | {r['speedup_bit']:7.2f} | "
              f"{r['time_multi']:18.3f} | {r['speedup_multi']:7.2f} |")
    # Aggregate statistics across every configuration.
    print(f"\n{'=' * 100}")
    print("Summary Statistics")
    print(f"{'=' * 100}\n")
    avg_speedup_bit = sum(r['speedup_bit'] for r in results) / len(results)
    avg_speedup_multi = sum(r['speedup_multi'] for r in results) / len(results)
    print(f"Average BitLinear speedup: {avg_speedup_bit:.2f}x")
    print(f"Average Multi-Ternary speedup: {avg_speedup_multi:.2f}x")
    if avg_speedup_bit < 1.0:
        print(f"\nNote: BitLinear is slower than nn.Linear by {1/avg_speedup_bit:.2f}x on average.")
        print("This is expected for the Python implementation. C++/CUDA extensions would be faster.")
    else:
        print(f"\nNote: BitLinear is faster than nn.Linear by {avg_speedup_bit:.2f}x on average!")
    print(f"\n{'=' * 100}")
    print("Benchmark Complete!")
    print(f"{'=' * 100}")
# Entry point: run the full benchmark suite when executed as a script.
if __name__ == "__main__":
    run_benchmarks()
|