"""
Memory usage benchmarking for BitLinear.
This script measures actual memory usage and compression ratios for BitLinear
compared to standard nn.Linear layers.
"""
import torch
import torch.nn as nn
from bitlinear import BitLinear, MultiTernaryLinear, pack_ternary_base3, estimate_memory_savings
import sys
def get_tensor_memory_mb(tensor):
    """Return the memory footprint of *tensor* in megabytes (MiB)."""
    total_bytes = tensor.nelement() * tensor.element_size()
    return total_bytes / (1024 * 1024)
def get_model_memory_mb(model):
    """Return the combined memory footprint of *model*'s parameters in MB (MiB)."""
    total = 0
    for param in model.parameters():
        total += param.nelement() * param.element_size()
    return total / (1024 * 1024)
def analyze_layer_memory(in_features, out_features):
    """Measure and report memory usage for one layer shape.

    Builds a dense ``nn.Linear`` plus its BitLinear / MultiTernaryLinear
    conversions, measures their parameter footprints, estimates the
    theoretical base-3 packed size, exercises the real packer to confirm,
    and returns the collected numbers as a dict.
    """
    rule = "=" * 100
    print("\n" + rule)
    print(f"Layer: {in_features} → {out_features}")
    print(rule + "\n")

    # Dense reference layer and its ternary counterparts.
    dense = nn.Linear(in_features, out_features, bias=True)
    ternary = BitLinear.from_linear(dense)
    multi = MultiTernaryLinear.from_linear(dense, k=2)

    dense_mb = get_model_memory_mb(dense)
    ternary_mb = get_model_memory_mb(ternary)  # still stored as float32 here
    multi_mb = get_model_memory_mb(multi)

    # Theoretical footprint after base-3 packing (5 ternary values per byte),
    # plus float32 bias and per-channel gamma kept unpacked.
    n_weights = in_features * out_features
    packed_weight_bytes = (n_weights + 4) // 5  # ceiling division: 5 trits/byte
    overhead_bytes = out_features * 4 + out_features * 4  # bias + gamma, float32
    packed_mb = (packed_weight_bytes + overhead_bytes) / (1024 ** 2)

    ratio_current = dense_mb / ternary_mb
    ratio_packed = dense_mb / packed_mb

    print(f"nn.Linear memory: {dense_mb:10.4f} MB")
    print(f"BitLinear memory (current): {ternary_mb:10.4f} MB (ratio: {ratio_current:5.2f}x)")
    print(f"BitLinear memory (packed): {packed_mb:10.4f} MB (ratio: {ratio_packed:5.2f}x)")
    print(f"MultiTernaryLinear memory (k=2): {multi_mb:10.4f} MB (ratio: {dense_mb/multi_mb:5.2f}x)")

    # Run the actual packer and compare against the unpacked tensor.
    print("\nPacking Test:")
    print("-" * 100)
    packed_codes, _shape = pack_ternary_base3(ternary.W_ternary)
    raw_mb = get_tensor_memory_mb(ternary.W_ternary)
    code_mb = get_tensor_memory_mb(packed_codes)
    print(f"Unpacked weights: {raw_mb:10.4f} MB")
    print(f"Packed weights: {code_mb:10.4f} MB")
    print(f"Actual compression: {raw_mb / code_mb:8.2f}x")

    return {
        'in_features': in_features,
        'out_features': out_features,
        'mem_linear': dense_mb,
        'mem_bitlinear': ternary_mb,
        'mem_packed': packed_mb,
        'mem_multi': multi_mb,
        'compression_current': ratio_current,
        'compression_packed': ratio_packed,
    }
def run_memory_benchmarks():
    """Run the full memory benchmark suite and print a report.

    Sweeps a set of square and FFN-shaped layers through
    ``analyze_layer_memory``, prints a markdown summary table and aggregate
    statistics, and closes with a GPT-2-small back-of-the-envelope estimate.
    """
    rule = "=" * 100
    print(rule)
    print("BitLinear Memory Benchmarks")
    print(rule)
    print(f"\nPyTorch version: {torch.__version__}")

    # (in_features, out_features) shapes to sweep.
    shapes = [
        (512, 512),
        (768, 768),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
        (768, 3072),   # Typical Transformer FFN
        (1024, 4096),  # Larger Transformer FFN
    ]
    results = [analyze_layer_memory(fan_in, fan_out) for fan_in, fan_out in shapes]

    # ---- Markdown summary table ----
    print("\n\n" + rule)
    print("Memory Compression Summary (Markdown Format)")
    print(rule + "\n")
    print("| Layer Size | nn.Linear (MB) | BitLinear Current (MB) | BitLinear Packed (MB) | Compression (Packed) |")
    print("|------------|----------------|------------------------|----------------------|----------------------|")
    for row in results:
        print(f"| {row['in_features']}×{row['out_features']:<4} | {row['mem_linear']:14.4f} | "
              f"{row['mem_bitlinear']:22.4f} | {row['mem_packed']:20.4f} | {row['compression_packed']:20.2f}x |")

    # ---- Aggregate statistics ----
    print("\n" + rule)
    print("Summary Statistics")
    print(rule + "\n")
    ratios = [row['compression_packed'] for row in results]
    print(f"Average compression ratio: {sum(ratios) / len(ratios):.2f}x")
    print(f"Minimum compression ratio: {min(ratios):.2f}x")
    print(f"Maximum compression ratio: {max(ratios):.2f}x")

    # ---- Real-world example: GPT-2 small (12 layers, d_model=768, d_ff=3072) ----
    print("\n" + rule)
    print("Real-World Example: GPT-2 Style Transformer")
    print(rule + "\n")
    num_layers, d_model, d_ff = 12, 768, 3072
    # Per layer: Q, K, V, O projections (4·d_model²) + two FFN matrices (2·d_model·d_ff).
    weights_per_layer = 4 * d_model * d_model + 2 * d_model * d_ff
    total_weights = weights_per_layer * num_layers
    dense_mb = total_weights * 4 / (1024 ** 2)  # float32 weights
    packed_weights_mb = ((total_weights + 4) // 5) / (1024 ** 2)  # base-3: 5 trits/byte
    # float32 bias + gamma vectors for every linear layer in every block.
    vectors_per_layer = 4 * d_model + d_ff + d_model
    overhead_mb = vectors_per_layer * 2 * num_layers * 4 / (1024 ** 2)
    packed_mb = packed_weights_mb + overhead_mb

    print(f"Configuration: {num_layers} layers, d_model={d_model}, d_ff={d_ff}")
    print(f"Total linear parameters: {total_weights:,}")
    print(f"\nnn.Linear memory: {dense_mb:10.2f} MB")
    print(f"BitLinear packed: {packed_mb:10.2f} MB")
    print(f"Memory saved: {dense_mb - packed_mb:10.2f} MB")
    print(f"Compression ratio: {dense_mb / packed_mb:10.2f}x")
    print("\n" + rule)
    print("Benchmark Complete!")
    print(rule)
# Entry point: run the full benchmark suite when executed as a script.
if __name__ == "__main__":
    run_memory_benchmarks()