
FLUX.2 [klein] 4B → W4A8 SVDQuant — Plan

Status: Active track (supersedes the block-surgery track) · Dev rig: 1× A100-80GB (SDPA, fake-quant simulation) · Teacher: klein distilled 4-step 4B (Apache 2.0) · Deliverable: a 4-bit-weight / 8-bit-activation klein-4B that holds teacher-at-4-steps quality, as a fast + small inference endpoint; Nunchaku-deployable later.

The block-surgery track (depth-prune single blocks → linear-attn surrogates → distill) is shelved — it topped out at ~1.15–1.26× and was quality-bounded (see block_surgery_plan.md / block_surgery_todo.md for the full history). Quantization is a better, orthogonal axis: it attacks memory and compute, is post-training (no fragile freeze-and-distill loop), and is proven on this exact model family (Nunchaku already ships SVDQuant weights for klein-9B).


0. Method — SVDQuant (our own fake-quant)

SVDQuant (Li et al., ICLR 2025, arXiv 2411.05007 / Nunchaku / DeepCompressor). Per Linear y = x Wᵀ + b (W: (out,in)):

  1. Smooth — migrate per-channel activation outliers into the weights: x̂ = x / s, Ŵ = W ⊙ s, with s_j = max|x_j|^α / max|W_:,j|^(1-α) (SmoothQuant/AWQ form). Exactly invariant: x̂ Ŵᵀ = x Wᵀ. (max|x_j| from calibration.)
  2. Absorb — keep the top-rank components as a 16-bit low-rank branch L that soaks up the dominant (outlier) energy. The residual R = Ŵ − L is smooth, so it quantizes well at low bit-width. Two ways to pick L (both implemented; WHITEN env):
    • plain = SVD of Ŵ, minimizes weight error ‖Ŵ−L‖ (base SVDQuant paper).
    • whitened (DEFAULT) = activation-aware SVD minimizing OUTPUT error ‖X̂(Ŵ−L)‖ via the act-Gram Ĝ=X̂ᵀX̂ (ASVD/SVD-LLM idea; our add, not in the paper). Robust eigen-√ of Ĝ, NOT Cholesky (bf16 Gram is non-PD). This is what makes rank a real quality knob.
  3. Refine (DEFAULT 3 iters; SVDQuant §4.2) — re-fit L to Ŵ−Q(R) to absorb the 4-bit rounding error, keep the best iterate. Free at inference; build-time only.
  4. Quantize — residual weights to 4-bit (group-wise along in-dim), activations to 8-bit (per-token dynamic). The low-rank branch stays in 16-bit.
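Steps 1 to 4 can be sketched in NumPy as below. This is a minimal illustration of the math with the plain-SVD choice of L, not the code in flux2distill/svdquant.py: the function names, the symmetric absmax quantizer, and the float64 arithmetic (standing in for bf16) are all assumptions made for the example.

```python
import numpy as np

def fake_quant_weight(w, bits=4, group=64):
    """Symmetric group-wise quantize->dequantize along the input dim."""
    out_f, in_f = w.shape
    qmax = 2 ** (bits - 1) - 1                      # 7 for 4-bit
    g = w.reshape(out_f, in_f // group, group)
    scale = np.abs(g).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)        # guard all-zero groups
    q = np.clip(np.round(g / scale), -qmax, qmax)
    return (q * scale).reshape(out_f, in_f)

def top_rank(m, rank):
    """Best rank-`rank` approximation in Frobenius norm (plain SVD)."""
    u, sv, vt = np.linalg.svd(m, full_matrices=False)
    return (u[:, :rank] * sv[:rank]) @ vt[:rank]

def svdquant_layer(w, act_absmax, rank=32, alpha=0.5, refine=3,
                   w_bits=4, group=64):
    """Return (s, L, Q(R)) for one Linear weight w of shape (out, in)."""
    # 1. Smooth: s_j = max|x_j|^alpha / max|W_:,j|^(1-alpha), W_hat = W * s
    a = np.maximum(act_absmax, 1e-8)
    c = np.maximum(np.abs(w).max(axis=0), 1e-8)
    s = a ** alpha / c ** (1 - alpha)
    w_hat = w * s                                   # scales column j by s_j
    # 2. Absorb: low-rank branch L, residual R = W_hat - L
    low = top_rank(w_hat, rank)
    # 4. Quantize the residual (fake-quant)
    q_res = fake_quant_weight(w_hat - low, w_bits, group)
    best = (np.linalg.norm(w_hat - low - q_res), low, q_res)
    # 3. Refine: re-fit L to W_hat - Q(R), keep the best iterate
    for _ in range(refine):
        low = top_rank(w_hat - q_res, rank)
        q_res = fake_quant_weight(w_hat - low, w_bits, group)
        err = np.linalg.norm(w_hat - low - q_res)
        if err < best[0]:
            best = (err, low, q_res)
    return s, best[1], best[2]
```

Because the loop keeps the best iterate, the weight-space error with refine=3 can never exceed the refine=0 error for the same inputs. The whitened variant changes only how `top_rank` is computed.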

Forward: y = (x̂·Lᵀ)[16-bit] + Q8(x̂)·Q4(R)ᵀ[low-bit] + b, where L has the same (out,in) shape as Ŵ. The low-rank branch is the high-precision error-correction path; rank ≪ 3072 so it is ~free on compute (memory only). Full math + the Cholesky→eigh fix + ablation caveats: report/QUANT_REPORT.md.
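The whitened choice of L and the eigen-√ route can be sketched as follows. With Ĝ = S² for a symmetric S, the output error ‖X̂(Ŵ−L)ᵀ‖ equals ‖(Ŵ−L)S‖, so the best rank-r L is the truncated SVD of ŴS mapped back through S⁻¹. The function names and the relative eigenvalue floor `rel_floor` are assumptions for illustration; the project's actual clamping rule is documented in report/QUANT_REPORT.md, not here.

```python
import numpy as np

def sym_sqrt_and_inv(gram, rel_floor=1e-6):
    """Symmetric square root of a Gram matrix and its inverse via eigh.

    Eigenvalues are clamped to a small positive floor, so a Gram that is
    not numerically positive definite still works (Cholesky would fail).
    """
    gram = 0.5 * (gram + gram.T)                 # enforce symmetry
    lam, v = np.linalg.eigh(gram)
    lam = np.maximum(lam, rel_floor * lam.max())
    root = np.sqrt(lam)
    return (v * root) @ v.T, (v / root) @ v.T    # S, S^-1

def plain_low_rank(w_hat, rank):
    """Minimizes the weight error ||W_hat - L||."""
    u, sv, vt = np.linalg.svd(w_hat, full_matrices=False)
    return (u[:, :rank] * sv[:rank]) @ vt[:rank]

def whitened_low_rank(w_hat, gram, rank):
    """Minimizes the output error ||X_hat (W_hat - L)^T|| given gram = X_hat^T X_hat."""
    s_half, s_inv = sym_sqrt_and_inv(gram)
    return plain_low_rank(w_hat @ s_half, rank) @ s_inv
```

On the same activations that produced the Gram, the whitened L has output error no larger than the plain L as long as no eigenvalue is clamped. That guarantee says nothing about held-out data.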

RESULTS — full 4×3 grid (2026-06-01, 300-img calib; supersedes the earlier 100-calib L-shape). Every method × every rank, one fixed calib (data/monet_cache), so all 12 cells are comparable (full table + per-cell metrics in RESULTS.md; report in report/QUANT_REPORT.{md,pdf}). Best per rank: r16 whiten+refine 0.0556 / r32 whiten+refine 0.0476 / r64 plain+refine 0.0446 (@ 3.67/3.59/3.43×). Overall best = r64 plain+refine 0.0446 @ 3.43×, all ~4–5× below the surgery frontier (0.23–0.32). Headline finding: the upgrades are NOT monotonic — refine is the reliable lever; whitening alone is non-monotonic in rank (overfits the 300-img Gram: hurts r16 & r64, helps r32) and only helps paired with refine or at moderate rank. At high rank, plain+refine (no Gram) wins.

⚠️ Bigger finding (2026-06-01 mechanism ablation): SmoothQuant (α=0.5) is HARMFUL at W4A8. The RTN floor (no smooth, no SVD, SMOOTH=0/s=1) = 0.0573 — beats SmoothQuant rank-0 (0.0729, −21%) AND the smoothed SVD r16/r32 cells. SmoothQuant migrates outliers into the weights, which only helps low-bit activations; at A8 the acts are already fine, so it just widens the 4-bit weight spread (worst-layer wrecon 0.15→0.26). The whole α=0.5 grid is mis-tuned. → #1 next experiment: re-run no-smooth/low-α (takes priority over the 2000-img re-sweep). See TODO.md / RESULTS.md.
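A toy example of the mechanism, with made-up numbers rather than measured ones: four weights in one quantization group sit on the 4-bit grid, and one input channel carries an activation outlier. Smoothing at α=0.5 inflates that weight column, which stretches the group's scale and pushes the other weights off the grid. The helper names are hypothetical.

```python
import numpy as np

def fake_quant_group(w, bits=4):
    """Symmetric absmax quantize->dequantize of a single weight group."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / qmax
    return np.clip(np.round(w / scale), -qmax, qmax) * scale

def rel_recon_error(w, bits=4):
    """Relative weight reconstruction error ||Q(w) - w|| / ||w||."""
    return float(np.linalg.norm(fake_quant_group(w, bits) - w) / np.linalg.norm(w))

def smooth_scales(act_absmax, w_absmax, alpha):
    """s_j = max|x_j|^alpha / max|W_:,j|^(1-alpha)."""
    return act_absmax ** alpha / w_absmax ** (1 - alpha)

# Toy group: weights chosen to land exactly on the 4-bit grid.
w = np.array([0.7, -0.7, 0.3, -0.3])
# Channel 0 has an activation outlier (illustrative value).
act_absmax = np.array([8.0, 1.0, 1.0, 1.0])

s = smooth_scales(act_absmax, np.abs(w), alpha=0.5)
wrecon_plain = rel_recon_error(w)        # essentially zero for this toy group
wrecon_smooth = rel_recon_error(w * s)   # clearly larger after smoothing
```

The toy is deliberately favorable to the unsmoothed weights; it shows the direction of the effect, not its size on the real model.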

Why W4A8 (not the paper's W4A4): 8-bit activations are far more forgiving — the right conservative lead for a 4-step distilled model where per-step quant error doesn't average out over many steps. We lead with W4A8 + rank 32, then sweep (rank, α, A4). W4A4 (the Nunchaku sweet spot) is a later comparison.

Fake-quant, not real kernels. On the A100 we simulate (quantize→dequantize in bf16) to map the quality frontier via the existing held-out velocity-loss metric. We are not training with DeepCompressor (effectively unmaintained) or running Nunchaku kernels here. The decomposition is kernel-agnostic, so a checkpoint built here is the same math a fused W4A8 kernel would run — making it Nunchaku-ready later (an open task, deferred).
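A fake-quant forward can be sketched like this: activations are quantized and immediately dequantized per token, and the rest is ordinary floating-point matmul. The names are illustrative and NumPy float64 stands in for bf16, so this shows the simulation idea rather than the SVDQuantLinear module itself.

```python
import numpy as np

def fake_quant_act(x, bits=8):
    """Per-token (per-row) dynamic symmetric quantize->dequantize."""
    qmax = 2 ** (bits - 1) - 1                   # 127 for 8-bit
    scale = np.abs(x).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)     # guard all-zero tokens
    return np.clip(np.round(x / scale), -qmax, qmax) * scale

def svdquant_forward(x, s, low, q_res, bias, a_bits=8):
    """y = x_hat L^T + Q_a(x_hat) Q4(R)^T + b, with x_hat = x / s.

    low   : 16-bit low-rank branch, same (out, in) shape as the weight
    q_res : residual weights, already fake-quantized to 4-bit
    """
    x_hat = x / s
    high_precision = x_hat @ low.T
    low_bit = fake_quant_act(x_hat, a_bits) @ q_res.T
    return high_precision + low_bit + bias
```

Setting `a_bits=4` in the same function gives the W4A4 corner, which is one way to compare activation bit-widths without touching the weights.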


1. What we quantize

  • All nn.Linear inside the 25 transformer blocks (transformer_blocks.* + single_transformer_blocks.*) — uniform W4A8, no per-block precision games (mixing precisions is slower/messier and we have no evidence the 5 cross-modal double blocks need more; revisit only if a fragility shows up).
  • Kept in bf16: norms (not Linear), embedders (x_embedder, context_embedder, time/text embedders), final proj_out. Tiny share of compute; quantizing the output projection tends to cost quality for ~nothing.
  • Knobs (flux2distill/config.py::QuantConfig): rank=32, alpha=0.5, w_bits=4, a_bits=8, w_group=64.
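To see how the knobs trade off storage, here is a small sketch that mirrors them in a dataclass and estimates effective bits per weight for one layer. It is a hypothetical stand-in, not the real flux2distill/config.py::QuantConfig. It assumes 16-bit group scales and a low-rank branch stored as two 16-bit factors of shapes (out, rank) and (rank, in).

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class QuantConfig:
    """Hypothetical mirror of the documented knobs and their defaults."""
    rank: int = 32
    alpha: float = 0.5
    w_bits: int = 4
    a_bits: int = 8
    w_group: int = 64

def effective_bits_per_weight(cfg, out_f, in_f, scale_bits=16, branch_bits=16):
    """Average stored bits per original weight of one (out_f, in_f) Linear."""
    residual = cfg.w_bits + scale_bits / cfg.w_group          # 4-bit codes + group scales
    branch = branch_bits * cfg.rank * (out_f + in_f) / (out_f * in_f)
    return residual + branch

cfg = QuantConfig()
bits = effective_bits_per_weight(cfg, 3072, 3072)   # about 4.58 bits
ratio = 16 / bits                                   # about 3.49x vs 16-bit, this layer only
```

This is a per-layer estimate under the stated storage assumptions. It is not expected to reproduce whole-model figures, which also depend on layer shapes and on the parts kept in bf16.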

2. Calibration data

  • 5–7k full-res jasperai/monet images via url (NOT the thumbnail image field) → VAE latents + captions, cached to data/monet_calib/ by scripts/11_cache_calib.py. Download is slow but worth it — more images = better activation statistics.
  • Calibration forward = real latents noised across sigmas with their captions (the same forward the eval metric uses), collecting per-input-channel activation abs-max.
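The statistics gathered during calibration can be accumulated in a streaming fashion. The sketch below keeps a running per-input-channel abs-max and, for the whitened variant, a running Gram matrix. `ActStats` is a hypothetical name; the project's collect_act_stats attaches hooks to the model, which is not reproduced here.

```python
import numpy as np

class ActStats:
    """Running calibration statistics for the input of one Linear layer."""

    def __init__(self, in_features):
        self.absmax = np.zeros(in_features)                # per-channel max |x_j|
        self.gram = np.zeros((in_features, in_features))   # sum of x^T x over tokens
        self.tokens = 0

    def update(self, x):
        """Fold in one batch of activations of shape (..., in_features)."""
        x = np.asarray(x, dtype=np.float64)   # accumulate in high precision
        x = x.reshape(-1, x.shape[-1])        # flatten batch/sequence dims
        self.absmax = np.maximum(self.absmax, np.abs(x).max(axis=0))
        self.gram += x.T @ x
        self.tokens += x.shape[0]
```

Neither statistic depends on rank, α, or group size, so a cache of these arrays can be reused across sweep settings.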

3. Metric & continuity

  • Primary: held-out velocity-matching loss vs the teacher on the fixed first-16 of data/monet_cache — identical construction to 08_train_recover.py, so the quant number is directly comparable to the block-surgery frontier (teacher-vs-student loss; a good W4A8 quant should land well below the surgery numbers ~0.23–0.32, since quant perturbs far less than dropping whole blocks). Also report velocity rel-L2 and teacher-vs-quant image montages.
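The two reported numbers can be sketched as below, assuming the velocity-matching loss is a plain mean-squared error between the student's and the teacher's velocity predictions. That form is an assumption; the authoritative construction is whatever 08_train_recover.py computes.

```python
import numpy as np

def velocity_metrics(v_student, v_teacher):
    """Return (mse, rel_l2) between two velocity predictions of equal shape.

    mse    : mean squared difference (assumed form of the velocity loss)
    rel_l2 : ||v_student - v_teacher|| / ||v_teacher||
    """
    diff = np.asarray(v_student) - np.asarray(v_teacher)
    mse = float(np.mean(diff ** 2))
    rel_l2 = float(np.linalg.norm(diff) / np.linalg.norm(v_teacher))
    return mse, rel_l2
```

In practice both values would be averaged over the fixed held-out samples and the noise levels used by the eval, so that runs stay comparable.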

4. Pipeline (scripts)

source .venv/bin/activate; export PYTHONPATH=.    # system python has NO torch
python3 scripts/11_cache_calib.py 2000 data/monet_calib   # full-res monet via url (already cached)
RANK=32 ALPHA=0.5 WGROUP=64 N_CALIB=100 WHITEN=1 REFINE=3 \
  python3 -u scripts/12_build_svdquant.py                 # calibrate(+Gram)->whiten/SVD/refine/4-bit
python3 -u scripts/13_eval_svdquant.py outputs/svdquant_r32_a0.5_w4a8_whiten   # vel-loss + montages

flux2distill/svdquant.py holds the module (SVDQuantLinear), the fake-quant primitives, collect_act_stats (absmax+Gram hooks), and the build/load surgery (apply_svdquant_from_stats / apply_svdquant_empty). quant_config.json is the build↔load contract (records specs = per-layer shape/rank so empty modules rebuild before load_state_dict). Build env knobs: RANK ALPHA WBITS ABITS WGROUP N_CALIB WHITEN REFINE CALIB_DIR MB OUT.
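The build↔load contract can be pictured as a JSON round trip: the build side records each layer's shape and rank, and the load side uses those specs to size empty parameters before any weights are read. The field names, parameter names, and the layer name below are hypothetical; the real quant_config.json schema is not shown in this plan.

```python
import json

def build_contract(layer_shapes, rank, alpha, w_bits, a_bits, w_group):
    """Create the dict that would be written to quant_config.json at build time."""
    specs = {
        name: {"out_features": o, "in_features": i, "rank": min(rank, o, i)}
        for name, (o, i) in layer_shapes.items()
    }
    return {"rank": rank, "alpha": alpha, "w_bits": w_bits,
            "a_bits": a_bits, "w_group": w_group, "specs": specs}

def empty_param_shapes(contract):
    """Shapes needed to allocate empty modules before loading a state dict."""
    shapes = {}
    for name, spec in contract["specs"].items():
        o, i, r = spec["out_features"], spec["in_features"], spec["rank"]
        shapes[name] = {
            "residual": (o, i),   # fake-quantized residual weights
            "low_up": (o, r),     # low-rank factor, output side
            "low_down": (r, i),   # low-rank factor, input side
            "smooth": (i,),       # per-input-channel smoothing scales
        }
    return shapes

# Round trip through JSON, as a build script and a loader would do via a file.
contract = build_contract({"transformer_blocks.0.example_linear": (3072, 3072)},
                          rank=32, alpha=0.5, w_bits=4, a_bits=8, w_group=64)
restored = json.loads(json.dumps(contract))
shapes = empty_param_shapes(restored)
```

Recording the rank per layer matters because a state dict cannot be loaded into modules whose low-rank factors were allocated with a different rank.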

5. Roadmap (see TODO.md for the live checklist)

  1. W4A8 baseline + rank sweep {16,32,64}, whitened+refine — DONE (results in §0 / RESULTS.md).
  2. Whitening + iterative refinement implemented and measured.
  3. 🔄 ACTIVE (2026-06-01): full method×rank GRID at fixed calib. ranks {16,32,64} × {plain, plain+whiten, plain+refine, plain+whiten+refine} = 12 builds, all at 300-calib from data/monet_cache (closes the L-shape: every cell comparable; runs the missing plain+refine). Also expanded the visual eval to 8 probe prompts. See TODO.md.
  4. DEFERRED — 2000-image calib re-sweep. After the 300-calib grid picks the method, re-run ranks {16,32,64} at the chosen method on the full-res 2000-image calib (cached by scripts/11 in data/monet_calib) to confirm the choice holds with richer activation statistics. (Needs the image download; do NOT run until the 300-calib grid is read.)
  5. Sweep α, group size, more refine iters; then W4A4 for the aggressive corner.
  6. One-time stats cache (absmax+Gram independent of rank/α/group) → cheap sweeps.
  7. Nunchaku-ready export + real low-bit kernels on Ada/Blackwell → the throughput number.
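The 12-cell grid from item 3 can be enumerated as environment settings for the build script. WHITEN=1 and REFINE=3 appear in the pipeline commands; treating WHITEN=0 and REFINE=0 as the "off" values is an assumption, as are the method labels and the helper name.

```python
from itertools import product

RANKS = (16, 32, 64)
# Method label -> env knobs. The "off" values (0) are assumed, not confirmed.
METHODS = {
    "plain": {"WHITEN": "0", "REFINE": "0"},
    "whiten": {"WHITEN": "1", "REFINE": "0"},
    "plain+refine": {"WHITEN": "0", "REFINE": "3"},
    "whiten+refine": {"WHITEN": "1", "REFINE": "3"},
}

def grid_envs(n_calib=300, calib_dir="data/monet_cache", alpha=0.5, w_group=64):
    """Yield (rank, method, env) for every cell of the method x rank grid."""
    for rank, (method, knobs) in product(RANKS, METHODS.items()):
        env = {
            "RANK": str(rank),
            "ALPHA": str(alpha),
            "WGROUP": str(w_group),
            "N_CALIB": str(n_calib),
            "CALIB_DIR": calib_dir,
            **knobs,
        }
        yield rank, method, env

cells = list(grid_envs())   # 3 ranks x 4 methods = 12 builds
```

Each `env` could be merged into `os.environ` and passed to `subprocess.run` for scripts/12_build_svdquant.py. Holding N_CALIB and CALIB_DIR fixed across cells is what keeps them comparable.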

6. Open items

  • Real speedup needs fused low-bit kernels (Nunchaku, Ada/Blackwell — NOT our A100 sm_80). On this box we only have the quality story; throughput/VRAM land on a Blackwell / B200 run.
  • CLOSED (2026-06-01): the method ablation was an L-shape, not a grid; the full 4×3 grid has now run. New open item it surfaced: whitening is unreliable at 300-calib (non-monotonic in rank) → the deferred 2000-img re-sweep (§5, item 4) tests whether richer calib stabilizes it.
  • Whether the 5 double (cross-modal) blocks tolerate W4A8 as well as the singles — only revisit with per-block precision if eval/montages show binding failures.
  • α for W4A8: 8-bit activations may want less smoothing (less weight burden) — sweep.
