# BitLinear Demo Notebook

This notebook provides an interactive demonstration of BitLinear, showing how to use it as a drop-in replacement for `nn.Linear` with significant memory savings.

## Installation

First, install the BitLinear package:

```bash
pip install -e .
```

## 1. Basic Usage

Let's start with a simple example:

```python
import torch
import torch.nn as nn
from bitlinear import BitLinear, estimate_memory_savings

# Create a BitLinear layer
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Create input
x = torch.randn(32, 128, 512)

# Forward pass (same interface as nn.Linear)
output = layer(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Weight values: {torch.unique(layer.W_ternary)}")
```
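The `W_ternary` values printed above come from quantizing the float weights to {-1, 0, +1}. As an illustration only (the package's exact scheme may differ), here is an absmean-style ternarization in the spirit of BitNet b1.58:

```python
import torch

def ternarize(W: torch.Tensor):
    """Illustrative absmean ternarization: scale by mean |W|, round, clamp.
    This is a sketch of the general technique, not BitLinear's internal code."""
    scale = W.abs().mean()                            # per-tensor absmean scale
    W_ternary = (W / (scale + 1e-8)).round().clamp_(-1, 1)
    return W_ternary, scale

W = torch.randn(4, 4)
W_t, s = ternarize(W)
print(torch.unique(W_t))   # values drawn from {-1, 0, 1}
```

The layer then computes `(x @ W_ternary.T) * scale`, so only the ternary matrix and one float scale need to be stored.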

## 2. Memory Savings

Calculate the memory savings:

```python
# Estimate memory savings
stats = estimate_memory_savings(512, 1024, num_layers=1)

print(f"Float32 weights: {stats['float32_bytes'] / 1024:.2f} KB")
print(f"Packed weights:  {stats['packed_bytes'] / 1024:.2f} KB")
print(f"Memory saved:    {stats['savings_bytes'] / 1024:.2f} KB")
print(f"Compression:     {stats['compression_ratio']:.1f}x")
```

## 3. Converting Existing Models

Convert a pre-trained model to use BitLinear:

```python
# Create a standard Linear layer
linear = nn.Linear(512, 1024)

# Simulate some training
with torch.no_grad():
    linear.weight.normal_(0, 0.02)

# Convert to BitLinear
bitlinear = BitLinear.from_linear(linear)

# Compare outputs
x = torch.randn(16, 512)

with torch.no_grad():
    out_linear = linear(x)
    out_bitlinear = bitlinear(x)

# Calculate similarity
mse = torch.mean((out_linear - out_bitlinear) ** 2).item()
cosine_sim = torch.nn.functional.cosine_similarity(
    out_linear.flatten(),
    out_bitlinear.flatten(),
    dim=0,
).item()

print(f"MSE: {mse:.6f}")
print(f"Cosine similarity: {cosine_sim:.6f}")
```

## 4. Transformer Example

Use BitLinear in a real Transformer:

```python
from bitlinear import convert_linear_to_bitlinear

# Create a Transformer encoder layer
model = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)

# Convert all Linear layers to BitLinear
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Test forward pass
x = torch.randn(10, 32, 512)  # (seq_len, batch, d_model)

with torch.no_grad():
    out_original = model(x)
    out_compressed = model_compressed(x)

# Compare
similarity = torch.nn.functional.cosine_similarity(
    out_original.flatten(),
    out_compressed.flatten(),
    dim=0,
).item()

print(f"Output similarity: {similarity:.4f}")
```

## 5. Multi-Ternary for Better Accuracy

Use multiple ternary components for improved approximation:

```python
from bitlinear import MultiTernaryLinear

# Create layers with different k values
linear = nn.Linear(512, 1024)
bitlinear_k1 = BitLinear.from_linear(linear)
bitlinear_k3 = MultiTernaryLinear.from_linear(linear, k=3)

# Compare accuracy
x = torch.randn(16, 512)

with torch.no_grad():
    out_orig = linear(x)
    out_k1 = bitlinear_k1(x)
    out_k3 = bitlinear_k3(x)

error_k1 = (torch.norm(out_orig - out_k1) / torch.norm(out_orig)).item()
error_k3 = (torch.norm(out_orig - out_k3) / torch.norm(out_orig)).item()

print(f"Relative error (k=1): {error_k1:.6f}")
print(f"Relative error (k=3): {error_k3:.6f}")
print(f"Improvement: {(error_k1 - error_k3) / error_k1 * 100:.1f}%")
```
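Conceptually, multi-ternary approximates the weight matrix as a sum of k scaled ternary matrices, `W ≈ Σᵢ αᵢ Tᵢ`, so each extra term reduces the residual. One way to fit such a decomposition is greedy residual fitting; this sketch is illustrative and `MultiTernaryLinear`'s actual algorithm may differ:

```python
import torch

def greedy_ternary_decompose(W: torch.Tensor, k: int = 3):
    """Greedily fit W as a sum of k scaled ternary matrices (illustrative sketch).
    Each step picks a ternary mask of the large-magnitude residual entries and
    the least-squares optimal scale along it, then subtracts that term."""
    residual = W.clone()
    terms = []
    for _ in range(k):
        # Ternary support: sign of entries whose magnitude exceeds the mean
        T = residual.sign() * (residual.abs() > residual.abs().mean()).float()
        # Least-squares optimal scale for this ternary direction
        alpha = (residual * T).sum() / T.pow(2).sum().clamp(min=1)
        terms.append((alpha, T))
        residual = residual - alpha * T
    return terms

W = torch.randn(8, 8)
terms = greedy_ternary_decompose(W, k=3)
approx = sum(a * T for a, T in terms)
print((W - approx).norm() / W.norm())  # relative error shrinks as k grows
```

Because each term stores only another packed ternary matrix plus one scale, accuracy improves at a modest, linear-in-k memory cost.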

## 6. Visualizing Ternary Weights

Visualize the ternary weight distribution:

```python
import matplotlib.pyplot as plt
import numpy as np

# Get ternary weights
W_ternary = bitlinear_k1.W_ternary.detach().numpy()

# Count values
unique, counts = np.unique(W_ternary, return_counts=True)

# Plot
plt.figure(figsize=(10, 6))
plt.bar(unique, counts, width=0.5)
plt.xlabel('Weight Value')
plt.ylabel('Count')
plt.title('Ternary Weight Distribution')
plt.xticks([-1, 0, 1])
plt.grid(axis='y', alpha=0.3)
plt.show()

# Print statistics
total = W_ternary.size
print(f"Total weights: {total}")
print(f"Zeros: {counts[unique == 0][0]} ({counts[unique == 0][0]/total*100:.1f}%)")
print(f"Ones (+1): {counts[unique == 1][0]} ({counts[unique == 1][0]/total*100:.1f}%)")
print(f"Negative ones (-1): {counts[unique == -1][0]} ({counts[unique == -1][0]/total*100:.1f}%)")
```

## 7. Memory Profiling

Profile actual memory usage:

```python
def get_model_memory_mb(model):
    """Get model memory (parameters and buffers) in MB."""
    # Include buffers as well as parameters, since packed ternary
    # weights may be registered as buffers rather than parameters.
    tensors = list(model.parameters()) + list(model.buffers())
    total_bytes = sum(t.element_size() * t.nelement() for t in tensors)
    return total_bytes / (1024 ** 2)

# Create models
model_linear = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=3072)
model_bitlinear = convert_linear_to_bitlinear(model_linear, inplace=False)

# Measure memory
mem_linear = get_model_memory_mb(model_linear)
mem_bitlinear = get_model_memory_mb(model_bitlinear)

print(f"Standard model: {mem_linear:.2f} MB")
print(f"BitLinear model: {mem_bitlinear:.2f} MB")
print(f"Memory savings: {(mem_linear - mem_bitlinear) / mem_linear * 100:.1f}%")
```

## 8. Benchmarking

Run a simple benchmark:

```python
import time

def benchmark(model, x, n_runs=100):
    """Average forward-pass latency in milliseconds."""
    model.eval()
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = model(x)
        # Timed runs
        start = time.time()
        for _ in range(n_runs):
            _ = model(x)
        end = time.time()
    return (end - start) / n_runs * 1000  # ms

# Create input matching the d_model=768 models from the previous section
x = torch.randn(64, 16, 768)  # (seq_len, batch, d_model)

# Benchmark
time_linear = benchmark(model_linear, x)
time_bitlinear = benchmark(model_bitlinear, x)

print(f"nn.Linear: {time_linear:.3f} ms")
print(f"BitLinear: {time_bitlinear:.3f} ms")
print(f"Speedup: {time_linear / time_bitlinear:.2f}x")
```

## Conclusion

BitLinear provides:
- ✅ ~19x memory compression
- ✅ Drop-in replacement for nn.Linear
- ✅ High output similarity (>96%)
- ✅ Easy model conversion
- ✅ Multi-ternary for better accuracy

Perfect for deploying large models on memory-constrained devices!
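The ~19x figure can be sanity-checked with back-of-envelope arithmetic. Assuming the packed format stores 5 ternary digits per byte (possible since 3^5 = 243 ≤ 256) plus one float32 scale per output row — the package's actual layout may differ:

```python
import math

in_features, out_features = 512, 1024
n_weights = in_features * out_features

float32_bytes = n_weights * 4                 # 4 bytes per float32 weight
packed_bytes = math.ceil(n_weights / 5)       # 5 ternary values per byte
scale_bytes = out_features * 4                # one float32 scale per output row

ratio = float32_bytes / (packed_bytes + scale_bytes)
print(f"Theoretical compression: {ratio:.1f}x")  # → Theoretical compression: 19.2x
```

A ternary weight carries log2(3) ≈ 1.585 bits of information, so ~20x is the packing ceiling; per-row scales account for the small gap down to ~19x.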

## Next Steps

- Try converting your own models
- Experiment with different k values for multi-ternary
- Run comprehensive benchmarks with `benchmarks/benchmark_memory.py`
- Check out `examples/transformer_example.py` for more complex usage