Implementation Guide

This document provides a roadmap for implementing the BitLinear functionality, following the structure defined in the project skeleton. It also shows how the same process can be replicated for other operations.

Implementation Order

Phase 1: Python Baseline (Correctness First)

Start here to establish correctness before optimizing.

1.1 Quantization (bitlinear/quantization.py)

Order of implementation:

  1. absmax_scale() - Simple max computation
  2. ternary_quantize() - Threshold-based quantization to {-1, 0, +1}
  3. weight_to_ternary() - Combines the above
  4. Test thoroughly with tests/test_quantization.py

Key considerations:

  • Threshold selection (try 0.33 * scale or 0.5 * scale)
  • Per-channel vs. global scaling trade-offs
  • Numerical stability (avoid division by zero)
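A minimal sketch of the three quantization functions follows. The signatures are assumptions, not the skeleton's: per-output-channel absmax scaling, a threshold of `ratio * scale` (0.5 by default), an epsilon clamp for all-zero rows, and gamma set to the mean |w| of the weights that survive the threshold. That gamma is the least-squares scale for a fixed ternary pattern; returning the absmax itself as gamma is a simpler alternative.

```python
import torch


def absmax_scale(W: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Per-output-channel max |w|, shape [out_features]; eps guards all-zero rows.
    return W.abs().amax(dim=1).clamp_min(eps)


def ternary_quantize(W: torch.Tensor, scale: torch.Tensor, ratio: float = 0.5) -> torch.Tensor:
    # |w| > ratio * scale maps to sign(w) in {-1, +1}; everything else maps to 0.
    threshold = (ratio * scale).unsqueeze(1)  # [out, 1] broadcasts across each row
    return torch.sign(W) * (W.abs() > threshold).to(W.dtype)


def weight_to_ternary(W: torch.Tensor, ratio: float = 0.5):
    W_t = ternary_quantize(W, absmax_scale(W), ratio)
    # gamma = mean |w| over the kept entries (least-squares scale for this pattern).
    kept = W_t.abs().sum(dim=1).clamp_min(1.0)
    gamma = (W.abs() * W_t.abs()).sum(dim=1) / kept
    return W_t, gamma


torch.manual_seed(0)
W = torch.randn(4, 8)
W_t, gamma = weight_to_ternary(W)
print(W_t)
print(gamma)
```

Keeping gamma per row (shape [out]) is what makes the `gamma.unsqueeze(0)` in `bitlinear_python()` below broadcast correctly.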

1.2 Functional Operations (bitlinear/functional.py)

Order of implementation:

  1. bitlinear_python() - Core ternary matmul

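    # Sketch; argument names and shapes are assumptions:
    #   x: [..., in_features], W_ternary: [out_features, in_features],
    #   gamma: [out_features], bias: [out_features] or None
    import torch

    def bitlinear_python(x, W_ternary, gamma, bias=None):
        output = torch.matmul(x, W_ternary.T)  # [..., out_features]
        output = output * gamma                # gamma broadcasts over leading dims
        if bias is not None:
            output = output + bias
        return output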
    
  2. greedy_ternary_decomposition() - Iterative residual quantization

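    # Sketch; assumes weight_to_ternary() returns a per-output-channel gamma [out]:
    residual = W.clone()
    ternary_weights, gammas = [], []
    for _ in range(k):
        W_t, gamma = weight_to_ternary(residual)
        ternary_weights.append(W_t)
        gammas.append(gamma)
        # view(-1, 1) scales each row by its own gamma (also works for a 0-d gamma)
        residual = residual - gamma.view(-1, 1) * W_t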
    
  3. multi_ternary_linear_python() - Sum of k ternary operations

  4. Test with tests/test_functional.py
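The functional pieces can be checked together with a short self-contained sketch. `_ternarize` is a hypothetical stand-in for `weight_to_ternary()` (per-row threshold at half the absmax, gamma = mean |w| of kept entries), and representing the decomposition as a list of `(W_t, gamma)` pairs is an assumed format. Because each greedy step fits the current residual, the weight reconstruction error should shrink as k grows.

```python
import torch


def _ternarize(W, ratio=0.5):
    # Hypothetical stand-in for weight_to_ternary(); gamma has shape [out].
    scale = W.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    W_t = torch.sign(W) * (W.abs() > ratio * scale).to(W.dtype)
    gamma = (W.abs() * W_t.abs()).sum(dim=1) / W_t.abs().sum(dim=1).clamp_min(1.0)
    return W_t, gamma


def greedy_ternary_decomposition(W, k):
    residual, terms = W.clone(), []
    for _ in range(k):
        W_t, gamma = _ternarize(residual)
        terms.append((W_t, gamma))
        residual = residual - gamma.unsqueeze(1) * W_t  # scale each row by its gamma
    return terms


def multi_ternary_linear_python(x, terms, bias=None):
    # y = sum_i (x @ W_t_i^T) * gamma_i, i.e. k ternary matmuls added together.
    out = sum(torch.matmul(x, W_t.T) * gamma for W_t, gamma in terms)
    return out + bias if bias is not None else out


torch.manual_seed(0)
W, x = torch.randn(64, 128), torch.randn(8, 128)
for k in (1, 2, 4):
    terms = greedy_ternary_decomposition(W, k)
    W_hat = sum(gamma.unsqueeze(1) * W_t for W_t, gamma in terms)
    err = (torch.norm(W - W_hat) / torch.norm(W)).item()
    print(f"k={k}: relative weight error {err:.3f}")
```

Each extra term costs one more ternary matmul at inference, so k trades accuracy against compute and memory.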

1.3 Layer Modules (bitlinear/layers.py)

Order of implementation:

  1. BitLinear.__init__() and reset_parameters()

    • Initialize dense weights using kaiming_uniform
    • Quantize to ternary using weight_to_ternary()
    • Store as buffers or parameters
  2. BitLinear.forward() - Call bitlinear_python()

  3. BitLinear.from_linear() - Conversion utility

  4. MultiTernaryLinear - Similar structure

  5. convert_linear_to_bitlinear() - Recursive module conversion

  6. Test with tests/test_layers.py
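A possible shape for the layer code, as an inference-only sketch: ternary weights and gamma are stored as buffers, the dense weights are initialized with `nn.Linear`'s default `kaiming_uniform_(a=sqrt(5))` and then quantized, and `_ternarize` is a hypothetical stand-in for `weight_to_ternary()`. The private `_load_dense` helper and all signatures are assumptions, not the skeleton's API.

```python
import math
import torch
import torch.nn as nn


def _ternarize(W, ratio=0.5):
    # Hypothetical stand-in for weight_to_ternary(); returns W_t and gamma [out].
    scale = W.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    W_t = torch.sign(W) * (W.abs() > ratio * scale).to(W.dtype)
    gamma = (W.abs() * W_t.abs()).sum(dim=1) / W_t.abs().sum(dim=1).clamp_min(1.0)
    return W_t, gamma


class BitLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        # Inference-only: ternary weights and scales are buffers (saved, not trained).
        self.register_buffer("W_ternary", torch.zeros(out_features, in_features))
        self.register_buffer("gamma", torch.ones(out_features))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        dense = torch.empty(self.out_features, self.in_features)
        nn.init.kaiming_uniform_(dense, a=math.sqrt(5))  # nn.Linear's default init
        self._load_dense(dense)
        if self.bias is not None:
            bound = 1.0 / math.sqrt(self.in_features)
            nn.init.uniform_(self.bias, -bound, bound)

    @torch.no_grad()
    def _load_dense(self, dense):
        W_t, gamma = _ternarize(dense)
        self.W_ternary.copy_(W_t)
        self.gamma.copy_(gamma)

    def forward(self, x):
        out = torch.matmul(x, self.W_ternary.T) * self.gamma
        return out + self.bias if self.bias is not None else out

    @classmethod
    def from_linear(cls, linear):
        layer = cls(linear.in_features, linear.out_features, bias=linear.bias is not None)
        layer._load_dense(linear.weight.detach().cpu())
        if linear.bias is not None:
            with torch.no_grad():
                layer.bias.copy_(linear.bias)
        return layer.to(linear.weight.device)


def convert_linear_to_bitlinear(module):
    # Replace every nn.Linear child in place, recursing into containers.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, BitLinear.from_linear(child))
        else:
            convert_linear_to_bitlinear(child)
    return module
```

Buffers move with `.to()`/`.cuda()` and appear in `state_dict()`, which suits inference; for training, the Phase 4 approach of keeping latent float weights as parameters is needed instead.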

Testing strategy:

  • Compare output shapes with nn.Linear
  • Verify ternary weight values
  • Test conversion from pre-trained weights
  • Validate in Transformer example

Phase 2: Memory Optimization

2.1 Base-3 Packing (bitlinear/packing.py)

Implement packing for memory efficiency:

  1. pack_ternary_base3() - 5 values per byte
  2. unpack_ternary_base3() - Reverse operation
  3. Verify roundtrip: pack → unpack == identity

Packing scheme:

Map: -1 → 0, 0 → 1, +1 → 2 (base-3 digits)
Pack 5 digits per byte: d0 + d1*3 + d2*9 + d3*27 + d4*81 (at most 242, so it fits in a uint8 because 3^5 = 243 ≤ 256)
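A vectorized PyTorch sketch of this scheme; the function names follow the guide, but the signatures (returning the original shape alongside the packed bytes, padding the tail with ternary 0) are assumptions. At 5 weights per byte the cost is 1.6 bits per weight, versus 2 bits for a plain 2-bit encoding.

```python
import math
import torch

POWERS = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32)  # 3**0 .. 3**4


def pack_ternary_base3(W_t: torch.Tensor):
    digits = W_t.flatten().to(torch.int32) + 1  # -1, 0, +1 -> 0, 1, 2
    pad = (-digits.numel()) % 5
    # Pad with digit 1 (ternary 0) so the length is a multiple of 5.
    digits = torch.cat([digits, torch.ones(pad, dtype=torch.int32)])
    packed = (digits.view(-1, 5) * POWERS).sum(dim=1)  # each value in [0, 242]
    return packed.to(torch.uint8), tuple(W_t.shape)


def unpack_ternary_base3(packed: torch.Tensor, shape, dtype=torch.float32):
    values = packed.to(torch.int32).unsqueeze(1)  # [n_bytes, 1]
    digits = (values // POWERS) % 3                # [n_bytes, 5], d0 first
    n = math.prod(shape)
    return (digits.flatten()[:n] - 1).to(dtype).reshape(shape)


W_t = torch.randint(-1, 2, (3, 7)).float()
packed, shape = pack_ternary_base3(W_t)
print(packed.numel(), "bytes for", W_t.numel(), "weights")
assert torch.equal(unpack_ternary_base3(packed, shape), W_t)
```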

Phase 3: C++ Extensions (Optional but Recommended)

3.1 CPU Implementation (bitlinear/cpp/bitlinear.cpp)

  1. Implement bitlinear_cpu_forward()

    • Basic matrix multiplication with ternary weights
    • Exploit ternary structure (skip multiplications)
  2. Implement multi_ternary_cpu_forward()

  3. Test integration with Python
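Before writing the C++ kernel, a Python reference for the "skip multiplications" idea can serve as a test oracle: since every weight is -1, 0 or +1, each output is one sum of selected inputs minus another, scaled once by gamma. The function name and the per-row index lists are illustrative assumptions; a real kernel would precompute index lists (or packed masks) once per weight matrix rather than on every call.

```python
import torch


def ternary_matmul_addsub(x, W_t, gamma):
    # x: [batch, in], W_t: [out, in] with entries in {-1, 0, +1}, gamma: [out].
    out = torch.empty(x.shape[0], W_t.shape[0], dtype=x.dtype, device=x.device)
    for j in range(W_t.shape[0]):
        pos = (W_t[j] > 0).nonzero(as_tuple=True)[0]  # inputs to add
        neg = (W_t[j] < 0).nonzero(as_tuple=True)[0]  # inputs to subtract
        # Zeros are skipped entirely; the only multiply is by gamma[j].
        out[:, j] = (x[:, pos].sum(dim=1) - x[:, neg].sum(dim=1)) * gamma[j]
    return out


torch.manual_seed(0)
x = torch.randn(4, 16)
W_t = torch.randint(-1, 2, (8, 16)).float()
gamma = torch.rand(8)
ref = torch.matmul(x, W_t.T) * gamma
print(torch.allclose(ternary_matmul_addsub(x, W_t, gamma), ref, atol=1e-5))
```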

Optimization opportunities (later):

  • AVX/AVX512 vectorization
  • OpenMP parallelization
  • Cache-efficient tiling

3.2 CUDA Kernels (bitlinear/cpp/bitlinear_kernel.cu)

Only after the CPU version works!

  1. Basic kernel without optimization

    • Thread per output element
    • Simple accumulation
  2. Optimized kernel:

    • Shared memory tiling
    • Warp-level reductions
    • Memory coalescing
    • Exploit ternary (conditional accumulation)
  3. Advanced (optional):

    • Tensor Core utilization
    • Mixed precision
    • Fused kernels (activation quantization + matmul)

Performance targets:

  • Should be faster than PyTorch's F.linear for large matrices
  • Aim for a 2-5x speedup over F.linear from the ternary optimizations

Phase 4: Training Support

4.1 Quantization-Aware Training (QAT)

Modify layers to support gradient flow:

  1. Straight-through estimator for ternary quantization
  2. Learnable scaling factors (gamma)
  3. Fine-tuning pre-trained models
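A minimal QAT sketch, assuming per-row absmax thresholding as in Phase 1; `TernarySTE` and `QATBitLinear` are hypothetical names. The forward pass uses ternary weights, while the backward pass hands the gradient through unchanged (the straight-through estimator), so the latent full-precision weights keep receiving updates. Gamma is an ordinary learnable `nn.Parameter`.

```python
import math
import torch
import torch.nn as nn


class TernarySTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, ratio):
        scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
        return torch.sign(w) * (w.abs() > ratio * scale).to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat quantization as the identity for gradients.
        return grad_output, None  # no gradient for `ratio`


class QATBitLinear(nn.Module):
    def __init__(self, in_features, out_features, ratio=0.5):
        super().__init__()
        self.ratio = ratio
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # Learnable per-output-channel scale, initialized to mean |w| per row.
        self.gamma = nn.Parameter(self.weight.detach().abs().mean(dim=1))

    def forward(self, x):
        W_t = TernarySTE.apply(self.weight, self.ratio)
        return torch.matmul(x, W_t.T) * self.gamma


torch.manual_seed(0)
layer = QATBitLinear(16, 8)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x, target = torch.randn(32, 16), torch.randn(32, 8)
for step in range(5):
    loss = nn.functional.mse_loss(layer(x), target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(step, loss.item())
```

An equivalent idiom without a custom `autograd.Function` is `w + (w_q - w).detach()`, which evaluates to `w_q` but differentiates as `w`.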

4.2 Initialization Strategies

Experiment with initialization for ternary weights:

  • Standard kaiming_uniform then quantize
  • Specialized initialization for ternary
  • Better threshold selection

Testing Strategy

Unit Tests

Run frequently during development:

pytest tests/test_quantization.py -v
pytest tests/test_functional.py -v
pytest tests/test_layers.py -v

Integration Tests

Test full pipelines:

  1. Dense model → quantization → inference
  2. Transformer with BitLinear layers
  3. Save/load model checkpoints
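For the checkpoint test, the choice in Phase 1.3 between buffers and parameters matters: both are included in `state_dict()`, but only parameters are returned by `parameters()` and therefore seen by optimizers. A minimal round-trip through an in-memory file, using a hypothetical `TernaryStore` module:

```python
import io
import torch
import torch.nn as nn


class TernaryStore(nn.Module):
    # Hypothetical module: ternary weights as a buffer, gamma as a parameter.
    def __init__(self, out_features, in_features):
        super().__init__()
        W = torch.randint(-1, 2, (out_features, in_features)).float()
        self.register_buffer("W_ternary", W)
        self.gamma = nn.Parameter(torch.rand(out_features))


def roundtrip(module, fresh):
    buf = io.BytesIO()
    torch.save(module.state_dict(), buf)
    buf.seek(0)
    fresh.load_state_dict(torch.load(buf))
    return fresh


src = TernaryStore(8, 4)
dst = roundtrip(src, TernaryStore(8, 4))
print(sorted(src.state_dict().keys()))            # ['W_ternary', 'gamma']
print([name for name, _ in src.named_parameters()])  # ['gamma']
assert torch.equal(src.W_ternary, dst.W_ternary) and torch.equal(src.gamma, dst.gamma)
```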

Numerical Correctness

Compare with reference:

import torch
import torch.nn as nn
from bitlinear import BitLinear

# Create the same layer in dense and ternary form
linear = nn.Linear(512, 512)
bitlinear = BitLinear.from_linear(linear)

x = torch.randn(32, 512)
with torch.no_grad():
    out_dense = linear(x)
    out_ternary = bitlinear(x)

# Should be similar (not identical, due to quantization)
error = torch.norm(out_dense - out_ternary) / torch.norm(out_dense)
print(f"Relative error: {error.item():.4f}")  # Expect ~0.1-0.3

Common Pitfalls

Quantization

  • Pitfall: Wrong threshold → too many zeros or too few
  • Solution: Start with 0.5 * scale, tune empirically
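One way to tune the threshold empirically is to sweep the ratio and watch sparsity alongside reconstruction error. The sketch below assumes per-row absmax thresholds and a least-squares gamma (mean |w| of kept entries); `threshold_sweep` is a hypothetical helper, and the ratios include the two starting points suggested in Phase 1.

```python
import math
import torch
import torch.nn as nn


def threshold_sweep(W, ratios):
    results = []
    scale = W.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    for r in ratios:
        mask = (W.abs() > r * scale).to(W.dtype)
        W_t = torch.sign(W) * mask
        kept = mask.sum(dim=1, keepdim=True).clamp_min(1.0)
        gamma = (W.abs() * mask).sum(dim=1, keepdim=True) / kept
        sparsity = 1.0 - mask.mean().item()  # fraction of zeros
        rel_err = (torch.norm(W - gamma * W_t) / torch.norm(W)).item()
        results.append((r, sparsity, rel_err))
    return results


torch.manual_seed(0)
W = torch.empty(256, 256)
nn.init.kaiming_uniform_(W, a=math.sqrt(5))
for r, sparsity, err in threshold_sweep(W, [0.2, 0.33, 0.5, 0.7]):
    print(f"ratio={r:.2f}  zeros={sparsity:.1%}  rel_err={err:.3f}")
```

Weight-space error is only a proxy; the output-space check under "Numerical Correctness" is the better final arbiter.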

Shape Handling

  • Pitfall: Broadcasting errors with gamma
  • Solution: Use .unsqueeze() carefully, test various input shapes
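A quick check of the broadcasting rules involved, assuming gamma is a per-output-channel vector of shape [out]. On the output side it broadcasts over any number of leading dimensions, while folding it into the weight matrix needs a trailing singleton dimension. Writing `gamma * W_t` instead raises a shape error when out != in, and silently scales columns rather than rows when out == in.

```python
import torch

out_f, in_f = 6, 4
W_t = torch.randint(-1, 2, (out_f, in_f)).float()
gamma = torch.rand(out_f) + 0.5  # per-output-channel scale, shape [out]


def forward(x):
    # x @ W_t.T has shape [..., out]; gamma [out] broadcasts over leading dims.
    return torch.matmul(x, W_t.T) * gamma


W_scaled = gamma.unsqueeze(1) * W_t  # [out, 1] * [out, in] scales each row

for shape in [(in_f,), (3, in_f), (2, 5, in_f)]:
    x = torch.randn(*shape)
    y = forward(x)
    assert y.shape == (*shape[:-1], out_f)
    assert torch.allclose(y, torch.matmul(x, W_scaled.T), atol=1e-6)
    print(shape, "->", tuple(y.shape))
```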

CUDA Compilation

  • Pitfall: CUDA version mismatches
  • Solution: Match PyTorch's CUDA version, use CPU-only build first

Gradients

  • Pitfall: No gradient flow through ternary quantization
  • Solution: Implement straight-through estimator for QAT

Performance Benchmarks

Create benchmarks to track progress:

import time
import torch
import torch.nn as nn
from bitlinear import BitLinear

@torch.no_grad()
def benchmark(layer, x, n_runs=100):
    # Warmup (CUDA context, kernel selection, caches)
    for _ in range(10):
        _ = layer(x)
    # CUDA launches are asynchronous: synchronize before reading the clock
    if x.is_cuda:
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_runs):
        _ = layer(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    return (end - start) / n_runs

# Compare
linear = nn.Linear(2048, 2048).cuda()
bitlinear = BitLinear(2048, 2048).cuda()
x = torch.randn(128, 2048).cuda()

time_linear = benchmark(linear, x)
time_bitlinear = benchmark(bitlinear, x)

print(f"nn.Linear: {time_linear*1000:.2f} ms")
print(f"BitLinear: {time_bitlinear*1000:.2f} ms")
print(f"Speedup: {time_linear/time_bitlinear:.2f}x")

Next Steps After Skeleton

  1. Implement Phase 1 (Python baseline)

    • Start with absmax_scale() and ternary_quantize()
    • Test each function as you go
    • Don't move to the next phase until tests pass
  2. Validate with Examples

    • Run examples/basic_usage.py
    • Run examples/transformer_example.py
    • Check output similarity and memory savings
  3. Optimize if Needed

    • Profile to find bottlenecks
    • Implement C++/CUDA only after Python works
    • Measure performance improvements
  4. Documentation

    • Add docstring details from implementation
    • Create API documentation
    • Write usage tutorials

Questions to Consider

As you implement, think about:

  1. Memory vs. Speed: Packed weights save memory but need unpacking
  2. Training vs. Inference: Different requirements for gradients
  3. Compatibility: Should work with existing PyTorch features (DDP, AMP, etc.)
  4. Extensibility: Easy to add new quantization schemes?
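To put the memory side of the first question in numbers, the storage for one weight matrix can be computed directly. The 2048 x 2048 size matches the benchmark above and is otherwise arbitrary; unpacked int8 ternary and 2-bit packing are included only as reference points, and the per-row gamma and bias are left out.

```python
import math


def weight_bytes(out_features, in_features):
    # Bytes for the weight matrix only (excludes gamma and bias).
    n = out_features * in_features
    return {
        "fp32": 4 * n,
        "fp16/bf16": 2 * n,
        "int8 (unpacked ternary)": n,
        "2-bit packed": math.ceil(n / 4),
        "base-3 packed (5 per byte)": math.ceil(n / 5),
    }


sizes = weight_bytes(2048, 2048)
for fmt, nbytes in sizes.items():
    ratio = sizes["fp32"] / nbytes
    print(f"{fmt:28s} {nbytes / 2**20:7.2f} MiB   fp32/this = {ratio:5.1f}")
```

The packed formats are about 20x smaller than fp32, but every forward pass must either unpack them or use a kernel that reads the packed layout directly.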

Good luck with the implementation! Start with correctness, then optimize.