# Model Card: BitLinear
## Model Description
**BitLinear** is a PyTorch implementation of ultra-low-precision (1.58-bit) ternary linear layers that can serve as drop-in replacements for `nn.Linear` in neural networks, particularly Transformers. It achieves ~19x memory compression while maintaining high output similarity.
### Model Details
- **Developed by:** BitLinear Contributors
- **Model type:** Quantization / Compression
- **Language:** Python, C++, CUDA
- **License:** MIT
- **Repository:** https://github.com/yourusername/bitlinear
## Intended Use
### Primary Use Cases
- **Edge Deployment:** Deploying large models on memory-constrained devices
- **Production Inference:** Reducing memory footprint for serving large language models
- **Research:** Exploring ultra-low-precision neural networks
- **Cost Optimization:** Reducing cloud infrastructure costs through memory savings
### Out-of-Scope Use Cases
- Training from scratch (requires quantization-aware training)
- Applications requiring exact numerical precision
- Real-time applications where Python overhead is prohibitive (use C++/CUDA extensions)
## How to Use
### Basic Usage
```python
import torch
from bitlinear import BitLinear
# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)
# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x) # Same as nn.Linear
```
### Converting Existing Models
```python
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear
# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)
# Use as normal
x = torch.randn(10, 32, 512)
output = model_compressed(x)
```
### Multi-Ternary for Better Accuracy
```python
from bitlinear import MultiTernaryLinear
# Use k=3 components for 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
```
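The decomposition behind `MultiTernaryLinear` (greedy residual quantization, per the Technical Specifications below) can be sketched as follows. The helper names, the 0.05 threshold, and the least-squares scale are illustrative assumptions, not the library's actual internals:

```python
import torch

def ternary_quantize(w, threshold=0.05):
    """One ternary component: per-row absmax threshold, then a least-squares scale."""
    alpha = w.abs().max(dim=1, keepdim=True).values
    t = torch.sign(w) * (w.abs() > threshold * alpha)  # {-1, 0, +1} pattern
    # Scale minimizing ||w - scale * t|| per row (clamp guards all-zero rows)
    scale = (w * t).sum(dim=1, keepdim=True) / t.abs().sum(dim=1, keepdim=True).clamp(min=1)
    return scale, t

def greedy_multi_ternary(w, k=3):
    """Approximate w as a sum of k scaled ternary components, each one
    fit to the residual left by the previous components."""
    residual = w.clone()
    components = []
    for _ in range(k):
        scale, t = ternary_quantize(residual)
        components.append((scale, t))
        residual = residual - scale * t
    return components
```

Each extra component fits what the previous ones missed, so the approximation error shrinks with `k`, at the cost of storing `k` packed weight tensors instead of one.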
## Performance
### Memory Compression
- **Average Compression:** 19.23x (~96% of the theoretical 20x)
- **GPT-2 Small Example:** 324 MB → 16.8 MB (307 MB saved)
| Layer Size | nn.Linear | BitLinear (Packed) | Compression |
|------------|-----------|-------------------|-------------|
| 512×512 | 1.00 MB | 0.05 MB | 18.6x |
| 1024×1024 | 4.00 MB | 0.21 MB | 19.3x |
| 4096×4096 | 64.02 MB | 3.23 MB | 19.8x |
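The table is consistent with a simple size model: fp32 stores 4 bytes per weight (plus a bias vector), while base-3 packing stores 5 ternary weights per byte, plus fp32 per-output-channel scales and the bias. A sketch of the arithmetic (function names are illustrative):

```python
def linear_memory_mb(out_features, in_features):
    """fp32 nn.Linear: 4 bytes per weight plus a 4-byte-per-element bias."""
    return (out_features * in_features + out_features) * 4 / 1024**2

def bitlinear_memory_mb(out_features, in_features):
    """Packed BitLinear: 5 ternary weights per byte (base-3 packing),
    plus fp32 per-output-channel scales and an fp32 bias."""
    packed_weights = -(-out_features * in_features // 5)  # ceil(n_weights / 5) bytes
    scales = out_features * 4
    bias = out_features * 4
    return (packed_weights + scales + bias) / 1024**2

for n in (512, 1024, 4096):
    dense, packed = linear_memory_mb(n, n), bitlinear_memory_mb(n, n)
    print(f"{n}x{n}: {dense:.2f} MB -> {packed:.2f} MB ({dense / packed:.1f}x)")
```

This model reproduces the ratios in the table (18.6x, 19.3x, 19.8x) and shows why larger layers approach the theoretical 20x: the per-channel scale and bias overhead is amortized over more weights.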
### Accuracy
- **Cosine Similarity:** > 0.96 versus full-precision `nn.Linear` outputs
- **Relative Error:** ~0.28
- **Multi-Ternary (k=3):** 75% error reduction vs k=1
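Metrics like these can be measured in a few lines of PyTorch. The sketch below uses random Gaussian weights and a simple absmax-threshold quantizer (both illustrative assumptions), so its numbers will differ from the repository's measurements on real trained weights:

```python
import torch
import torch.nn.functional as F

def ternary_approx(w, threshold=0.05):
    """Ternary-quantize w with a per-row absmax threshold and least-squares scale."""
    alpha = w.abs().max(dim=1, keepdim=True).values
    t = torch.sign(w) * (w.abs() > threshold * alpha)
    scale = (w * t).sum(dim=1, keepdim=True) / t.abs().sum(dim=1, keepdim=True).clamp(min=1)
    return scale * t

torch.manual_seed(0)
w = torch.randn(1024, 512) / 512 ** 0.5     # fan-in scaled, as in typical init
x = torch.randn(32, 512)

y_ref = F.linear(x, w)                      # full-precision reference
y_q = F.linear(x, ternary_approx(w))        # ternary approximation

cos = F.cosine_similarity(y_ref.flatten(), y_q.flatten(), dim=0)
rel = (y_ref - y_q).norm() / y_ref.norm()
print(f"cosine similarity: {cos:.3f}, relative error: {rel:.3f}")
```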
## Limitations
### Known Limitations
1. **Accuracy Trade-off:** Ternary quantization introduces approximation error (~3-5% in typical settings)
2. **Training:** Requires quantization-aware training (QAT) for optimal results
3. **Speed:** Python implementation may be slower than nn.Linear (use C++/CUDA for production)
4. **Activation Quantization:** Currently only weights are quantized (full BitNet includes activation quantization)
### Recommendations
- Fine-tune converted models for best accuracy
- Use k≥2 for MultiTernaryLinear when accuracy is critical
- Profile performance on your specific hardware
- Test accuracy on your specific task before deployment
## Training
### Quantization-Aware Training (QAT)
For best results, fine-tune models with BitLinear layers:
```python
# Convert pre-trained model
model_bit = convert_linear_to_bitlinear(pretrained_model)
# Fine-tune with standard training loop
optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
# ... train as normal ...
```
### From Scratch Training
Training from scratch with ternary weights requires:
- Careful initialization
- Straight-through estimators for gradients
- Potentially modified learning rates
See `read/IMPLEMENTATION_GUIDE.md` for details.
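A minimal sketch of a straight-through estimator for ternary weights: the forward pass quantizes, while the backward pass treats quantization as the identity so gradients reach the latent full-precision weights. The threshold and scaling choices here are illustrative, not the library's exact scheme:

```python
import torch

class TernarySTE(torch.autograd.Function):
    """Ternary quantization whose backward pass is the identity (straight-through)."""

    @staticmethod
    def forward(ctx, w):
        # Quantize to a per-row scale times {-1, 0, +1} (illustrative scheme)
        alpha = w.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)
        t = torch.sign(w) * (w.abs() > 0.05 * alpha)
        scale = (w * t).sum(dim=1, keepdim=True) / t.abs().sum(dim=1, keepdim=True).clamp(min=1)
        return scale * t

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pretend quantization was the identity
        return grad_output

w = torch.randn(8, 16, requires_grad=True)
y = TernarySTE.apply(w)
y.sum().backward()   # w.grad is populated despite the non-differentiable rounding
```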
## Technical Specifications
### Architecture
- **Weight Quantization:** Ternary {-1, 0, +1}
- **Scaling:** Per-output-channel absmax scaling
- **Packing:** Base-3 encoding (5 values per byte)
- **Decomposition:** Greedy residual quantization for multi-ternary
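Base-3 packing works because 3^5 = 243 ≤ 256, so five ternary digits fit in one byte, which is the source of the ~20x theoretical compression over fp32 (4 bytes per weight vs. 1/5 byte). A NumPy sketch with illustrative helper names:

```python
import numpy as np

def pack_ternary(trits):
    """Pack ternary values {-1, 0, +1} into bytes, 5 per byte, in base 3."""
    digits = np.asarray(trits, dtype=np.int64) + 1      # {-1, 0, 1} -> {0, 1, 2}
    digits = np.pad(digits, (0, (-len(digits)) % 5))    # pad to a multiple of 5
    groups = digits.reshape(-1, 5)
    powers = 3 ** np.arange(5)                          # 1, 3, 9, 27, 81
    return (groups @ powers).astype(np.uint8)           # each byte in [0, 242]

def unpack_ternary(packed, n):
    """Recover the first n ternary values from packed bytes."""
    vals = packed.astype(np.int64)[:, None] // 3 ** np.arange(5) % 3
    return (vals.reshape(-1) - 1)[:n]                   # back to {-1, 0, +1}

weights = np.array([1, -1, 0, 1, 0, -1, 1], dtype=np.int8)
packed = pack_ternary(weights)
restored = unpack_ternary(packed, len(weights))
# 7 trits occupy 2 packed bytes instead of 28 fp32 bytes
```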
### Implementation
- **Python:** Pure PyTorch baseline
- **C++:** Optimized CPU kernels with PyBind11
- **CUDA:** GPU kernels with warp-level reductions and shared memory tiling
### Requirements
- Python ≥ 3.8
- PyTorch ≥ 2.0.0
- NumPy ≥ 1.20.0
- C++ compiler (for C++ extensions)
- CUDA toolkit (optional, for GPU support)
## Evaluation
### Benchmarks
Comprehensive benchmarks available in `BENCHMARKS.md`:
- Memory compression analysis
- Forward pass timing
- Accuracy metrics
- Real-world transformer examples
### Validation
All implementations validated against:
- Unit tests (pytest suite)
- Numerical correctness tests
- Integration tests with Transformers
- Cross-implementation consistency (Python vs C++)
## Citation
If you use BitLinear in your research, please cite:
```bibtex
@article{jmlr_ternary_2024,
  title={Ternary Representations of Neural Networks},
  journal={Journal of Machine Learning Research},
  volume={26},
  year={2024},
  url={https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}
@article{bitnet2023,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}
```
## Model Card Contact
For questions or issues, please open an issue on GitHub or contact the maintainers.
## Glossary
- **Ternary Quantization:** Representing weights with only three values {-1, 0, +1}
- **Absmax Scaling:** Per-output-channel scaling factor computed from the maximum absolute weight in each channel
- **Base-3 Packing:** Encoding ternary values in base-3 for memory efficiency
- **Multi-Ternary:** Sum of k ternary components for improved approximation
- **QAT:** Quantization-Aware Training - training with quantization in the loop
## More Information
- **Documentation:** See `README.md` and `read/` directory
- **Examples:** See `examples/` directory
- **Benchmarks:** See `BENCHMARKS.md`
- **Implementation Guide:** See `read/IMPLEMENTATION_GUIDE.md`