# Model Card: BitLinear

## Model Description

**BitLinear** is a PyTorch implementation of ultra-low-precision (1.58-bit) ternary linear layers that can serve as drop-in replacements for `nn.Linear` in neural networks, particularly Transformers. It achieves ~19x memory compression while maintaining high output similarity.

### Model Details

- **Developed by:** BitLinear Contributors
- **Model type:** Quantization / Compression
- **Language:** Python, C++, CUDA
- **License:** MIT
- **Repository:** https://github.com/yourusername/bitlinear

## Intended Use

### Primary Use Cases

- **Edge Deployment:** Deploying large models on memory-constrained devices
- **Production Inference:** Reducing memory footprint for serving large language models
- **Research:** Exploring ultra-low-precision neural networks
- **Cost Optimization:** Reducing cloud infrastructure costs through memory savings

### Out-of-Scope Use Cases

- Training from scratch (requires quantization-aware training)
- Applications requiring exact numerical precision
- Real-time applications where Python overhead is prohibitive (use the C++/CUDA extensions)

## How to Use

### Basic Usage

```python
import torch
from bitlinear import BitLinear

# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # Same as nn.Linear
```

### Converting Existing Models

```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Use as normal
x = torch.randn(10, 32, 512)
output = model_compressed(x)
```

### Multi-Ternary for Better Accuracy

```python
from bitlinear import MultiTernaryLinear

# Use k=3 components for 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
```

## Performance

### Memory Compression

- **Average Compression:** 19.23x (~96% of the theoretical 20x)
- **GPT-2 Small Example:** 324 MB → 16.8 MB (307 MB saved)

| Layer Size | nn.Linear | BitLinear (Packed) | Compression |
|------------|-----------|--------------------|-------------|
| 512×512    | 1.00 MB   | 0.05 MB            | 18.6x       |
| 1024×1024  | 4.00 MB   | 0.21 MB            | 19.3x       |
| 4096×4096  | 64.02 MB  | 3.23 MB            | 19.8x       |

### Accuracy

- **Cosine Similarity:** > 0.96
- **Relative Error:** ~0.28
- **Multi-Ternary (k=3):** 75% error reduction vs. k=1

## Limitations

### Known Limitations

1. **Accuracy Trade-off:** Ternary quantization introduces approximation error (~3-5% typical)
2. **Training:** Requires quantization-aware training (QAT) for optimal results
3. **Speed:** The Python implementation may be slower than nn.Linear (use the C++/CUDA extensions for production)
4. **Activation Quantization:** Currently only weights are quantized (full BitNet also quantizes activations)

### Recommendations

- Fine-tune converted models for best accuracy
- Use k≥2 for MultiTernaryLinear when accuracy is critical
- Profile performance on your specific hardware
- Test accuracy on your specific task before deployment

## Training

### Quantization-Aware Training (QAT)

For best results, fine-tune models with BitLinear layers:

```python
import torch
from bitlinear import convert_linear_to_bitlinear

# Convert the pre-trained model
model_bit = convert_linear_to_bitlinear(pretrained_model)

# Fine-tune with a standard training loop
optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
# ... train as normal ...
```

### From Scratch Training

Training from scratch with ternary weights requires:

- Careful initialization
- Straight-through estimators for gradients
- Potentially modified learning rates

See `read/IMPLEMENTATION_GUIDE.md` for details.

## Technical Specifications

### Architecture

- **Weight Quantization:** Ternary {-1, 0, +1}
- **Scaling:** Per-output-channel absmax scaling
- **Packing:** Base-3 encoding (5 values per byte)
- **Decomposition:** Greedy residual quantization for multi-ternary

The sketches below illustrate the quantization, packing, and decomposition steps in turn.
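As a concrete illustration of the first two bullets, here is a minimal sketch of ternary quantization with per-output-channel absmax scaling, assuming the glossary's definition of absmax (the largest weight magnitude, here taken per row). The function names are illustrative, not part of the `bitlinear` API:

```python
import torch

def quantize_ternary_absmax(w: torch.Tensor):
    """Quantize a weight matrix to ternary {-1, 0, +1} with
    per-output-channel absmax scaling (one scale per row)."""
    # Per-row scale: the largest magnitude in each output channel.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    # |w / scale| <= 1, so rounding yields values in {-1, 0, +1}:
    # entries above half the channel's absmax round to +/-1, the rest to 0.
    ternary = torch.round(w / scale)
    return ternary, scale

def dequantize(ternary: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Reconstruct the approximate float weights."""
    return ternary * scale

# Example: quantize a 1024x512 weight matrix and measure the error.
w = torch.randn(1024, 512)
t, s = quantize_ternary_absmax(w)
w_hat = dequantize(t, s)
rel_err = (w - w_hat).norm() / w.norm()
print(f"unique values: {t.unique().tolist()}, relative error: {rel_err:.3f}")
```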
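The base-3 packing step stores 5 ternary values per byte, which is where the theoretical 20x figure comes from: 32-bit floats shrink to 8/5 = 1.6 bits per weight (per-channel scales account for the remaining gap to ~19x). A rough, self-contained sketch of pack/unpack under that scheme, again with hypothetical helper names:

```python
import torch

def pack_base3(ternary: torch.Tensor) -> torch.Tensor:
    """Pack a ternary {-1, 0, +1} tensor into bytes, 5 values per byte.
    Works because 3**5 = 243 fits in the 256 states of one uint8."""
    digits = (ternary.flatten() + 1).to(torch.uint8)      # {-1,0,1} -> {0,1,2}
    pad = (-digits.numel()) % 5                           # pad to a multiple of 5
    digits = torch.cat([digits, digits.new_zeros(pad)])
    groups = digits.view(-1, 5).to(torch.int32)
    weights = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32)
    return (groups * weights).sum(dim=1).to(torch.uint8)  # 5 base-3 digits -> 1 byte

def unpack_base3(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Invert pack_base3, recovering the first `numel` ternary values."""
    codes = packed.to(torch.int32)
    digits = []
    for _ in range(5):
        digits.append(codes % 3)                          # peel off one base-3 digit
        codes = codes // 3
    digits = torch.stack(digits, dim=1).flatten()[:numel]
    return digits.to(torch.float32) - 1                   # {0,1,2} -> {-1,0,1}

# Round-trip check on a 512x512 layer's worth of ternary values.
t = torch.randint(-1, 2, (512 * 512,)).float()
packed = pack_base3(t)
assert torch.equal(unpack_base3(packed, t.numel()), t)
print(f"{t.numel()} ternary values -> {packed.numel()} bytes (1.6 bits/value)")
```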
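Finally, the greedy residual decomposition behind `MultiTernaryLinear` can be sketched as repeated ternary quantization of whatever the previous components failed to capture, so the approximation is a sum of k scaled ternary matrices. The helper below is an illustrative approximation of that idea, not the library's implementation:

```python
import torch

def multi_ternary_decompose(w: torch.Tensor, k: int = 3):
    """Greedy residual quantization: approximate w as a sum of k scaled
    ternary components, each quantizing the residual left by the others."""
    components = []
    residual = w.clone()
    for _ in range(k):
        scale = residual.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
        ternary = torch.round(residual / scale)
        components.append((ternary, scale))
        residual = residual - ternary * scale   # quantize the leftover next round
    return components

# Each added component should shrink the reconstruction error.
w = torch.randn(1024, 512)
for k in (1, 2, 3):
    parts = multi_ternary_decompose(w, k)
    w_hat = sum(t * s for t, s in parts)
    print(f"k={k}: relative error {((w - w_hat).norm() / w.norm()):.3f}")
```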
### Implementation

- **Python:** Pure PyTorch baseline
- **C++:** Optimized CPU kernels with PyBind11
- **CUDA:** GPU kernels with warp-level reductions and shared-memory tiling

### Requirements

- Python ≥ 3.8
- PyTorch ≥ 2.0.0
- NumPy ≥ 1.20.0
- C++ compiler (for the C++ extensions)
- CUDA toolkit (optional, for GPU support)

## Evaluation

### Benchmarks

Comprehensive benchmarks are available in `BENCHMARKS.md`:

- Memory compression analysis
- Forward pass timing
- Accuracy metrics
- Real-world transformer examples

### Validation

All implementations are validated against:

- Unit tests (pytest suite)
- Numerical correctness tests
- Integration tests with Transformers
- Cross-implementation consistency (Python vs. C++)

## Citation

If you use BitLinear in your research, please cite:

```bibtex
@article{jmlr_ternary_2024,
  title={Ternary Representations of Neural Networks},
  journal={Journal of Machine Learning Research},
  volume={26},
  year={2024},
  url={https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}

@article{bitnet2023,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}
```

## Model Card Contact

For questions or issues, please open an issue on GitHub or contact the maintainers.

## Glossary

- **Ternary Quantization:** Representing weights with only three values {-1, 0, +1}
- **Absmax Scaling:** Scaling factor computed as max(abs(weights)), here per output channel
- **Base-3 Packing:** Encoding five ternary values per byte (3^5 = 243 ≤ 256) for memory efficiency
- **Multi-Ternary:** Sum of k ternary components for improved approximation
- **QAT:** Quantization-Aware Training, i.e. training with quantization in the loop

## More Information

- **Documentation:** See `README.md` and the `read/` directory
- **Examples:** See the `examples/` directory
- **Benchmarks:** See `BENCHMARKS.md`
- **Implementation Guide:** See `read/IMPLEMENTATION_GUIDE.md`