|
|
"""
|
|
|
Memory usage benchmarking for BitLinear.
|
|
|
|
|
|
This script measures actual memory usage and compression ratios for BitLinear
|
|
|
compared to standard nn.Linear layers.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from bitlinear import BitLinear, MultiTernaryLinear, pack_ternary_base3, estimate_memory_savings
|
|
|
import sys
|
|
|
|
|
|
|
|
|
def get_tensor_memory_mb(tensor):
    """Return the storage footprint of *tensor* in mebibytes (MiB)."""
    n_bytes = tensor.nelement() * tensor.element_size()
    return n_bytes / (1024 * 1024)
|
|
|
|
|
|
|
|
|
def get_model_memory_mb(model):
    """Return the combined memory of all of *model*'s parameters in MiB."""
    per_param_bytes = (p.nelement() * p.element_size() for p in model.parameters())
    return sum(per_param_bytes) / (1024 ** 2)
|
|
|
|
|
|
|
|
|
def analyze_layer_memory(in_features, out_features):
    """Measure and print memory usage for one layer shape.

    Builds an ``nn.Linear`` of the given shape, converts it to ``BitLinear``
    and ``MultiTernaryLinear`` (k=2), prints a comparison — including the
    theoretical size under base-3 packing — then exercises
    ``pack_ternary_base3`` on the actual ternary weights.

    Returns a dict of the measured/derived numbers for summary tables.
    """
    banner = "=" * 100
    print(f"\n{banner}")
    print(f"Layer: {in_features} → {out_features}")
    print(f"{banner}\n")

    # Reference dense layer plus the two ternary conversions of it.
    dense = nn.Linear(in_features, out_features, bias=True)
    ternary = BitLinear.from_linear(dense)
    multi = MultiTernaryLinear.from_linear(dense, k=2)

    mem_linear = get_model_memory_mb(dense)
    mem_bitlinear = get_model_memory_mb(ternary)
    mem_multi = get_model_memory_mb(multi)

    # Theoretical packed size: 5 ternary digits fit in one byte via base-3
    # (3**5 = 243 <= 256); bias and gamma stay fp32, one value per output.
    weight_count = in_features * out_features
    packed_bytes = (weight_count + 4) // 5
    bias_bytes = out_features * 4
    gamma_bytes = out_features * 4
    theoretical_packed_mb = (packed_bytes + bias_bytes + gamma_bytes) / (1024 ** 2)

    compression_current = mem_linear / mem_bitlinear
    compression_packed = mem_linear / theoretical_packed_mb

    print(f"nn.Linear memory: {mem_linear:10.4f} MB")
    print(f"BitLinear memory (current): {mem_bitlinear:10.4f} MB (ratio: {compression_current:5.2f}x)")
    print(f"BitLinear memory (packed): {theoretical_packed_mb:10.4f} MB (ratio: {compression_packed:5.2f}x)")
    print(f"MultiTernaryLinear memory (k=2): {mem_multi:10.4f} MB (ratio: {mem_linear/mem_multi:5.2f}x)")

    # Sanity-check the real packing routine against the theoretical numbers.
    print(f"\nPacking Test:")
    print("-" * 100)

    ternary_weights = ternary.W_ternary
    packed, _shape = pack_ternary_base3(ternary_weights)

    unpacked_size_mb = get_tensor_memory_mb(ternary_weights)
    packed_size_mb = get_tensor_memory_mb(packed)
    actual_compression = unpacked_size_mb / packed_size_mb

    print(f"Unpacked weights: {unpacked_size_mb:10.4f} MB")
    print(f"Packed weights: {packed_size_mb:10.4f} MB")
    print(f"Actual compression: {actual_compression:8.2f}x")

    return {
        'in_features': in_features,
        'out_features': out_features,
        'mem_linear': mem_linear,
        'mem_bitlinear': mem_bitlinear,
        'mem_packed': theoretical_packed_mb,
        'mem_multi': mem_multi,
        'compression_current': compression_current,
        'compression_packed': compression_packed,
    }
|
|
|
|
|
|
|
|
|
def run_memory_benchmarks():
    """Run the full suite of memory benchmarks and print a report.

    Benchmarks a range of square (attention-style) and rectangular
    (feed-forward-style) layer shapes, prints a markdown summary table and
    aggregate statistics, and closes with a back-of-the-envelope estimate
    for a GPT-2-small-style transformer stack.
    """
    rule = "=" * 100
    print(rule)
    print("BitLinear Memory Benchmarks")
    print(rule)
    print(f"\nPyTorch version: {torch.__version__}")

    # (in_features, out_features): square projections plus FFN-shaped layers.
    layer_sizes = [
        (512, 512),
        (768, 768),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
        (768, 3072),
        (1024, 4096),
    ]

    results = [analyze_layer_memory(n_in, n_out) for n_in, n_out in layer_sizes]

    print(f"\n\n{rule}")
    print("Memory Compression Summary (Markdown Format)")
    print(f"{rule}\n")

    print("| Layer Size | nn.Linear (MB) | BitLinear Current (MB) | BitLinear Packed (MB) | Compression (Packed) |")
    print("|------------|----------------|------------------------|----------------------|----------------------|")

    for r in results:
        row = (
            f"| {r['in_features']}×{r['out_features']:<4} | {r['mem_linear']:14.4f} | "
            f"{r['mem_bitlinear']:22.4f} | {r['mem_packed']:20.4f} | {r['compression_packed']:20.2f}x |"
        )
        print(row)

    print(f"\n{rule}")
    print("Summary Statistics")
    print(f"{rule}\n")

    packed_ratios = [r['compression_packed'] for r in results]
    avg_compression = sum(packed_ratios) / len(packed_ratios)

    print(f"Average compression ratio: {avg_compression:.2f}x")
    print(f"Minimum compression ratio: {min(packed_ratios):.2f}x")
    print(f"Maximum compression ratio: {max(packed_ratios):.2f}x")

    print(f"\n{rule}")
    print("Real-World Example: GPT-2 Style Transformer")
    print(f"{rule}\n")

    num_layers = 12
    d_model = 768
    d_ff = 3072

    # Weights per transformer block: Q/K/V/O projections + the two FFN mats.
    weights_per_layer = (4 * d_model * d_model) + (d_model * d_ff) + (d_ff * d_model)
    linear_total = weights_per_layer * num_layers

    # fp32 baseline vs. base-3 packed ternary (5 weights per byte).
    linear_mem_mb = (linear_total * 4) / (1024 ** 2)
    packed_mem_mb = ((linear_total + 4) // 5) / (1024 ** 2)

    # fp32 overhead kept alongside the packed weights: one bias and one gamma
    # per output channel of each of the six linears in a block.
    biases_per_layer = (4 * d_model) + d_ff + d_model
    gammas_per_layer = (4 * d_model) + d_ff + d_model
    overhead_mb = ((biases_per_layer + gammas_per_layer) * num_layers * 4) / (1024 ** 2)

    packed_total_mb = packed_mem_mb + overhead_mb
    compression = linear_mem_mb / packed_total_mb

    print(f"Configuration: {num_layers} layers, d_model={d_model}, d_ff={d_ff}")
    print(f"Total linear parameters: {linear_total:,}")
    print(f"\nnn.Linear memory: {linear_mem_mb:10.2f} MB")
    print(f"BitLinear packed: {packed_total_mb:10.2f} MB")
    print(f"Memory saved: {linear_mem_mb - packed_total_mb:10.2f} MB")
    print(f"Compression ratio: {compression:10.2f}x")

    print(f"\n{rule}")
    print("Benchmark Complete!")
    print(f"{rule}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: run the full benchmark suite when executed as a script.
    run_memory_benchmarks()
|
|
|
|