# Model Card: BitLinear
## Model Description
**BitLinear** is a PyTorch implementation of ultra-low-precision (1.58-bit) ternary linear layers that can serve as drop-in replacements for `nn.Linear` in neural networks, particularly Transformers. It achieves ~19x memory compression relative to FP32 `nn.Linear` weights while keeping output cosine similarity above 0.96.
### Model Details
- **Developed by:** BitLinear Contributors
- **Model type:** Quantization / Compression
- **Language:** Python, C++, CUDA
- **License:** MIT
- **Repository:** https://github.com/yourusername/bitlinear
## Intended Use
### Primary Use Cases
- **Edge Deployment:** Deploying large models on memory-constrained devices
- **Production Inference:** Reducing memory footprint for serving large language models
- **Research:** Exploring ultra-low-precision neural networks
- **Cost Optimization:** Reducing cloud infrastructure costs through memory savings
### Out-of-Scope Use Cases
- Training from scratch (requires quantization-aware training)
- Applications requiring exact numerical precision
- Real-time applications where Python overhead is prohibitive (use C++/CUDA extensions)
## How to Use
### Basic Usage
```python
import torch
from bitlinear import BitLinear
# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)
# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # shape (32, 128, 1024), same as nn.Linear
```
### Converting Existing Models
```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear
# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)
# Use as normal; input is (seq_len, batch, d_model) because batch_first defaults to False
x = torch.randn(10, 32, 512)
output = model_compressed(x)
```
### Multi-Ternary for Better Accuracy
```python
from bitlinear import MultiTernaryLinear
# Use k=3 ternary components: ~75% lower error than k=1 (see Accuracy)
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
```
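The Technical Specifications section describes the multi-ternary decomposition as greedy residual quantization. The sketch below illustrates that idea on a single weight row in plain Python. It is not the library's implementation: the quantizer (scale = max |w|, code = round(w / scale)) and the helper names `ternary_quantize`, `multi_ternary`, `reconstruct`, and `l2_error` are assumptions made for illustration.

```python
def ternary_quantize(values):
    """Return (scale, codes) with codes in {-1, 0, +1}.
    Assumed scheme: scale = max(|v|), code = round(v / scale)."""
    scale = max(abs(v) for v in values)
    if scale == 0.0:
        return 0.0, [0] * len(values)
    return scale, [int(round(v / scale)) for v in values]

def multi_ternary(values, k):
    """Greedy residual decomposition: quantize, subtract, repeat k times."""
    residual = list(values)
    components = []
    for _ in range(k):
        scale, codes = ternary_quantize(residual)
        components.append((scale, codes))
        # What the components so far fail to explain is quantized next.
        residual = [r - scale * c for r, c in zip(residual, codes)]
    return components

def reconstruct(components, n):
    """Sum the scaled ternary components back into one row."""
    out = [0.0] * n
    for scale, codes in components:
        for i, c in enumerate(codes):
            out[i] += scale * c
    return out

def l2_error(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5

row = [0.9, -0.35, 0.1, -0.6, 0.02, 0.45]
for k in (1, 2, 3):
    approx = reconstruct(multi_ternary(row, k), len(row))
    print(k, round(l2_error(row, approx), 4))
```

With this quantizer each extra component can only shrink the residual, so the reconstruction error is non-increasing in k. The sketch is not meant to reproduce the 75% figure quoted in this card, which depends on the actual weights and implementation.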
## Performance
### Memory Compression
- **Average Compression:** 19.23x (96% of the theoretical 20x, i.e. 32-bit floats vs. 1.6 bits per packed weight)
- **GPT-2 Small Example:** 324 MB → 16.8 MB (307 MB saved)
| Layer Size | nn.Linear (FP32) | BitLinear (Packed) | Compression |
|------------|------------------|--------------------|-------------|
| 512×512 | 1.00 MB | 0.05 MB | 18.6x |
| 1024×1024 | 4.00 MB | 0.21 MB | 19.3x |
| 4096×4096 | 64.02 MB | 3.23 MB | 19.8x |
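The table values can be reproduced with simple arithmetic. The sketch below assumes a storage layout that this card does not spell out: five ternary weights per byte, one FP32 scale per output channel, an FP32 bias, and "MB" meaning 2**20 bytes. The helper names `linear_bytes`, `packed_bytes`, and `compression` are hypothetical.

```python
import math

MIB = 2 ** 20  # the table's "MB" is assumed to mean 2**20 bytes

def linear_bytes(in_features: int, out_features: int) -> int:
    """FP32 nn.Linear: 4 bytes per weight plus 4 bytes per bias element."""
    return 4 * in_features * out_features + 4 * out_features

def packed_bytes(in_features: int, out_features: int) -> int:
    """Assumed packed layout: 5 ternary weights per byte,
    plus one FP32 scale per output channel and an FP32 bias."""
    weight_bytes = math.ceil(in_features * out_features / 5)
    scale_bytes = 4 * out_features
    bias_bytes = 4 * out_features
    return weight_bytes + scale_bytes + bias_bytes

def compression(in_features: int, out_features: int) -> float:
    return linear_bytes(in_features, out_features) / packed_bytes(in_features, out_features)

for n in (512, 1024, 4096):
    dense = linear_bytes(n, n) / MIB
    packed = packed_bytes(n, n) / MIB
    print(f"{n}x{n}: {dense:.2f} MB -> {packed:.2f} MB ({compression(n, n):.1f}x)")
```

Under these assumptions the ratios round to 18.6x, 19.3x, and 19.8x. The per-channel scale and bias are a fixed overhead per output row, which is why the ratio approaches 20x as layers grow.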
### Accuracy
- **Cosine Similarity:** > 0.96 (96%+)
- **Relative Error:** ~0.28 (28%)
- **Multi-Ternary (k=3):** 75% error reduction vs k=1
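The card does not define these metrics. A common reading, assumed here, is that both compare the output of the original `nn.Linear` with the output of the ternary layer on the same input: cosine similarity measures the angle between the two outputs, and relative error is the L2 norm of the difference divided by the L2 norm of the reference. The function names below are hypothetical and the definitions in `BENCHMARKS.md` may differ.

```python
import math

def cosine_similarity(ref, approx):
    """Cosine of the angle between the reference and approximate outputs."""
    dot = sum(r * a for r, a in zip(ref, approx))
    norm_ref = math.sqrt(sum(r * r for r in ref))
    norm_approx = math.sqrt(sum(a * a for a in approx))
    return dot / (norm_ref * norm_approx)

def relative_error(ref, approx):
    """||ref - approx||_2 / ||ref||_2 (assumed definition)."""
    diff = math.sqrt(sum((r - a) ** 2 for r, a in zip(ref, approx)))
    return diff / math.sqrt(sum(r * r for r in ref))

ref = [1.0, -2.0, 0.5, 3.0]      # stand-in for an nn.Linear output
approx = [1.1, -1.8, 0.0, 3.2]   # stand-in for a BitLinear output
print(cosine_similarity(ref, approx), relative_error(ref, approx))
```

The two metrics are not interchangeable: cosine similarity ignores errors in overall magnitude, while relative error penalizes them, so a layer can score well on one and poorly on the other.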
## Limitations
### Known Limitations
1. **Accuracy Trade-off:** Ternary quantization introduces approximation error (output cosine similarity typically drops by ~3-5%; relative error is ~28%, see Accuracy)
2. **Training:** Requires quantization-aware training (QAT) for optimal results
3. **Speed:** Python implementation may be slower than nn.Linear (use C++/CUDA for production)
4. **Activation Quantization:** Currently only weights are quantized (full BitNet includes activation quantization)
### Recommendations
- Fine-tune converted models for best accuracy
- Use k≥2 for MultiTernaryLinear when accuracy is critical
- Profile performance on your specific hardware
- Test accuracy on your specific task before deployment
## Training
### Quantization-Aware Training (QAT)
For best results, fine-tune models with BitLinear layers:
```python
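import torch
from bitlinear import convert_linear_to_bitlinear
# Convert a pre-trained model (`pretrained_model` is any nn.Module you have already loaded)
model_bit = convert_linear_to_bitlinear(pretrained_model)
# Fine-tune with a standard training loop
optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
# ... train as normal ...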
```
### From Scratch Training
Training from scratch with ternary weights requires:
- Careful initialization
- Straight-through estimators for gradients
- Potentially modified learning rates
See `read/IMPLEMENTATION_GUIDE.md` for details.
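As an illustration of the straight-through estimator mentioned above, here is a minimal sketch for a single neuron with a squared-error loss, written with manual gradients in plain Python so it has no dependencies. The quantizer and the names `ternary_quantize` and `ste_step` are assumptions, not the library's code. In PyTorch the same effect is commonly obtained with the expression `w + (quantize(w) - w).detach()`.

```python
def ternary_quantize(weights):
    """Assumed quantizer: scale = max(|w|), code = round(w / scale)."""
    scale = max(abs(w) for w in weights)
    if scale == 0.0:
        return [0.0] * len(weights)
    return [scale * round(w / scale) for w in weights]

def ste_step(latent, x, target, lr):
    """One SGD step on a squared-error loss for y = quantize(latent) . x."""
    w_q = ternary_quantize(latent)                    # forward pass uses ternary weights
    y = sum(w * xi for w, xi in zip(w_q, x))
    loss = (y - target) ** 2
    grad_wq = [2.0 * (y - target) * xi for xi in x]   # dLoss/dw_q
    # Straight-through: treat d(w_q)/d(latent) as the identity, so the
    # gradient w.r.t. the quantized weights updates the latent weights directly.
    new_latent = [w - lr * g for w, g in zip(latent, grad_wq)]
    return new_latent, loss

latent = [0.3, -0.8, 0.5]          # full-precision "shadow" weights
x, target = [1.0, 2.0, -1.0], 1.0
for step in range(20):
    latent, loss = ste_step(latent, x, target, lr=0.05)
print(loss, ternary_quantize(latent))
```

The rounding step has zero gradient almost everywhere, so without the straight-through substitution the latent weights would never receive an update.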
## Technical Specifications
### Architecture
- **Weight Quantization:** Ternary {-1, 0, +1}
- **Scaling:** Per-output-channel absmax scaling
- **Packing:** Base-3 encoding (5 values per byte)
- **Decomposition:** Greedy residual quantization for multi-ternary
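Base-3 packing works because 3**5 = 243 fits in one byte, so five ternary values cost 8 bits, or 1.6 bits each. The following is a minimal sketch of one possible encoding; the digit order, the padding value, and the function names `pack_ternary` and `unpack_ternary` are assumptions and may not match the byte layout used by the library's kernels.

```python
def pack_ternary(codes):
    """Pack values from {-1, 0, +1} into bytes, 5 values per byte."""
    digits = [c + 1 for c in codes]           # map {-1, 0, +1} -> {0, 1, 2}
    digits += [1] * (-len(digits) % 5)        # pad to a multiple of 5 with the digit for 0
    packed = bytearray()
    for i in range(0, len(digits), 5):
        byte = 0
        for d in reversed(digits[i:i + 5]):   # first value becomes the least-significant digit
            byte = byte * 3 + d
        packed.append(byte)                   # at most 242, so it always fits in a byte
    return bytes(packed)

def unpack_ternary(packed, count):
    """Inverse of pack_ternary; `count` drops the padding."""
    codes = []
    for byte in packed:
        for _ in range(5):
            byte, d = divmod(byte, 3)         # peel off base-3 digits, lowest first
            codes.append(d - 1)
    return codes[:count]

weights = [1, -1, 0, 0, 1, -1, 1]
blob = pack_ternary(weights)
print(len(blob), unpack_ternary(blob, len(weights)) == weights)
```

The original length has to be stored alongside the bytes (here passed as `count`), since padding values are indistinguishable from real zeros.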
### Implementation
- **Python:** Pure PyTorch baseline
- **C++:** Optimized CPU kernels with PyBind11
- **CUDA:** GPU kernels with warp-level reductions and shared memory tiling
### Requirements
- Python ≥ 3.8
- PyTorch ≥ 2.0.0
- NumPy ≥ 1.20.0
- C++ compiler (for C++ extensions)
- CUDA toolkit (optional, for GPU support)
## Evaluation
### Benchmarks
Comprehensive benchmarks available in `BENCHMARKS.md`:
- Memory compression analysis
- Forward pass timing
- Accuracy metrics
- Real-world transformer examples
### Validation
All implementations are validated with:
- Unit tests (pytest suite)
- Numerical correctness tests
- Integration tests with Transformers
- Cross-implementation consistency (Python vs C++)
## Citation
If you use BitLinear in your research, please cite:
```bibtex
@article{jmlr_ternary_2024,
title={Ternary Representations of Neural Networks},
journal={Journal of Machine Learning Research},
volume={26},
year={2025},
url={https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}
@article{bitnet2023,
title={BitNet: Scaling 1-bit Transformers for Large Language Models},
author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
journal={arXiv preprint arXiv:2310.11453},
year={2023}
}
```
## Model Card Contact
For questions or issues, please open an issue on GitHub or contact the maintainers.
## Glossary
- **Ternary Quantization:** Representing weights with only three values {-1, 0, +1}
- **Absmax Scaling:** Scaling factor computed as max(abs(weights)), taken per output channel
- **Base-3 Packing:** Encoding ternary values in base-3 for memory efficiency
- **Multi-Ternary:** Sum of k ternary components for improved approximation
- **QAT:** Quantization-aware training, i.e. training with quantization in the loop
## More Information
- **Documentation:** See `README.md` and `read/` directory
- **Examples:** See `examples/` directory
- **Benchmarks:** See `BENCHMARKS.md`
- **Implementation Guide:** See `read/IMPLEMENTATION_GUIDE.md`