# Model Card: BitLinear

## Model Description

**BitLinear** is a PyTorch implementation of ultra-low-precision (1.58-bit) ternary linear layers that serve as drop-in replacements for `nn.Linear` in neural networks, particularly Transformers. It achieves ~19x memory compression while maintaining high output similarity.

### Model Details

- **Developed by:** BitLinear Contributors
- **Model type:** Quantization / Compression
- **Language:** Python, C++, CUDA
- **License:** MIT
- **Repository:** https://github.com/yourusername/bitlinear
## Intended Use

### Primary Use Cases

- **Edge Deployment:** Deploying large models on memory-constrained devices
- **Production Inference:** Reducing memory footprint for serving large language models
- **Research:** Exploring ultra-low-precision neural networks
- **Cost Optimization:** Reducing cloud infrastructure costs through memory savings

### Out-of-Scope Use Cases

- Training from scratch (requires quantization-aware training)
- Applications requiring exact numerical precision
- Real-time applications where Python overhead is prohibitive (use the C++/CUDA extensions instead)
## How to Use

### Basic Usage

```python
import torch
from bitlinear import BitLinear

# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # Same as nn.Linear
```
### Converting Existing Models

```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Use as normal
x = torch.randn(10, 32, 512)
output = model_compressed(x)
```
### Multi-Ternary for Better Accuracy

```python
from bitlinear import MultiTernaryLinear

# Use k=3 components for 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
```
## Performance

### Memory Compression

- **Average Compression:** 19.23x (~96% of the theoretical 20x)
- **GPT-2 Small Example:** 324 MB → 16.8 MB (307 MB saved)
| Layer Size | nn.Linear | BitLinear (Packed) | Compression |
|------------|-----------|--------------------|-------------|
| 512×512    | 1.00 MB   | 0.05 MB            | 18.6x       |
| 1024×1024  | 4.00 MB   | 0.21 MB            | 19.3x       |
| 4096×4096  | 64.02 MB  | 3.23 MB            | 19.8x       |
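The ratios in the table can be sanity-checked with simple arithmetic. The sketch below assumes an fp32 baseline, base-3 packing at 5 ternary values per byte, and one fp32 scale per output channel; it ignores biases and any per-layer metadata, so it will not match the measured numbers exactly.

```python
def ternary_compression(out_features: int, in_features: int) -> float:
    """Approximate compression vs. an fp32 nn.Linear weight matrix."""
    fp32_bytes = out_features * in_features * 4          # 4 bytes per fp32 weight
    packed_bytes = -(-out_features * in_features // 5)   # ceil: 5 ternary values per byte
    scale_bytes = out_features * 4                       # one fp32 scale per output channel
    return fp32_bytes / (packed_bytes + scale_bytes)

for n in (512, 1024, 4096):
    print(f"{n}x{n}: {ternary_compression(n, n):.1f}x")
```

The ratio approaches the theoretical 20x (32 bits vs. ~1.6 bits per weight) as layers grow, because the per-channel scales amortize over more weights.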
### Accuracy

- **Cosine Similarity:** > 0.96 (96%+)
- **Relative Error:** ~0.28 (28%)
- **Multi-Ternary (k=3):** 75% error reduction vs. k=1
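The multi-ternary error reduction can be illustrated with a minimal pure-PyTorch sketch of greedy residual quantization. This is an illustration, not the library's implementation: it uses a single mean-absolute scale per component rather than the library's per-channel scaling, so the exact error values will differ from the numbers above.

```python
import torch

def ternary_step(residual):
    # Fit one ternary component {-1, 0, +1} with a mean-absolute scale
    scale = residual.abs().mean()
    return scale * (residual / (scale + 1e-8)).round().clamp_(-1, 1)

def multi_ternary_approx(w, k):
    # Greedy residual quantization: each component fits what the previous ones missed
    approx = torch.zeros_like(w)
    for _ in range(k):
        approx = approx + ternary_step(w - approx)
    return approx

torch.manual_seed(0)
w = torch.randn(1024, 512)
for k in (1, 2, 3):
    rel_err = ((w - multi_ternary_approx(w, k)).norm() / w.norm()).item()
    print(f"k={k}: relative weight error {rel_err:.3f}")
```

Each added component quantizes the residual left by the previous ones, so the approximation error decreases monotonically with k.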
## Limitations

### Known Limitations

1. **Accuracy Trade-off:** Ternary quantization introduces approximation error (~3-5% typical)
2. **Training:** Requires quantization-aware training (QAT) for optimal results
3. **Speed:** The Python implementation may be slower than nn.Linear (use the C++/CUDA extensions for production)
4. **Activation Quantization:** Currently only weights are quantized (full BitNet also quantizes activations)

### Recommendations

- Fine-tune converted models for best accuracy
- Use k ≥ 2 for MultiTernaryLinear when accuracy is critical
- Profile performance on your specific hardware
- Test accuracy on your specific task before deployment
## Training

### Quantization-Aware Training (QAT)

For best results, fine-tune models with BitLinear layers:

```python
import torch
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model_bit = convert_linear_to_bitlinear(pretrained_model)

# Fine-tune with a standard training loop
optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
# ... train as normal ...
```
### From-Scratch Training

Training from scratch with ternary weights requires:

- Careful initialization
- Straight-through estimators for gradients
- Potentially modified learning rates

See `read/IMPLEMENTATION_GUIDE.md` for details.
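The straight-through estimator mentioned above can be sketched in plain PyTorch (an illustrative sketch, not the library's training code): the forward pass ternarizes, while the backward pass treats quantization as the identity so gradients reach the full-precision shadow weights.

```python
import torch

class TernarySTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        # Forward: ternarize with a mean-absolute scale
        scale = w.abs().mean()
        return scale * (w / (scale + 1e-8)).round().clamp(-1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: pass the gradient straight through, as if forward were identity
        return grad_output

w = torch.randn(4, 8, requires_grad=True)
y = TernarySTE.apply(w)
y.sum().backward()  # w.grad is defined even though round() has zero gradient a.e.
```

Without the straight-through trick, the gradient of `round()` is zero almost everywhere and the weights would never update.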
## Technical Specifications

### Architecture

- **Weight Quantization:** Ternary {-1, 0, +1}
- **Scaling:** Per-output-channel absmax scaling
- **Packing:** Base-3 encoding (5 values per byte)
- **Decomposition:** Greedy residual quantization for multi-ternary
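The base-3 packing bullet can be demonstrated with a small round-trip sketch in NumPy: five ternary digits fit in one byte because 3^5 = 243 ≤ 256. The digit ordering here is an assumption for illustration; the library's actual byte layout may differ.

```python
import numpy as np

POWERS = 3 ** np.arange(5)  # [1, 3, 9, 27, 81]

def pack_base3(trits):
    """Pack ternary values {-1, 0, +1} into bytes, 5 values per byte."""
    digits = (np.asarray(trits) + 1).astype(np.int64)  # map {-1,0,+1} -> {0,1,2}
    pad = (-len(digits)) % 5                           # pad to a multiple of 5
    digits = np.pad(digits, (0, pad))
    packed = (digits.reshape(-1, 5) * POWERS).sum(axis=1)  # max 2*121 = 242 < 256
    return packed.astype(np.uint8), len(trits)

def unpack_base3(packed, n):
    """Recover the original ternary values from the packed bytes."""
    digits = packed.astype(np.int64)[:, None] // POWERS % 3
    return digits.reshape(-1)[:n] - 1

trits = np.random.default_rng(0).integers(-1, 2, size=37)
packed, n = pack_base3(trits)
print(f"{n} ternary values -> {packed.nbytes} bytes")  # vs. 4*n bytes in fp32
```

This yields 1 byte per 5 weights versus 4 bytes per fp32 weight, which is where the ~20x theoretical ceiling comes from.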
### Implementation

- **Python:** Pure PyTorch baseline
- **C++:** Optimized CPU kernels with PyBind11
- **CUDA:** GPU kernels with warp-level reductions and shared-memory tiling

### Requirements

- Python ≥ 3.8
- PyTorch ≥ 2.0.0
- NumPy ≥ 1.20.0
- C++ compiler (for the C++ extensions)
- CUDA toolkit (optional, for GPU support)
## Evaluation

### Benchmarks

Comprehensive benchmarks are available in `BENCHMARKS.md`:

- Memory compression analysis
- Forward-pass timing
- Accuracy metrics
- Real-world Transformer examples

### Validation

All implementations are validated against:

- Unit tests (pytest suite)
- Numerical correctness tests
- Integration tests with Transformers
- Cross-implementation consistency (Python vs. C++)
## Citation

If you use BitLinear in your research, please cite:

```bibtex
@article{jmlr_ternary_2024,
  title={Ternary Representations of Neural Networks},
  journal={Journal of Machine Learning Research},
  volume={26},
  year={2024},
  url={https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}

@article{bitnet2023,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}
```
## Model Card Contact

For questions or issues, please open an issue on GitHub or contact the maintainers.

## Glossary

- **Ternary Quantization:** Representing weights with only three values: {-1, 0, +1}
- **Absmax Scaling:** A scaling factor computed as max(abs(weights)), here applied per output channel
- **Base-3 Packing:** Encoding ternary values in base 3 for memory efficiency (5 values per byte)
- **Multi-Ternary:** A sum of k ternary components for improved approximation
- **QAT:** Quantization-Aware Training, i.e., training with quantization in the loop

## More Information

- **Documentation:** See `README.md` and the `read/` directory
- **Examples:** See the `examples/` directory
- **Benchmarks:** See `BENCHMARKS.md`
- **Implementation Guide:** See `read/IMPLEMENTATION_GUIDE.md`