BitLinear Demo Notebook
This notebook provides an interactive demonstration of BitLinear, showing how to use it as a drop-in replacement for nn.Linear with significant memory savings.
Installation
First, install the BitLinear package:
pip install -e .
1. Basic Usage
Let's start with a simple example:
import torch
import torch.nn as nn
from bitlinear import BitLinear, estimate_memory_savings

# Instantiate a BitLinear layer — same constructor shape as nn.Linear.
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Random activations shaped (batch, seq_len, features).
x = torch.randn(32, 128, 512)

# Forward pass uses the exact nn.Linear calling convention.
output = layer(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Weight values: {torch.unique(layer.W_ternary)}")
2. Memory Savings
Calculate the memory savings:
# Report the storage footprint of float32 weights vs. packed ternary weights.
stats = estimate_memory_savings(512, 1024, num_layers=1)

kib = 1024  # bytes per KiB
print(f"Float32 weights: {stats['float32_bytes'] / kib:.2f} KB")
print(f"Packed weights: {stats['packed_bytes'] / kib:.2f} KB")
print(f"Memory saved: {stats['savings_bytes'] / kib:.2f} KB")
print(f"Compression: {stats['compression_ratio']:.1f}x")
3. Converting Existing Models
Convert a pre-trained model to use BitLinear:
# Start from a standard nn.Linear layer.
linear = nn.Linear(512, 1024)

# Simulate training by giving the weights a realistic spread.
# NOTE: the statement below must be indented under the `with` block —
# the flattened notebook transcript had lost this indentation.
with torch.no_grad():
    linear.weight.normal_(0, 0.02)

# Convert the trained layer to BitLinear.
bitlinear = BitLinear.from_linear(linear)

# Run both layers on the same input and compare outputs.
x = torch.randn(16, 512)
with torch.no_grad():
    out_linear = linear(x)
    out_bitlinear = bitlinear(x)

# Quantify the approximation quality: mean squared error and the
# cosine similarity of the flattened output vectors.
mse = torch.mean((out_linear - out_bitlinear) ** 2).item()
cosine_sim = torch.nn.functional.cosine_similarity(
    out_linear.flatten(),
    out_bitlinear.flatten(),
    dim=0
).item()

print(f"MSE: {mse:.6f}")
print(f"Cosine similarity: {cosine_sim:.6f}")
4. Transformer Example
Use BitLinear in a real Transformer:
from bitlinear import convert_linear_to_bitlinear

# Build a standard Transformer encoder layer.
model = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)

# Swap every nn.Linear inside it for a BitLinear equivalent.
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Run both models on the same batch: (seq_len, batch, d_model).
# NOTE: the two forward passes must be indented under the `with` block —
# the flattened notebook transcript had lost this indentation.
x = torch.randn(10, 32, 512)
with torch.no_grad():
    out_original = model(x)
    out_compressed = model_compressed(x)

# Cosine similarity between the flattened outputs.
similarity = torch.nn.functional.cosine_similarity(
    out_original.flatten(),
    out_compressed.flatten(),
    dim=0
).item()

print(f"Output similarity: {similarity:.4f}")
5. Multi-Ternary for Better Accuracy
Use multiple ternary components for improved approximation:
from bitlinear import MultiTernaryLinear

# One ternary component (k=1) vs. three components (k=3) approximating
# the same trained weights.
linear = nn.Linear(512, 1024)
bitlinear_k1 = BitLinear.from_linear(linear)
bitlinear_k3 = MultiTernaryLinear.from_linear(linear, k=3)

# Evaluate all three layers on the same input.
# NOTE: the three forward passes must be indented under the `with` block —
# the flattened notebook transcript had lost this indentation.
x = torch.randn(16, 512)
with torch.no_grad():
    out_orig = linear(x)
    out_k1 = bitlinear_k1(x)
    out_k3 = bitlinear_k3(x)

# Relative L2 error of each approximation against the exact output.
error_k1 = (torch.norm(out_orig - out_k1) / torch.norm(out_orig)).item()
error_k3 = (torch.norm(out_orig - out_k3) / torch.norm(out_orig)).item()

print(f"Relative error (k=1): {error_k1:.6f}")
print(f"Relative error (k=3): {error_k3:.6f}")
print(f"Improvement: {(error_k1 - error_k3) / error_k1 * 100:.1f}%")
6. Visualizing Ternary Weights
Visualize the ternary weight distribution:
import matplotlib.pyplot as plt
import numpy as np

# Pull the ternary weight matrix out of the k=1 layer.
W_ternary = bitlinear_k1.W_ternary.detach().numpy()

# Count how often each of {-1, 0, +1} occurs.
unique, counts = np.unique(W_ternary, return_counts=True)

# Bar chart of the value distribution.
plt.figure(figsize=(10, 6))
plt.bar(unique, counts, width=0.5)
plt.xlabel('Weight Value')
plt.ylabel('Count')
plt.title('Ternary Weight Distribution')
plt.xticks([-1, 0, 1])
plt.grid(axis='y', alpha=0.3)
plt.show()

# Print statistics. Use a dict lookup with a default of 0 so a value that
# never occurs reports zero instead of raising IndexError on an empty
# boolean-mask indexing expression like counts[unique == 0][0].
value_counts = {int(v): int(c) for v, c in zip(unique, counts)}
total = W_ternary.size
n_zero = value_counts.get(0, 0)
n_pos = value_counts.get(1, 0)
n_neg = value_counts.get(-1, 0)
print(f"Total weights: {total}")
print(f"Zeros: {n_zero} ({n_zero/total*100:.1f}%)")
print(f"Ones (+1): {n_pos} ({n_pos/total*100:.1f}%)")
print(f"Negative ones (-1): {n_neg} ({n_neg/total*100:.1f}%)")
7. Memory Profiling
Profile actual memory usage:
import torch
import gc
def get_model_memory_mb(model):
    """Return the total parameter memory of *model* in MiB.

    Sums ``element_size() * nelement()`` over all parameters. Buffers,
    gradients, and activations are not counted.

    NOTE: the function body must be indented — the flattened notebook
    transcript had lost this indentation.
    """
    total_bytes = sum(p.element_size() * p.nelement() for p in model.parameters())
    return total_bytes / (1024 ** 2)
# Build a float32 Transformer layer and a BitLinear copy of it.
model_linear = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=3072)
model_bitlinear = convert_linear_to_bitlinear(model_linear, inplace=False)

# Compare the two parameter footprints.
mem_linear = get_model_memory_mb(model_linear)
mem_bitlinear = get_model_memory_mb(model_bitlinear)

print(f"Standard model: {mem_linear:.2f} MB")
print(f"BitLinear model: {mem_bitlinear:.2f} MB")
print(f"Memory savings: {(mem_linear - mem_bitlinear) / mem_linear * 100:.1f}%")
8. Benchmarking
Run a simple benchmark:
import time
def benchmark(model, x, n_runs=100):
    """Return the mean latency of one ``model(x)`` call in milliseconds.

    Performs 10 untimed warmup calls first (lets lazy initialization and
    caches settle), then averages ``n_runs`` timed calls.

    Uses ``time.perf_counter`` rather than ``time.time``: perf_counter is
    monotonic and high-resolution, while time.time can jump if the wall
    clock is adjusted mid-benchmark.

    NOTE: the function body must be indented — the flattened notebook
    transcript had lost this indentation.
    """
    # Warmup passes, not timed.
    for _ in range(10):
        _ = model(x)
    # Timed passes.
    start = time.perf_counter()
    for _ in range(n_runs):
        _ = model(x)
    end = time.perf_counter()
    return (end - start) / n_runs * 1000  # mean ms per call
# Time both models on an identical batch.
x = torch.randn(32, 128, 512)

time_linear = benchmark(model_linear, x)
time_bitlinear = benchmark(model_bitlinear, x)

print(f"nn.Linear: {time_linear:.3f} ms")
print(f"BitLinear: {time_bitlinear:.3f} ms")
print(f"Speedup: {time_linear / time_bitlinear:.2f}x")
Conclusion
BitLinear provides:
- ✅ ~19x memory compression
- ✅ Drop-in replacement for nn.Linear
- ✅ High output similarity (>96%)
- ✅ Easy model conversion
- ✅ Multi-ternary for better accuracy
Perfect for deploying large models on memory-constrained devices!
For further exploration, try the following:
- Try converting your own models
- Experiment with different k values for multi-ternary
- Run comprehensive benchmarks with `benchmarks/benchmark_memory.py`
- Check out `examples/transformer_example.py` for more complex usage