# flux2distill — FLUX.2 [klein] 4B compression
Compress **FLUX.2 [klein] distilled 4B** (4-step, CFG-free MM-DiT) into a smaller, faster model.
Current rig (since 2026-06-10): **1× RTX PRO 4500 Blackwell 32 GB (sm_120)**, **system python**,
**torch 2.12.0+cu130** (the older A100/`.venv`/cu126 notes below are historical — see **`CLAUDE.md`**
for the authoritative current environment, the ephemeral-pod caveat, and full decision log).
See `plan.md` for the active design, `RESULTS.md` for all numbers.
> ⚠️ **Ephemeral pod:** only `/workspace` (this repo, synced to the HF bucket) persists — `models/`,
> the python stack, and any agent memory are wiped on restart. Record durable facts in `CLAUDE.md`/docs.
## ACTIVE TRACK — W4A8 SVDQuant (post-training quantization)
Our own fake-quant SVDQuant: per-Linear `smooth → (whitened) SVD low-rank (16-bit) + iterative
refine → 4-bit residual`. **★ As of 2026-06-13 this is a DEPLOYABLE NVFP4 model on real Blackwell
kernels:** NVFP4 (E2M1+group-16+FP8 scales) beats INT4 on quality AND speed; we built Nunchaku from
source (sm_120), wrote our own bf16→Nunchaku NVFP4 exporter, and ship
`outputs/nvfp4/deploy/klein4b_nvfp4_fused.safetensors` (2.9 GB) — teacher-indistinguishable,
**1.74×@512 / 1.90×@1024 end-to-end, −24% VRAM**. Quality champion NVFP4-W+FP8-A r128 = **0.0169**;
deployable NVFP4 W4A4 r128 = **0.0303**. Full write-up + Pareto figures: **`report/NVFP4_REPORT.md`**;
setup/footguns `docs/CUDA_SETUP_RUNBOOK.md`; next-step speedups `docs/SPEEDUP_IDEAS.md`. (Earlier
fake-quant grid + math: `report/QUANT_REPORT*.{md,pdf}`.)
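For intuition, here is a minimal numpy sketch of the per-Linear decomposition. It is not the repo's code, and it rests on several assumptions:

- Smoothing is SmoothQuant-style, per input channel, with a hypothetical `alpha=0.5`.
- The low-rank branch uses a plain truncated SVD rather than the whitened, iteratively refined one.
- Quantization is weight-only (activations stay in float).
- The residual gets NVFP4-like treatment: E2M1 values in groups of 16, with scales kept in float32 instead of FP8.

All function names are hypothetical.

```python
import numpy as np

# Magnitudes representable by FP4 E2M1 (the sign is handled separately).
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quant_fp4(w, group=16):
    """NVFP4-like fake quant (hypothetical helper): one absmax scale per group of `group`
    input channels, mapped so the group max lands on 6.0, then round to nearest E2M1.
    Real NVFP4 stores scales as FP8; here they stay float for simplicity."""
    out_f, in_f = w.shape
    assert in_f % group == 0, "in_features must be a multiple of the group size"
    g = w.reshape(out_f, in_f // group, group)
    scale = np.abs(g).max(axis=-1, keepdims=True) / E2M1[-1]
    scale = np.where(scale == 0, 1.0, scale)  # all-zero group: avoid divide-by-zero
    idx = np.abs(np.abs(g / scale)[..., None] - E2M1).argmin(axis=-1)  # nearest magnitude
    return (np.sign(g) * E2M1[idx] * scale).reshape(out_f, in_f)

def svdquant_weight(w, x_calib, rank=16, alpha=0.5, group=16):
    """Smooth -> rank-r SVD branch (kept in high precision) -> 4-bit residual.
    w is a Linear weight (out, in); y ~= (x / s) @ (low_rank + r_q).T."""
    # SmoothQuant-style migration rule (assumed): s_j = max|X_j|^a / max|W_j|^(1-a).
    act_max = np.abs(x_calib).max(axis=0) + 1e-8
    w_max = np.abs(w).max(axis=0) + 1e-8
    s = act_max**alpha / w_max ** (1 - alpha)
    w_s = w * s  # column j of W scaled by s_j; activations are divided by s_j
    # Plain truncated SVD (the README's variant is whitened + iteratively refined).
    u, sig, vt = np.linalg.svd(w_s, full_matrices=False)
    low_rank = (u[:, :rank] * sig[:rank]) @ vt[:rank]
    r_q = fake_quant_fp4(w_s - low_rank, group=group)
    return s, low_rank, r_q

def linear_fwd(x, s, low_rank, r_q):
    """Fake-quant forward: smoothed activations times (16-bit branch + 4-bit residual)."""
    return (x / s) @ (low_rank + r_q).T

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 128)).astype(np.float32)
x = rng.standard_normal((32, 128)).astype(np.float32)
x[:, 3] *= 20  # one outlier activation channel
s, L, Rq = svdquant_weight(w, x, rank=16)
y_ref = x @ w.T
rel_err = np.linalg.norm(linear_fwd(x, s, L, Rq) - y_ref) / np.linalg.norm(y_ref)
```

Smoothing is an exact reparameterization: `(x / s) @ (w * s).T` equals `x @ w.T` up to rounding. Only quantizing the residual introduces error, and the 16-bit low-rank branch shrinks the residual that the 4-bit values must cover.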
**★★ 2026-06-14 — matched head-to-head (N=512, MJHQ-30k, 512px):** our deployable NVFP4 W4A4 r128 vs
plain NVFP4 r0, our fake-q, and **BFL's official FP8** (real public baseline). The SVDQuant low-rank
branch helps at NVFP4 W4A4 — **LPIPS −19.7%, PSNR +1.27 dB, FID-to-teacher −14.7%** vs plain r0, no
semantic loss; real kernel reproduces fake-quant. **BFL's official NVFP4 could NOT be run** (cutlass
tensor-core swizzled layout, needs their TensorRT runtime — documented, no proxy faked). Full report
**`report/HEADTOHEAD_klein4b_nvfp4.md`**; numbers `RESULTS.md` (2026-06-14) + `outputs/eval/h2h/metrics.json`;
speed `outputs/nvfp4/benchmark_headtohead.json`; figures `report/figures/h2h_*.png`. Pipeline:
`scripts/run_h2h.sh` → `scripts/34_metrics.py` → `scripts/42_h2h_figures.py` (+ `scripts/run_probes.sh`,
BFL fp8 loader `scripts/41_gen_bfl_fp8.py`).
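Some notes on reading the PSNR delta. PSNR = 10·log10(MAX²/MSE), so on a single image a +1.27 dB gain means the MSE is multiplied by 10^(−0.127) ≈ 0.75, i.e. about 25% lower. Averaged over 512 images, the mean-PSNR gain maps to this only approximately. LPIPS and FID are quoted as signed relative changes against the r0 baseline.

The stdlib sketch below assumes a [0, 1] pixel range. Its helper names are hypothetical, and it is not `scripts/34_metrics.py`.

```python
import math

def psnr(a, b, max_val=1.0):
    """PSNR in dB between two equal-length pixel sequences in [0, max_val] (hypothetical helper)."""
    mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    return float("inf") if mse == 0 else 10 * math.log10(max_val**2 / mse)

def rel_change(new, base):
    """Signed relative change, e.g. -0.197 for a 19.7% drop."""
    return (new - base) / base

def mse_ratio_from_db(delta_db):
    """MSE(new) / MSE(base) implied by a PSNR gain of delta_db on one image."""
    return 10 ** (-delta_db / 10)

print(f"+1.27 dB -> MSE x {mse_ratio_from_db(1.27):.3f}")  # -> 0.746
```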
```bash
export PYTHONPATH=. # system python (no .venv since 2026-06-10), torch 2.12+cu130
# one grid cell (build + eval), its own logs: args = RANK variant WHITEN REFINE
bash scripts/run_cell.sh 64 plain_refine 0 3 # -> outputs/abl_c300_r64_plain_refine/
python3 scripts/make_quant_report_assets.py # analysis figures
python3 scripts/build_report_pdf.py # report/QUANT_REPORT.pdf (incl. all montages)
```
Run experiments ONE AT A TIME, each with its own log and a Monitor (no batched background loops). Calibration
uses the cached `data/monet_cache` latents, so the 300-image grid needs no image download. The 2000-image
calibration re-sweep (`scripts/11` → `data/monet_calib`) is the queued next experiment; see `TODO.md`.
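Here is a minimal Python sketch of the one-at-a-time rule. It assumes the `RANK variant WHITEN REFINE` argument order of `scripts/run_cell.sh` shown above. The log directory, the log naming, and the stop-on-failure policy are hypothetical, and the Monitor is not modeled. Because `subprocess.run` blocks until each cell exits, cells never overlap.

```python
import subprocess
from pathlib import Path

def cell_cmd(rank, variant, whiten, refine):
    """Argument order follows `bash scripts/run_cell.sh RANK variant WHITEN REFINE`."""
    return ["bash", "scripts/run_cell.sh", str(rank), variant, str(whiten), str(refine)]

def run_grid(cells, log_dir="logs", dry_run=True):
    """Run grid cells strictly sequentially, one log file per cell (hypothetical layout).
    Returns (cmd, log_path, returncode) tuples; returncode is None on a dry run."""
    results = []
    for rank, variant, whiten, refine in cells:
        cmd = cell_cmd(rank, variant, whiten, refine)
        log = Path(log_dir) / f"r{rank}_{variant}_w{whiten}_f{refine}.log"
        if dry_run:
            results.append((cmd, log, None))
            continue
        log.parent.mkdir(parents=True, exist_ok=True)
        with open(log, "w") as fh:
            # Blocks until this cell finishes: no background jobs, no batching.
            rc = subprocess.run(cmd, stdout=fh, stderr=subprocess.STDOUT).returncode
        results.append((cmd, log, rc))
        if rc != 0:
            break  # stop the sweep at the first failure so its log can be inspected
    return results
```

Call it with `run_grid([(64, "plain_refine", 0, 3)], dry_run=False)` to run the single cell from the example above. Stopping at the first failure is a design choice; it keeps a broken cell from silently burning GPU time on the rest of the grid.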
### Backup / sync to the HF bucket
Work is archived to the HF bucket **`hf://buckets/Mercity/FluxDistill`** via `hf sync`. Upload needs a
**write token** (`HF_TOKEN`, never commit it; an ephemeral-pod restart wipes the cached login, so
re-export it — the bucket is *public-read* so downloads work without one). `--no-delete` is the default
(additive backup; local deletions don't propagate). **Preview with `--dry-run` first**, and aggregate
the plan by size — that's how the 311 GB-of-`.pt` footgun was caught.
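The dry-run plan's output format isn't reproduced here. As a local stand-in, this stdlib sketch does three things:

- walks the working tree,
- drops relative paths that match the exclude patterns, using `fnmatch` (the matcher the gotcha below names),
- totals the remaining bytes per file extension.

A pile of `.pt` at the top of the list is the warning sign. The function names are hypothetical.

```python
import os
from collections import Counter
from fnmatch import fnmatch
from pathlib import Path, PurePosixPath

def walk_local(root):
    """Yield (relative POSIX path, size in bytes) for every file under root (hypothetical helper)."""
    root = Path(root)
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            path = Path(dirpath) / name
            yield path.relative_to(root).as_posix(), path.stat().st_size

def aggregate(entries, excludes):
    """Total bytes per extension for entries NOT matching any fnmatch exclude pattern."""
    totals = Counter()
    for rel, size in entries:
        if any(fnmatch(rel, pat) for pat in excludes):
            continue  # this path would be skipped by the sync
        totals[PurePosixPath(rel).suffix or "<none>"] += size
    return totals
```

To check the current tree, run `aggregate(walk_local("."), ["*quant_state.pt", "models/klein-4b/*"]).most_common(10)`, which lists the ten largest extensions still headed for the bucket.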
> **Pattern gotcha:** `hf sync` matching is Python **`fnmatch`**, where `*` already crosses `/`. So use
> `*__pycache__*` (NOT `**/__pycache__/**`, which misses top-level dirs) and `*quant_state.pt` (NOT
> `**/...`). `dir/*` matches everything under `dir` at any depth.
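You can check the matching rules above directly with the standard library. The paths are illustrative: one directory name is borrowed from the grid example, and the top-level `quant_state.pt` is hypothetical. `matches` is a hypothetical helper.

```python
from fnmatch import fnmatch

def matches(pattern, paths):
    """Return the '/'-separated relative paths that `pattern` matches under fnmatch."""
    return [p for p in paths if fnmatch(p, pattern)]

paths = [
    "__pycache__/a.cpython-312.pyc",                     # top-level cache dir
    "scripts/__pycache__/b.cpython-312.pyc",             # nested cache dir
    "quant_state.pt",                                    # hypothetical top-level state
    "outputs/abl_c300_r64_plain_refine/quant_state.pt",  # nested state
    "tmp/x/y/z.bin",                                     # deep file under tmp/
]

# `*` also matches "/", so "*__pycache__*" and "*quant_state.pt" catch every depth,
# while the "**/" forms require a "/" before the name and miss top-level entries.
for pat in ["*__pycache__*", "**/__pycache__/**", "*quant_state.pt", "**/quant_state.pt", "tmp/*"]:
    print(f"{pat:<20} -> {matches(pat, paths)}")
```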
```bash
export HF_TOKEN=hf_...   # write token; rotate if it ever leaks
# A comment cannot follow a line-continuation "\", so each exclude is an array element instead.
EXCLUDES=(
  --exclude "models/klein-4b/*"                 # public teacher (re-download via hf)
  --exclude "models/bfl-klein-4b-nvfp4/*"       # public BFL checkpoints
  --exclude "models/bfl-klein-4b-fp8/*"
  --exclude "models/klein-9b-nunchaku/*"        # public nunchaku FLUX.2 loader repo
  --exclude "miniforge3/*" --exclude ".cache/*" --exclude "tmp/*"
  --exclude "*__pycache__*" --exclude "*.pyc" --exclude ".ipynb_checkpoints/*"
  --exclude "recovered/*"
  --exclude "*quant_state.pt"                   # *.pt fake-quant states are huge + regenerable
  --exclude "build_nunchaku/src/build/*"        # build temp objects (the wheel + kernel src ARE kept)
)
# To push ONLY new deliverables (skip the ~10 GB deploy safetensors / calib already in the bucket),
# also append to EXCLUDES:
#   --exclude "outputs/nvfp4/deploy/*" --exclude "outputs/eval/imgs/*" --exclude "data/*" --exclude "monet_cache/*"
# Add --dry-run to preview; --no-delete is the default (deletions don't propagate).
hf sync ./ hf://buckets/Mercity/FluxDistill "${EXCLUDES[@]}"
```
---
## SHELVED TRACK — block surgery (depth-prune → surrogates → distill)
Topped out at a ~1.15–1.26× speedup and was quality-bounded (best 0.231 vs quant's ~0.045). Kept for the record
(`block_surgery_plan.md`, `block_surgery_todo.md`, scripts `01–10`). The rest of this README
documents that track. NOTE: its `.pt` model states were deleted to reclaim space (sample images /
logs / `selection.json` kept). Original design + decision log below.
## Status (2026-05-31)
| Stage | State |
|---|---|
| Env + klein-4B download + arch verification | ✅ |
| Surgery: block selection + warm-started surrogates → student | ✅ |
| Inference (teacher & student) | ✅ teacher 0.45s/img, student ~0.31s/img @512/4steps |
| Eval: 28-prompt set + multi-agent visual review | ✅ `outputs/eval/baseline/REVIEW.md` |
| Data: monet URL→VAE-latent cache | ✅ `data/monet_cache/` |
| Basic distillation training loop | ✅ velocity-match + FM grounding, Muon+AdamW |
**Key finding:** a per-token low-rank+GELU surrogate **cannot reproduce attention's
token-mixing**, so dropping 12 of 20 single blocks (v1) collapses the model. **v2** keeps
most blocks full and drops only the **6 least-important** single blocks (by leave-one-out
ablation) → **3.16B, functional even before recovery training**. The route back to ~2B is a *token-mixing*
surrogate (local-window / linear attention) — see `plan.md` TODO.
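The leave-one-out selection works like this: score each single block by how much a held-out loss rises when that block alone is bypassed, then drop the `k` blocks with the smallest rise. The pure-Python sketch below is illustrative only. `eval_loss` is a hypothetical callable; in practice it would run the model with one block skipped. The toy penalties are made up for the demo and are not results.

```python
from typing import Callable, Dict, Iterable, List, Optional

def loo_importance(blocks: Iterable[int], eval_loss: Callable[[Optional[int]], float]) -> Dict[int, float]:
    """Importance of block i = loss with block i bypassed minus the intact loss.
    eval_loss(None) evaluates the full model; eval_loss(i) skips block i (hypothetical API)."""
    base = eval_loss(None)
    return {i: eval_loss(i) - base for i in blocks}

def least_important(importance: Dict[int, float], k: int) -> List[int]:
    """The k blocks whose removal hurts least, returned in index order."""
    return sorted(sorted(importance, key=importance.get)[:k])

# Toy demo over 20 single blocks with made-up per-block penalties.
toy_penalty = {i: 0.01 * ((7 * i) % 20) for i in range(20)}

def toy_eval(skip: Optional[int]) -> float:
    return 1.0 + (0.0 if skip is None else toy_penalty[skip])

imp = loo_importance(range(20), toy_eval)
drop = least_important(imp, 6)  # the 6 cheapest blocks to remove in this toy setup
```

Each block is scored in isolation, so the sketch ignores interactions between removed blocks. That is a known limitation of leave-one-out ranking.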
## Models produced
- `outputs/student/` — v1 (drop 12 by SVD-energy) — **non-functional** (reference).
- `outputs/student_v2/` — v2 (drop 6 by importance) — **3.16B, functional** baseline.
- `outputs/train_v2/` — v2 after the basic recovery run (+ sample grids).
## Layout
```
flux2distill/
config.py # all knobs (model / surgery / data / train / eval)
surrogate.py # LowRankResidualSurrogate (x + B·σ(A·x)) + lstsq/SVD init
surgery.py # importance ablation, SVD-energy selection, build/attach student
calibration.py # surrogate warm-start gradient fit
losses.py # velocity matching + flow-matching grounding
data.py # cached-latent dataset
model_utils.py # load teacher/student, Muon/AdamW param split, param counts
eval_utils.py # prompt parsing, student loader, comparison grids
optim/muon.py # Muon optimizer (2D weights)
scripts/
01_inspect_model.py # introspect transformer module tree / params
02_teacher_smoke.py # teacher 4-step generation sanity
03_build_student.py # v1 surgery (SVD-energy, drop 12)
04_gen_eval.py [tag] # teacher-vs-student images across prompt set
05_build_student_v2.py [drop_k] # v2 surgery (importance, drop 6)
06_cache_data.py [N] # monet URL → VAE latents cache
07_train.py [steps] # FLAWED baseline run (trained all weights → diverged); kept as record
08_train_recover.py [steps] [adamw|muon] [lr] # CORRECT: surrogate-only, frozen base, cosine+clip
prompts/eval_prompts.txt # 28 prompts, tagged by capability
plan.md # design + decision log + findings
```
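`surrogate.py` is listed as `x + B·σ(A·x)` with a lstsq/SVD init. The numpy sketch below is one plausible reading and not the file's actual code. It assumes:

- σ is GELU (tanh approximation).
- The init fits a linear map from block input to the residual `y − x` by least squares, then factors it with a rank-r SVD.
- A and B are scaled by `eps` and `2/eps`, so that GELU's near-zero slope of 1/2 makes the surrogate start close to that linear fit.

`from_lstsq` and `eps` are hypothetical. The sketch also makes the key finding above concrete: each token's output depends only on that token's input, so the surrogate cannot mix tokens.

```python
import numpy as np

def gelu(z):
    """GELU, tanh approximation."""
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

class LowRankResidualSurrogate:
    """Per-token surrogate y = x + B @ gelu(A @ x), applied to each token (row) independently."""

    def __init__(self, A, B):
        self.A, self.B = A, B  # A: (r, d), B: (d, r)

    def __call__(self, x):  # x: (tokens, d)
        return x + gelu(x @ self.A.T) @ self.B.T

    @classmethod
    def from_lstsq(cls, x, y, rank, eps=1e-2):
        """Hypothetical init: lstsq fit of y - x ~ x @ W.T, rank-r SVD of W, and
        gelu(z) ~ z / 2 near 0, so A is scaled by eps and B by 2 / eps."""
        w_t, *_ = np.linalg.lstsq(x, y - x, rcond=None)  # x @ w_t ~ y - x
        u, s, vt = np.linalg.svd(w_t.T, full_matrices=False)
        root = np.sqrt(s[:rank])
        a = eps * (root[:, None] * vt[:rank])     # (r, d)
        b = (2.0 / eps) * (u[:, :rank] * root)    # (d, r)
        return cls(a, b)

rng = np.random.default_rng(0)
d, r = 8, 2
m = 0.3 * rng.standard_normal((d, r)) @ rng.standard_normal((r, d))  # toy rank-2 residual map
x = rng.standard_normal((200, d))  # 200 tokens
y = x + x @ m.T                    # toy per-token target (a real block would also mix tokens)
sur = LowRankResidualSurrogate.from_lstsq(x, y, rank=r, eps=1e-3)
rel_err = np.linalg.norm(sur(x) - y) / np.linalg.norm(y - x)
```

The toy target is deliberately per-token, so the fit is near-exact. A real block's attention output depends on other tokens, and no choice of `A` and `B` can reproduce that.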
## Run order
```bash
export PYTHONPATH=.
python3 scripts/01_inspect_model.py # (optional) verify architecture
python3 scripts/02_teacher_smoke.py # teacher works
python3 scripts/05_build_student_v2.py 6 # build the v2 student (drop 6)
python3 scripts/04_gen_eval.py baseline # teacher vs student image pairs
python3 scripts/06_cache_data.py 200 # cache training data
python3 -u scripts/08_train_recover.py 300 adamw 1e-4 # surrogate-only recovery (correct recipe)
```
## Training recipe (research-led)
Only the **6 surrogate modules (~19M, 0.6%)** are trained; the pretrained network is **frozen**
(training the kept blocks at high LR was what diverged — see plan.md). Surrogates are adapter-like,
so the **diffusion/LoRA regime** applies: **AdamW @ 1e-4, cosine decay to a 15%-of-base floor
(not 0), grad-clip 1.0, fp32 master** on the trained params. Muon's lr~0.02 is a *bulk-pretraining* value (nanoGPT/Kimi) — reserved for the
later full-recovery run (the §8 Muon-vs-AdamW A/B), not adapter training. The loop logs a fixed
held-out **eval velocity-loss** (objective metric), per-step sample images, grad norm, and saves the
best checkpoint; a divergence guard auto-stops if eval-loss exceeds 3× baseline.
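The schedule, clipping, and divergence guard can be sketched in pure Python under a few assumptions:

- There is no warmup.
- "Baseline" means the held-out eval loss measured before training.
- Clipping is written as the scale factor that PyTorch's `torch.nn.utils.clip_grad_norm_` applies, `max_norm / (total_norm + 1e-6)` capped at 1.

The function names are hypothetical.

```python
import math

def lr_at(step, total_steps, base_lr=1e-4, floor_frac=0.15):
    """Cosine decay from base_lr at step 0 to floor_frac * base_lr at total_steps (not to 0)."""
    t = min(max(step, 0), total_steps) / total_steps
    floor = floor_frac * base_lr
    return floor + (base_lr - floor) * 0.5 * (1.0 + math.cos(math.pi * t))

def clip_scale(grad_norm, max_norm=1.0):
    """Factor applied to all gradients for global-norm clipping (1.0 when already within max_norm)."""
    return min(1.0, max_norm / (grad_norm + 1e-6))

def diverged(eval_loss, baseline_loss, factor=3.0):
    """Divergence guard: stop once held-out eval loss exceeds factor x the pre-training baseline."""
    return eval_loss > factor * baseline_loss
```

The floor keeps the adapters learning at 1.5e-5 through the final steps, instead of freezing them as a decay to 0 would.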
## Notes / upgrades for the big run (B200)
- **Surrogate v2 → token-mixing** (the real lever to reach ~2B): local-window or linear attention.
- FlashAttention-4 / FlexAttention; larger batch; `torch.compile`.
- **fp32 master weights** + fp32 moments (current dev run trains in bf16).
- **Trajectory velocity matching** on the 4 schedule sigmas (current run samples σ~U(0,1) on cached latents).
- Feature matching on retained blocks (masked KD); offline latent shards at 300k scale.
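The numpy sketch below shows the two losses named in `losses.py` and the σ-sampling switch from the trajectory bullet. It assumes the rectified-flow convention x_σ = (1 − σ)·x0 + σ·ε with target velocity ε − x0, which is an assumption about this repo.

- **Velocity matching** compares the student's velocity to the teacher's.
- **Flow-matching grounding** compares the student's velocity to ε − x0.

`fm_weight` and all names are hypothetical, and the "models" are toy linear functions. The 4 klein schedule sigmas are not given in this README, so the list in the example is a placeholder.

```python
import numpy as np

def noised(x0, eps, sigma):
    """Rectified-flow interpolation x_sigma = (1 - sigma) * x0 + sigma * eps (assumed convention)."""
    return (1.0 - sigma) * x0 + sigma * eps

def distill_losses(student, teacher, x0, rng, sigmas=None, fm_weight=0.1):
    """Velocity matching (student vs teacher) + flow-matching grounding (student vs eps - x0).
    sigmas=None samples sigma ~ U(0,1) per example; otherwise sigma is drawn from the given list."""
    n = x0.shape[0]
    if sigmas is None:
        sigma = rng.uniform(0.0, 1.0, size=n)
    else:
        sigma = rng.choice(np.asarray(sigmas, dtype=float), size=n)
    sigma_b = sigma.reshape(n, *([1] * (x0.ndim - 1)))  # broadcast over latent dims
    eps = rng.standard_normal(x0.shape)
    x_t = noised(x0, eps, sigma_b)
    v_s = student(x_t, sigma)
    vm = np.mean((v_s - teacher(x_t, sigma)) ** 2)  # velocity matching
    fm = np.mean((v_s - (eps - x0)) ** 2)           # flow-matching grounding
    return vm + fm_weight * fm, vm, fm

rng = np.random.default_rng(0)
x0 = rng.standard_normal((16, 4))  # toy "latents"
teacher = lambda x, s: -x          # toy velocity fields, not real models
student = lambda x, s: -0.9 * x
loss_u, *_ = distill_losses(student, teacher, x0, rng)                                 # sigma ~ U(0,1)
loss_s, *_ = distill_losses(student, teacher, x0, rng, sigmas=[1.0, 0.75, 0.5, 0.25])  # placeholder schedule
```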
