# FLUX.2 [klein] 4B → W4A8 SVDQuant — Plan
> **Status:** Active track (supersedes the block-surgery track) · **Dev rig:** 1× A100-80GB
> (SDPA, fake-quant simulation) · **Teacher:** klein **distilled** 4-step 4B (Apache 2.0)
> **Deliverable:** a 4-bit-weight / 8-bit-activation klein-4B that holds teacher-at-4-steps
> quality, as a fast + small inference endpoint; Nunchaku-deployable later.
The block-surgery track (depth-prune single blocks → linear-attn surrogates → distill) is
**shelved** — it topped out at ~1.15–1.26× and was quality-bounded (see
`block_surgery_plan.md` / `block_surgery_todo.md` for the full history). Quantization is a
better, orthogonal axis: it attacks memory *and* compute, is post-training (no fragile
freeze-and-distill loop), and is proven on this exact model family (Nunchaku already ships
SVDQuant weights for klein-9B).

---
## 0. Method — SVDQuant (our own fake-quant)
SVDQuant (Li et al., ICLR 2025, arXiv 2411.05007 / Nunchaku / DeepCompressor). Per Linear
`y = x Wᵀ + b` (W: (out,in)):
1. **Smooth** — migrate per-channel activation outliers into the weights:
`x̂ = x / s`, `Ŵ = W ⊙ s` (column `j` of `W` is scaled by `s_j`), with
`s_j = max|x_j|^α / max|W_:,j|^(1-α)` (SmoothQuant/AWQ form). Exactly invariant:
`x̂ Ŵᵀ = x Wᵀ`. (`max|x_j|` comes from calibration.)
2. **Absorb** — keep the top-**rank** components as a **16-bit low-rank branch** `L` that
soaks up the dominant (outlier) energy. Residual `R = Ŵ − L` is smooth → easy to low-bit.
Two ways to pick `L` (both implemented; `WHITEN` env):
- **plain** = SVD of `Ŵ`, minimizes weight error `‖Ŵ−L‖` (base SVDQuant paper).
- **whitened** (DEFAULT) = activation-aware SVD minimizing OUTPUT error `‖X̂(Ŵ−L)ᵀ‖` via the
act-Gram `Ĝ=X̂ᵀX̂` (ASVD/SVD-LLM idea; **our add, not in the paper**). Uses a robust eigen-√ of `Ĝ`,
NOT Cholesky (the bf16 Gram is not numerically PD). This is what makes rank a real quality knob.
3. **Refine** (DEFAULT 3 iters; SVDQuant §4.2) — re-fit `L` to `Ŵ−Q(R)` to absorb the 4-bit
rounding error, keep the best iterate. Free at inference; build-time only.
4. **Quantize** — residual weights to **4-bit** (group-wise along in-dim), activations to
**8-bit** (per-token dynamic). The low-rank branch keeps full-precision `x̂`.
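
The build-time steps above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the code in `flux2distill/svdquant.py`: the function names (`fake_quant_weight`, `low_rank`, `svdquant_decompose`) are hypothetical, the 4-bit quantizer is assumed to be symmetric with one scale per group, only the plain (non-whitened) SVD is shown, and the "best iterate" is judged by weight reconstruction error.

```python
import numpy as np


def fake_quant_weight(w, bits=4, group=64):
    """Quantize -> dequantize with one symmetric scale per group along the in-dim."""
    out_f, in_f = w.shape
    assert in_f % group == 0, "sketch assumes in_features is a multiple of the group size"
    g = w.reshape(out_f, in_f // group, group)
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit
    scale = np.abs(g).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)  # all-zero group: any scale works
    q = np.clip(np.round(g / scale), -qmax - 1, qmax)
    return (q * scale).reshape(out_f, in_f)


def low_rank(m, rank):
    """Best rank-`rank` approximation of m in Frobenius norm (truncated SVD)."""
    u, sv, vt = np.linalg.svd(m, full_matrices=False)
    return (u[:, :rank] * sv[:rank]) @ vt[:rank]


def svdquant_decompose(w, act_absmax, rank=32, alpha=0.5, bits=4, group=64, refine=3):
    """Return (s, L, Q(R)) for a weight w of shape (out, in)."""
    # 1. Smooth: s_j = max|x_j|^alpha / max|W_:,j|^(1 - alpha), one scale per input channel.
    act_absmax = np.asarray(act_absmax, dtype=np.float64)
    w_absmax = np.maximum(np.abs(w).max(axis=0), 1e-8)
    s = np.maximum(act_absmax ** alpha / w_absmax ** (1 - alpha), 1e-8)
    w_hat = w * s  # broadcasting scales column j by s_j

    # 2. Absorb + 4. Quantize: low-rank branch L, then fake-quantize the residual.
    L = low_rank(w_hat, rank)
    q_r = fake_quant_weight(w_hat - L, bits, group)
    best_err, best_L, best_q = np.linalg.norm(w_hat - L - q_r), L, q_r

    # 3. Refine: re-fit L to w_hat - Q(R), re-quantize, keep the best iterate.
    for _ in range(refine):
        L = low_rank(w_hat - q_r, rank)
        q_r = fake_quant_weight(w_hat - L, bits, group)
        err = np.linalg.norm(w_hat - L - q_r)
        if err < best_err:
            best_err, best_L, best_q = err, L, q_r
    return s, best_L, best_q
```

Because the first iterate is always a candidate, the returned pair can never have a larger weight reconstruction error than the unrefined decomposition.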
Forward: `y = (x̂·Lᵀ)[16-bit] + Q8(x̂)·Q4(R)ᵀ[low-bit] + b`, where `L` has the same (out,in)
layout as `Ŵ`. The low-rank branch is the high-precision *error-correction* path; rank ≪ 3072,
so it adds almost no compute (its cost is mostly memory).
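
A fake-quant version of this forward pass might look like the sketch below. The names are hypothetical, and two things are assumed rather than taken from the repo: `L` is stored as two thin factors (`L = l_up @ l_down`), and the 8-bit activation quantizer is symmetric with one dynamic scale per token.

```python
import numpy as np


def fake_quant_act(x, bits=8):
    """Per-token dynamic quantize -> dequantize: one symmetric scale per row."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8-bit
    scale = np.abs(x).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)  # all-zero token: any scale works
    return np.clip(np.round(x / scale), -qmax - 1, qmax) * scale


def svdquant_forward(x, s, l_down, l_up, r_deq, bias=None):
    """x: (tokens, in); s: (in,); l_down: (rank, in); l_up: (out, rank); r_deq: (out, in)."""
    x_hat = x / s                              # smoothed activations
    low = (x_hat @ l_down.T) @ l_up.T          # high-precision low-rank branch, x_hat @ L.T
    res = fake_quant_act(x_hat) @ r_deq.T      # low-bit branch: Q8(x_hat) @ Q4(R).T
    y = low + res
    return y if bias is None else y + bias
```

For a 3072×3072 layer at rank 32 (dimensions assumed for illustration), the two thin matmuls cost 32·(3072+3072) = 196,608 multiply-accumulates per token against 3072² = 9,437,184 for the dense product, roughly 2%.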
**Full math + the Cholesky→eigh fix + ablation caveats: `report/QUANT_REPORT.md`.**
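
The whitened choice of `L` from step 2 can be sketched as below. This is an illustration only: `whitened_low_rank` and its `eps` floor are hypothetical, and the report remains the reference for the actual math. The idea is that `‖X̂(Ŵ−L)ᵀ‖_F = ‖(Ŵ−L)Ĝ^{1/2}‖_F`, so a truncated SVD of `Ŵ Ĝ^{1/2}` mapped back through `Ĝ^{-1/2}` gives the rank-r matrix with the smallest output error. The square root comes from a symmetric eigendecomposition with clamped eigenvalues, which still works when the Gram is not numerically positive-definite.

```python
import numpy as np


def whitened_low_rank(w_hat, gram, rank, eps=1e-6):
    """Rank-`rank` L minimizing ||X (w_hat - L)^T||_F, given gram = X^T X."""
    gram = 0.5 * (gram + gram.T)               # enforce exact symmetry
    lam, v = np.linalg.eigh(gram)              # eigen-decomposition, no Cholesky needed
    lam = np.maximum(lam, eps * lam.max())     # clamp tiny or negative eigenvalues
    root = (v * np.sqrt(lam)) @ v.T            # G^{1/2}
    inv_root = (v / np.sqrt(lam)) @ v.T        # G^{-1/2}
    u, sv, vt = np.linalg.svd(w_hat @ root, full_matrices=False)
    return ((u[:, :rank] * sv[:rank]) @ vt[:rank]) @ inv_root
```

By construction the plain SVD has the lower weight error and the whitened one has the lower output error on the calibration activations, which is the trade-off the `WHITEN` knob selects.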
**RESULTS — full 4×3 grid (2026-06-01, 300-img calib; supersedes the earlier 100-calib L-shape).**
Every method × every rank, one fixed calib (`data/monet_cache`), so all 12 cells are comparable
(full table + per-cell metrics in `RESULTS.md`, report `report/QUANT_REPORT.{md,pdf}`). Best per rank:
r16 whiten+refine 0.0556 / r32 whiten+refine 0.0476 / r64 **plain+refine 0.0446** (@ 3.67/3.59/3.43×).
**Overall best = r64 plain+refine 0.0446 @ 3.43×**, all ~4–5× below the surgery frontier (0.23–0.32).
Headline finding: the upgrades are NOT monotonic — **refine** is the reliable lever; **whitening alone
is non-monotonic in rank** (overfits the 300-img Gram: hurts r16 & r64, helps r32) and only helps paired
with refine or at moderate rank. At high rank, plain+refine (no Gram) wins.
**⚠️ Bigger finding (2026-06-01 mechanism ablation): SmoothQuant (α=0.5) is HARMFUL at W4A8.** The RTN
floor (no smooth, no SVD, `SMOOTH=0`/s=1) = **0.0573** — beats SmoothQuant rank-0 (0.0729, −21%) AND the
smoothed SVD r16/r32 cells. SmoothQuant migrates outliers into the weights, which only helps low-bit
*activations*; at A8 the acts are already fine, so it just widens the 4-bit *weight* spread (worst-layer
wrecon 0.15→0.26). **The whole α=0.5 grid is mis-tuned.** → #1 next experiment: re-run the grid
with no smoothing / low α (this takes priority over the 2000-img re-sweep). See `TODO.md` / `RESULTS.md`.
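
The mechanism can be reproduced on synthetic data. The toy below is an assumption-laden sketch, not the ablation itself: the weights are Gaussian, one activation channel per 64-wide group is given a 100× abs-max, and "error" is the relative Frobenius error of a symmetric group-wise 4-bit fake-quant. Its numbers are unrelated to the measured 0.15→0.26.

```python
import numpy as np


def fake_quant_weight(w, bits=4, group=64):
    """Quantize -> dequantize with one symmetric scale per group along the in-dim."""
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group, group)
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(g).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)
    return (np.clip(np.round(g / scale), -qmax - 1, qmax) * scale).reshape(out_f, in_f)


def rel_weight_error(w, bits=4, group=64):
    """Relative Frobenius error of the fake-quantized weight."""
    return np.linalg.norm(fake_quant_weight(w, bits, group) - w) / np.linalg.norm(w)


def smoothing_demo(alpha=0.5, seed=0):
    """Return (error without smoothing, error after SmoothQuant-style scaling)."""
    rng = np.random.default_rng(seed)
    w = rng.standard_normal((256, 256))
    act_absmax = np.ones(256)
    act_absmax[::64] = 100.0  # synthetic: one outlier activation channel per group
    s = act_absmax ** alpha / np.abs(w).max(axis=0) ** (1 - alpha)
    return rel_weight_error(w), rel_weight_error(w * s)


err_plain, err_smooth = smoothing_demo()
print(f"no smoothing: {err_plain:.4f}   alpha=0.5: {err_smooth:.4f}")
```

Each scaled-up outlier column inflates the shared scale of its group, so the other 63 weights in that group are rounded on a coarser grid.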
**Why W4A8 (not the paper's W4A4):** 8-bit activations are far more forgiving — the right
conservative lead for a **4-step distilled** model where per-step quant error doesn't
average out over many steps. We lead with W4A8 + rank 32, then sweep (rank, α, A4). W4A4
(the Nunchaku sweet spot) is a later comparison.
**Fake-quant, not real kernels.** On the A100 we *simulate* (quantize→dequantize in bf16)
to map the **quality** frontier via the existing held-out velocity-loss metric. We are not
training with DeepCompressor (effectively unmaintained) or running Nunchaku kernels here.
The decomposition is kernel-agnostic, so a checkpoint built here is the same math a fused
W4A8 kernel would run, which makes it Nunchaku-ready later (an open task, deferred).

---
## 1. What we quantize
- **All `nn.Linear` inside the 25 transformer blocks** (`transformer_blocks.*` +
`single_transformer_blocks.*`) — uniform W4A8, **no per-block precision games**
(mixing precisions is slower/messier and we have no evidence the 5 cross-modal double
blocks need more; revisit only if a fragility shows up).
- **Kept in bf16:** norms (not Linear), embedders (`x_embedder`, `context_embedder`,
time/text embedders), final `proj_out`. Tiny share of compute; quantizing the output
projection tends to cost quality for ~nothing.
- Knobs (`flux2distill/config.py::QuantConfig`): `rank=32`, `alpha=0.5`, `w_bits=4`,
`a_bits=8`, `w_group=64`.
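
As a rough picture of the selection rule and the knobs, consider the sketch below. The dataclass mirrors the field names and defaults listed above, but its layout is a guess at `QuantConfig`, not a copy of it; `should_quantize` and the example module names are hypothetical, and the `nn.Linear` type check the real build needs is left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QuantConfig:
    """Defaults taken from the knob list above; the class layout itself is assumed."""
    rank: int = 32
    alpha: float = 0.5
    w_bits: int = 4
    a_bits: int = 8
    w_group: int = 64


# Only modules under the double and single transformer blocks are quantized.
QUANT_PREFIXES = ("transformer_blocks.", "single_transformer_blocks.")


def should_quantize(module_name: str) -> bool:
    """Name-based filter; embedders and `proj_out` live outside the blocks and stay bf16."""
    return module_name.startswith(QUANT_PREFIXES)
```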
## 2. Calibration data
- **5–7k full-res `jasperai/monet` images via `url`** (NOT the thumbnail `image` field) →
VAE latents + captions, cached to `data/monet_calib/` by `scripts/11_cache_calib.py`.
Download is slow but worth it — more images = better activation statistics.
- Calibration forward = real latents noised across sigmas with their captions (the **same**
forward the eval metric uses), collecting per-input-channel activation abs-max.
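
The statistics gathered during the calibration forward can be accumulated batch by batch. The sketch below is a NumPy stand-in for what a forward hook would do. `ActStats` is a hypothetical name, the Gram matrix is included because the whitened SVD needs it, and accumulating in float64 is a choice made for this sketch rather than a description of `collect_act_stats`.

```python
import numpy as np


class ActStats:
    """Running per-input-channel abs-max and Gram matrix for one Linear's input."""

    def __init__(self, in_features):
        self.absmax = np.zeros(in_features)
        self.gram = np.zeros((in_features, in_features))
        self.tokens = 0

    def update(self, x):
        """x: activations of shape (..., in_features) from one calibration batch."""
        x = np.asarray(x, dtype=np.float64)
        x = x.reshape(-1, x.shape[-1])                 # flatten batch and token dims
        self.absmax = np.maximum(self.absmax, np.abs(x).max(axis=0))
        self.gram += x.T @ x                           # sum of outer products over tokens
        self.tokens += x.shape[0]
```

Neither statistic depends on rank, α, or group size, so one pass over the calibration set can serve every build in a sweep.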
## 3. Metric & continuity
- Primary: **held-out velocity-matching loss vs the teacher** on the fixed first-16 of
`data/monet_cache` — identical construction to `08_train_recover.py`, so the quant number
is **directly comparable** to the block-surgery frontier (teacher-vs-student loss; a good
W4A8 quant should land *well below* the surgery numbers ~0.23–0.32, since quant perturbs
far less than dropping whole blocks). Also report velocity rel-L2 and teacher-vs-quant
image montages.
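
The two reported numbers can be written down as in the sketch below. The reduction is an assumption: the loss is taken to be a plain mean-squared error between student and teacher velocity predictions, and the exact construction in `08_train_recover.py` is not reproduced here. The function names are hypothetical.

```python
import numpy as np


def velocity_loss(v_student, v_teacher):
    """Mean-squared error between student and teacher velocity predictions (assumed form)."""
    return float(np.mean((v_student - v_teacher) ** 2))


def velocity_rel_l2(v_student, v_teacher, eps=1e-12):
    """Relative L2 error: ||v_student - v_teacher|| / ||v_teacher||."""
    num = np.linalg.norm(v_student - v_teacher)
    den = np.linalg.norm(v_teacher)
    return float(num / (den + eps))
```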

---
## 4. Pipeline (scripts)
```bash
source .venv/bin/activate; export PYTHONPATH=. # system python has NO torch
python3 scripts/11_cache_calib.py 2000 data/monet_calib # full-res monet via url (already cached)
RANK=32 ALPHA=0.5 WGROUP=64 N_CALIB=100 WHITEN=1 REFINE=3 \
python3 -u scripts/12_build_svdquant.py # calibrate(+Gram)->whiten/SVD/refine/4-bit
python3 -u scripts/13_eval_svdquant.py outputs/svdquant_r32_a0.5_w4a8_whiten # vel-loss + montages
```
`flux2distill/svdquant.py` holds the module (`SVDQuantLinear`), the fake-quant primitives,
`collect_act_stats` (absmax+Gram hooks), and the build/load surgery (`apply_svdquant_from_stats` /
`apply_svdquant_empty`). `quant_config.json` is the build↔load contract (records `specs`
= per-layer shape/rank so empty modules rebuild before `load_state_dict`). Build env knobs:
`RANK ALPHA WBITS ABITS WGROUP N_CALIB WHITEN REFINE CALIB_DIR MB OUT`.
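
The build↔load contract can be pictured as a small JSON round trip. Everything schema-related in the sketch below is hypothetical: the `knobs`/`specs` key layout, the example layer name, and the parameter names in `empty_shapes` are illustrative guesses, not the contents of the real `quant_config.json`. The point is only that the stored per-layer shape and rank are enough to allocate correctly sized empty tensors before a state dict is loaded.

```python
import json


def save_quant_config(path, knobs, specs):
    """Write build knobs plus per-layer specs (hypothetical schema)."""
    with open(path, "w") as f:
        json.dump({"knobs": knobs, "specs": specs}, f, indent=2)


def load_quant_config(path):
    """Read the config back as a plain dict."""
    with open(path) as f:
        return json.load(f)


def empty_shapes(spec, w_group):
    """Tensor shapes an empty quantized layer would need (parameter names are illustrative)."""
    out_f, in_f, rank = spec["out_features"], spec["in_features"], spec["rank"]
    return {
        "smooth": (in_f,),                    # per-input-channel smoothing scale s
        "l_down": (rank, in_f),               # low-rank factors, L = l_up @ l_down
        "l_up": (out_f, rank),
        "q_weight": (out_f, in_f),            # 4-bit residual codes
        "w_scale": (out_f, in_f // w_group),  # one scale per weight group
    }
```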
## 5. Roadmap (see `TODO.md` for the live checklist)
1. **W4A8 baseline + rank sweep** {16,32,64}, whitened+refine: DONE (results in §0 / RESULTS.md).
2. **Whitening + iterative refinement** implemented and measured.
3. 🔄 **ACTIVE (2026-06-01): full method×rank GRID at fixed calib.** ranks {16,32,64} ×
{plain, plain+whiten, plain+refine, plain+whiten+refine} = 12 builds, all at **300-calib
from `data/monet_cache`** (closes the L-shape: every cell comparable; runs the missing
`plain+refine`). Also expanded the visual eval to **8 probe prompts**. See `TODO.md`.
4. **DEFERRED — 2000-image calib re-sweep.** After the 300-calib grid picks the method, re-run
ranks {16,32,64} at the chosen method on the **full-res 2000-image** calib (`scripts/11` →
`data/monet_calib`) to confirm the choice holds with richer activation statistics. (Needs the
image download; do NOT run until the 300-calib grid is read.)
5. **Sweep** α, group size, more refine iters; then **W4A4** for the aggressive corner.
6. **One-time stats cache** (absmax+Gram independent of rank/α/group) → cheap sweeps.
7. **Nunchaku-ready export** + **real low-bit kernels** on Ada/Blackwell → the throughput number.
## 6. Open items
- Real speedup needs fused low-bit kernels (Nunchaku, Ada/Blackwell — NOT our A100 sm_80). On
this box we only have the quality story; throughput/VRAM land on a Blackwell / B200 run.
- ~~The method ablation is an L-shape, not a grid~~ — **CLOSED (2026-06-01):** the full 4×3 grid ran.
New open item it surfaced: **whitening is unreliable at 300-calib** (non-monotonic in rank) → the
deferred 2000-img re-sweep (§5, item 4) tests whether richer calib stabilizes it.
- Whether the 5 double (cross-modal) blocks tolerate W4A8 as well as the singles — only revisit
with per-block precision if eval/montages show binding failures.
- α for W4A8: 8-bit activations may want *less* smoothing (less weight burden) — sweep.
