File size: 6,795 Bytes
fd8c8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""

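Memory usage benchmarking for BitLinear.

This script measures the parameter memory of BitLinear and MultiTernaryLinear
layers, plus the theoretical and actual size of base-3 packed ternary weights,
and reports compression ratios relative to standard nn.Linear layers.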

"""

import torch
import torch.nn as nn
from bitlinear import BitLinear, MultiTernaryLinear, pack_ternary_base3, estimate_memory_savings
import sys


def get_tensor_memory_mb(tensor):
    """Return the size of ``tensor``'s elements in MB (1 MB = 1024**2 bytes).

    The byte count is the number of elements multiplied by the per-element
    byte width, both as reported by the tensor itself.
    """
    bytes_per_mb = 1024 ** 2
    total_bytes = tensor.nelement() * tensor.element_size()
    return total_bytes / bytes_per_mb


def get_model_memory_mb(model):
    """Return the combined size of ``model``'s parameters in MB.

    Only tensors yielded by ``model.parameters()`` are counted; each one
    contributes its element count times its per-element byte width.
    """
    byte_count = 0
    for param in model.parameters():
        byte_count += param.nelement() * param.element_size()
    return byte_count / (1024 ** 2)


def analyze_layer_memory(in_features, out_features):
    """Analyze and print memory usage for a single layer shape.

    Builds an ``nn.Linear`` of the given shape, converts it to ``BitLinear``
    and ``MultiTernaryLinear`` (k=2) through their ``from_linear``
    constructors, and prints the parameter memory of each next to a
    theoretical base-3 packed size. It then packs the BitLinear ternary
    weights with ``pack_ternary_base3`` and prints the measured packed size.

    Args:
        in_features: Input dimension of the layer.
        out_features: Output dimension of the layer.

    Returns:
        A dict with keys ``in_features``, ``out_features``, ``mem_linear``,
        ``mem_bitlinear``, ``mem_packed``, ``mem_multi``,
        ``compression_current`` and ``compression_packed``. Memory values are
        in MB; compression values are ratios relative to ``nn.Linear``.
    """
    rule = "=" * 100

    print(f"\n{rule}")
    # Separate the two dimensions (previously printed fused, e.g. "512512");
    # "×" matches the layer labels used in the summary table.
    print(f"Layer: {in_features}×{out_features}")
    print(f"{rule}\n")

    # Create layers
    linear = nn.Linear(in_features, out_features, bias=True)
    bitlinear = BitLinear.from_linear(linear)
    multi_ternary = MultiTernaryLinear.from_linear(linear, k=2)

    # Parameter memory as currently held by each module.
    mem_linear = get_model_memory_mb(linear)
    mem_bitlinear = get_model_memory_mb(bitlinear)
    mem_multi = get_model_memory_mb(multi_ternary)

    # Theoretical packed memory (base-3 packing): 5 ternary values fit in one
    # byte because 3**5 = 243 <= 256; round up to whole bytes.
    weights_count = in_features * out_features
    packed_bytes = (weights_count + 4) // 5
    bias_bytes = out_features * 4  # float32 bias
    # Assumes one float32 gamma per output feature — TODO confirm against
    # BitLinear's actual scaling-factor shape.
    gamma_bytes = out_features * 4
    theoretical_packed_mb = (packed_bytes + bias_bytes + gamma_bytes) / (1024 ** 2)

    # Compression ratios relative to the float nn.Linear baseline.
    compression_current = mem_linear / mem_bitlinear
    compression_packed = mem_linear / theoretical_packed_mb

    # Print results
    print(f"nn.Linear memory:                {mem_linear:10.4f} MB")
    print(f"BitLinear memory (current):      {mem_bitlinear:10.4f} MB  (ratio: {compression_current:5.2f}x)")
    print(f"BitLinear memory (packed):       {theoretical_packed_mb:10.4f} MB  (ratio: {compression_packed:5.2f}x)")
    print(f"MultiTernaryLinear memory (k=2): {mem_multi:10.4f} MB  (ratio: {mem_linear/mem_multi:5.2f}x)")

    # Test actual packing
    print("\nPacking Test:")
    print("-" * 100)

    W_ternary = bitlinear.W_ternary
    # The second return value (the original shape) is not needed for sizing.
    packed, _original_shape = pack_ternary_base3(W_ternary)

    unpacked_size_mb = get_tensor_memory_mb(W_ternary)
    packed_size_mb = get_tensor_memory_mb(packed)
    actual_compression = unpacked_size_mb / packed_size_mb

    print(f"Unpacked weights: {unpacked_size_mb:10.4f} MB")
    print(f"Packed weights:   {packed_size_mb:10.4f} MB")
    print(f"Actual compression: {actual_compression:8.2f}x")

    return {
        'in_features': in_features,
        'out_features': out_features,
        'mem_linear': mem_linear,
        'mem_bitlinear': mem_bitlinear,
        'mem_packed': theoretical_packed_mb,
        'mem_multi': mem_multi,
        'compression_current': compression_current,
        'compression_packed': compression_packed,
    }


def run_memory_benchmarks(layer_sizes=None):
    """Run the memory benchmarks and print a report to stdout.

    For every layer shape, ``analyze_layer_memory`` is called and its result
    collected. The report then contains a Markdown summary table, aggregate
    compression statistics, and a closed-form estimate for a GPT-2-small
    style Transformer (12 layers, d_model=768, d_ff=3072).

    Args:
        layer_sizes: Optional iterable of ``(in_features, out_features)``
            pairs to benchmark. When ``None`` (the default), a built-in set
            of square and Transformer-FFN shapes is used. An empty iterable
            skips the per-layer analysis and the aggregate statistics.

    Returns:
        None. All output is printed.
    """
    if layer_sizes is None:
        # Default test configurations
        layer_sizes = [
            (512, 512),
            (768, 768),
            (1024, 1024),
            (2048, 2048),
            (4096, 4096),
            (768, 3072),   # Typical Transformer FFN
            (1024, 4096),  # Larger Transformer FFN
        ]

    rule = "=" * 100

    print(rule)
    print("BitLinear Memory Benchmarks")
    print(rule)
    print(f"\nPyTorch version: {torch.__version__}")

    results = [
        analyze_layer_memory(in_features, out_features)
        for in_features, out_features in layer_sizes
    ]

    # Generate summary table
    print(f"\n\n{rule}")
    print("Memory Compression Summary (Markdown Format)")
    print(f"{rule}\n")

    print("| Layer Size | nn.Linear (MB) | BitLinear Current (MB) | BitLinear Packed (MB) | Compression (Packed) |")
    print("|------------|----------------|------------------------|----------------------|----------------------|")

    for r in results:
        print(f"| {r['in_features']}×{r['out_features']:<4} | {r['mem_linear']:14.4f} | "
              f"{r['mem_bitlinear']:22.4f} | {r['mem_packed']:20.4f} | {r['compression_packed']:20.2f}x |")

    # Overall statistics
    print(f"\n{rule}")
    print("Summary Statistics")
    print(f"{rule}\n")

    ratios = [r['compression_packed'] for r in results]
    if ratios:
        avg_compression = sum(ratios) / len(ratios)
        min_compression = min(ratios)
        max_compression = max(ratios)

        print(f"Average compression ratio: {avg_compression:.2f}x")
        print(f"Minimum compression ratio: {min_compression:.2f}x")
        print(f"Maximum compression ratio: {max_compression:.2f}x")
    else:
        # Nothing was benchmarked: avoid dividing by zero / min() of nothing.
        print("No layers benchmarked.")

    # Transformer example
    print(f"\n{rule}")
    print("Real-World Example: GPT-2 Style Transformer")
    print(f"{rule}\n")

    # GPT-2 small: 12 layers, d_model=768, d_ff=3072
    num_layers = 12
    d_model = 768
    d_ff = 3072

    # Each layer has: Q, K, V, O projections (4 × d_model²) + 2 FFN layers (d_model×d_ff + d_ff×d_model)
    linear_per_layer = (4 * d_model * d_model) + (d_model * d_ff) + (d_ff * d_model)
    linear_total = linear_per_layer * num_layers

    # Calculate memory
    linear_mem_mb = (linear_total * 4) / (1024 ** 2)  # float32
    packed_mem_mb = ((linear_total + 4) // 5) / (1024 ** 2)  # base-3 packed, 5 values per byte

    # Per-layer float32 overhead: one bias and one scaling factor (gamma) per
    # output feature of each of the six projections.
    biases_per_layer = (4 * d_model) + d_ff + d_model
    gammas_per_layer = (4 * d_model) + d_ff + d_model
    overhead_mb = ((biases_per_layer + gammas_per_layer) * num_layers * 4) / (1024 ** 2)

    packed_total_mb = packed_mem_mb + overhead_mb
    compression = linear_mem_mb / packed_total_mb

    print(f"Configuration: {num_layers} layers, d_model={d_model}, d_ff={d_ff}")
    print(f"Total linear parameters: {linear_total:,}")
    print(f"\nnn.Linear memory:      {linear_mem_mb:10.2f} MB")
    print(f"BitLinear packed:      {packed_total_mb:10.2f} MB")
    print(f"Memory saved:          {linear_mem_mb - packed_total_mb:10.2f} MB")
    print(f"Compression ratio:     {compression:10.2f}x")

    print(f"\n{rule}")
    print("Benchmark Complete!")
    print(rule)


# Run the full benchmark suite only when executed as a script, not on import.
if __name__ == "__main__":
    run_memory_benchmarks()