|
|
"""
|
|
|
Performance benchmarking for BitLinear vs nn.Linear.
|
|
|
|
|
|
This script benchmarks forward pass time for various layer sizes and batch sizes,
|
|
|
comparing BitLinear (Python implementation) with standard nn.Linear.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import time
|
|
|
from bitlinear import BitLinear, MultiTernaryLinear
|
|
|
import sys
|
|
|
|
|
|
|
|
|
def benchmark_forward_pass(layer, x, n_warmup=10, n_runs=100):
    """
    Benchmark forward pass time for a layer.

    Args:
        layer: PyTorch module to benchmark
        x: Input tensor
        n_warmup: Number of warmup iterations (excluded from the timing)
        n_runs: Number of benchmark iterations

    Returns:
        Average time per forward pass in milliseconds (float)
    """
    # Warmup: trigger any lazy initialization / cache population so the
    # timed section measures steady-state performance only.
    with torch.no_grad():
        for _ in range(n_warmup):
            _ = layer(x)

    # Use time.perf_counter() rather than time.time(): perf_counter() is
    # monotonic and high-resolution, while time.time() can jump (clock
    # adjustments) and has coarse granularity on some platforms.
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_runs):
            _ = layer(x)
    end_time = time.perf_counter()

    avg_time_ms = (end_time - start_time) / n_runs * 1000
    return avg_time_ms
|
|
|
|
|
|
|
|
|
def run_benchmarks(n_warmup=10, n_runs=100):
    """
    Run comprehensive benchmarks comparing nn.Linear, BitLinear, and
    MultiTernaryLinear across several layer sizes and batch configurations,
    printing per-configuration timings, a Markdown summary table, and
    aggregate speedup statistics.

    Args:
        n_warmup: Warmup iterations per measurement (not timed). Default 10.
        n_runs: Timed iterations per measurement. Default 100.
    """
    print("=" * 100)
    print("BitLinear Performance Benchmarks")
    print("=" * 100)
    print(f"\nPyTorch version: {torch.__version__}")
    print("Device: CPU")
    # Report the actual parameters (previously hard-coded literals that could
    # silently drift out of sync with benchmark_forward_pass's defaults).
    print(f"Number of warmup runs: {n_warmup}")
    print(f"Number of benchmark runs: {n_runs}")

    # (in_features, out_features) pairs to benchmark.
    layer_sizes = [
        (512, 512),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
    ]

    # (batch_size, seq_len) pairs to benchmark.
    batch_configs = [
        (1, 1),
        (16, 128),
        (32, 128),
        (64, 128),
    ]

    results = []

    for in_features, out_features in layer_sizes:
        print(f"\n{'=' * 100}")
        print(f"Layer Size: {in_features} → {out_features}")
        print(f"{'=' * 100}")

        for batch_size, seq_len in batch_configs:
            print(f"\nBatch: {batch_size}, Seq Length: {seq_len}")
            print("-" * 100)

            x = torch.randn(batch_size, seq_len, in_features)

            # All quantized layers are derived from the same nn.Linear weights
            # so the comparison isolates implementation speed.
            linear = nn.Linear(in_features, out_features)
            bitlinear = BitLinear.from_linear(linear)
            multi_ternary = MultiTernaryLinear.from_linear(linear, k=2)

            # Pass the benchmark parameters through explicitly so the printed
            # header above matches what is actually measured.
            time_linear = benchmark_forward_pass(linear, x, n_warmup, n_runs)
            time_bitlinear = benchmark_forward_pass(bitlinear, x, n_warmup, n_runs)
            time_multi = benchmark_forward_pass(multi_ternary, x, n_warmup, n_runs)

            # Speedup relative to the nn.Linear baseline (>1 means faster).
            speedup_bit = time_linear / time_bitlinear
            speedup_multi = time_linear / time_multi

            print(f"nn.Linear: {time_linear:8.3f} ms")
            print(f"BitLinear: {time_bitlinear:8.3f} ms (speedup: {speedup_bit:5.2f}x)")
            print(f"MultiTernaryLinear: {time_multi:8.3f} ms (speedup: {speedup_multi:5.2f}x)")

            results.append({
                'in_features': in_features,
                'out_features': out_features,
                'batch_size': batch_size,
                'seq_len': seq_len,
                'time_linear': time_linear,
                'time_bitlinear': time_bitlinear,
                'time_multi': time_multi,
                'speedup_bit': speedup_bit,
                'speedup_multi': speedup_multi,
            })

    print(f"\n\n{'=' * 100}")
    print("Summary Table (Markdown Format)")
    print(f"{'=' * 100}\n")

    print("| Layer Size | Batch | Seq Len | nn.Linear (ms) | BitLinear (ms) | Speedup | Multi-Ternary (ms) | Speedup |")
    print("|------------|-------|---------|----------------|----------------|---------|--------------------|---------| ")

    for r in results:
        print(f"| {r['in_features']}×{r['out_features']:<4} | {r['batch_size']:5} | {r['seq_len']:7} | "
              f"{r['time_linear']:14.3f} | {r['time_bitlinear']:14.3f} | {r['speedup_bit']:7.2f} | "
              f"{r['time_multi']:18.3f} | {r['speedup_multi']:7.2f} |")

    print(f"\n{'=' * 100}")
    print("Summary Statistics")
    print(f"{'=' * 100}\n")

    avg_speedup_bit = sum(r['speedup_bit'] for r in results) / len(results)
    avg_speedup_multi = sum(r['speedup_multi'] for r in results) / len(results)

    print(f"Average BitLinear speedup: {avg_speedup_bit:.2f}x")
    print(f"Average Multi-Ternary speedup: {avg_speedup_multi:.2f}x")

    if avg_speedup_bit < 1.0:
        print(f"\nNote: BitLinear is slower than nn.Linear by {1/avg_speedup_bit:.2f}x on average.")
        print("This is expected for the Python implementation. C++/CUDA extensions would be faster.")
    else:
        print(f"\nNote: BitLinear is faster than nn.Linear by {avg_speedup_bit:.2f}x on average!")

    print(f"\n{'=' * 100}")
    print("Benchmark Complete!")
    print(f"{'=' * 100}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
run_benchmarks()
|
|
|
|