| # FLUX.2 [klein] 4B → W4A8 SVDQuant — Plan | |
| > **Status:** Active track (supersedes the block-surgery track) · **Dev rig:** 1× A100-80GB | |
| > (SDPA, fake-quant simulation) · **Teacher:** klein **distilled** 4-step 4B (Apache 2.0) | |
| > **Deliverable:** a 4-bit-weight / 8-bit-activation klein-4B that holds teacher-at-4-steps | |
| > quality, as a fast + small inference endpoint; Nunchaku-deployable later. | |
| The block-surgery track (depth-prune single blocks → linear-attn surrogates → distill) is | |
| **shelved** — it topped out at ~1.15–1.26× and was quality-bounded (see | |
| `block_surgery_plan.md` / `block_surgery_todo.md` for the full history). Quantization is a | |
| better, orthogonal axis: it attacks memory *and* compute, is post-training (no fragile | |
| freeze-and-distill loop), and is proven on this exact model family (Nunchaku already ships | |
| SVDQuant weights for klein-9B). | |
| --- | |
| ## 0. Method — SVDQuant (our own fake-quant) | |
| SVDQuant (Li et al., ICLR 2025, arXiv 2411.05007 / Nunchaku / DeepCompressor). Per Linear | |
| `y = x Wᵀ + b` (W: (out,in)): | |
| 1. **Smooth** — migrate per-channel activation outliers into the weights: | |
| `x̂ = x / s`, `Ŵ = W ⊙ s`, with `s_j = max|x_j|^α / max|W_:,j|^(1-α)` (SmoothQuant/AWQ | |
| form). Exactly invariant: `x̂ Ŵᵀ = x Wᵀ`. (`max|x_j|` from calibration.) | |
| 2. **Absorb** — keep the top-**rank** components as a **16-bit low-rank branch** `L` that | |
| soaks up the dominant (outlier) energy. Residual `R = Ŵ − L` is smooth → easy to low-bit. | |
| Two ways to pick `L` (both implemented; `WHITEN` env): | |
| - **plain** = SVD of `Ŵ`, minimizes weight error `‖Ŵ−L‖` (base SVDQuant paper). | |
| - **whitened** (DEFAULT) = activation-aware SVD minimizing OUTPUT error `‖X̂(Ŵ−L)ᵀ‖` via the | |
| act-Gram `Ĝ=X̂ᵀX̂` (ASVD/SVD-LLM idea; **our add, not in the paper**). Uses a robust eigen-√ of `Ĝ`, | |
| NOT Cholesky (the bf16 Gram is non-PD, so Cholesky fails). This is what makes rank a real quality knob. | |
| 3. **Refine** (DEFAULT 3 iters; SVDQuant §4.2) — re-fit `L` to `Ŵ−Q(R)` to absorb the 4-bit | |
| rounding error, keep the best iterate. Free at inference; build-time only. | |
| 4. **Quantize** — residual weights to **4-bit** (group-wise along in-dim), activations to | |
| **8-bit** (per-token dynamic). The low-rank branch keeps full-precision `x̂`. | |
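The smooth and absorb steps above can be sketched in NumPy on toy data. This is a minimal illustration, not the code in `flux2distill/svdquant.py`: the shapes, the outlier injection, the helper names (`low_rank_plain`, `low_rank_whitened`, `output_error`) and the eigenvalue floor `eps` are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tok, d_in, d_out, rank, alpha = 256, 64, 48, 8, 0.5

# Toy calibration activations with a few outlier input channels (assumed data).
X = rng.standard_normal((n_tok, d_in))
X[:, :4] *= 20.0
W = rng.standard_normal((d_out, d_in)) * 0.05  # (out, in), as in y = x W^T + b

# 1. Smooth: s_j = max|x_j|^alpha / max|W_:,j|^(1-alpha), one scale per input channel.
s = np.abs(X).max(axis=0) ** alpha / np.abs(W).max(axis=0) ** (1.0 - alpha)
X_hat = X / s   # x_hat = x / s
W_hat = W * s   # W_hat = W * s, broadcast over the in-dim columns
assert np.allclose(X_hat @ W_hat.T, X @ W.T)  # the product is unchanged


def low_rank_plain(W_hat, rank):
    """Best rank-r approximation of W_hat in weight (Frobenius) error."""
    U, S, Vt = np.linalg.svd(W_hat, full_matrices=False)
    return (U[:, :rank] * S[:rank]) @ Vt[:rank]


def low_rank_whitened(W_hat, G, rank, eps=1e-6):
    """Rank-r L minimizing ||X_hat (W_hat - L)^T||_F, given G = X_hat^T X_hat."""
    evals, evecs = np.linalg.eigh(G)                 # symmetric eigendecomposition
    evals = np.clip(evals, eps * evals.max(), None)  # floor tiny or negative eigenvalues
    root = (evecs * np.sqrt(evals)) @ evecs.T        # G^(1/2)
    inv_root = (evecs / np.sqrt(evals)) @ evecs.T    # G^(-1/2)
    U, S, Vt = np.linalg.svd(W_hat @ root, full_matrices=False)
    return ((U[:, :rank] * S[:rank]) @ Vt[:rank]) @ inv_root


def output_error(L):
    return float(np.linalg.norm(X_hat @ (W_hat - L).T))


G = X_hat.T @ X_hat
L_plain = low_rank_plain(W_hat, rank)
L_white = low_rank_whitened(W_hat, G, rank)

print("output error  plain/whitened:", output_error(L_plain), output_error(L_white))
print("weight error  plain/whitened:",
      float(np.linalg.norm(W_hat - L_plain)), float(np.linalg.norm(W_hat - L_white)))
```

On the calibration data itself the whitened branch cannot have a larger output error than the plain one, and the plain branch cannot have a larger weight error; whether that carries over to held-out data is exactly what the results grid measures.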
| Forward: `y = (x̂·Lᵀ)[16-bit] + Q8(x̂)·Q4(R)ᵀ[low-bit] + b`, with `L` shaped (out,in) like `Ŵ`. The low-rank branch is the | |
| high-precision *error-correction* path; rank ≪ 3072, so its extra compute is negligible and its main cost is the memory for the 16-bit branch. | |
| **Full math + the Cholesky→eigh fix + ablation caveats: `report/QUANT_REPORT.md`.** | |
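The quantize, refine and forward steps can be sketched the same way. This is a hedged NumPy illustration with hypothetical helper names; it uses plain (not whitened) SVD for brevity and picks the "best iterate" by weight reconstruction error, which is an assumption since the selection criterion is not stated here.

```python
import numpy as np


def quant_weight_groupwise(R, bits=4, group=64):
    """Symmetric fake-quant of (out, in) weights, one scale per group along the in-dim."""
    out_dim, in_dim = R.shape
    qmax = 2 ** (bits - 1) - 1
    Rg = R.reshape(out_dim, in_dim // group, group)
    scale = np.abs(Rg).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)
    q = np.clip(np.round(Rg / scale), -qmax - 1, qmax)
    return (q * scale).reshape(out_dim, in_dim)  # quantize -> dequantize


def quant_act_per_token(x, bits=8):
    """Symmetric fake-quant of activations with one dynamic scale per token (row)."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(x).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)
    return np.clip(np.round(x / scale), -qmax - 1, qmax) * scale


def truncated_svd(M, rank):
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    return (U[:, :rank] * S[:rank]) @ Vt[:rank]


def build_with_refine(W_hat, rank, iters=3, bits=4, group=64):
    """Alternate L <- SVD_r(W_hat - Q(R)) and R <- W_hat - L; keep the best iterate."""
    L = truncated_svd(W_hat, rank)
    errs, best = [], None
    for it in range(iters + 1):
        if it > 0:
            L = truncated_svd(W_hat - Rq, rank)  # re-fit L to absorb rounding error
        Rq = quant_weight_groupwise(W_hat - L, bits, group)
        err = float(np.linalg.norm(W_hat - (L + Rq)))
        errs.append(err)
        if best is None or err < best[0]:
            best = (err, L, Rq)
    return errs, best[1], best[2]


def forward(x_hat, L, Rq, b):
    """y = x_hat L^T (high precision) + Q8(x_hat) Q4(R)^T (low-bit) + b."""
    return x_hat @ L.T + quant_act_per_token(x_hat) @ Rq.T + b


rng = np.random.default_rng(0)
d_in, d_out, rank = 64, 48, 8
W_hat = rng.standard_normal((d_out, d_in)) * 0.05
W_hat[:, :2] *= 10.0                      # a couple of outlier input columns
x_hat = rng.standard_normal((4, d_in))
b = np.zeros(d_out)

errs, L, Rq = build_with_refine(W_hat, rank, iters=3, group=16)
y = forward(x_hat, L, Rq, b)
y_ref = x_hat @ W_hat.T + b
print("weight error per iterate:", errs)
print("output rel-L2:", float(np.linalg.norm(y - y_ref) / np.linalg.norm(y_ref)))
```

Because the loop keeps the best iterate, the returned error can never be worse than the no-refine build (iterate 0), which matches the "free at inference, build-time only" framing.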
| **RESULTS — full 4×3 grid (2026-06-01, 300-img calib; supersedes the earlier 100-calib L-shape).** | |
| Every method × every rank, one fixed calib (`data/monet_cache`), so all 12 cells are comparable | |
| (full table + per-cell metrics in `RESULTS.md`, report `report/QUANT_REPORT.{md,pdf}`). Best per rank: | |
| r16 whiten+refine 0.0556 / r32 whiten+refine 0.0476 / r64 **plain+refine 0.0446** (@ 3.67/3.59/3.43×). | |
| **Overall best = r64 plain+refine 0.0446 @ 3.43×**, all ~4–5× below the surgery frontier (0.23–0.32). | |
| Headline finding: the upgrades are NOT monotonic — **refine** is the reliable lever; **whitening alone | |
| is non-monotonic in rank** (overfits the 300-img Gram: hurts r16 & r64, helps r32) and only helps paired | |
| with refine or at moderate rank. At high rank, plain+refine (no Gram) wins. | |
| **⚠️ Bigger finding (2026-06-01 mechanism ablation): SmoothQuant (α=0.5) is HARMFUL at W4A8.** The RTN | |
| floor (no smooth, no SVD, `SMOOTH=0`/s=1) = **0.0573**, which is 21% below SmoothQuant rank-0 (0.0729) AND beats the | |
| smoothed SVD r16/r32 cells. SmoothQuant migrates activation outliers into the weights, which only helps low-bit | |
| *activations*; at A8 the acts are already fine, so it just widens the 4-bit *weight* spread (worst-layer | |
| wrecon 0.15→0.26). **The whole α=0.5 grid is mis-tuned.** Top-priority next experiment: re-run with no smoothing / low α | |
| (ahead of the 2000-img re-sweep). See `TODO.md` / `RESULTS.md`. | |
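As a quick arithmetic check on the numbers quoted above, the relative gaps can be recomputed directly. The constants are copied from this section; the helper `rel_change` is just an illustrative name.

```python
# Held-out velocity-loss numbers quoted in this section (lower is better).
RTN_FLOOR = 0.0573
SMOOTH_RANK0 = 0.0729
BEST_PER_RANK = {16: 0.0556, 32: 0.0476, 64: 0.0446}


def rel_change(new: float, ref: float) -> float:
    """Signed relative change of `new` with respect to `ref`."""
    return (new - ref) / ref


floor_vs_smooth = rel_change(RTN_FLOOR, SMOOTH_RANK0)
print(f"RTN floor vs SmoothQuant rank-0: {floor_vs_smooth:+.1%}")
for rank, loss in BEST_PER_RANK.items():
    print(f"best r{rank} cell vs RTN floor: {rel_change(loss, RTN_FLOOR):+.1%}")
```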
| **Why W4A8 (not the paper's W4A4):** 8-bit activations are far more forgiving — the right | |
| conservative lead for a **4-step distilled** model where per-step quant error doesn't | |
| average out over many steps. We lead with W4A8 + rank 32, then sweep (rank, α, A4). W4A4 | |
| (the Nunchaku sweet spot) is a later comparison. | |
| **Fake-quant, not real kernels.** On the A100 we *simulate* (quantize→dequantize in bf16) | |
| to map the **quality** frontier via the existing held-out velocity-loss metric. We are not | |
| training with DeepCompressor (effectively unmaintained) or running Nunchaku kernels here. | |
| The decomposition is kernel-agnostic, so a checkpoint built here is the same math a fused | |
| W4A8 kernel would run — making it Nunchaku-ready later (an open task, deferred). | |
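The "same math" claim can be illustrated on a toy layer: dequantize-then-matmul (the fake-quant path) and integer matmul with scales applied afterwards (what a fused kernel computes) agree. This is a float64 NumPy sketch under assumed shapes; it does not model bf16 rounding, packing, or any actual Nunchaku kernel layout.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tok, d_in, d_out, group = 8, 128, 16, 64
x = rng.standard_normal((n_tok, d_in))
R = rng.standard_normal((d_out, d_in)) * 0.02

# Integer codes and scales: per-token 8-bit activations, group-wise 4-bit weights.
sx = np.abs(x).max(axis=1, keepdims=True) / 127.0                 # (n_tok, 1)
qx = np.clip(np.round(x / sx), -128, 127).astype(np.int32)
Rg = R.reshape(d_out, d_in // group, group)
sw = np.abs(Rg).max(axis=2) / 7.0                                 # (d_out, n_groups)
qw = np.clip(np.round(Rg / sw[:, :, None]), -8, 7).astype(np.int32)

# Fake-quant path: dequantize both operands, then a floating-point matmul.
x_dq = qx * sx
R_dq = (qw * sw[:, :, None]).reshape(d_out, d_in)
y_fake = x_dq @ R_dq.T

# Integer path: accumulate integer products per group, then apply the scales.
qx_g = qx.reshape(n_tok, d_in // group, group)
acc = np.einsum("tgk,ogk->tog", qx_g, qw)          # integer partial sums per group
y_int = sx * (acc * sw[None, :, :]).sum(axis=2)    # (n_tok, d_out)

print("max abs difference:", float(np.abs(y_fake - y_int).max()))
```

The group-wise weight scale is why the integer path has to rescale per group before summing, rather than once per output channel.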
| --- | |
| ## 1. What we quantize | |
| - **All `nn.Linear` inside the 25 transformer blocks** (`transformer_blocks.*` + | |
| `single_transformer_blocks.*`) — uniform W4A8, **no per-block precision games** | |
| (mixing precisions is slower/messier and we have no evidence the 5 cross-modal double | |
| blocks need more; revisit only if a fragility shows up). | |
| - **Kept in bf16:** norms (not Linear), embedders (`x_embedder`, `context_embedder`, | |
| time/text embedders), final `proj_out`. Tiny share of compute; quantizing the output | |
| projection tends to cost quality for ~nothing. | |
| - Knobs (`flux2distill/config.py::QuantConfig`): `rank=32`, `alpha=0.5`, `w_bits=4`, | |
| `a_bits=8`, `w_group=64`. | |
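The selection rule above amounts to a name-prefix filter over the model's modules. A minimal sketch follows; `should_quantize` and the example module names are hypothetical, and the real surgery lives in `flux2distill/svdquant.py`.

```python
# Prefixes of the block containers whose Linear layers get W4A8 (from the list above).
QUANT_PREFIXES = ("transformer_blocks.", "single_transformer_blocks.")


def should_quantize(name: str, is_linear: bool) -> bool:
    """True for Linear layers that live inside one of the transformer blocks."""
    return is_linear and name.startswith(QUANT_PREFIXES)


# Hypothetical (name, is_linear) pairs, in the style of `model.named_modules()`.
examples = [
    ("transformer_blocks.0.attn.to_q", True),        # double block Linear -> quantize
    ("single_transformer_blocks.7.attn.to_q", True),  # single block Linear -> quantize
    ("transformer_blocks.0.norm1", False),            # norm, not a Linear -> keep bf16
    ("x_embedder", True),                             # embedder -> keep bf16
    ("proj_out", True),                               # final projection -> keep bf16
]
for name, is_linear in examples:
    print(f"{name:45s} quantize={should_quantize(name, is_linear)}")
```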
| ## 2. Calibration data | |
| - **5–7k full-res `jasperai/monet` images via `url`** (NOT the thumbnail `image` field) → | |
| VAE latents + captions, cached to `data/monet_calib/` by `scripts/11_cache_calib.py`. | |
| Download is slow but worth it — more images = better activation statistics. | |
| - Calibration forward = real latents noised across sigmas with their captions (the **same** | |
| forward the eval metric uses), collecting per-input-channel activation abs-max. | |
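Collecting the calibration statistics is a streaming reduction over batches. The sketch below is a NumPy stand-in for the hook-based `collect_act_stats`; the class name `ActStats` and the float64 accumulation are assumptions for illustration. It also accumulates the Gram that the whitened variant needs.

```python
import numpy as np


class ActStats:
    """Running per-input-channel abs-max and Gram (X^T X) for one Linear input."""

    def __init__(self, d_in: int):
        self.absmax = np.zeros(d_in)
        self.gram = np.zeros((d_in, d_in))  # accumulated in float64
        self.n_tokens = 0

    def update(self, x) -> None:
        x = np.asarray(x, dtype=np.float64)
        x = x.reshape(-1, x.shape[-1])       # flatten (batch, tokens, d_in) -> (N, d_in)
        self.absmax = np.maximum(self.absmax, np.abs(x).max(axis=0))
        self.gram += x.T @ x
        self.n_tokens += x.shape[0]


rng = np.random.default_rng(0)
batches = [rng.standard_normal((2, 5, 16)) for _ in range(3)]
stats = ActStats(d_in=16)
for batch in batches:
    stats.update(batch)

# The streaming result matches the one-shot computation over all tokens.
all_x = np.concatenate([b.reshape(-1, 16) for b in batches], axis=0)
print(np.allclose(stats.absmax, np.abs(all_x).max(axis=0)),
      np.allclose(stats.gram, all_x.T @ all_x))
```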
| ## 3. Metric & continuity | |
| - Primary: **held-out velocity-matching loss vs the teacher** on the fixed first-16 of | |
| `data/monet_cache` — identical construction to `08_train_recover.py`, so the quant number | |
| is **directly comparable** to the block-surgery frontier (teacher-vs-student loss; a good | |
| W4A8 quant should land *well below* the surgery numbers ~0.23–0.32, since quant perturbs | |
| far less than dropping whole blocks). Also report velocity rel-L2 and teacher-vs-quant | |
| image montages. | |
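For orientation, the two reported numbers could be computed as below. This assumes the velocity-matching loss is a mean squared error between student and teacher velocity predictions; the exact construction is defined in `08_train_recover.py` and may differ, so treat the function names and the MSE form as assumptions.

```python
import numpy as np


def velocity_loss(v_student: np.ndarray, v_teacher: np.ndarray) -> float:
    """Assumed form: mean squared error between predicted velocities."""
    return float(np.mean((v_student - v_teacher) ** 2))


def velocity_rel_l2(v_student: np.ndarray, v_teacher: np.ndarray) -> float:
    """Relative L2 error: ||v_s - v_t|| / ||v_t||."""
    return float(np.linalg.norm(v_student - v_teacher) / np.linalg.norm(v_teacher))


# Toy check: a student that is uniformly 10% off a constant teacher field.
v_teacher = np.ones((4, 8))
v_student = 1.1 * v_teacher
print(velocity_loss(v_student, v_teacher), velocity_rel_l2(v_student, v_teacher))
```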
| --- | |
| ## 4. Pipeline (scripts) | |
| ```bash | |
| source .venv/bin/activate; export PYTHONPATH=. # system python has NO torch | |
| python3 scripts/11_cache_calib.py 2000 data/monet_calib # full-res monet via url (already cached) | |
| RANK=32 ALPHA=0.5 WGROUP=64 N_CALIB=100 WHITEN=1 REFINE=3 \ | |
| python3 -u scripts/12_build_svdquant.py # calibrate(+Gram)->whiten/SVD/refine/4-bit | |
| python3 -u scripts/13_eval_svdquant.py outputs/svdquant_r32_a0.5_w4a8_whiten # vel-loss + montages | |
| ``` | |
| `flux2distill/svdquant.py` holds the module (`SVDQuantLinear`), the fake-quant primitives, | |
| `collect_act_stats` (absmax+Gram hooks), and the build/load surgery (`apply_svdquant_from_stats` / | |
| `apply_svdquant_empty`). `quant_config.json` is the build↔load contract (records `specs` | |
| = per-layer shape/rank so empty modules rebuild before `load_state_dict`). Build env knobs: | |
| `RANK ALPHA WBITS ABITS WGROUP N_CALIB WHITEN REFINE CALIB_DIR MB OUT`. | |
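A build script driven by these env knobs could parse them roughly as follows. Everything here is a hypothetical sketch: `BuildKnobs`, `knobs_from_env` and `default_out_dir` are illustrative names, only the knobs with defaults stated in this plan are covered, and the output-directory pattern is inferred from the example path in the commands above rather than taken from `12_build_svdquant.py`.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class BuildKnobs:
    rank: int = 32
    alpha: float = 0.5
    w_bits: int = 4
    a_bits: int = 8
    w_group: int = 64
    whiten: bool = True   # WHITEN=1 is the stated default
    refine: int = 3       # default refine iterations


def knobs_from_env(env=None) -> BuildKnobs:
    """Read build knobs from a mapping (defaults to os.environ)."""
    env = os.environ if env is None else env
    d = BuildKnobs()
    return BuildKnobs(
        rank=int(env.get("RANK", d.rank)),
        alpha=float(env.get("ALPHA", d.alpha)),
        w_bits=int(env.get("WBITS", d.w_bits)),
        a_bits=int(env.get("ABITS", d.a_bits)),
        w_group=int(env.get("WGROUP", d.w_group)),
        whiten=str(env.get("WHITEN", "1")) != "0",
        refine=int(env.get("REFINE", d.refine)),
    )


def default_out_dir(k: BuildKnobs) -> str:
    """Assumed naming pattern, guessed from the example eval path."""
    tag = f"svdquant_r{k.rank}_a{k.alpha}_w{k.w_bits}a{k.a_bits}"
    return "outputs/" + tag + ("_whiten" if k.whiten else "")


example_env = {"RANK": "32", "ALPHA": "0.5", "WGROUP": "64", "WHITEN": "1", "REFINE": "3"}
knobs = knobs_from_env(example_env)
print(knobs)
print(default_out_dir(knobs))
```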
| ## 5. Roadmap (see `TODO.md` for the live checklist) | |
| 1. ✅ **W4A8 baseline + rank sweep** {16,32,64}, whitened+refine — DONE (results in §0 / RESULTS.md). | |
| 2. ✅ **Whitening + iterative refinement** implemented and measured. | |
| 3. ✅ **DONE (2026-06-01): full method×rank GRID at fixed calib.** ranks {16,32,64} × | |
| {plain, whiten, plain+refine, whiten+refine} = 12 builds, all at **300-calib | |
| from `data/monet_cache`** (closes the L-shape: every cell comparable; runs the missing | |
| `plain+refine`). Also expanded the visual eval to **8 probe prompts**. Results in §0 / `RESULTS.md`; see `TODO.md`. | |
| 4. **DEFERRED — 2000-image calib re-sweep.** After the 300-calib grid picks the method, re-run | |
| ranks {16,32,64} at the chosen method on the **full-res 2000-image** calib (`scripts/11` → | |
| `data/monet_calib`) to confirm the choice holds with richer activation statistics. (Needs the | |
| image download; do NOT run until the 300-calib grid is read.) | |
| 5. **Sweep** α, group size, more refine iters; then **W4A4** for the aggressive corner. | |
| 6. **One-time stats cache** (absmax+Gram independent of rank/α/group) → cheap sweeps. | |
| 7. **Nunchaku-ready export** + **real low-bit kernels** on Ada/Blackwell → the throughput number. | |
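The method×rank grid from item 3 can be enumerated as a list of env-var settings for the build script. This is a sketch under assumptions: the method labels follow the results text in §0, and mapping "no refine" to `REFINE=0` and "plain" to `WHITEN=0` is a guess at how the knobs are interpreted.

```python
from itertools import product

RANKS = (16, 32, 64)
# Method label -> (WHITEN, REFINE); the 0 values for "off" are assumed.
METHODS = {
    "plain": (0, 0),
    "whiten": (1, 0),
    "plain+refine": (0, 3),
    "whiten+refine": (1, 3),
}


def grid_envs(n_calib: int = 300, calib_dir: str = "data/monet_cache"):
    """One env dict per (method, rank) cell, all sharing the same calibration set."""
    cells = []
    for (method, (whiten, refine)), rank in product(METHODS.items(), RANKS):
        cells.append({
            "label": f"{method}@r{rank}",
            "RANK": str(rank),
            "WHITEN": str(whiten),
            "REFINE": str(refine),
            "N_CALIB": str(n_calib),
            "CALIB_DIR": calib_dir,
        })
    return cells


cells = grid_envs()
print(len(cells), "builds")
for cell in cells:
    print(cell["label"], {k: v for k, v in cell.items() if k != "label"})
```

Holding `N_CALIB` and `CALIB_DIR` fixed across all cells is what makes the 12 numbers directly comparable.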
| ## 6. Open items | |
| - Real speedup needs fused low-bit kernels (Nunchaku, Ada/Blackwell — NOT our A100 sm_80). On | |
| this box we only have the quality story; throughput/VRAM land on a Blackwell / B200 run. | |
| - ~~The method ablation is an L-shape, not a grid~~ — **CLOSED (2026-06-01):** the full 4×3 grid ran. | |
| New open item it surfaced: **whitening is unreliable at 300-calib** (non-monotonic in rank) → the | |
| deferred 2000-img re-sweep (§5, item 4) tests whether richer calib stabilizes it. | |
| - Whether the 5 double (cross-modal) blocks tolerate W4A8 as well as the singles — only revisit | |
| with per-block precision if eval/montages show binding failures. | |
| - α for W4A8: 8-bit activations may want *less* smoothing (less weight burden) — sweep. | |