# Quantization Deep Dive
## Overview
We explored multiple quantization strategies for LeWM. This document records what was tried, why it worked or didn't, and the engineering trade-offs.
## Formats Evaluated
| Format | Bits/weight | Compression | Cos vs f32 | Status |
|--------|-------------|-------------|-------------|--------|
| **INT8 (encoder)** | 8 | 4x | 0.9999 | Production |
| **Q4 (predictor)** | 4 | 2x more | 0.998 | Production |
| **INT8+Q4 (full)** | mixed | 5x total | 0.999 | **Production** |
| Q4 encoder only | 4 | 4x | 0.93 | Rejected |
| Ternary {-1,0,+1} | 2 | 8x | ~0.85 | Rejected |
| WANDA 20% | sparse Q4 | 20% fewer | ~0.99 | Experimental |
| WANDA 40% | sparse Q4 | 40% fewer | ~0.97 | Experimental |
## INT8 Per-Channel (Encoder)
### Method
```
f32_weight [out, in] → per-output-channel quantization
                     → int8_weight [out, in]
                     → per-output-channel scales [out]
At inference:
  result = input @ dequant(weight) × scales
         = input @ (int8_weight - zero_point) × scales
         = (input @ int8_weight) × scales      // zero_point = 0 for symmetric
```
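A minimal NumPy sketch of this scheme, for concreteness; function names are illustrative and this is not the PIE kernel:

```python
import numpy as np

def quantize_int8_per_channel(w):
    """Symmetric per-output-channel INT8 quantization of w[out, in]."""
    scales = np.abs(w).max(axis=1) / 127.0            # one scale per output channel
    scales = np.where(scales == 0.0, 1.0, scales)     # guard all-zero rows
    q = np.clip(np.round(w / scales[:, None]), -127, 127).astype(np.int8)
    return q, scales.astype(np.float32)

def int8_gemv(x, q, scales):
    """y = (x @ int8_weight^T) * scales; zero_point = 0, so no correction term."""
    return (q.astype(np.int32) @ x).astype(np.float32) * scales

# quick sanity check against the f32 reference
w = np.random.randn(64, 128).astype(np.float32)
x = np.random.randn(128).astype(np.float32)
q, s = quantize_int8_per_channel(w)
ref, approx = w @ x, int8_gemv(x, q, s)
print(np.dot(ref, approx) / (np.linalg.norm(ref) * np.linalg.norm(approx)))
```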
### Why It Works
1. **Per-channel** preserves channel-level statistics. Each output channel gets its own scale.
2. **Symmetric** (zero_point = 0) avoids the overhead of asymmetric quantization.
3. **INT8 GEMV** on PIE SIMD: 16-wide multiply-accumulate per cycle.
4. **Encoder activations** have predictable dynamic range after LayerNorm.
### Quality
| Layer | cos vs f32 |
|-------|-----------|
| patch_embed | 0.9999 |
| encoder layer 0 | 0.9999 |
| encoder layer 5 | 0.9998 |
| encoder.proj | 0.9999 |
| **Total** | **0.9999** |
### Engineering Notes
- Activation quantization is dynamic (per-row at inference time), not static; see the sketch after these notes.
- QKV shared quantization: the same normalized input is quantized once and reused for Q, K, and V.
- Scales are stored as f32 (4 bytes per channel), a negligible overhead.
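For the first two notes, a sketch of what dynamic per-row activation quantization plus a pure-integer matmul could look like under the same symmetric scheme (illustrative only; the quantized input can be reused for the Q, K, and V projections):

```python
import numpy as np

def quantize_activations_per_row(x):
    """Dynamic symmetric INT8: one scale per row, computed at inference time."""
    scales = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scales = np.where(scales == 0.0, 1.0, scales)
    q = np.clip(np.round(x / scales), -127, 127).astype(np.int8)
    return q, scales.astype(np.float32)           # the same (q, scales) can feed Q, K, and V

def int8_matmul(xq, x_scales, wq, w_scales):
    """Accumulate xq @ wq^T in int32, then rescale by the row and channel scales."""
    acc = xq.astype(np.int32) @ wq.astype(np.int32).T    # pure integer GEMM
    return acc.astype(np.float32) * x_scales * w_scales  # [rows, 1] * [out] broadcast
```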
## Q4 Block (Predictor)
### Method
```
f32_weight [out, in] → per-32-element-block quantization
                     → nibble-packed weight data
                     → per-block f16 scales
At inference:
  for each block of 32:
    unpack nibbles → int8 [-8, 7]
    dot = simd_dot(input_block, unpacked)
    result += dot × block_scale
```
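A self-contained sketch of the block scheme, including the nibble packing covered in the engineering notes below (illustrative, not the SIMD kernel):

```python
import numpy as np

BLOCK = 32

def quantize_q4_row(w_row):
    """Quantize one weight row into 32-element blocks: packed nibbles + one f16 scale per block."""
    blocks = w_row.reshape(-1, BLOCK)
    scales = np.abs(blocks).max(axis=1) / 7.0                  # block max maps to +/-7
    scales = np.where(scales == 0.0, 1.0, scales)
    q = np.clip(np.round(blocks / scales[:, None]), -8, 7).astype(np.int8)
    lo = (q[:, 0::2] + 8).astype(np.uint8)                     # store value + 8, range [0, 15]
    hi = (q[:, 1::2] + 8).astype(np.uint8)
    return lo | (hi << 4), scales.astype(np.float16)           # two weights per byte

def q4_dot(x_row, packed, scales):
    """Unpack nibbles block by block, integer dot product, then scale and accumulate."""
    q = np.empty((packed.shape[0], BLOCK), dtype=np.int8)
    q[:, 0::2] = (packed & 0x0F).astype(np.int8) - 8           # nibble decodes as value - 8
    q[:, 1::2] = (packed >> 4).astype(np.int8) - 8
    xb = x_row.reshape(-1, BLOCK)
    return float(((q * xb).sum(axis=1) * scales.astype(np.float32)).sum())
```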
### Why It Works
1. **Per-block** (32 elements) matches the predictor's adaLN normalization.
2. **adaLN modulation** provides implicit normalization, so weights don't need per-channel precision.
3. **Fewer layers** = less error accumulation: a 4-layer predictor vs a 6-layer encoder.
4. **PIE SIMD** handles the nibble unpack + dot in tight loops.
### Why Full Q4 (Encoder + Predictor) Fails
When we skip INT8 and quantize the encoder to Q4:
| Layer | INT8 cos | Q4 cos |
|-------|----------|--------|
| encoder | 0.9999 | 0.93 |
| predictor | 0.998 | 0.998 |
| **Total** | **0.999** | **0.93** |
**Root cause**: ViT encoder has high dynamic range in intermediate activations. The 32-element block granularity doesn't align with the encoder's channel statistics. INT8's per-channel precision is essential.
### Engineering Notes
- Nibbles decode as `value - 8`, giving the range [-8, 7]
- Block scales are stored as f16 (2 bytes per block), i.e. 64 bytes of scales per 32×32 weight tile
- Zero weights are rare in Q4 (<1%), so a skip-zero optimization was not implemented
## Ternary ({-1, 0, +1})
### Method
```
f32_weight → hard threshold
           → +1 if w > +tau
           → -1 if w < -tau
           →  0 otherwise
           → bit-packed ternary
At inference:
  result = input @ ternary_weight
         = sum(sign(w_i) × x_i)      // pure addition/subtraction
```
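A sketch of the thresholding and the multiply-free accumulation (bit-packing omitted; names are illustrative):

```python
import numpy as np

def ternarize(w, tau):
    """Hard-threshold weights into {-1, 0, +1}."""
    t = np.zeros_like(w, dtype=np.int8)
    t[w > +tau] = +1
    t[w < -tau] = -1
    return t

def ternary_gemv(x, t):
    """Each output is a sum of signed inputs: additions and subtractions only."""
    return np.where(t > 0, x, np.where(t < 0, -x, 0.0)).sum(axis=1)

w = np.random.randn(64, 128).astype(np.float32)
x = np.random.randn(128).astype(np.float32)
t = ternarize(w, tau=0.5 * w.std())   # one of the thresholds tried (tau = 0.5 sigma)
y = ternary_gemv(x, t)
```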
### Why It Fails
| Metric | Q4 | Ternary |
|--------|-----|---------|
| Cos vs f32 | 0.998 | ~0.85 |
| Compression | 8x vs f32 | 16x vs f32 |
**Root cause**: adaLN generates six modulation vectors (scale1, shift1, gate1, scale2, shift2, gate2) that are multiplied into and added onto the normalized activations. The magnitudes of these modulation vectors matter, and ternary quantization destroys them.
**What was tried**:
- Various thresholds (τ = 0.5σ, σ, 2σ of the weight distribution)
- STE (straight-through estimator) during fine-tuning; see the sketch below
- Mixed ternary (ternary weights + fp32 scales)
None recovered quality sufficiently.
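For reference, the STE variant above is commonly implemented like this in PyTorch; a generic sketch, not the exact fine-tuning code used here:

```python
import torch

class TernarySTE(torch.autograd.Function):
    """Forward: hard ternarization; backward: pass the gradient straight through."""
    @staticmethod
    def forward(ctx, w, tau):
        t = torch.zeros_like(w)
        t[w > tau] = 1.0
        t[w < -tau] = -1.0
        return t

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None     # identity gradient for w, no gradient for tau

# inside a layer's forward pass during fine-tuning (hypothetical usage):
#   w_t = TernarySTE.apply(self.weight, tau)
#   y = torch.nn.functional.linear(x, w_t, self.bias)
```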
## WANDA Pruning
### Method
```
1. Forward pass on calibration set → collect activations
2. Compute WANDA score: s(w) = |w| × ||a||
3. Sort scores, prune bottom N%
4. Fine-tune 1 epoch (not yet done for the results below)
5. Re-quantize remaining weights to Q4
```
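A minimal sketch of steps 2–3 for a single layer (unstructured variant; calibration, fine-tuning, and re-quantization omitted; names are illustrative):

```python
import numpy as np

def wanda_prune(w, act_norms, sparsity):
    """Prune the weights with the lowest WANDA scores s(w) = |w| * ||a||.

    w          -- [out, in] layer weights
    act_norms  -- [in] L2 norm of each input feature over the calibration set
    sparsity   -- fraction of weights to remove, e.g. 0.2 or 0.4
    """
    scores = np.abs(w) * act_norms[None, :]
    k = int(sparsity * w.size)
    threshold = np.partition(scores.ravel(), k)[k]   # approximately the k-th smallest score
    mask = scores >= threshold
    return w * mask, mask                            # pruned weights + bitmap for sparse storage
```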
### Results
| Model | Pruned | Size | Cos | Notes |
|-------|--------|------|-----|-------|
| Baseline Q4 | 0% | 23.6 MB | 0.998 | Reference |
| WANDA 20% | 20% | 22.0 MB | ~0.99 | Pruned, no fine-tune |
| WANDA 40% | 40% | 25.1 MB | ~0.97 | Bitmap overhead exceeds savings |
### Engineering Notes
- 40% pruned model has **larger** binary than 20% due to bitmap overhead
- Skip-zero GEMV needs hardware support (not implemented in PIE SIMD)
- Would benefit from fine-tuning after pruning
## Shared Lessons
1. **Per-channel > per-block for encoders**: Encoders have high per-channel variance. INT8's per-channel precision beats Q4's per-block.
2. **Predictors are quantization-friendly**: adaLN provides implicit normalization. Predictors can use Q4 with minimal quality loss.
3. **Architecture changes beat quantization**: hybrid_ALAL (3.0M params) achieves similar quality to slim_96d (10.2M params) at INT8+Q4. Architecture > precision.
4. **Epoch 1 models have headroom**: All current slim/epoch-1 models will improve with longer training. The quality comparisons should be re-run at convergence.
5. **Hardwired is the limit**: Q4 weights decompose to shift-add operations. Zero multiplications, zero memory fetches. That's the theoretical floor.
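To illustrate the last point, multiplying an activation by a fixed 4-bit weight takes at most a few shifts and adds; a toy integer sketch:

```python
def mul_by_q4_weight(x: int, w: int) -> int:
    """Multiply x by a 4-bit weight w in [-8, 7] using only shifts, adds, and negation."""
    assert -8 <= w <= 7
    m, acc, bit = abs(w), 0, 0
    while m:
        if m & 1:
            acc += x << bit        # each set bit of the weight costs one shifted add
        m >>= 1
        bit += 1
    return -acc if w < 0 else acc

assert mul_by_q4_weight(13, -5) == -65
```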