# BitLinear Demo Notebook This notebook provides an interactive demonstration of BitLinear, showing how to use it as a drop-in replacement for nn.Linear with significant memory savings. ## Installation First, install the BitLinear package: ```bash pip install -e . ``` ## 1. Basic Usage Let's start with a simple example: ```python import torch import torch.nn as nn from bitlinear import BitLinear, estimate_memory_savings # Create a BitLinear layer layer = BitLinear(in_features=512, out_features=1024, bias=True) # Create input x = torch.randn(32, 128, 512) # Forward pass (same interface as nn.Linear) output = layer(x) print(f"Input shape: {x.shape}") print(f"Output shape: {output.shape}") print(f"Weight values: {torch.unique(layer.W_ternary)}") ``` ## 2. Memory Savings Calculate the memory savings: ```python # Estimate memory savings stats = estimate_memory_savings(512, 1024, num_layers=1) print(f"Float32 weights: {stats['float32_bytes'] / 1024:.2f} KB") print(f"Packed weights: {stats['packed_bytes'] / 1024:.2f} KB") print(f"Memory saved: {stats['savings_bytes'] / 1024:.2f} KB") print(f"Compression: {stats['compression_ratio']:.1f}x") ``` ## 3. Converting Existing Models Convert a pre-trained model to use BitLinear: ```python # Create a standard Linear layer linear = nn.Linear(512, 1024) # Simulate some training with torch.no_grad(): linear.weight.normal_(0, 0.02) # Convert to BitLinear bitlinear = BitLinear.from_linear(linear) # Compare outputs x = torch.randn(16, 512) with torch.no_grad(): out_linear = linear(x) out_bitlinear = bitlinear(x) # Calculate similarity mse = torch.mean((out_linear - out_bitlinear) ** 2).item() cosine_sim = torch.nn.functional.cosine_similarity( out_linear.flatten(), out_bitlinear.flatten(), dim=0 ).item() print(f"MSE: {mse:.6f}") print(f"Cosine similarity: {cosine_sim:.6f}") ``` ## 4. Transformer Example Use BitLinear in a real Transformer: ```python from bitlinear import convert_linear_to_bitlinear # Create a Transformer encoder layer model = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048) # Convert all Linear layers to BitLinear model_compressed = convert_linear_to_bitlinear(model, inplace=False) # Test forward pass x = torch.randn(10, 32, 512) # (seq_len, batch, d_model) with torch.no_grad(): out_original = model(x) out_compressed = model_compressed(x) # Compare similarity = torch.nn.functional.cosine_similarity( out_original.flatten(), out_compressed.flatten(), dim=0 ).item() print(f"Output similarity: {similarity:.4f}") ``` ## 5. Multi-Ternary for Better Accuracy Use multiple ternary components for improved approximation: ```python from bitlinear import MultiTernaryLinear # Create layers with different k values linear = nn.Linear(512, 1024) bitlinear_k1 = BitLinear.from_linear(linear) bitlinear_k3 = MultiTernaryLinear.from_linear(linear, k=3) # Compare accuracy x = torch.randn(16, 512) with torch.no_grad(): out_orig = linear(x) out_k1 = bitlinear_k1(x) out_k3 = bitlinear_k3(x) error_k1 = (torch.norm(out_orig - out_k1) / torch.norm(out_orig)).item() error_k3 = (torch.norm(out_orig - out_k3) / torch.norm(out_orig)).item() print(f"Relative error (k=1): {error_k1:.6f}") print(f"Relative error (k=3): {error_k3:.6f}") print(f"Improvement: {(error_k1 - error_k3) / error_k1 * 100:.1f}%") ``` ## 6. Visualizing Ternary Weights Visualize the ternary weight distribution: ```python import matplotlib.pyplot as plt import numpy as np # Get ternary weights W_ternary = bitlinear_k1.W_ternary.detach().numpy() # Count values unique, counts = np.unique(W_ternary, return_counts=True) # Plot plt.figure(figsize=(10, 6)) plt.bar(unique, counts, width=0.5) plt.xlabel('Weight Value') plt.ylabel('Count') plt.title('Ternary Weight Distribution') plt.xticks([-1, 0, 1]) plt.grid(axis='y', alpha=0.3) plt.show() # Print statistics total = W_ternary.size print(f"Total weights: {total}") print(f"Zeros: {counts[unique == 0][0]} ({counts[unique == 0][0]/total*100:.1f}%)") print(f"Ones (+1): {counts[unique == 1][0]} ({counts[unique == 1][0]/total*100:.1f}%)") print(f"Negative ones (-1): {counts[unique == -1][0]} ({counts[unique == -1][0]/total*100:.1f}%)") ``` ## 7. Memory Profiling Profile actual memory usage: ```python import torch import gc def get_model_memory_mb(model): """Get model memory in MB.""" total_bytes = sum(p.element_size() * p.nelement() for p in model.parameters()) return total_bytes / (1024 ** 2) # Create models model_linear = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=3072) model_bitlinear = convert_linear_to_bitlinear(model_linear, inplace=False) # Measure memory mem_linear = get_model_memory_mb(model_linear) mem_bitlinear = get_model_memory_mb(model_bitlinear) print(f"Standard model: {mem_linear:.2f} MB") print(f"BitLinear model: {mem_bitlinear:.2f} MB") print(f"Memory savings: {(mem_linear - mem_bitlinear) / mem_linear * 100:.1f}%") ``` ## 8. Benchmarking Run a simple benchmark: ```python import time def benchmark(model, x, n_runs=100): # Warmup for _ in range(10): _ = model(x) # Benchmark start = time.time() for _ in range(n_runs): _ = model(x) end = time.time() return (end - start) / n_runs * 1000 # ms # Create input x = torch.randn(32, 128, 512) # Benchmark time_linear = benchmark(model_linear, x) time_bitlinear = benchmark(model_bitlinear, x) print(f"nn.Linear: {time_linear:.3f} ms") print(f"BitLinear: {time_bitlinear:.3f} ms") print(f"Speedup: {time_linear / time_bitlinear:.2f}x") ``` ## Conclusion BitLinear provides: - ✅ ~19x memory compression - ✅ Drop-in replacement for nn.Linear - ✅ High output similarity (>96%) - ✅ Easy model conversion - ✅ Multi-ternary for better accuracy Perfect for deploying large models on memory-constrained devices! ## For the future o the following - Try converting your own models - Experiment with different k values for multi-ternary - Run comprehensive benchmarks with `benchmarks/benchmark_memory.py` - Check out `examples/transformer_example.py` for more complex usage