# FLUX.2 [klein] 4B → ~2B Distillation — Build Plan
> **Status:** Locked v1 (supersedes `init-plan.md`) · **Compute:** 1× B200 (192 GB) · **Teacher:** klein **distilled** 4-step 4B (Apache 2.0)
> **Deliverable:** a ~2.05B, 4-step, CFG-free text-to-image model positioned as a fast inference endpoint + **distill-down target** for narrow domains, plus the downstream distill-down kit.
This doc reflects the decisions locked in discussion. Where it departs from `init-plan.md`, this doc wins.
---
## Addendum (2026-05-31) — verified config, dataset, dev-environment scope
**Architecture verified** from `black-forest-labs/FLUX.2-klein-4B/transformer/config.json` (the §2 gate is cleared):
- `num_attention_heads=24` × `attention_head_dim=128` = **d=3072** ✓; `num_layers=5` (double) ✓; `num_single_layers=20` (single) ✓; `joint_attention_dim=7680` ✓; `in_channels=128` ✓.
- `is_distilled=true` ✓ (correct checkpoint); **`guidance_embeds=false`** → fully CFG-free, *no guidance vector at all* (nothing to inherit).
- Corrections to earlier assumptions: **`mlp_ratio=3.0`** (not 4.0 — re-derive per-block params from the real state dict); transformer weights = 7.75 GB BF16 ≈ **3.87B params** (the "4B" is the denoiser). Stack: `Flux2Transformer2DModel` / `Qwen3ForCausalLM` / `AutoencoderKLFlux2` / `FlowMatchEulerDiscreteScheduler`, diffusers `0.37.0.dev0` (install from git).
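Because `mlp_ratio=3.0` invalidates the earlier per-block estimates, the per-block counts are easiest to re-derive directly from parameter shapes. A minimal sketch, assuming diffusers-style key prefixes `transformer_blocks.{i}.` (double) and `single_transformer_blocks.{i}.` (single); those prefixes are an assumption to check against the real `Flux2Transformer2DModel` keys.

```python
import re
from collections import defaultdict
from math import prod

# Hypothetical key pattern (diffusers-style), e.g.
# "single_transformer_blocks.3.attn.to_q.weight". Verify against the real keys.
BLOCK_RE = re.compile(r"^(transformer_blocks|single_transformer_blocks)\.(\d+)\.")


def params_per_block(shapes):
    """Sum parameter counts per (block_type, index) from a {name: shape} mapping.

    Anything outside a numbered block (embedders, proj_out, shared modulation)
    is pooled under ("other", -1).
    """
    counts = defaultdict(int)
    for name, shape in shapes.items():
        m = BLOCK_RE.match(name)
        key = (m.group(1), int(m.group(2))) if m else ("other", -1)
        counts[key] += prod(shape)  # prod(()) == 1 for scalar params
    return dict(counts)
```

To feed it from the checkpoint without materializing tensors, one option is safetensors' `safe_open(path, framework="pt")`, collecting `{k: f.get_slice(k).get_shape() for k in f.keys()}` across the shards.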
**Dataset:** [`jasperai/monet`](https://huggingface.co/datasets/jasperai/monet) — large; **stream** it. For this dev pass use **20–30k rows @ 512×512** for the distillation/grounding signal (full 300k is the later large-GPU run).
**Empirical findings from the first surgery run (2026-05-31):**
- **Student is ~2.44B, not ~2.05B.** Real `Flux2SingleTransformerBlock` = 122.7M and `Flux2TransformerBlock` (double) = 245.4M (5 doubles = 1.23B un-prunable floor). Keep 8 single + 5 double + embedders ≈ 2.40B core; +12 surrogates @ r=512 (~3.15M each = 38M) → **2.44B** (matches the *original* init-plan estimate; the "2.05B" correction was wrong).
- **A per-token low-rank surrogate cannot reproduce attention token-mixing**: warm-start residual rel-err ≈ 0.9 (only the MLP-ish per-token component is captured). Those 12 positions effectively become per-token MLP layers, so the student is "13 attention blocks + 12 per-token layers." Samples before recovery training are heavily degraded (text destroyed, coarse layout partially survives), which is **expected pre-recovery**. **→ Surrogate design is the primary thing to iterate** (options: a cheap token-mixing surrogate, keeping more single blocks full, or relying on distillation to compensate via the remaining 13 attention layers). Block selection via absolute tail-energy chose to surrogate single blocks 0–11 and keep 12–19 full.
**Surgery v2 decision (2026-05-31, after eval):** the per-token surrogate is the wrong function class (can't token-mix), so v1 (drop 12) is non-functional. **v2: keep most single blocks full, drop only ~4–6 selected by measured importance** (leave-one-out ablation: skip each block, measure relative change in the final latent; drop the least-important). Dropping 6 → keep 14 full → student ~3.16B (~1.2× faster) but *functional*; this is the conservative staged baseline. Aggressive compression returns once the surrogate can token-mix.
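The leave-one-out ranking can be sketched as follows. This is a toy version under stated assumptions: each block is a callable on a NumPy array, "skipping" a block means passing its input through unchanged (identity, since the blocks are residual), and the score is measured on the stack's output rather than on the final latent after the full 4-step rollout that the real run uses. Function names are illustrative.

```python
import numpy as np


def leave_one_out_importance(blocks, x):
    """Relative change in the final output when each block is skipped (identity)."""

    def run(skip=None):
        h = x
        for i, blk in enumerate(blocks):
            if i != skip:
                h = blk(h)
        return h

    ref = run()
    ref_norm = np.linalg.norm(ref)
    return np.array(
        [np.linalg.norm(run(skip=i) - ref) / ref_norm for i in range(len(blocks))]
    )


def least_important(blocks, x, k):
    """Indices of the k blocks whose removal perturbs the output least."""
    scores = leave_one_out_importance(blocks, x)
    return sorted(np.argsort(scores)[:k].tolist())
```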
**TODO / to think about — cheap token-mixing surrogate (the real lever for getting back to ~2B):** replace the per-token `x + B·σ(A·x)` with a surrogate that moves information between token positions at low cost, so we can drop *many* blocks again without collapse. Candidates: (a) **local-window attention** (cheap neighborhood mixing — cheapest path to restore edges/texture); (b) **linear / kernelized attention** (Performer-style, O(N), genuine global mixing); (c) **token-mixing MLP** (MLP-Mixer along the token axis). Warm-start target: per-block residual rel-err ≤0.3–0.4 before training. This is what makes the original "drop 12 → ~2.4B" viable.
**RESULT (2026-05-31) — linear-attention surrogate wins decisively.** Implemented `LinearAttentionSurrogate` (O(N) linear attention, elu+1 feature map, 8 heads × 64, zero-init output → identity start; `scripts/09_build_linattn.py`). Same 6 dropped blocks (12–17), same held-out eval batch, AdamW @ constant 1e-4. Eval velocity-loss **0.5046 → 0.2528 (−50%)** vs the per-token surrogate's **0.308** — an **18% lower floor**, and it crossed *below* the per-token's converged floor by **step 50** (0.282). Samples visibly sharper (near-teacher lake; coherent, detailed storefront with partially legible sign) vs the per-token's soft output. Cost ~equal: surrogates 38M vs 19M (1.2% vs 0.6% of model), O(N) so negligible speed impact. **Conclusion: token-mixing is the right surrogate; this unblocks dropping *more* blocks (toward the original ~2B target) while staying recoverable.** Next: push `drop_k` (10–12) with the linear-attn surrogate. (Minor: constant LR jittered at the floor at step 250 — use the floored-cosine for the real run.)
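A NumPy sketch of an O(N) linear-attention surrogate matching the description (elu+1 feature map, 8 heads × 64, zero-initialized output projection so the block starts as an exact identity). It is not `scripts/09_build_linattn.py`: it omits the RoPE/conv additions mentioned later, and the random init scale for the Q/K/V projections is an assumption.

```python
import numpy as np


def elu_plus_one(z):
    # elu(z) + 1: z + 1 for z > 0, exp(z) for z <= 0 (strictly positive feature map).
    # np.minimum keeps exp() from overflowing on the branch np.where discards.
    return np.where(z > 0, z + 1.0, np.exp(np.minimum(z, 0.0)))


class LinearAttentionSurrogate:
    def __init__(self, d, heads=8, head_dim=64, seed=0):
        rng = np.random.default_rng(seed)
        inner = heads * head_dim
        scale = 1.0 / np.sqrt(d)  # assumed init scale
        self.h, self.dh = heads, head_dim
        self.wq = rng.normal(0.0, scale, (d, inner))
        self.wk = rng.normal(0.0, scale, (d, inner))
        self.wv = rng.normal(0.0, scale, (d, inner))
        self.wo = np.zeros((inner, d))  # zero-init output -> identity at start

    def __call__(self, x):
        n = x.shape[0]  # x: (tokens, d)

        def split(t):  # (n, h*dh) -> (h, n, dh)
            return t.reshape(n, self.h, self.dh).transpose(1, 0, 2)

        q = elu_plus_one(split(x @ self.wq))
        k = elu_plus_one(split(x @ self.wk))
        v = split(x @ self.wv)
        kv = np.einsum("hnd,hne->hde", k, v)  # (h, dh, dh) summary, O(N) to build
        z = 1.0 / (np.einsum("hnd,hd->hn", q, k.sum(axis=1)) + 1e-6)  # normalizer
        out = np.einsum("hnd,hde,hn->hne", q, kv, z)  # (h, n, dh)
        out = out.transpose(1, 0, 2).reshape(n, self.h * self.dh)
        return x + out @ self.wo  # residual shortcut preserved exactly
```

The `kv` summary has shape heads × 64 × 64 regardless of token count, which is why the cost grows linearly in N instead of quadratically, while every output token still depends on every input token.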
---
## Session summary (2026-05-31) — full deep-dive in `report/REPORT.pdf`
**Block-ID convention:** single-stream blocks **S0..S19** (S0 = first/shallowest after the 5 double
blocks; S19 = last/deepest). "Deepest N" = highest IDs.
**Where things stand (the measured speed↔quality frontier):**
- best quality = linattn drop-6 (drop S12–S17) +RoPE+conv+warmstart → **0.231 @ 1.15×, 3.18B**;
- best perceptual = drop-8 (drop S10–S17) +focused+FFN → **0.269 @ 1.20×**;
- most compression = drop-10 (drop S8–S17, FFN on S14–S17) → **~0.322 @ 1.26×, 2.74B**;
- v1 per-token drop-12 (drop S0–S11) **collapsed**.
**Two structural ceilings (measured):** (1) single blocks are **78%** of transformer compute, the 5
double blocks only **21%** (per-block ~1.08×) → single-block surgery is the right lever but it's near
its useful limit; double-block FFN shrink ≈ 2% speedup, **rejected**. (2) Single blocks have low
redundancy (every one shifts the output ≥43%).
**Recommended next:** lock **drop-7 (FFN on deepest 3 incl. last)** as the balanced single-block model
(~3.08B, ~1.17×, ~0.25). For a genuine ~2×, the next *project* is **step reduction (4→2) via
step-distillation** — not more block surgery — plus text-encoder caching and the B200 scale-up.
**Basic training run result (2026-05-31): DIVERGED; the recipe was too aggressive.** Loss was stable at ~0.83 until step ~120, then spiked (2.46 @160) and blew up (8.08 @300); samples degraded from step 0 and collapsed to checkerboard noise. Diagnosis (ranked): (1) **Muon lr 0.02 applied to the well-pretrained *kept* blocks corrupts them**: kept blocks should be frozen or trained at a much lower lr than the surrogates; (2) **bf16 weights, no fp32 master**; (3) **no LR warmup, no grad clipping**. The step-0 v2 warm-start remained the best checkpoint. This empirically confirms the §13 risk ("Muon on warm-start recovery under-validated → A/B vs AdamW"). **Fixes for next run:** per-group LRs (surrogates high; kept blocks frozen or at a tiny lr), LR warmup + grad clip, fp32 master, AdamW-vs-Muon A/B, lower base LR.
**Corrected recovery recipe (`scripts/08_train_recover.py`), research-led.** LR research: Muon's lr≈0.02 is a *bulk-pretraining* value (nanoGPT speedruns; Kimi/Moonshot RMS-match it to reuse AdamW-scale LRs), which is wrong for a few adapter modules and was the cause of the blowup. Our surrogates are adapter-like (fresh, small, frozen base) → **diffusion/LoRA regime: AdamW ~1e-4–2e-4** (1e-5 too weak; full-finetune ~1e-6). **Recipe:** freeze the whole pretrained net and train **only the 6 surrogates (~18.9M, 0.6%)**; AdamW @ 1e-4, **cosine decay, grad-clip 1.0, fp32 master** on trained params (bf16 autocast compute); fixed held-out **eval velocity-loss** as the objective metric; per-step image logging; best-ckpt; divergence guard (stop if eval-loss > 3× baseline). **LR schedule:** cosine from base **to a 15%-of-base floor** (`MIN_LR_FRAC=0.15`), no warmup, and NOT to 0. Decaying to exactly 0 wasted the last ~30 steps (lr <1e-6 ≈ dead tail) and would break resume/extend (a continued run would start at lr≈0). Also **size `STEPS` to the data** so the cosine tail isn't spent doing nothing (this dev run flattened by ~step 250). Muon stays the tool for the *later full-recovery* run (its §9.1 A/B), where much of the network trains. Note: surrogate-only recovery is *bounded* by the per-token surrogate's inability to token-mix, so expect stability + modest improvement, not full sharpening; the token-mixing surrogate TODO is what unlocks real recovery.
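The floored cosine schedule and the divergence guard reduce to two small functions. A sketch, assuming `step` counts optimizer steps from 0 and that the guard compares against the step-0 eval velocity-loss; the function names are illustrative, not the script's.

```python
import math


def floored_cosine_lr(step, total_steps, base_lr, min_lr_frac=0.15):
    """Cosine decay from base_lr down to min_lr_frac * base_lr, no warmup.

    Past total_steps the LR holds at the floor, so an extended/resumed run
    keeps training instead of restarting at ~0.
    """
    min_lr = min_lr_frac * base_lr
    progress = min(step, total_steps) / total_steps
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


def should_stop(eval_loss, baseline_eval_loss, factor=3.0):
    """Divergence guard: stop if eval loss exceeds factor x the baseline."""
    return eval_loss > factor * baseline_eval_loss
```

In PyTorch the schedule can be attached as a multiplier with `torch.optim.lr_scheduler.LambdaLR(opt, lambda s: floored_cosine_lr(s, STEPS, 1.0))`, leaving the base LR on the optimizer's param groups.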
**Result (corrected run, 300 steps, surrogate-only):** eval velocity-loss **0.4756 → 0.308, −35%**, monotonic, **zero divergence**, 28.7 GB, ~7 min; samples **visibly sharper** (mountain ridges/reflection, storefront detail) toward the teacher. **Controlled Muon-vs-AdamW A/B** (identical recipe, only optimizer swapped): AdamW@1e-4 → 0.3081 (−35.2%), Muon@2e-3 → 0.3056 (−35.8%) — **statistically equivalent**, both stable. Takeaway: the earlier blow-up was *not* Muon — it was lr 0.02 on the wrong params (all weights incl. pretrained kept blocks). Used correctly (low lr, only new params) Muon is fine but offers no edge for a handful of adapter modules; reserve it for the bulk full-recovery run.
**Dev-environment scope (this box ≠ the target box):** running on **1× A100-80GB**, not B200. So for now: **no FlashAttention-4 / Blackwell kernels — use PyTorch SDPA**; smaller batch; goal is a **clean, correct, scalable baseline** (surgery + calibration + inference + eval + a basic training loop) that lifts onto a bigger GPU later, *not* a perfected/perf-tuned run. Memory/throughput numbers in §8 are the B200 target, not this box.
---
## 0. What changed from `init-plan.md` (read first)
1. **Teacher is the distilled 4-step 4B, NOT the 50-step base.** Pure velocity/output matching (L2 regression) **cannot compress step count** — it clones a velocity field, and integrating a cloned field needs the same step budget as the source. Distilling from the 50-step base would leave us owing a separate 50→4 step-compression phase (DMD/consistency/adversarial), which we explicitly excluded. Cloning the already-4-step distilled teacher inherits 4-step behavior for free.
2. **Student is ~2.05B, not 2.45B.** With *true* low-rank surrogates (r=512), the 12 surrogates total ~38M params, not the ~410M the old plan assumed. Core kept ≈ 2.04B + surrogates ≈ 0.04B ≈ **~2.05B**.
3. **Surrogates are low-rank + nonlinearity** (`x + B·σ(A·x)`), r=512, not bare linear. ~Same FLOPs as bare linear (the activation is ~free), more faithful to a removed transformer block. Costs us the *exact* closed-form init → we add a short offline calibration fit (below).
4. **Distillation is trajectory-aware**, sampled along the teacher's 4-step rollout — not the naive `x_t=(1−t)·noise+t·latent` data interpolant. With only 4 large Euler steps, exposure bias compounds viciously; matching where the student will actually operate is what kills it.
5. **Positioning:** the 2B is a fast inference endpoint + distill-down target for **narrow** domains, **not** a directly-fine-tunable general base. (Distilled models are harder to fine-tune directly; the distill-down kit routes downstream specialization through the teacher, so the 2B never needs direct fine-tuning.)
6. **CFG-free inheritance:** the distilled teacher is guidance-distilled (`guidance_scale=1.0`, single forward, no CFG doubling). The student inherits the guidance embedder and runs CFG-free too.
---
## 1. Objective
Compress **FLUX.2 [klein] distilled 4B** (rectified-flow MM-DiT, 4-step, guidance-distilled) into a **~2.05B** model that:
- runs in **4 steps** (same as teacher — we are not changing step count),
- is **~1.6–1.7× faster** than the 4B and ~5 GB BF16,
- retains most of teacher-at-4-steps quality,
- serves as a **distill-down target**: third parties specialize it to a narrow domain via the kit in §11 (fine-tune a bigger teacher → amplify → distill into the 2B), rather than fine-tuning the 2B directly.
Recovered via knowledge distillation on a single B200, ~half-day.
---
## 2. Teacher
- **Model:** FLUX.2 [klein] **distilled** 4B — step- and guidance-distilled, 4 inference steps, `guidance_scale=1.0`, Apache 2.0.
- **Why this one (not base):** see §0.1. Regression preserves step count; we want a 4-step student, so we need a 4-step teacher.
- **Conditioning:** Qwen3-4B text encoder (intermediate-layer stack → conditioning vector). Frozen; not part of the denoiser. **Student stays on Qwen3-4B** (never bolt the 9B's Qwen3-8B encoder on — dimension mismatch).
- **Latent space:** shared FLUX.2 VAE. Identical across the klein family — this is what makes the §11 cross-size distill-down valid.
- **Forward cost:** single forward per step (CFG-free). Used frozen (eval, `no_grad`, BF16, compiled) as an online target generator.
> **Open item (verify before surgery):** the exact 4B block layout (`init-plan.md` asserts 5 double + 20 single, d=3072, 24 heads — *unverified from the model card*). Pull `config.json` / the diffusers Flux2 impl and confirm block counts + hidden dim before cutting. The entire prune + correspondence map depends on these numbers. Everything below assumes the asserted layout; adjust if the config differs.
---
## 3. Target architecture (~2.05B student)
Full-width, depth-pruned, removed depth replaced by light low-rank surrogates.
| Component | Spec | Rationale |
| --- | --- | --- |
| Width | `d=3072` (unchanged) | Preserves texture/detail; widths match teacher → **no projector** needed for feature matching |
| Double-stream | **5 (all kept)** | Preserve cross-modal / instruction binding (these are also the un-prunable FLOP floor) |
| Single-stream | **8 kept full** (of 20) | Depth-only prune; recovers better than width prune |
| Removed depth | **12 single blocks → low-rank+nonlinearity surrogates** (r=512) | Residual-only low-rank approximation, shortcut preserved; SVD-init + calibration warm-start |
| Embedders / modulation | Inherited from teacher (timestep, **guidance**, shared-across-block AdaLN) | Keeps CFG-free behavior; AdamW param group |
| Effective size | core ~2.04B + surrogates ~0.04B ≈ **~2.05B** | |
| Inference | **~1.6–1.7× faster** than 4B, ~5 GB BF16, **4 steps** | |
---
## 4. Surrogate design (the removed 12 blocks)
**Form:** `out = x + B·σ(A·x)`, with `A: d→r`, `B: r→d`, `r=512`, `σ` = GELU-family (match teacher block activation). Shortcut `x` preserved exactly — we only approximate the block's **residual delta**.
**Compression at r=512** (vs d=3072):
- Bottleneck = 1/6 (16.7%) of residual-stream width.
- ~110M params/block → ~3.15M/block ⇒ **~35× param compression** (keep ~3% of the block).
- ~200 → ~10 GFLOP/block ⇒ **~21× cheaper** (skips attention/N²).
- Rank is a **quality knob, not a speed knob**: across r∈{128,256,512} total-model FLOPs move <3% and speedup stays ~1.65×. 512 = most capacity at no speed cost. (Escape hatch: 768/1024 still <5% of total FLOPs if we ever want less compression.)
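The compression figures follow from the surrogate's two weight matrices. A worked check, counting weights only (biases, if present, add r + d per surrogate) and using the ~110M block size assumed in this section:

```python
def surrogate_params(d=3072, r=512, bias=False):
    """Parameters of x + B·σ(A·x): A is d->r, B is r->d."""
    n = 2 * d * r
    return n + (r + d if bias else 0)


def compression(block_params, d=3072, r=512):
    """How many times smaller the surrogate is than the block it replaces."""
    return block_params / surrogate_params(d, r)


bottleneck = 512 / 3072                 # 1/6 of residual-stream width
per_surrogate = surrogate_params()      # 3,145,728 ≈ 3.15M
ratio = compression(110e6)              # ≈ 35x
total_12 = 12 * per_surrogate           # ≈ 37.7M
```

With the measured 122.7M single-block size from the addendum, the same ratio comes out near 39×; 12 surrogates total ~37.7M, consistent with the ~38M figure in the addendum.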
**Which 12 to surrogate vs which 8 to keep full:** data-driven on a calibration batch. For each single block, compute the SVD energy of its **residual delta** (input→output map). Blocks whose top-512 singular values capture most energy are cheap to approximate → **surrogate them**; high-rank residuals → **keep full**. (Falls back to / cross-checks against block-influence/CKA importance. Middle blocks are usually the most compressible.)
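Block selection by residual-delta spectrum can be sketched in NumPy. Assumptions: each block's residual delta is collected on the calibration batch as a (tokens × d) matrix of output minus input; `rank_r_energy` is the fraction-captured criterion described here, while `choose_surrogates` ranks by absolute tail energy, the variant the addendum reports using. Names are illustrative.

```python
import numpy as np


def rank_r_energy(delta, r=512):
    """Fraction of squared Frobenius energy captured by the top-r singular values."""
    s = np.linalg.svd(delta, compute_uv=False)
    energy = s ** 2
    return float(energy[:r].sum() / energy.sum())


def tail_energy(delta, r=512):
    """Absolute squared energy a rank-r approximation cannot represent."""
    s = np.linalg.svd(delta, compute_uv=False)
    return float((s[r:] ** 2).sum())


def choose_surrogates(deltas, n_surrogate, r=512):
    """deltas: one (tokens, d) residual-delta matrix per single block.

    Lowest tail energy = cheapest to approximate at rank r = surrogate it.
    """
    tails = [tail_energy(D, r) for D in deltas]
    return sorted(np.argsort(tails)[:n_surrogate].tolist())
```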
**Init (warm start):** the nonlinearity breaks the exact closed-form least-squares init, so:
1. Init `A,B` via **truncated-SVD of the linear least-squares solution** (treat σ≈identity / scale so `A·x` lands in σ's near-linear region).
2. **Short offline calibration fit** — a few hundred gradient steps per surrogate to match each block's I/O on the calibration set, *before* the main run. Cheap, offline, not part of the training-budget run.
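Step 1 can be sketched as below, in row-vector convention (`out = x + σ(x·A)·B`, so `A` and `B` here are the transposes of the text's `A` and `B`). Assumptions: `x` and `delta` are (tokens × d) calibration inputs and residual deltas. Since GELU(z) ≈ z/2 near 0 (its slope at the origin is 0.5), the sketch optionally shrinks `A` by `eps` to stay in that near-linear region and rescales `B` by `1/(eps·sigma_slope)` to compensate; step 2 then fine-tunes both with ordinary gradient steps.

```python
import numpy as np


def svd_init(x, delta, r, sigma_slope=1.0, eps=1.0):
    """Rank-r warm start for out = x + sigma(x @ a) @ b.

    1. Least-squares linear map w with delta ≈ x @ w   (w: d x d)
    2. Truncated SVD w ≈ U_r S_r V_r^T, split as (U_r S_r^1/2)(S_r^1/2 V_r^T)
    3. Scale a by eps (small pre-activations) and b by 1/(eps * sigma_slope)
       so that sigma(z) ≈ sigma_slope * z reproduces x @ w_r.
    """
    w, *_ = np.linalg.lstsq(x, delta, rcond=None)
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    root = np.sqrt(s[:r])
    a = (u[:, :r] * root) * eps                       # (d, r)
    b = (root[:, None] * vt[:r]) / (eps * sigma_slope)  # (r, d)
    return a, b
```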
**Sanity:** shape + forward-parity checks after surgery; confirm full width `d=3072` so feature matching needs no projector.
---
## 5. Distillation methodology (online teacher)
Convention: `x_t = (1−t)·noise + t·latent` (t=0 noise → t=1 data); straight-line velocity `v = latent − noise`.
**Primary — trajectory velocity matching (instills 4-step behavior):**
- Roll out the **teacher's 4-step sampler** from noise, collecting the 4 visited `(x_{t_i}, t_i)` points — the exact points the student sees at inference.
- At each point: teacher velocity `v_teacher` is the target; student predicts `v_student`; loss = MSE.
- **Concentrate timesteps on the 4 schedule points**, not uniform `t`. This is the departure from the old plan's data-interpolant sampling, and it's what kills few-step exposure bias.
- v1: **teacher-trajectory** rollout (stable, cheap). Optional polish later: a fraction of **student-trajectory** (on-policy) steps where the teacher corrects student-visited points (DAgger-style) to close residual exposure bias.
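The teacher-trajectory rollout and the per-point velocity loss can be sketched with plain Euler steps under this document's convention (t = 0 noise → t = 1 data). Assumptions: a uniform 4-point schedule (the real scheduler's timesteps will differ), and `teacher_fn`/`student_fn` as stand-ins taking `(x, t)`. Note that diffusers' flow-matching scheduler parameterizes with sigma = 1 − t (noise at sigma = 1), so its model output points the opposite way; verify the sign before mixing the two.

```python
import numpy as np

SCHEDULE = [0.0, 0.25, 0.5, 0.75, 1.0]  # hypothetical uniform 4-step schedule


def euler_rollout(velocity_fn, noise, t_schedule):
    """Integrate dx/dt = v(x, t) from noise (t=0) toward data (t=1).

    Returns the final sample and every (x_t, t, v) where the velocity was queried.
    """
    x = noise
    visited = []
    for t, t_next in zip(t_schedule[:-1], t_schedule[1:]):
        v = velocity_fn(x, t)
        visited.append((x, t, v))
        x = x + (t_next - t) * v
    return x, visited


def trajectory_velocity_loss(student_fn, teacher_fn, noise, t_schedule):
    """v1: supervise the student only at the points the teacher's sampler visits."""
    _, points = euler_rollout(teacher_fn, noise, t_schedule)
    return float(np.mean([np.mean((student_fn(x, t) - v) ** 2) for x, t, v in points]))
```

An oracle check: with the exact straight-line velocity, four Euler steps land exactly on the data latent, and a student identical to the teacher gets zero loss.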
**Secondary — masked feature matching on retained blocks:**
- Match student↔teacher hidden states at the retained-block positions (no projector — widths match). Match each kept block to its *original teacher position* (indices renumber after the prune).
- **Masked KD:** exclude DiT extreme-activation outliers (massive activations) from the feature loss.
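Masked feature matching and the kept-block correspondence can be sketched as follows. The outlier rule (|teacher activation| above `outlier_factor` × its median) is a hypothetical stand-in for whatever massive-activation criterion is chosen, and `correspondence` assumes kept teacher blocks keep their relative order after the prune.

```python
import numpy as np


def masked_feature_loss(h_student, h_teacher, outlier_factor=10.0):
    """MSE between hidden states, skipping entries where the teacher is 'massive'.

    Returns (loss, keep_mask). No projector: both tensors share width d.
    """
    mag = np.abs(h_teacher)
    keep = mag <= outlier_factor * np.median(mag)
    diff = (h_student - h_teacher) ** 2
    return float(diff[keep].mean()), keep


def correspondence(kept_teacher_indices):
    """Student block i <-> its original teacher block index after renumbering."""
    return {s: t for s, t in enumerate(sorted(kept_teacher_indices))}
```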
**Grounding — real-data flow-matching:**
- Standard FM loss on real latents at randomly sampled `t` (broad timestep coverage + grounding). Cheap (1 forward).
**Loss:** `w1·velocity-MSE (trajectory)` + `w2·feature-match (masked)` + `w3·real-data FM`. Tune weights; velocity primary.
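The grounding term and the weighted sum can be sketched in NumPy. Assumptions: `t ~ U(0, 1)` per sample (the plan only says "randomly sampled"), `model_fn(x_t, t)` returns a velocity in this document's convention, and the weights `w1, w2, w3` are placeholders to be tuned, not values from the plan.

```python
import numpy as np


def real_data_fm_loss(model_fn, latent, rng):
    """Rectified-flow matching on real latents.

    x_t = (1 - t) * noise + t * latent, target v = latent - noise.
    """
    b = latent.shape[0]
    t = rng.uniform(0.0, 1.0, size=(b,) + (1,) * (latent.ndim - 1))  # broadcastable
    noise = rng.standard_normal(latent.shape)
    x_t = (1.0 - t) * noise + t * latent
    target = latent - noise
    return float(np.mean((model_fn(x_t, t) - target) ** 2))


def total_loss(l_traj, l_feat, l_fm, w1=1.0, w2=0.1, w3=0.5):
    # Placeholder weights; the trajectory velocity term is the primary one.
    return w1 * l_traj + w2 * l_feat + w3 * l_fm
```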
**Multi-timestep / epochs:** fresh noise + fresh rollout per image per epoch (new supervision, not repetition). 4 epochs × 300k ≈ **1.2M views**.
---
## 6. Optimizer & precision
- **Muon** on 2D hidden weights (attention/MLP, incl. surrogate `A,B`) + **AdamW** on embedders, `proj_out`, norms, shared AdaLN modulation. Muon LR runs higher than AdamW — tune LR + WD per group.
- **BF16** compute, **FP32** master weights + FP32 moments. No quantization in training.
- Micro-batch **64**, gradient-accumulate to **effective ~256** (Muon's large-batch sweet spot). No gradient checkpointing planned (fits).
- `torch.compile` (CUDA-graph mode); **FlashAttention** (Blackwell-native, via FlexAttention) for the retained blocks.
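The Muon/AdamW split can be expressed as a routing rule over named parameters. A sketch, assuming the name keywords below identify embedders, `proj_out`, norms and AdaLN modulation; they are hypothetical and must be matched to the real module names. In PyTorch the same test would use `p.ndim == 2` while iterating `model.named_parameters()`.

```python
def split_param_groups(
    named_shapes,
    adamw_keywords=("embed", "proj_out", "norm", "modulation"),
):
    """Route 2D hidden weight matrices (attention/MLP, surrogate A/B) to Muon;
    everything else (embedders, proj_out, norms, shared modulation, biases)
    goes to AdamW. Returns (muon_names, adamw_names)."""
    muon, adamw = [], []
    for name, shape in named_shapes:
        if len(shape) == 2 and not any(k in name for k in adamw_keywords):
            muon.append(name)
        else:
            adamw.append(name)
    return muon, adamw
```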
---
## 7. Data pipeline
- **300k curated images + captions**, diversity-first (broad styles/subjects/composition + in-image text coverage). Stable IDs. *(Sourcing/licensing for an Apache-2.0 release is an open item — §13.)*
- **VAE-encode offline → BF16 latents in streaming shards** (WebDataset / tensordict). VAE never runs in the training loop.
- **Qwen3-4B text embeddings:** run online (intermediate layers, frozen). Optionally batch-precompute per shard if I/O-bound (deferred — not a v1 priority).
- Dataloader streams `(latent, caption)` with aggressive prefetch; per step samples noise and forms the teacher rollout.
- Bulk training at **512px**; optional short 768/1024 polish if time/memory headroom remains.
---
## 8. Compute & infra (1× B200, 192 GB)
**Memory (fixed ~45 GB):** FP32 master ~9.8 + BF16 weights ~4.9 + BF16 grads ~4.9 + Muon momentum ~9.0 + AdamW m,v ~1.6 + teacher 4B (fwd) ~6.8 + Qwen3-4B (fwd) ~8.0. → ~137 GB free for activations. (Student is ~2.05B; weight/grad/state lines scale to that.)
**Batch:** activations ≈ 1.63 GB/sample @512px → micro-batch 64 ≈ ~104 GB activations; fits without checkpointing.
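The memory budget is simple arithmetic; a quick check using the figures above as given. `free_gb=137` is taken directly from the text as the activation headroom, and 1.63 GB/sample is the stated 512px estimate.

```python
FIXED_GB = {
    "fp32_master": 9.8,
    "bf16_weights": 4.9,
    "bf16_grads": 4.9,
    "muon_momentum": 9.0,
    "adamw_m_v": 1.6,
    "teacher_4b_fwd": 6.8,
    "qwen3_4b_fwd": 8.0,
}


def fixed_total_gb():
    return sum(FIXED_GB.values())  # ≈ 45 GB


def activation_fit(micro_batch, act_gb_per_sample=1.63, free_gb=137.0):
    """Activation memory for a micro-batch and whether it fits the headroom."""
    act = micro_batch * act_gb_per_sample
    return act, act <= free_gb


def max_micro_batch(act_gb_per_sample=1.63, free_gb=137.0):
    return int(free_gb // act_gb_per_sample)
```

At micro-batch 64 this gives ~104 GB of activations, leaving room up to roughly 84 samples before checkpointing would be needed.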
**Throughput / MFU:** the loop runs three forwards (student fwd+bwd, teacher fwd, Qwen3-4B fwd) + compile + FlexAttention. **Plan around 30–35% MFU**; treat 40%+ as upside. **Validate MFU on a ~100-step smoke run before committing the full run.**
**Time/cost (context only — budget de-prioritized):** ~13–17 h depending on MFU; GPU cost ~$66–85. Data curation/captioning is a separate, larger cost not in this figure (§13).
---
## 9. De-risking / early experiments (run before the full run)
Two unproven components sit in the critical path — isolate each on a small first chunk:
1. **Optimizer A/B:** Muon vs AdamW on the first training chunk (loss-curve + sample quality). Confirm Muon's warm-start recovery before committing.
2. **Surrogate A/B:** low-rank+nonlinearity surrogate **vs** plain block-drop (no replacement), same chunk — confirms the surrogate is earning its keep.
3. **Rank/calibration check:** verify SVD-init + calibration-fit reconstruction error per surrogate; confirm r=512 (vs 256) recovery delta.
4. **MFU smoke:** 100-step compiled run to confirm ≥30% before the long run.
---
## 10. Eval & go/no-go
- **Metrics:** FID/CLIP + targeted prompt probes (in-image text/typography, multi-object composition/counting, spatial/attribute binding, texture/fine detail, small faces/hands) — all measured **vs teacher-at-4-steps** (the realistic ceiling).
- **Go/no-go thresholds (set concretely before the run):** e.g. CLIP within X of teacher@4, FID within Y, no probe category catastrophically below teacher@4. Define X/Y on a held-out probe set during the early chunk.
- **Downstream smoke test:** one narrow-domain distill-down pass to confirm the kit path works end-to-end.
---
## 11. Distill-down kit (downstream specialization)
Ship the 2B **plus** a kit so users specialize it via a bigger teacher rather than fine-tuning the 2B directly — best for **small downstream datasets**.
- **Mechanism:** fine-tune the larger teacher (4B or 9B) on the narrow domain → use it as a data amplifier (synthetic in-domain samples) → distill into the 2B via velocity matching (same loop as §5, teacher swappable).
- **Cross-size validity:** whole klein family shares one VAE/latent space → 9B→2B and 4B→2B velocity matching is well-defined. **Student stays on Qwen3-4B** regardless of teacher.
- **Ship:** the **block-correspondence map** (student block ↔ teacher block) + projector recipe for optional feature matching + scripts + minimal config — the hard part users can't easily derive.
- **License:** 4B teacher path = commercial-OK (Apache 2.0); 9B teacher path = non-commercial only.
---
## 12. Build order (milestones)
1. **Verify 4B config** (block counts, d, encoder) — gate before surgery.
2. **Data + latents** — curate 300k diverse pairs; VAE-encode to shards.
3. **Surgery** — SVD-energy block selection (keep 8 / surrogate 12); build low-rank+nonlinearity surrogates; SVD-init + offline calibration fit; forward-parity checks.
4. **De-risking chunk** (§9) — optimizer A/B, surrogate A/B, rank check, MFU smoke. Decide go/no-go.
5. **Recovery run** — Muon+AdamW, online 4-step teacher, trajectory velocity + masked feature + real-data FM, 1.2M views @512px, BF16, effective batch 256.
6. **Eval** (§10) — FID/CLIP + probes vs teacher@4; downstream smoke test.
7. **(Optional) polish** — short 768/1024 phase.
8. **Release** — 2B weights + distill-down kit + model card (state: 4-step, CFG-free, distill-down-target positioning).
---
## 13. Open items / risks
- **Verify 4B architecture from config** before surgery (gates everything). [§2, §12.1]
- **Data sourcing & licensing** — 300k curated/captioned images for an Apache-2.0 release; cost + provenance not yet settled. Curate for breadth, not volume; quality gaps will track wherever the set is thin.
- **Muon on warm-start recovery** under-validated → A/B (§9.1).
- **Surrogate quality vs compression** — r=512 default; per-block rank could be tuned by reconstruction target if some blocks under-recover.
- **1.2M views is partial recovery** in absolute terms → monitor loss/probes; be willing to extend if a capability (esp. in-image text) stays weak.
- **MFU sensitivity**: a sloppy loop at ~25% MFU runs ≈1.6× longer/costlier than one at 40% (≈1.4× vs 35%); the §8 optimizations + smoke test are the guardrail.
- **Positioning discipline** — model card must be explicit that the 2B is a narrow-domain distill-down target, not a directly-fine-tunable general base.
