FLUX.2 [klein] 4B → W4A8 SVDQuant — Plan
Status: Active track (supersedes the block-surgery track) · Dev rig: 1× A100-80GB (SDPA, fake-quant simulation) · Teacher: klein distilled 4-step 4B (Apache 2.0) · Deliverable: a 4-bit-weight / 8-bit-activation klein-4B that holds teacher-at-4-steps quality, as a fast, small inference endpoint; Nunchaku-deployable later.
The block-surgery track (depth-prune single blocks → linear-attn surrogates → distill) is
shelved: it topped out at ~1.15–1.26× speedup and was quality-bounded (see
block_surgery_plan.md / block_surgery_todo.md for the full history). Quantization is a
better, orthogonal axis: it attacks memory and compute, is post-training (no fragile
freeze-and-distill loop), and is proven on this exact model family (Nunchaku already ships
SVDQuant weights for klein-9B).
0. Method — SVDQuant (our own fake-quant)
SVDQuant (Li et al., ICLR 2025, arXiv 2411.05007; code in Nunchaku / DeepCompressor). For each Linear
y = x Wᵀ + b, with W of shape (out, in):
- Smooth: migrate per-channel activation outliers into the weights:
  x̂ = x / s, Ŵ = W ⊙ s (column j of W scaled by s_j), with s_j = max|x_j|^α / max|W_:,j|^(1−α) (SmoothQuant/AWQ form). Exactly invariant: x̂ Ŵᵀ = x Wᵀ. (max|x_j| comes from calibration.)
- Absorb: keep the top-rank components as a 16-bit low-rank branch L that soaks up the dominant (outlier) energy. The residual R = Ŵ − L is smooth, so it is easy to quantize to low bit-width. Two ways to pick L (both implemented; WHITEN env):
  - plain = SVD of Ŵ, minimizing the weight error ‖Ŵ − L‖ (base SVDQuant paper).
  - whitened (DEFAULT) = activation-aware SVD minimizing the OUTPUT error ‖X̂(Ŵ − L)ᵀ‖ via the activation Gram Ĝ = X̂ᵀX̂ (ASVD/SVD-LLM idea; our addition, not in the paper). It uses a robust eigendecomposition square root of Ĝ, NOT Cholesky, because the bf16 Gram is not reliably positive-definite. This is what makes rank a real quality knob.
- Refine (DEFAULT 3 iters; SVDQuant §4.2): re-fit L to Ŵ − Q(R) to absorb the 4-bit rounding error, keeping the best iterate. Free at inference; build-time only.
- Quantize: residual weights to 4-bit (group-wise along the in-dim), activations to 8-bit (per-token dynamic). The low-rank branch keeps the full-precision x̂.
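A minimal NumPy sketch of the Smooth and Absorb steps, under stated assumptions: float64 instead of bf16, a random toy layer with one outlier input channel, and hypothetical helper names (smooth_factors, gram_sqrt_pair, low_rank_whitened) that are not the API of flux2distill/svdquant.py. It checks the smoothing invariance and shows why whitening works: since ‖X̂(Ŵ − L)ᵀ‖ = ‖(Ŵ − L)Ĝ^{1/2}‖, the best rank-r fit can be taken in the whitened space and mapped back with Ĝ^{−1/2}.

```python
import numpy as np

def smooth_factors(act_absmax, W, alpha=0.5, eps=1e-8):
    # s_j = max|x_j|^alpha / max|W[:, j]|^(1 - alpha), one factor per input channel j
    w_absmax = np.abs(W).max(axis=0)
    return (act_absmax + eps) ** alpha / (w_absmax + eps) ** (1.0 - alpha)

def gram_sqrt_pair(G, rel_floor=1e-6):
    # Robust G^{1/2} and G^{-1/2} via a symmetric eigendecomposition (no Cholesky):
    # eigenvalues are clamped to a small positive floor so a noisy, non-PD Gram still works.
    G = 0.5 * (G + G.T)
    evals, evecs = np.linalg.eigh(G)
    floor = rel_floor * max(float(evals.max()), 0.0) + 1e-12
    root = np.sqrt(np.clip(evals, floor, None))
    return (evecs * root) @ evecs.T, (evecs / root) @ evecs.T

def truncated_svd(A, rank):
    U, sig, Vt = np.linalg.svd(A, full_matrices=False)
    return (U[:, :rank] * sig[:rank]) @ Vt[:rank]

def low_rank_plain(W_hat, rank):
    # minimizes the weight error ||W_hat - L||_F
    return truncated_svd(W_hat, rank)

def low_rank_whitened(W_hat, G_hat, rank):
    # ||X_hat (W_hat - L)^T||_F = ||(W_hat - L) S||_F with S = G_hat^{1/2},
    # so take the best rank-r fit of W_hat S and map it back with S^{-1}
    S, S_inv = gram_sqrt_pair(G_hat)
    return truncated_svd(W_hat @ S, rank) @ S_inv

rng = np.random.default_rng(0)
X = rng.standard_normal((512, 64))
X[:, 3] *= 30.0                                    # one outlier input channel
W = 0.05 * rng.standard_normal((48, 64))           # (out, in)

s = smooth_factors(np.abs(X).max(axis=0), W, alpha=0.5)
X_hat, W_hat = X / s, W * s                        # W * s scales column j of W by s_j
G_hat = X_hat.T @ X_hat

def output_error(L):
    return np.linalg.norm(X_hat @ (W_hat - L).T)

err_plain = output_error(low_rank_plain(W_hat, 8))
err_white = output_error(low_rank_whitened(W_hat, G_hat, 8))
print(f"output error  plain: {err_plain:.4f}  whitened: {err_white:.4f}")
```

With an exact float64 Gram the whitened fit can never have a larger output error on the calibration rows than the plain SVD, since it is the optimum of that objective; whether a 300-image Gram generalizes to held-out inputs is the separate question the RESULTS grid measures.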
Forward: y = (x̂·Lᵀ)[16-bit] + Q8(x̂)·Q4(R)ᵀ[low-bit] + b. The low-rank branch is the
high-precision error-correction path; since rank ≪ 3072, it is nearly free in compute and costs mainly memory.
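The refine loop and the forward pass above can be sketched the same way. This is a simplified version under explicit assumptions, not the repo's implementation: plain (unwhitened) refinement, weight-space error as the "best iterate" criterion, symmetric absmax scales, and float64 fake-quant; all function names are hypothetical.

```python
import numpy as np

def fq_weight_int4(W, group=64, bits=4):
    # symmetric group-wise fake-quant along the in-dim: quantize -> dequantize
    out_f, in_f = W.shape
    assert in_f % group == 0
    qmax = 2 ** (bits - 1) - 1                      # 7 for signed 4-bit
    Wg = W.reshape(out_f, in_f // group, group)
    scale = np.abs(Wg).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)
    return (np.clip(np.round(Wg / scale), -qmax - 1, qmax) * scale).reshape(out_f, in_f)

def fq_act_int8(x, bits=8):
    # dynamic per-token (per-row) symmetric fake-quant
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(x).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)
    return np.clip(np.round(x / scale), -qmax - 1, qmax) * scale

def truncated_svd(A, rank):
    U, sig, Vt = np.linalg.svd(A, full_matrices=False)
    return (U[:, :rank] * sig[:rank]) @ Vt[:rank]

def build_refined(W_hat, rank, iters=3, group=64):
    # plain refine: L <- SVD_r(W_hat - Q(W_hat - L)); keep the iterate with the lowest weight error
    L = truncated_svd(W_hat, rank)
    best = None
    for i in range(iters + 1):                      # i == 0 is the unrefined starting point
        R_q = fq_weight_int4(W_hat - L, group)
        err = np.linalg.norm(W_hat - (L + R_q))
        if best is None or err < best[0]:
            best = (err, L, R_q)
        if i < iters:
            L = truncated_svd(W_hat - R_q, rank)    # re-fit L to absorb the rounding error
    return best

def svdquant_forward(x_hat, L, R_q, b):
    # y = (x_hat . L^T)[16-bit] + Q8(x_hat) . Q4(R)^T [low-bit] + b
    return x_hat @ L.T + fq_act_int8(x_hat) @ R_q.T + b

rng = np.random.default_rng(1)
W_hat = rng.standard_normal((96, 128))              # already-smoothed weight, (out, in)
x_hat = rng.standard_normal((16, 128))              # already-smoothed activations, (tokens, in)
b = np.zeros(96)

best_err, L, R_q = build_refined(W_hat, rank=16, iters=3)
y = svdquant_forward(x_hat, L, R_q, b)
ref = x_hat @ W_hat.T + b
rel_out = np.linalg.norm(y - ref) / np.linalg.norm(ref)
print(f"weight error {best_err:.3f}, relative output error {rel_out:.4f}")
```

Keeping the best iterate (rather than the last) makes the loop safe to run for a fixed number of iterations: by construction it can never end worse than the unrefined start.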
Full math + the Cholesky→eigh fix + ablation caveats: report/QUANT_REPORT.md.
RESULTS — full 4×3 grid (2026-06-01, 300-img calib; supersedes the earlier 100-calib L-shape).
Every method × every rank, one fixed calib (data/monet_cache), so all 12 cells are comparable
(full table + per-cell metrics in RESULTS.md, report report/QUANT_REPORT.{md,pdf}). Best per rank:
r16 whiten+refine 0.0556 / r32 whiten+refine 0.0476 / r64 plain+refine 0.0446 (at 3.67/3.59/3.43× compression).
Overall best = r64 plain+refine 0.0446 at 3.43×; every best-per-rank cell is ~4–5× below the surgery frontier (0.23–0.32).
Headline finding: the upgrades are NOT monotonic — refine is the reliable lever; whitening alone
is non-monotonic in rank (overfits the 300-img Gram: hurts r16 & r64, helps r32) and only helps paired
with refine or at moderate rank. At high rank, plain+refine (no Gram) wins.
⚠️ Bigger finding (2026-06-01 mechanism ablation): SmoothQuant (α=0.5) is HARMFUL at W4A8. The RTN
floor (no smooth, no SVD; SMOOTH=0, i.e. s=1) = 0.0573. It beats SmoothQuant rank-0 (0.0729, −21%) AND the weaker
smoothed SVD r16/r32 cells (though not the best r16/r32 cells, 0.0556/0.0476). SmoothQuant migrates outliers into the
weights, which only helps low-bit activations; at A8 the activations already quantize well, so smoothing just widens the 4-bit weight spread (worst-layer
wrecon 0.15→0.26). The whole α=0.5 grid is mis-tuned. → #1 next experiment: re-run no-smooth/low-α
(beats the 2000-img re-sweep in priority). See TODO.md / RESULTS.md.
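A toy, single-row illustration of the spread mechanism (all numbers invented; one 8-element 4-bit group, α = 0.5). Smoothing inflates the weight of the outlier channel, which then sets the group's scale, so the normal channels collapse onto a couple of 4-bit levels. It illustrates the weight-spread effect only; it does not reproduce the measured losses or say anything about the net output error of a real layer. Helper names are hypothetical.

```python
def smooth_row(w_row, act_absmax, alpha=0.5):
    # single output row, so max|W[:, j]| is just |w_j|
    s = [a ** alpha / abs(w) ** (1 - alpha) for a, w in zip(act_absmax, w_row)]
    return [w * sj for w, sj in zip(w_row, s)], s

def int4_group(vals, qmax=7):
    # one symmetric 4-bit group: integer codes plus a single absmax scale
    scale = max(abs(v) for v in vals) / qmax
    codes = [max(-qmax - 1, min(qmax, round(v / scale))) for v in vals]
    return codes, scale

w_row = [0.10, -0.08, 0.12, -0.05, 0.09, -0.11, 0.07, 0.065]
act_absmax = [1.0] * 7 + [50.0]                    # channel 7 is an activation outlier

codes_raw, scale_raw = int4_group(w_row)
w_hat, s = smooth_row(w_row, act_absmax)
codes_sm, scale_sm = int4_group(w_hat)

def rel_err_non_outlier(codes, scale, smooth):
    # error of the effective weight Q(w_hat_j) / s_j vs the original w_j, normal channels only
    errs = [abs(c * scale / sj - w) / abs(w) for c, sj, w in zip(codes, smooth, w_row)]
    return sum(errs[:7]) / 7

print("4-bit levels used (raw, smoothed):", len(set(codes_raw)), len(set(codes_sm)))
print("mean rel. error on normal channels (raw, smoothed):",
      round(rel_err_non_outlier(codes_raw, scale_raw, [1.0] * 8), 3),
      round(rel_err_non_outlier(codes_sm, scale_sm, s), 3))
```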
Why W4A8 (not the paper's W4A4): 8-bit activations are far more forgiving — the right conservative lead for a 4-step distilled model where per-step quant error doesn't average out over many steps. We lead with W4A8 + rank 32, then sweep (rank, α, A4). W4A4 (the Nunchaku sweet spot) is a later comparison.
Fake-quant, not real kernels. On the A100 we simulate (quantize→dequantize in bf16) to map the quality frontier via the existing held-out velocity-loss metric. We are not training with DeepCompressor (effectively unmaintained) or running Nunchaku kernels here. The decomposition is kernel-agnostic, so a checkpoint built here is the same math a fused W4A8 kernel would run — making it Nunchaku-ready later (an open task, deferred).
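Why fake-quant numbers should carry over to a real kernel, as a small pure-Python sketch under an assumed scheme (symmetric per-token int8 activations, symmetric per-group int4 weights, a single weight row, float64 rather than bf16): dequantizing and then multiplying in float gives the same dot product as an integer multiply-accumulate per group followed by one rescale by sx·sw, up to floating-point rounding. The helper names are hypothetical.

```python
import random

def quantize_sym(vals, bits):
    # symmetric quantization: integer codes plus one float scale for the whole list
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in vals) / qmax or 1.0      # guard against an all-zero list
    codes = [max(-qmax - 1, min(qmax, round(v / scale))) for v in vals]
    return codes, scale

def fake_quant_dot(x, w_row, group=64):
    # simulation path: dequantize both operands to float, then a plain float dot product
    qx, sx = quantize_sym(x, 8)                           # per-token int8
    total = 0.0
    for g in range(0, len(x), group):
        qw, sw = quantize_sym(w_row[g:g + group], 4)      # per-(row, group) int4
        total += sum((a * sx) * (c * sw) for a, c in zip(qx[g:g + group], qw))
    return total

def int_kernel_dot(x, w_row, group=64):
    # kernel-style path: exact integer multiply-accumulate per group, one rescale per group
    qx, sx = quantize_sym(x, 8)
    total = 0.0
    for g in range(0, len(x), group):
        qw, sw = quantize_sym(w_row[g:g + group], 4)
        acc = sum(a * c for a, c in zip(qx[g:g + group], qw))   # Python ints: exact
        total += sx * sw * acc
    return total

random.seed(0)
x = [random.gauss(0, 1) for _ in range(256)]
w_row = [random.gauss(0, 0.05) for _ in range(256)]
sim, kern = fake_quant_dot(x, w_row), int_kernel_dot(x, w_row)
exact = sum(a * b for a, b in zip(x, w_row))
print(f"simulated {sim:.6f}  integer-kernel {kern:.6f}  unquantized {exact:.6f}")
```

The remaining gap to a real fused kernel is accumulation precision and bf16 rounding, not the quantization grid itself.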
1. What we quantize
- All nn.Linear inside the 25 transformer blocks (transformer_blocks.* + single_transformer_blocks.*): uniform W4A8, no per-block precision games (mixing precisions is slower/messier, and we have no evidence the 5 cross-modal double blocks need more; revisit only if a fragility shows up).
- Kept in bf16: norms (not Linear), embedders (x_embedder, context_embedder, time/text embedders), and the final proj_out. These are a tiny share of compute, and quantizing the output projection tends to cost quality for ~nothing.
- Knobs (flux2distill/config.py::QuantConfig): rank=32, alpha=0.5, w_bits=4, a_bits=8, w_group=64.
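A sketch of the module selection and knobs above. QuantConfigSketch and should_quantize are hypothetical stand-ins (the real QuantConfig in flux2distill/config.py is not reproduced here), and the module names are illustrative diffusers-style names, not a verified listing of the klein-4B model.

```python
from dataclasses import dataclass
from fnmatch import fnmatchcase

@dataclass(frozen=True)
class QuantConfigSketch:
    # hypothetical stand-in mirroring the knobs listed for flux2distill/config.py::QuantConfig
    rank: int = 32
    alpha: float = 0.5
    w_bits: int = 4
    a_bits: int = 8
    w_group: int = 64

# quantize Linears only inside the double- and single-stream blocks
INCLUDE = ("transformer_blocks.*", "single_transformer_blocks.*")

def should_quantize(name: str, is_linear: bool) -> bool:
    # embedders and the final proj_out live outside the blocks, so they stay bf16
    return is_linear and any(fnmatchcase(name, pat) for pat in INCLUDE)

names = {  # illustrative (assumed) module names -> whether the module is an nn.Linear
    "x_embedder": True,
    "context_embedder": True,
    "transformer_blocks.0.attn.to_q": True,
    "single_transformer_blocks.19.proj_mlp": True,
    "transformer_blocks.0.norm1": False,
    "proj_out": True,
}
picked = [n for n, is_lin in names.items() if should_quantize(n, is_lin)]
print(picked)
```

fnmatchcase anchors the pattern at the start of the name, so "transformer_blocks.*" does not accidentally match "single_transformer_blocks.…".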
2. Calibration data
- 5–7k full-res jasperai/monet images via url (NOT the thumbnail image field) → VAE latents + captions, cached to data/monet_calib/ by scripts/11_cache_calib.py. The download is slow but worth it: more images give better activation statistics.
- Calibration forward = real latents noised across sigmas with their captions (the same forward the eval metric uses), collecting per-input-channel activation abs-max (plus the activation Gram used for whitening).
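A minimal sketch of the statistics the calibration forward accumulates per Linear input: a running per-channel abs-max and the Gram XᵀX of the raw input (the smoothed Ĝ follows by per-channel rescaling). The ActStats class is hypothetical; in PyTorch its update would typically be called from a hook registered with nn.Module.register_forward_pre_hook. Accumulating the Gram in float64 is an assumption here, motivated by the non-PD bf16 Gram noted in §0.

```python
import numpy as np

class ActStats:
    """Hypothetical running statistics for one Linear's (raw) input x."""

    def __init__(self, in_features):
        self.absmax = np.zeros(in_features)                 # per-input-channel max|x_j|
        self.gram = np.zeros((in_features, in_features))    # running X^T X, kept in float64
        self.n_rows = 0

    def update(self, x):
        x2 = np.asarray(x, dtype=np.float64)
        x2 = x2.reshape(-1, x2.shape[-1])                    # (batch * tokens, in_features)
        self.absmax = np.maximum(self.absmax, np.abs(x2).max(axis=0))
        self.gram += x2.T @ x2
        self.n_rows += x2.shape[0]

rng = np.random.default_rng(0)
stats = ActStats(32)
# toy batches of shape (batch, tokens, in); the varying scale stands in for different sigmas
batches = [scale * rng.standard_normal((2, 10, 32)) for scale in (0.3, 1.0, 2.5)]
for xb in batches:
    stats.update(xb)
print(stats.n_rows, stats.absmax.max())
```

Both statistics are order-independent sums/maxima over rows, so the calibration forward can be split into any number of batches without changing the result.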
3. Metric & continuity
- Primary: held-out velocity-matching loss vs the teacher on the fixed first 16 entries of
data/monet_cache, built identically to 08_train_recover.py, so the quant number is directly comparable to the block-surgery frontier (teacher-vs-student loss; a good W4A8 quant should land well below the surgery numbers ~0.23–0.32, since quantization perturbs far less than dropping whole blocks). Also report velocity rel-L2 and teacher-vs-quant image montages.
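A sketch of the two numbers reported per checkpoint, assuming the velocity-matching loss is a mean squared error between teacher and quantized-model velocity predictions on identical noised inputs, and that rel-L2 is pooled over the fixed held-out set. The exact construction lives in 08_train_recover.py and may differ; velocity_metrics is a hypothetical name.

```python
import numpy as np

def velocity_metrics(pairs):
    """pairs: (v_teacher, v_student) arrays for the fixed held-out entries.
    Returns (mean per-entry MSE, pooled relative L2 = ||v_s - v_t|| / ||v_t|| over all entries)."""
    mses, num, den = [], 0.0, 0.0
    for v_t, v_s in pairs:
        diff = v_s - v_t
        mses.append(float(np.mean(diff ** 2)))
        num += float(np.sum(diff ** 2))
        den += float(np.sum(v_t ** 2))
    return sum(mses) / len(mses), (num / den) ** 0.5

rng = np.random.default_rng(0)
held_out = []
for _ in range(16):                       # stand-in for the fixed first-16 eval entries
    v_t = rng.standard_normal((4, 8, 8))
    held_out.append((v_t, v_t + 0.05 * rng.standard_normal(v_t.shape)))
loss, rel = velocity_metrics(held_out)
print(f"vel-loss {loss:.5f}  rel-L2 {rel:.4f}")
```

Fixing the entries (and, in the real eval, the noise and sigmas) is what makes numbers from different checkpoints comparable cell to cell.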
4. Pipeline (scripts)
```bash
source .venv/bin/activate; export PYTHONPATH=.   # system python has NO torch
python3 scripts/11_cache_calib.py 2000 data/monet_calib   # full-res monet via url (already cached)
RANK=32 ALPHA=0.5 WGROUP=64 N_CALIB=100 WHITEN=1 REFINE=3 \
  python3 -u scripts/12_build_svdquant.py   # calibrate(+Gram)->whiten/SVD/refine/4-bit
python3 -u scripts/13_eval_svdquant.py outputs/svdquant_r32_a0.5_w4a8_whiten   # vel-loss + montages
```
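A hypothetical Python driver for the 4×3 method×rank grid in §5, using only the build env knobs this plan names (RANK, ALPHA, WGROUP, N_CALIB, CALIB_DIR, WHITEN, REFINE, OUT) with the calib settings from the RESULTS section (300 images from data/monet_cache). The method→(WHITEN, REFINE) mapping, REFINE=0 meaning "no refinement", and the OUT naming are assumptions. It defaults to a dry run that only prints the command lines.

```python
import itertools
import os
import subprocess

# method -> (WHITEN, REFINE); REFINE=0 is assumed to mean "no refinement iterations"
METHODS = {
    "plain": ("0", "0"),
    "plain+whiten": ("1", "0"),
    "plain+refine": ("0", "3"),
    "plain+whiten+refine": ("1", "3"),
}
RANKS = (16, 32, 64)

def grid_runs(n_calib=300, calib_dir="data/monet_cache"):
    runs = []
    for (method, (whiten, refine)), rank in itertools.product(METHODS.items(), RANKS):
        out = f"outputs/grid_r{rank}_{method.replace('+', '_')}"    # hypothetical naming
        runs.append({"RANK": str(rank), "ALPHA": "0.5", "WGROUP": "64",
                     "N_CALIB": str(n_calib), "CALIB_DIR": calib_dir,
                     "WHITEN": whiten, "REFINE": refine, "OUT": out})
    return runs

def launch(runs, dry_run=True):
    for knobs in runs:
        build = ["python3", "-u", "scripts/12_build_svdquant.py"]
        evaluate = ["python3", "-u", "scripts/13_eval_svdquant.py", knobs["OUT"]]
        if dry_run:
            prefix = " ".join(f"{k}={v}" for k, v in knobs.items())
            print(prefix, " ".join(build), "&&", " ".join(evaluate))
        else:
            env = {**os.environ, **knobs}          # knobs are passed as environment variables
            subprocess.run(build, env=env, check=True)
            subprocess.run(evaluate, env=env, check=True)

launch(grid_runs())   # dry run: prints the 12 build + eval command lines
```

Generating every cell from one table is what keeps the grid an actual grid: each cell differs from its neighbours in exactly one knob.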
flux2distill/svdquant.py holds the module (SVDQuantLinear), the fake-quant primitives,
collect_act_stats (absmax+Gram hooks), and the build/load surgery (apply_svdquant_from_stats /
apply_svdquant_empty). quant_config.json is the build↔load contract: it records specs
(per-layer shape/rank) so that empty modules can be rebuilt before load_state_dict. Build env knobs:
RANK ALPHA WBITS ABITS WGROUP N_CALIB WHITEN REFINE CALIB_DIR MB OUT.
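Why quant_config.json has to carry per-layer specs: an empty module must allocate tensors of the right shapes before load_state_dict can fill them, and rank can differ per layer. The JSON schema, parameter names, and layer shapes below are hypothetical and illustrative, not the real SVDQuantLinear fields.

```python
import json

def make_quant_config(layers, cfg):
    # layers: {module_name: (out_features, in_features, rank)}; schema is hypothetical
    return {"config": cfg,
            "specs": {n: {"out": o, "in": i, "rank": r} for n, (o, i, r) in layers.items()}}

def empty_param_shapes(qc):
    # shapes an empty SVDQuant-style module would allocate before load_state_dict
    # (parameter names are illustrative, not the real SVDQuantLinear fields)
    g = qc["config"]["w_group"]
    shapes = {}
    for name, sp in qc["specs"].items():
        o, i, r = sp["out"], sp["in"], sp["rank"]
        shapes[f"{name}.lowrank_down"] = (r, i)      # x_hat @ down^T -> (tokens, r)
        shapes[f"{name}.lowrank_up"] = (o, r)        # L = up @ down
        shapes[f"{name}.w_codes"] = (o, i)           # 4-bit codes, left unpacked for clarity
        shapes[f"{name}.w_scales"] = (o, i // g)     # one scale per (row, group)
        shapes[f"{name}.smooth"] = (i,)              # per-input-channel s
    return shapes

cfg = {"w_bits": 4, "a_bits": 8, "w_group": 64, "alpha": 0.5}
layers = {  # illustrative names and shapes; ranks differ per layer on purpose
    "transformer_blocks.0.attn.to_q": (3072, 3072, 32),
    "single_transformer_blocks.0.proj_mlp": (9216, 3072, 16),
}
text = json.dumps(make_quant_config(layers, cfg), indent=2)    # written at build time
shapes = empty_param_shapes(json.loads(text))                    # read back at load time
print(shapes["transformer_blocks.0.attn.to_q.w_scales"])
```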
5. Roadmap (see TODO.md for the live checklist)
- ✅ W4A8 baseline + rank sweep {16,32,64}, whitened+refine — DONE (results in §0 / RESULTS.md).
- ✅ Whitening + iterative refinement implemented and measured.
- 🔄 ACTIVE (2026-06-01): full method×rank GRID at fixed calib. Ranks {16,32,64} ×
{plain, plain+whiten, plain+refine, plain+whiten+refine} = 12 builds, all at 300-calib
from data/monet_cache (closes the L-shape: every cell comparable; runs the missing plain+refine). Also expanded the visual eval to 8 probe prompts. See TODO.md.
- DEFERRED: 2000-image calib re-sweep. After the 300-calib grid picks the method, re-run ranks {16,32,64} at the chosen method on the full-res 2000-image calib (scripts/11 → data/monet_calib) to confirm the choice holds with richer activation statistics. (Needs the image download; do NOT run until the 300-calib grid is read.)
- Sweep α, group size, more refine iters; then W4A4 for the aggressive corner.
- One-time stats cache (absmax+Gram independent of rank/α/group) → cheap sweeps.
- Nunchaku-ready export + real low-bit kernels on Ada/Blackwell → the throughput number.
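The one-time stats cache works because smoothing is a per-channel rescaling: for x̂ = x / s, max|x̂_j| = max|x_j| / s_j and Ĝ = diag(1/s)·G·diag(1/s). The raw abs-max and Gram can therefore be cached once and re-derived for any α, while rank and group size never touch them. A NumPy check on random data; the helper name and the toy weight maxima are hypothetical.

```python
import numpy as np

def smoothed_stats(raw_absmax, raw_gram, s):
    # stats of x_hat = x / s derived from cached raw-x stats, for any positive smoothing vector s
    inv = 1.0 / s
    return raw_absmax * inv, raw_gram * np.outer(inv, inv)    # diag(1/s) G diag(1/s)

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 16))
raw_absmax, raw_gram = np.abs(X).max(axis=0), X.T @ X           # computed once, cached

for alpha in (0.0, 0.25, 0.5):
    w_absmax = rng.uniform(0.05, 0.2, size=16)                  # toy per-column weight maxima
    s = raw_absmax ** alpha / w_absmax ** (1 - alpha)
    absmax_hat, gram_hat = smoothed_stats(raw_absmax, raw_gram, s)
    X_hat = X / s                                               # direct recomputation, for checking
    assert np.allclose(absmax_hat, np.abs(X_hat).max(axis=0))
    assert np.allclose(gram_hat, X_hat.T @ X_hat)
print("cached raw stats reproduce the smoothed stats for every alpha tried")
```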
6. Open items
- Real speedup needs fused low-bit kernels (Nunchaku, Ada/Blackwell — NOT our A100 sm_80). On this box we only have the quality story; throughput/VRAM land on a Blackwell / B200 run.
- The method ablation is an L-shape, not a grid: CLOSED (2026-06-01), the full 4×3 grid ran. New open item it surfaced: whitening is unreliable at 300-calib (non-monotonic in rank) → the deferred 2000-img re-sweep (§5, item 4) tests whether richer calib stabilizes it.
- Whether the 5 double (cross-modal) blocks tolerate W4A8 as well as the singles; revisit with per-block precision only if eval/montages show binding failures.
- α for W4A8: 8-bit activations may want less smoothing (less weight burden) — sweep.