# Implementation Guide
This document provides a roadmap for implementing the BitLinear functionality, following the structure defined in the project skeleton. It also shows how the same process can be adapted to other operations.
## Implementation Order
### Phase 1: Python Baseline (Correctness First)
Start here to establish correctness before optimizing.
#### 1.1 Quantization (`bitlinear/quantization.py`)
Order of implementation:
1. `absmax_scale()` - Simple max computation
2. `ternary_quantize()` - Threshold-based quantization to {-1, 0, +1}
3. `weight_to_ternary()` - Combines the above
4. Test thoroughly with `tests/test_quantization.py` (a sketch of the first three functions follows the considerations below)
**Key considerations:**
- Threshold selection (try 0.33 * scale or 0.5 * scale)
- Per-channel vs. global scaling trade-offs
- Numerical stability (avoid division by zero)
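A minimal sketch of these functions, assuming per-tensor absmax scaling and a mean-based gamma; the threshold ratio and the gamma formula are illustrative choices, not the only valid ones:
```python
import torch

def absmax_scale(W):
    # Per-tensor scale: largest absolute weight (clamped to avoid division by zero)
    return W.abs().max().clamp(min=1e-8)

def ternary_quantize(W, scale, threshold_ratio=0.5):
    # Weights inside the threshold band map to 0, the rest to +/-1
    threshold = threshold_ratio * scale
    return torch.where(W.abs() <= threshold, torch.zeros_like(W), torch.sign(W))

def weight_to_ternary(W):
    scale = absmax_scale(W)
    W_t = ternary_quantize(W, scale)
    # gamma: mean |W| over the nonzero positions, so gamma * W_t tracks W's magnitude
    nonzero = W_t.abs().sum().clamp(min=1.0)
    gamma = (W.abs() * W_t.abs()).sum() / nonzero
    return W_t, gamma
```
Per-channel scaling (one gamma per output row) is the natural alternative to this per-tensor version; the trade-off is noted in the considerations above.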
#### 1.2 Functional Operations (`bitlinear/functional.py`)
Order of implementation:
1. `bitlinear_python()` - Core ternary matmul
```python
# Sketch of the core ternary matmul:
def bitlinear_python(x, W_ternary, gamma, bias=None):
    # x @ W_t^T, then rescale outputs by gamma
    output = torch.matmul(x, W_ternary.T)
    output = output * gamma.unsqueeze(0)
    if bias is not None:
        output = output + bias
    return output
```
2. `greedy_ternary_decomposition()` - Iterative residual quantization
```python
# Sketch of iterative residual quantization:
def greedy_ternary_decomposition(W, k):
    residual = W.clone()
    W_list, gammas = [], []
    for i in range(k):
        W_t, gamma = weight_to_ternary(residual)
        W_list.append(W_t)
        gammas.append(gamma)
        # Subtract this term so the next pass quantizes what remains
        residual = residual - gamma * W_t
    return W_list, gammas
```
3. `multi_ternary_linear_python()` - Sum of k ternary operations
4. Test with `tests/test_functional.py`
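A minimal sketch of the multi-term forward pass, assuming `greedy_ternary_decomposition()` returns the lists built above:
```python
def multi_ternary_linear_python(x, W_list, gammas, bias=None):
    # Sum of k ternary matmuls; each extra term tightens the approximation
    output = sum(torch.matmul(x, W_t.T) * g for W_t, g in zip(W_list, gammas))
    if bias is not None:
        output = output + bias
    return output
```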
#### 1.3 Layer Modules (`bitlinear/layers.py`)
Order of implementation:
1. `BitLinear.__init__()` and `reset_parameters()`
- Initialize dense weights using kaiming_uniform
- Quantize to ternary using `weight_to_ternary()`
- Store as buffers or parameters
2. `BitLinear.forward()` - Call `bitlinear_python()`
3. `BitLinear.from_linear()` - Conversion utility
4. `MultiTernaryLinear` - Similar structure
5. `convert_linear_to_bitlinear()` - Recursive module conversion (included in the sketch after the testing notes)
6. Test with `tests/test_layers.py`
**Testing strategy:**
- Compare output shapes with nn.Linear
- Verify ternary weight values
- Test conversion from pre-trained weights
- Validate in Transformer example
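A minimal sketch of the layer wiring, assuming the Phase 1 functions above; the buffer registration and conversion details are illustrative, not the skeleton's exact API:
```python
import math
import torch
import torch.nn as nn

class BitLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        weight = torch.empty(out_features, in_features)
        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        W_t, gamma = weight_to_ternary(weight)
        # Buffers: ternary weights are not trained directly in Phase 1
        self.register_buffer("weight_ternary", W_t)
        self.register_buffer("gamma", gamma)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        return bitlinear_python(x, self.weight_ternary, self.gamma, self.bias)

    @classmethod
    def from_linear(cls, linear):
        layer = cls(linear.in_features, linear.out_features,
                    bias=linear.bias is not None)
        W_t, gamma = weight_to_ternary(linear.weight.data)
        layer.weight_ternary.copy_(W_t)
        layer.gamma.copy_(gamma)
        if linear.bias is not None:
            layer.bias.data.copy_(linear.bias.data)
        return layer

def convert_linear_to_bitlinear(module):
    # Recursively swap every nn.Linear child for a BitLinear
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, BitLinear.from_linear(child))
        else:
            convert_linear_to_bitlinear(child)
    return module
```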
### Phase 2: Memory Optimization
#### 2.1 Base-3 Packing (`bitlinear/packing.py`)
Implement packing for memory efficiency:
1. `pack_ternary_base3()` - 5 values per byte
2. `unpack_ternary_base3()` - Reverse operation
3. Verify roundtrip: pack → unpack == identity
**Packing scheme:**
```
Map: -1 → 0, 0 → 1, +1 → 2 (base-3 digits)
Pack 5 digits per byte: d0 + d1*3 + d2*9 + d3*27 + d4*81
```
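A minimal sketch of the pair, assuming weights arrive as a tensor of {-1, 0, +1}; the flattening and padding handling is illustrative:
```python
import torch

_POWERS = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int16)

def pack_ternary_base3(w_ternary):
    # Map {-1, 0, +1} -> {0, 1, 2}, then pack 5 base-3 digits per byte
    digits = (w_ternary.flatten() + 1).to(torch.int16)
    pad = (-digits.numel()) % 5
    if pad:
        digits = torch.cat([digits, digits.new_zeros(pad)])
    packed = (digits.view(-1, 5) * _POWERS).sum(dim=1)
    return packed.to(torch.uint8)  # max value is 2 * 121 = 242, fits a byte

def unpack_ternary_base3(packed, numel):
    # Peel off base-3 digits, then map {0, 1, 2} back to {-1, 0, +1}
    vals = packed.to(torch.int16)
    digits = []
    for _ in range(5):
        digits.append(vals % 3)
        vals = vals // 3
    flat = torch.stack(digits, dim=1).flatten()[:numel]
    return (flat - 1).to(torch.int8)
```
Roundtrip check: `unpack_ternary_base3(pack_ternary_base3(W_t), W_t.numel())` should equal the flattened `W_t` (as int8); the caller reshapes back to the original matrix shape.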
### Phase 3: C++ Extensions (Optional but Recommended)
#### 3.1 CPU Implementation (`bitlinear/cpp/bitlinear.cpp`)
1. Implement `bitlinear_cpu_forward()`
- Basic matrix multiplication with ternary weights
- Exploit ternary structure (skip multiplications; a reference loop is sketched below)
2. Implement `multi_ternary_cpu_forward()`
3. Test integration with Python
**Optimization opportunities (later):**
- AVX/AVX512 vectorization
- OpenMP parallelization
- Cache-efficient tiling
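The ternary trick the C++ inner loop exploits, shown here as a Python reference rather than extension code: zero weights are skipped entirely, and the remaining terms need only adds and subtracts.
```python
def ternary_dot_reference(x_row, w_row):
    # Scalar reference for one output element; no multiplications needed
    acc = 0.0
    for xj, wj in zip(x_row, w_row):
        if wj == 1:
            acc += xj
        elif wj == -1:
            acc -= xj
        # wj == 0: contributes nothing, skip
    return acc
```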
#### 3.2 CUDA Kernels (`bitlinear/cpp/bitlinear_kernel.cu`)
Only after CPU version works!
1. Basic kernel without optimization
- Thread per output element
- Simple accumulation
2. Optimized kernel:
- Shared memory tiling
- Warp-level reductions
- Memory coalescing
- Exploit ternary (conditional accumulation)
3. Advanced (optional):
- Tensor Core utilization
- Mixed precision
- Fused kernels (activation quantization + matmul)
**Performance targets:**
- Should be faster than PyTorch's F.linear for large matrices
- Aim for 2-5x speedup from ternary optimization
### Phase 4: Training Support
#### 4.1 Quantization-Aware Training (QAT)
Modify layers to support gradient flow:
1. Straight-through estimator for ternary quantization
2. Learnable scaling factors (gamma)
3. Fine-tuning pre-trained models
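A minimal straight-through estimator sketch, assuming `weight_to_ternary()` from Phase 1; the identity backward is the simplest STE variant:
```python
import torch

class TernaryQuantizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight):
        # Quantize in the forward pass only
        W_t, gamma = weight_to_ternary(weight)
        return W_t * gamma

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat quantization as identity for gradients
        return grad_output

# In a QAT forward pass, quantize on the fly from the dense weight:
# output = F.linear(x, TernaryQuantizeSTE.apply(self.weight), self.bias)
```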
#### 4.2 Initialization Strategies
Experiment with initialization for ternary weights:
- Standard kaiming_uniform then quantize
- Specialized initialization for ternary
- Better threshold selection
## Testing Strategy
### Unit Tests
Run frequently during development:
```bash
pytest tests/test_quantization.py -v
pytest tests/test_functional.py -v
pytest tests/test_layers.py -v
```
### Integration Tests
Test full pipelines:
1. Dense model → quantization → inference
2. Transformer with BitLinear layers
3. Save/load model checkpoints
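For the checkpoint test, the standard PyTorch path should work unchanged, assuming `model` is a converted module whose ternary weights live in buffers (buffers are included in `state_dict()`):
```python
import torch

# Ternary weights and gamma round-trip through state_dict like parameters
torch.save(model.state_dict(), "bitlinear_checkpoint.pt")
model.load_state_dict(torch.load("bitlinear_checkpoint.pt"))
```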
### Numerical Correctness
Compare with reference:
```python
import torch
import torch.nn as nn
from bitlinear import BitLinear

# Create the same layer in dense and ternary form
linear = nn.Linear(512, 512)
bitlinear = BitLinear.from_linear(linear)
x = torch.randn(32, 512)
out_dense = linear(x)
out_ternary = bitlinear(x)
# Outputs should be similar, not identical, due to quantization
error = torch.norm(out_dense - out_ternary) / torch.norm(out_dense)
print(f"Relative error: {error:.4f}")  # Expect ~0.1-0.3
```
## Common Pitfalls
### Quantization
- **Pitfall:** A poorly chosen threshold produces too many zeros or too few
- **Solution:** Start with 0.5 * scale and tune empirically
### Shape Handling
- **Pitfall:** Broadcasting errors with gamma
- **Solution:** Use `.unsqueeze()` carefully, test various input shapes
### CUDA Compilation
- **Pitfall:** CUDA version mismatches
- **Solution:** Match PyTorch's CUDA version, use CPU-only build first
### Gradients
- **Pitfall:** No gradient flow through ternary quantization
- **Solution:** Implement straight-through estimator for QAT
## Performance Benchmarks
Create benchmarks to track progress:
```python
import time
import torch
import torch.nn as nn
from bitlinear import BitLinear

def benchmark(layer, x, n_runs=100):
    # Warmup
    for _ in range(10):
        _ = layer(x)
    torch.cuda.synchronize()
    # Benchmark
    start = time.time()
    for _ in range(n_runs):
        _ = layer(x)
    # Wait for queued GPU kernels before stopping the clock
    torch.cuda.synchronize()
    end = time.time()
    return (end - start) / n_runs

# Compare
linear = nn.Linear(2048, 2048).cuda()
bitlinear = BitLinear(2048, 2048).cuda()
x = torch.randn(128, 2048).cuda()
time_linear = benchmark(linear, x)
time_bitlinear = benchmark(bitlinear, x)
print(f"nn.Linear: {time_linear*1000:.2f} ms")
print(f"BitLinear: {time_bitlinear*1000:.2f} ms")
print(f"Speedup: {time_linear/time_bitlinear:.2f}x")
```
## Next Steps After Skeleton
1. **Implement Phase 1** (Python baseline)
- Start with `absmax_scale()` and `ternary_quantize()`
- Test each function as you go
- Don't move to next phase until tests pass
2. **Validate with Examples**
- Run `examples/basic_usage.py`
- Run `examples/transformer_example.py`
- Check output similarity and memory savings
3. **Optimize if Needed**
- Profile to find bottlenecks
- Implement C++/CUDA only after Python works
- Measure performance improvements
4. **Documentation**
- Add docstring details from implementation
- Create API documentation
- Write usage tutorials
## Resources
### Papers
- BitNet: https://arxiv.org/abs/2310.11453
- Ternary Neural Networks: https://jmlr.org/papers/volume26/24-2050/24-2050.pdf
### PyTorch Resources
- Custom Extensions: https://pytorch.org/tutorials/advanced/cpp_extension.html
- CUDA Programming: https://pytorch.org/tutorials/advanced/custom_ops.html
### Quantization
- QAT Guide: https://pytorch.org/docs/stable/quantization.html
- Straight-through Estimator: Bengio et al., 2013
## Questions to Consider
As you implement, think about:
1. **Memory vs. Speed:** Packed weights save memory but need unpacking
2. **Training vs. Inference:** Different requirements for gradients
3. **Compatibility:** Should work with existing PyTorch features (DDP, AMP, etc.)
4. **Extensibility:** Easy to add new quantization schemes?
Good luck with implementation! Start with correctness, then optimize.