eren23
# Quantization Deep Dive
## Overview
We explored multiple quantization strategies for LeWM. This document records what was tried, why it worked or didn't, and the engineering trade-offs.
## Formats Evaluated
| Format | Bits/weight | Compression | Cos vs f32 | Status |
|--------|-------------|-------------|------------|--------|
| **INT8 (encoder)** | 8 | 4x vs f32 | 0.9999 | Production |
| **Q4 (predictor)** | 4 | 8x vs f32 (2x over INT8) | 0.998 | Production |
| **INT8+Q4 (full)** | mixed | 5x overall | 0.999 | **Production** |
| Q4 encoder only | 4 | 4x | 0.93 | Rejected |
| Ternary {-1, 0, +1} | 2 | 16x vs f32 | ~0.85 | Rejected |
| WANDA 20% | sparse Q4 | 20% fewer weights | ~0.99 | Experimental |
| WANDA 40% | sparse Q4 | 40% fewer weights | ~0.97 | Experimental |
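Every quality figure in this document is a cosine similarity against the f32 model. The document does not say which tensors it is measured on; the sketch below assumes the flattened outputs (for example, embeddings) of the f32 and quantized runs are compared, and the example vectors are made up.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two flattened output vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Hypothetical outputs: f32 reference vs. a slightly perturbed quantized run.
ref = [0.5, -1.2, 3.3, 0.0, 2.1]
quant = [0.49, -1.21, 3.31, 0.01, 2.08]
print(f"cos vs f32 = {cosine_similarity(ref, quant):.4f}")
```

A value of 1.0 means the quantized output points in exactly the same direction as the f32 output. Cosine similarity ignores overall scale, so it does not capture a uniform magnitude shift.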
## INT8 Per-Channel (Encoder)
### Method
```
f32_weight [out, in] → per-output-channel symmetric quantization
    → int8_weight [out, in]
    → per-output-channel scales [out]
At inference:
    dequant(int8_weight)[o, :] = (int8_weight[o, :] - zero_point) × scales[o]
    result = input @ dequant(int8_weight).T
           = (input @ int8_weight.T) × scales   // zero_point = 0 for symmetric
```
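A minimal pure-Python sketch of the weight side of this method. It assumes the common absmax rule (scale = max|w| / 127 per output channel, round to nearest, clamp to [-127, 127]); the document does not state the exact scale rule or rounding mode, and the function names are hypothetical. Activations stay in float here.

```python
def quantize_int8_per_channel(weight):
    """Symmetric per-output-channel INT8 quantization.

    weight: list of rows, shape [out][in]. Python floats stand in for f32.
    Returns (int8 rows, one scale per output channel).
    Assumption: scale = max|w| / 127, codes clamped to [-127, 127].
    """
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def int8_gemv(x, q_rows, scales):
    """result[o] = (x . int8_weight[o]) * scale[o]; zero_point = 0 (symmetric)."""
    return [sum(xi * qi for xi, qi in zip(x, row)) * s
            for row, s in zip(q_rows, scales)]


def f32_gemv(x, weight):
    """Unquantized reference."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


W = [[0.10, -0.50, 0.25, 0.05],
     [2.00, 1.50, -3.00, 0.75]]   # second channel has a much larger range
x = [1.0, -2.0, 0.5, 3.0]
q, s = quantize_int8_per_channel(W)
print(int8_gemv(x, q, s))
print(f32_gemv(x, W))
```

Because each row gets its own scale, the small-range first channel is not forced onto the much coarser grid (3.0 / 127) that the second channel needs.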
### Why It Works
1. **Per-channel** preserves channel-level statistics. Each output channel gets its own scale.
2. **Symmetric** (zero_point = 0) removes the zero-point correction terms that asymmetric quantization adds to every dot product.
3. **INT8 GEMV** on PIE SIMD: 16-wide multiply-accumulate per cycle.
4. **Encoder activations** have predictable dynamic range after LayerNorm.
### Quality
| Layer | cos vs f32 |
|-------|-----------|
| patch_embed | 0.9999 |
| encoder layer 0 | 0.9999 |
| encoder layer 5 | 0.9998 |
| encoder.proj | 0.9999 |
| **Total** | **0.9999** |
### Engineering Notes
- Activation quantization is dynamic (per-row at inference time), not static.
- QKV shared quantization: same normalized input quantized once for Q, K, V.
- Scales are stored as f32 (4 bytes per output channel), a negligible overhead.
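A sketch of the two activation-side notes above: a per-row scale computed at inference time, and one quantization of the normalized input shared by the Q, K and V projections. The absmax rule, the function names, and the toy identity weights are assumptions for illustration.

```python
def quantize_row_int8(x):
    """Dynamic symmetric INT8 quantization of one activation row.

    Assumption: scale = max|x| / 127, recomputed for every row at inference time.
    """
    max_abs = max(abs(c) for c in x)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    return [max(-127, min(127, round(c / scale))) for c in x], scale


def int8_linear(xq, x_scale, w_rows, w_scales):
    """Integer dot products, then a single float rescale per output channel."""
    return [sum(a * b for a, b in zip(xq, row)) * x_scale * ws
            for row, ws in zip(w_rows, w_scales)]


def qkv_projection(x_norm, wq, wk, wv):
    """Quantize the normalized input once and reuse it for Q, K and V.

    Each of wq, wk, wv is an (int8_rows, per_channel_scales) pair.
    """
    xq, xs = quantize_row_int8(x_norm)              # one quantization ...
    return tuple(int8_linear(xq, xs, rows, scales)  # ... three projections
                 for rows, scales in (wq, wk, wv))


x_norm = [0.8, -1.1, 0.3, 1.6]   # hypothetical LayerNorm output for one token
eye = ([[127 if i == j else 0 for j in range(4)] for i in range(4)], [1 / 127] * 4)
q, k, v = qkv_projection(x_norm, eye, eye, eye)
print(q)
```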
## Q4 Block (Predictor)
### Method
```
f32_weight [out, in] → per-32-element-block quantization
    → nibble-packed weight data
    → per-block f16 scales
At inference:
    for each block of 32:
        unpack nibbles → int8 [-8, 7]
        dot = simd_dot(input_block, unpacked)
        result += dot × block_scale
```
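A pure-Python sketch of the Q4 scheme above: 32-element blocks, one f16 scale per block (rounded through the standard `struct` half-precision format `'e'`), two 4-bit codes per byte, and decoding as `value - 8`. The scale rule (max|w| / 7) and the nibble order (even index in the low nibble) are assumptions; the document does not specify them, and the function names are hypothetical.

```python
import struct

BLOCK = 32


def to_f16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def quantize_q4(row):
    """Quantize one weight row into 32-element blocks.

    Assumption: scale = max|w| / 7 per block (stored as f16), codes clamped to
    [-8, 7] and stored as code + 8 in a nibble, two nibbles per byte.
    """
    assert len(row) % BLOCK == 0
    packed, scales = bytearray(), []
    for start in range(0, len(row), BLOCK):
        block = row[start:start + BLOCK]
        scale = to_f16(max(abs(w) for w in block) / 7.0)
        if scale == 0.0:          # all-zero (or sub-f16) block: codes are all 0
            scale = 1.0
        codes = [max(-8, min(7, round(w / scale))) + 8 for w in block]
        for lo, hi in zip(codes[0::2], codes[1::2]):
            packed.append(lo | (hi << 4))
        scales.append(scale)
    return bytes(packed), scales


def q4_dot(x, packed, scales):
    """Dot product of x with one Q4 row, one block at a time."""
    result = 0.0
    for b, scale in enumerate(scales):
        unpacked = []
        for byte in packed[b * BLOCK // 2:(b + 1) * BLOCK // 2]:
            unpacked += [(byte & 0x0F) - 8, (byte >> 4) - 8]  # value - 8
        x_block = x[b * BLOCK:(b + 1) * BLOCK]
        dot = sum(xi * qi for xi, qi in zip(x_block, unpacked))
        result += dot * scale
    return result


row = [((i * 37) % 23 - 11) / 10 for i in range(64)]   # toy weights, 2 blocks
x = [((i * 13) % 7 - 3) / 4 for i in range(64)]        # toy activations
packed, scales = quantize_q4(row)
approx = q4_dot(x, packed, scales)
exact = sum(a * b for a, b in zip(x, row))
print(len(packed), len(scales), approx, exact)
```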
### Why It Works
1. **Per-block** (32 elements) matches the predictor's adaLN normalization.
2. **adaLN modulation** provides implicit normalization, so weights don't need per-channel precision.
3. **Fewer layers** mean less error accumulation: the predictor has 4 layers vs. the encoder's 6.
4. **PIE SIMD** handles the nibble unpack + dot in tight loops.
### Why Full Q4 (Encoder + Predictor) Fails
When we skip INT8 and quantize the encoder to Q4:
| Component | INT8 encoder + Q4 predictor (cos) | Full Q4 (cos) |
|-----------|-----------------------------------|---------------|
| encoder | 0.9999 | 0.93 |
| predictor | 0.998 | 0.998 |
| **Total** | **0.999** | **0.93** |
**Root cause**: ViT encoder has high dynamic range in intermediate activations. The 32-element block granularity doesn't align with the encoder's channel statistics. INT8's per-channel precision is essential.
### Engineering Notes
- Nibbles decode as `value - 8` → range [-8, 7]
- Block scales are stored as f16 (2 bytes per 32-weight block), i.e. 64 bytes of scales per 32×32 weight tile
- Zero weights are rare in Q4 (<1%), so a skip-zero optimization is not implemented
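The scale overhead can be checked with a few lines of arithmetic: a 32×32 tile holds 1024 weights, i.e. 512 bytes of nibbles plus 32 blocks × 2 bytes = 64 bytes of f16 scales. Counting the scales, Q4 costs 4.5 bits per weight, so the effective ratio vs f32 is about 7.1x rather than the nominal 8x. The helper name is hypothetical.

```python
def q4_storage_bytes(n_weights, block=32, scale_bytes=2):
    """Bytes for nibble-packed Q4 weights plus one f16 scale per block."""
    assert n_weights % block == 0
    weight_bytes = n_weights // 2                 # two 4-bit codes per byte
    scale_total = (n_weights // block) * scale_bytes
    return weight_bytes, scale_total


w_bytes, s_bytes = q4_storage_bytes(32 * 32)
bits_per_weight = 8 * (w_bytes + s_bytes) / (32 * 32)
ratio_vs_f32 = (4 * 32 * 32) / (w_bytes + s_bytes)
print(w_bytes, s_bytes, bits_per_weight, round(ratio_vs_f32, 2))  # -> 512 64 4.5 7.11
```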
## Ternary ({-1, 0, +1})
### Method
```
f32_weight → hard threshold at ±tau
    → +1 if w > +tau
    → -1 if w < -tau
    →  0 otherwise
    → bit-packed ternary (2 bits per weight)
At inference:
    result = input @ ternary_weight
           = sum(t_i × x_i), with t_i ∈ {-1, 0, +1}   // pure addition/subtraction
```
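A sketch of the ternary method above: thresholding at ±τ and a dot product that only adds or subtracts inputs, skipping zero weights entirely. The bit-packing step is omitted, and the weights, inputs and τ are made-up values.

```python
def ternarize(weights, tau):
    """Hard-threshold weights to {-1, 0, +1}."""
    return [1 if w > tau else -1 if w < -tau else 0 for w in weights]


def ternary_dot(x, t):
    """Dot product with ternary weights using only additions and subtractions."""
    acc = 0.0
    for xi, ti in zip(x, t):
        if ti == 1:
            acc += xi
        elif ti == -1:
            acc -= xi
        # ti == 0: this input is skipped entirely
    return acc


w = [0.42, -0.03, -0.77, 0.10, 0.55, -0.20]
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
t = ternarize(w, tau=0.15)
print(t, ternary_dot(x, t))  # -> [1, 0, -1, 0, 1, -1] -3.0
```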
### Why It Fails
| Metric | Q4 | Ternary |
|--------|-----|---------|
| Cos vs f32 | 0.998 | ~0.85 |
| Compression | 8x vs f32 | 16x vs f32 |
**Root cause**: adaLN generates 6 modulation vectors (scale1, shift1, gate1, scale2, shift2, gate2) that multiply and add to the normalized activations. The magnitudes of these modulation vectors matter, and ternary quantization destroys them.
**What was tried**:
- Various thresholds (τ = 0.5σ, σ, 2σ)
- STE (straight-through estimator) during fine-tuning
- Mixed ternary (ternary weights + fp32 scales)
None recovered quality sufficiently.
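A sketch of the threshold sweep and the mixed-ternary variant listed above. It assumes σ is the row's weight standard deviation and that the fp32 scale is the mean |w| over the weights left nonzero (the Ternary Weight Networks choice, which is the least-squares optimal scale for a fixed ternary pattern). The document does not say how its scales were computed; the names and the toy row are hypothetical.

```python
import statistics


def ternarize_with_scale(row, k):
    """Ternary weights plus one fp32 scale per row ("mixed ternary").

    Assumptions: tau = k * sigma (sigma = population std of the row), and
    alpha = mean |w| over the weights that stay nonzero.
    """
    tau = k * statistics.pstdev(row)
    t = [1 if w > tau else -1 if w < -tau else 0 for w in row]
    kept = [abs(w) for w, ti in zip(row, t) if ti != 0]
    alpha = sum(kept) / len(kept) if kept else 0.0
    return t, alpha


def reconstruction_error(row, t, alpha):
    """Squared error between the f32 row and alpha * t."""
    return sum((w - alpha * ti) ** 2 for w, ti in zip(row, t))


row = [0.9, -0.05, 0.3, -1.4, 0.02, 0.6, -0.25, 0.1]
for k in (0.5, 1.0, 2.0):   # tau = 0.5σ, σ, 2σ
    t, alpha = ternarize_with_scale(row, k)
    print(k, t, round(alpha, 3), round(reconstruction_error(row, t, alpha), 3))
```

Even with an optimal per-row scale, every kept weight in a row collapses to the same magnitude, which is the information loss described in the root cause above.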
## WANDA Pruning
### Method
```
1. Forward pass on calibration set → collect activations
2. Compute WANDA score: s_ij = |W_ij| × ||X_j||_2   // X_j: input feature j across calibration tokens
3. Sort scores, prune bottom N%
4. Fine-tune 1 epoch (skipped for the WANDA 20% result below)
5. Re-quantize remaining weights to Q4
```
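Steps 2 and 3 above can be sketched as follows. The score uses the per-input-feature L2 norm from the WANDA paper. Step 3 does not say whether the bottom N% is taken globally or per output row; this sketch compares weights within each output row, as the WANDA paper does. The function name and toy data are hypothetical.

```python
import math


def wanda_prune(weight, calib_inputs, sparsity):
    """Zero the lowest-scoring weights in each output row.

    weight: [out][in]; calib_inputs: [tokens][in].
    Score s_ij = |W_ij| * ||X_j||_2, where ||X_j||_2 is the L2 norm of input
    feature j over all calibration tokens.
    """
    n_in = len(weight[0])
    feat_norm = [math.sqrt(sum(tok[j] ** 2 for tok in calib_inputs))
                 for j in range(n_in)]
    n_prune = int(n_in * sparsity)
    pruned = []
    for row in weight:
        scores = [abs(w) * feat_norm[j] for j, w in enumerate(row)]
        drop = set(sorted(range(n_in), key=lambda j: scores[j])[:n_prune])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(row)])
    return pruned


W = [[0.5, -0.1, 0.3, 0.05, -0.4],
     [0.2, 0.6, -0.05, 0.3, 0.1]]
X = [[1.0, 10.0, 0.1, 5.0, 1.0],    # features 1 and 3 carry large activations
     [2.0, 10.0, 0.2, 5.0, 1.0]]
print(wanda_prune(W, X, sparsity=0.4))
```

In the toy data, the small weight -0.1 in the first row survives because it multiplies a large input feature, whereas pure magnitude pruning at 40% would remove it.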
### Results
| Model | Pruned | Size | Cos | Notes |
|-------|--------|------|-----|-------|
| Baseline Q4 | 0% | 23.6 MB | 0.998 | Reference |
| WANDA 20% | 20% | 22.0 MB | ~0.99 | Pruned, no fine-tune |
| WANDA 40% | 40% | 25.1 MB | ~0.97 | Bitmap overhead exceeds savings |
### Engineering Notes
- The 40% pruned model has a **larger** binary than both the 20% model and the unpruned Q4 baseline (25.1 MB vs. 22.0 MB and 23.6 MB) due to sparsity-bitmap overhead
- Skip-zero GEMV needs hardware support (not implemented in PIE SIMD)
- Would benefit from fine-tuning after pruning
## Shared Lessons
1. **Per-channel > per-block for encoders**: Encoders have high per-channel variance. INT8's per-channel precision beats Q4's per-block.
2. **Predictors are quantization-friendly**: adaLN provides implicit normalization. Predictors can use Q4 with minimal quality loss.
3. **Architecture changes beat quantization**: hybrid_ALAL (3.0M params) achieves similar quality to slim_96d (10.2M params) at INT8+Q4. Architecture > precision.
4. **Epoch 1 models have headroom**: All current slim/epoch-1 models are expected to improve with longer training, so these quality comparisons should be re-run at convergence.
5. **Hardwired is the limit**: With weights hardwired into logic, each Q4 weight multiply decomposes into shift-add operations: zero multiplications and zero weight fetches from memory. That is the theoretical floor.
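A sketch of the arithmetic behind lesson 5: a Q4 weight in [-8, 7] has a magnitude of at most 8, which fits in bits 0..3, so multiplying an integer activation by it is a sum of at most four shifted copies. This Python function reads the weight's bits at runtime purely for illustration; in a hardwired design the set of shifts would be fixed per weight at design time. The function name is hypothetical.

```python
def mul_q4_shift_add(x: int, w: int) -> int:
    """Multiply integer x by a Q4 weight w in [-8, 7] using only shifts and adds."""
    if not -8 <= w <= 7:
        raise ValueError("Q4 weight must be in [-8, 7]")
    mag = -w if w < 0 else w
    acc = 0
    for bit in range(4):             # |w| <= 7 uses bits 0..2; |w| == 8 is bit 3
        if (mag >> bit) & 1:
            acc += x << bit          # add a shifted copy of x
    return -acc if w < 0 else acc    # negation, not a multiply


print(mul_q4_shift_add(13, -7))  # -> -91
```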