# Model Card: BitLinear

## Model Description

**BitLinear** is a PyTorch implementation of ultra-low-precision (1.58-bit) ternary linear layers that serve as drop-in replacements for `nn.Linear` in neural networks, particularly Transformers. It achieves roughly 19x memory compression while keeping outputs close to full precision (cosine similarity above 0.96).

### Model Details

- **Developed by:** BitLinear Contributors
- **Model type:** Quantization / Compression
- **Language:** Python, C++, CUDA
- **License:** MIT
- **Repository:** https://github.com/yourusername/bitlinear

## Intended Use

### Primary Use Cases

- **Edge Deployment:** Deploying large models on memory-constrained devices
- **Production Inference:** Reducing memory footprint for serving large language models
- **Research:** Exploring ultra-low-precision neural networks
- **Cost Optimization:** Reducing cloud infrastructure costs through memory savings

### Out-of-Scope Use Cases

- Training from scratch (requires quantization-aware training)
- Applications requiring exact numerical precision
- Real-time applications where Python overhead is prohibitive (use C++/CUDA extensions)

## How to Use

### Basic Usage

```python
import torch
from bitlinear import BitLinear

# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # Same as nn.Linear
```

### Converting Existing Models

```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Use as normal
x = torch.randn(10, 32, 512)
output = model_compressed(x)
```

### Multi-Ternary for Better Accuracy

```python
from bitlinear import MultiTernaryLinear

# Use k=3 components for 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
```
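
The k-component idea is a greedy residual decomposition: each ternary component is fitted to the residual the previous components left behind. A minimal sketch, assuming a common 0.75·mean|w| threshold and a least-squares scale (the library's exact rule may differ, and `greedy_ternary_decompose` is an illustrative name, not the library API):

```python
import torch

def ternarize(w):
    # One ternary component: threshold, then least-squares scale.
    # The 0.75 * mean|w| threshold is a common heuristic.
    delta = 0.75 * w.abs().mean()
    t = torch.sign(w) * (w.abs() > delta)
    alpha = (w * t).sum() / t.abs().sum().clamp(min=1)
    return alpha, t

def greedy_ternary_decompose(w, k=3):
    # Approximate w as a sum of k scaled ternary tensors, each
    # fitted to the residual left by the previous components.
    residual = w.clone()
    components = []
    for _ in range(k):
        alpha, t = ternarize(residual)
        components.append((alpha, t))
        residual = residual - alpha * t
    return components

torch.manual_seed(0)
w = torch.randn(1024, 512)
comps = greedy_ternary_decompose(w, k=3)
approx = sum(alpha * t for alpha, t in comps)
err = (w - approx).norm() / w.norm()  # shrinks as k grows
```

Each extra component multiplies storage by roughly k/1 but attacks the remaining error directly, which is where the reported error reduction for k=3 comes from.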

## Performance

### Memory Compression

- **Average Compression:** 19.23x (96% of the theoretical 20x)
- **GPT-2 Small Example:** 324 MB → 16.8 MB (307 MB saved)

| Layer Size | nn.Linear | BitLinear (Packed) | Compression |
|------------|-----------|-------------------|-------------|
| 512×512    | 1.00 MB   | 0.05 MB          | 18.6x       |
| 1024×1024  | 4.00 MB   | 0.21 MB          | 19.3x       |
| 4096×4096  | 64.02 MB  | 3.23 MB          | 19.8x       |
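
One plausible accounting that reproduces the table's numbers: fp32 stores 4 bytes per weight, packed ternary stores 5 weights per byte, and per-output-channel fp32 scales plus the fp32 bias remain unpacked, which is why measured compression falls slightly short of the 20x ideal. A sketch of the calculation (the overhead model is an assumption, not the library's exact layout):

```python
def layer_sizes_mb(in_features, out_features, bias=True):
    n = in_features * out_features
    # fp32 baseline: 4 bytes per weight (+ fp32 bias)
    fp32 = 4 * (n + (out_features if bias else 0))
    # packed ternary: 1 byte per 5 weights, plus fp32
    # per-output-channel scales (+ fp32 bias)
    packed = (n + 4) // 5 + 4 * out_features * (2 if bias else 1)
    return fp32 / 2**20, packed / 2**20

for size in (512, 1024, 4096):
    fp32_mb, packed_mb = layer_sizes_mb(size, size)
    print(f"{size}x{size}: {fp32_mb:.2f} MB -> {packed_mb:.2f} MB "
          f"({fp32_mb / packed_mb:.1f}x)")
```

The fixed per-channel overhead matters more for small layers, which is why the 512×512 layer compresses less than the 4096×4096 one.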

### Accuracy

- **Cosine Similarity:** > 0.96 (96%+)
- **Relative Error:** ~0.28 (28%)
- **Multi-Ternary (k=3):** 75% error reduction vs k=1

## Limitations

### Known Limitations

1. **Accuracy Trade-off:** Ternary quantization introduces approximation error (~3-5% typical)
2. **Training:** Requires quantization-aware training (QAT) for optimal results
3. **Speed:** Python implementation may be slower than nn.Linear (use C++/CUDA for production)
4. **Activation Quantization:** Currently only weights are quantized (full BitNet includes activation quantization)

### Recommendations

- Fine-tune converted models for best accuracy
- Use k≥2 for MultiTernaryLinear when accuracy is critical
- Profile performance on your specific hardware
- Test accuracy on your specific task before deployment
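
Following the last recommendation, a quick sanity check is to compare a layer's full-precision and quantized outputs on representative inputs. The sketch below uses a naive per-output-channel ternarization of an `nn.Linear` as a stand-in for the real `BitLinear` (the quantization code here is illustrative, not the library's implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
ref = nn.Linear(512, 1024)
x = torch.randn(32, 512)

# Naive per-output-channel ternarization (stand-in for BitLinear)
w = ref.weight.detach()
delta = 0.75 * w.abs().mean(dim=1, keepdim=True)   # per-row threshold
t = torch.sign(w) * (w.abs() > delta)              # ternary {-1, 0, +1}
alpha = (w * t).sum(dim=1, keepdim=True) / \
        t.abs().sum(dim=1, keepdim=True).clamp(min=1)

y_ref = ref(x)
y_q = F.linear(x, alpha * t, ref.bias)

cos = F.cosine_similarity(y_ref.flatten(), y_q.flatten(), dim=0)
print(f"cosine similarity: {cos.item():.3f}")
```

Running the same comparison with your own inputs and the real `BitLinear` layer gives a task-relevant estimate before committing to deployment.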

## Training

### Quantization-Aware Training (QAT)

For best results, fine-tune models with BitLinear layers:

```python
import torch
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model (pretrained_model loaded elsewhere)
model_bit = convert_linear_to_bitlinear(pretrained_model)

# Fine-tune with standard training loop
optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
# ... train as normal ...
```

### From Scratch Training

Training from scratch with ternary weights requires:
- Careful initialization
- Straight-through estimators for gradients
- Potentially modified learning rates
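
The straight-through estimator (STE) mentioned above can be sketched in a few lines: ternarize on the forward pass, but propagate gradients as if the quantizer were the identity, since the quantizer itself has zero gradient almost everywhere. An illustrative sketch, not the library's training code:

```python
import torch

class TernarySTE(torch.autograd.Function):
    """Ternarize in forward; pass gradients straight through in backward."""

    @staticmethod
    def forward(ctx, w):
        delta = 0.75 * w.abs().mean()          # heuristic threshold
        return torch.sign(w) * (w.abs() > delta)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: treat the quantizer as differentiable
        return grad_output

torch.manual_seed(0)
w = torch.randn(8, 8, requires_grad=True)
q = TernarySTE.apply(w)
loss = (q ** 2).sum()
loss.backward()
# w.grad exists even though the quantizer is piecewise constant
```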

See `read/IMPLEMENTATION_GUIDE.md` for details.

## Technical Specifications

### Architecture

- **Weight Quantization:** Ternary {-1, 0, +1}
- **Scaling:** Per-output-channel absmax scaling
- **Packing:** Base-3 encoding (5 values per byte)
- **Decomposition:** Greedy residual quantization for multi-ternary
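
The packing ratio follows from 3^5 = 243 ≤ 256: five ternary digits, shifted from {-1, 0, +1} to {0, 1, 2}, fit in one uint8. A minimal sketch of the encoding (function names are illustrative, not the library API):

```python
import torch

def pack_base3(t):
    # Pack a ternary tensor (values in {-1, 0, +1}) at 5 values/byte.
    digits = (t.flatten() + 1).to(torch.uint8)           # map to {0, 1, 2}
    pad = (-digits.numel()) % 5
    digits = torch.cat([digits, digits.new_zeros(pad)])  # pad to multiple of 5
    groups = digits.view(-1, 5).to(torch.int64)
    weights = torch.tensor([1, 3, 9, 27, 81])            # base-3 place values
    return (groups * weights).sum(dim=1).to(torch.uint8)

def unpack_base3(packed, n):
    # Invert pack_base3, recovering the first n ternary values.
    codes = packed.to(torch.int64)
    digits = torch.stack([(codes // w) % 3 for w in (1, 3, 9, 27, 81)], dim=1)
    return (digits.flatten()[:n] - 1).to(torch.int8)

torch.manual_seed(0)
t = torch.randint(-1, 2, (512 * 1024,))
packed = pack_base3(t)
restored = unpack_base3(packed, t.numel())
# 4 bytes/value in fp32 vs 0.2 bytes/value packed -> ~20x
```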

### Implementation

- **Python:** Pure PyTorch baseline
- **C++:** Optimized CPU kernels with PyBind11
- **CUDA:** GPU kernels with warp-level reductions and shared memory tiling

### Requirements

- Python ≥ 3.8
- PyTorch ≥ 2.0.0
- NumPy ≥ 1.20.0
- C++ compiler (for C++ extensions)
- CUDA toolkit (optional, for GPU support)

## Evaluation

### Benchmarks

Comprehensive benchmarks available in `BENCHMARKS.md`:
- Memory compression analysis
- Forward pass timing
- Accuracy metrics
- Real-world transformer examples

### Validation

All implementations validated against:
- Unit tests (pytest suite)
- Numerical correctness tests
- Integration tests with Transformers
- Cross-implementation consistency (Python vs C++)

## Citation

If you use BitLinear in your research, please cite:

```bibtex
@article{jmlr_ternary_2024,
  title={Ternary Representations of Neural Networks},
  journal={Journal of Machine Learning Research},
  volume={26},
  year={2024},
  url={https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}

@article{bitnet2023,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}
```

## Model Card Contact

For questions or issues, please open an issue on GitHub or contact the maintainers.

## Glossary

- **Ternary Quantization:** Representing weights with only three values {-1, 0, +1}
- **Absmax Scaling:** Scaling factor computed from the maximum absolute value of the weights (here, per output channel)
- **Base-3 Packing:** Encoding ternary values in base-3 for memory efficiency
- **Multi-Ternary:** Sum of k ternary components for improved approximation
- **QAT:** Quantization-Aware Training - training with quantization in the loop

## More Information

- **Documentation:** See `README.md` and `read/` directory
- **Examples:** See `examples/` directory
- **Benchmarks:** See `BENCHMARKS.md`
- **Implementation Guide:** See `read/IMPLEMENTATION_GUIDE.md`