FLUX.2 [klein] 4B → ~2B Distillation — Build Plan
Status: Locked v1 (supersedes init-plan.md) · Compute: 1× B200 (192 GB) · Teacher: klein distilled 4-step 4B (Apache 2.0) · Deliverable: a ~2.05B, 4-step, CFG-free text-to-image model positioned as a fast inference endpoint + distill-down target for narrow domains, plus the downstream distill-down kit.
This doc reflects the decisions locked in discussion. Where it departs from init-plan.md, this doc wins.
Addendum (2026-05-31) — verified config, dataset, dev-environment scope
Architecture verified from black-forest-labs/FLUX.2-klein-4B/transformer/config.json (the §2 gate is cleared):
num_attention_heads=24 × attention_head_dim=128 → d=3072 ✓; num_layers=5 (double) ✓; num_single_layers=20 (single) ✓; joint_attention_dim=7680 ✓; in_channels=128 ✓; is_distilled=true ✓ (correct checkpoint); guidance_embeds=false → fully CFG-free, no guidance vector at all (nothing to inherit).
Corrections to earlier assumptions:
- mlp_ratio=3.0 (not 4.0); re-derive per-block params from the real state dict.
- Transformer weights = 7.75 GB BF16 ≈ 3.87B params (the "4B" is the denoiser).
- Stack: Flux2Transformer2DModel / Qwen3ForCausalLM / AutoencoderKLFlux2 / FlowMatchEulerDiscreteScheduler, diffusers 0.37.0.dev0 (install from git).
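The §2 gate can be expressed as a check. This is a minimal sketch: it assumes the checkpoint's `transformer/config.json` has been loaded into a dict. The expected values are copied from the list above. The helper names and the local path are hypothetical.

```python
import json
from typing import Any, Dict

# Expected values transcribed from the verified config above.
EXPECTED: Dict[str, Any] = {
    "num_attention_heads": 24,
    "attention_head_dim": 128,
    "num_layers": 5,           # double-stream blocks
    "num_single_layers": 20,   # single-stream blocks
    "joint_attention_dim": 7680,
    "in_channels": 128,
    "mlp_ratio": 3.0,
    "is_distilled": True,
    "guidance_embeds": False,
}


def check_transformer_config(cfg: Dict[str, Any]) -> int:
    """Raise if the checkpoint differs from the layout the surgery assumes; return hidden width d."""
    bad = {k: (cfg.get(k), v) for k, v in EXPECTED.items() if cfg.get(k) != v}
    if bad:
        raise ValueError(f"config mismatch (got, expected): {bad}")
    return cfg["num_attention_heads"] * cfg["attention_head_dim"]


def params_from_bf16_bytes(num_bytes: float) -> float:
    """BF16 stores 2 bytes per parameter."""
    return num_bytes / 2


def load_config(path: str) -> Dict[str, Any]:
    # Hypothetical local path to the downloaded transformer/config.json.
    with open(path) as f:
        return json.load(f)
```

Usage: `d = check_transformer_config(load_config("transformer/config.json"))`. A mismatch stops the pipeline before any block is cut. For the 7.75 GB checkpoint, `params_from_bf16_bytes(7.75e9)` gives the ≈3.87B figure above.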
Dataset: jasperai/monet — large; stream it. For this dev pass use 20–30k rows @ 512×512 for the distillation/grounding signal (full 300k is the later large-GPU run).
Empirical findings from the first surgery run (2026-05-31):
- Student is ~2.44B, not ~2.05B. Real Flux2SingleTransformerBlock = 122.7M and Flux2TransformerBlock (double) = 245.4M (5 doubles = 1.23B un-prunable floor). Keep 8 single + 5 double + embedders ≈ 2.40B core; +12 surrogates @ r=512 (~3.15M each = 38M) → 2.44B (matches the original init-plan estimate; the "2.05B" correction was wrong).
- A per-token low-rank surrogate cannot reproduce attention token-mixing: warm-start residual rel-err ≈ 0.9 (only the MLP-ish per-token component is captured). Those 12 positions effectively become per-token MLP layers; the student is "13 attention blocks + 12 per-token layers." Pre-training images are heavily degraded (text destroyed, coarse layout partially survives), which is expected pre-recovery. → Surrogate design is the primary thing to iterate (options: cheap token-mixing surrogate, keep more single blocks full, or rely on distillation to compensate via the remaining 13 attention layers).
- Block selection via absolute tail-energy chose to surrogate single blocks 0–11 and keep 12–19 full.
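The size accounting above can be checked arithmetically. This sketch assumes a specific layout, which is a hypothesis rather than something read from the diffusers source:
- A single block is a fused input projection (QKV plus a gated MLP input) followed by a fused output projection.
- A double block has QKV, attention-out and gated FFN weights per stream.
- There are no biases, and norm and modulation parameters are ignored.

This layout happens to reproduce the measured 122.7M / 245.4M. The "other" bucket (embedders, modulation, proj_out) is derived by subtraction from the 7.75 GB estimate.

```python
D = 3072                    # hidden width (24 heads x 128)
H = int(3.0 * D)            # MLP hidden width from mlp_ratio=3.0 -> 9216


def single_block_params(d: int = D, h: int = H) -> int:
    # Assumed: fused input projection d -> (3d QKV + 2h gated-MLP input),
    # fused output projection (d + h) -> d; no biases, norms/modulation ignored.
    return d * (3 * d + 2 * h) + (d + h) * d


def double_block_params(d: int = D, h: int = H) -> int:
    # Assumed, per stream (image and text): QKV d->3d, attn-out d->d,
    # gated FFN d->2h and h->d; no biases.
    per_stream = d * 3 * d + d * d + d * 2 * h + h * d
    return 2 * per_stream


def surrogate_params(d: int = D, r: int = 512) -> int:
    return 2 * d * r  # A: d->r plus B: r->d


TOTAL = 3.875e9  # 7.75 GB BF16 / 2 bytes per param
# Everything outside the 25 blocks (embedders, modulation, proj_out), by subtraction.
OTHER = TOTAL - (20 * single_block_params() + 5 * double_block_params())


def student_params(n_single_full: int, n_surrogates: int = 0, r: int = 512) -> float:
    """Kept single blocks + all 5 double blocks + non-block params + surrogates."""
    return (n_single_full * single_block_params() + 5 * double_block_params()
            + OTHER + n_surrogates * surrogate_params(r=r))
```

Usage: `student_params(8)` gives the ≈2.40B core, and `student_params(8, 12)` gives ≈2.44B. If the real state dict disagrees with the assumed layout, the helpers are where to correct it.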
Surgery v2 decision (2026-05-31, after eval): the per-token surrogate is the wrong function class (can't token-mix), so v1 (drop 12) is non-functional. v2: keep most single blocks full, drop only ~4–6 selected by measured importance (leave-one-out ablation: skip each block, measure relative change in the final latent; drop the least-important). Dropping 6 → keep 14 full → student 3.16B (1.2× faster) but functional; this is the conservative staged baseline. Aggressive compression returns once the surrogate can token-mix.
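The leave-one-out importance ranking above can be sketched as follows. This is a minimal, framework-free version. Toy residual "blocks" acting on a short vector stand in for the real single-stream blocks. The real measurement compares final latents after sampling, averaged over a held-out batch; the function names here are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vec = List[float]
Block = Callable[[Vec], Vec]


def run(blocks: Sequence[Block], x: Vec, skip: int = -1) -> Vec:
    """Apply the block stack in order, optionally skipping one index."""
    for i, blk in enumerate(blocks):
        if i != skip:
            x = blk(x)
    return x


def rel_change(a: Vec, ref: Vec) -> float:
    """||a - ref|| / ||ref||."""
    num = math.sqrt(sum((ai - ri) ** 2 for ai, ri in zip(a, ref)))
    return num / math.sqrt(sum(ri * ri for ri in ref))


def leave_one_out_scores(blocks: Sequence[Block], x: Vec) -> List[float]:
    """Relative change of the final output when each block is skipped."""
    ref = run(blocks, x)
    return [rel_change(run(blocks, x, skip=i), ref) for i in range(len(blocks))]


def least_important(blocks: Sequence[Block], x: Vec, k: int) -> List[int]:
    """Indices of the k blocks whose removal changes the output least (drop candidates)."""
    scores = leave_one_out_scores(blocks, x)
    return sorted(sorted(range(len(blocks)), key=scores.__getitem__)[:k])
```

Because each block is scored in isolation, this ranking ignores interactions between removed blocks. Re-measuring after each drop (greedy) is a more conservative variant.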
TODO / to think about — cheap token-mixing surrogate (the real lever for getting back to ~2B): replace the per-token x + B·σ(A·x) with a surrogate that moves information between token positions at low cost, so we can drop many blocks again without collapse. Candidates: (a) local-window attention (cheap neighborhood mixing — cheapest path to restore edges/texture); (b) linear / kernelized attention (Performer-style, O(N), genuine global mixing); (c) token-mixing MLP (MLP-Mixer along the token axis). Warm-start target: per-block residual rel-err ≤0.3–0.4 before training. This is what makes the original "drop 12 → ~2.4B" viable.
RESULT (2026-05-31) — linear-attention surrogate wins decisively. Implemented LinearAttentionSurrogate (O(N) linear attention, elu+1 feature map, 8 heads × 64, zero-init output → identity start; scripts/09_build_linattn.py). Same 6 dropped blocks (12–17), same held-out eval batch, AdamW @ constant 1e-4. Eval velocity-loss 0.5046 → 0.2528 (−50%) vs the per-token surrogate's 0.308 — an 18% lower floor, and it crossed below the per-token's converged floor by step 50 (0.282). Samples visibly sharper (near-teacher lake; coherent, detailed storefront with partially legible sign) vs the per-token's soft output. Cost ~equal: surrogates 38M vs 19M (1.2% vs 0.6% of model), O(N) so negligible speed impact. Conclusion: token-mixing is the right surrogate; this unblocks dropping more blocks (toward the original ~2B target) while staying recoverable. Next: push drop_k (10–12) with the linear-attn surrogate. (Minor: constant LR jittered at the floor at step 250 — use the floored-cosine for the real run.)
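The linear-attention surrogate above can be illustrated in pure Python. This sketch is not scripts/09_build_linattn.py:
- It uses a single head instead of 8 × 64, toy widths, and no RoPE or conv.
- Matrices are lists of rows, and all names are hypothetical.

It shows the three properties the result relies on. The elu+1 feature map keeps weights positive. The key/value summary is built once, so cost is linear in sequence length. A zero-initialized output projection makes the block start as the identity.

```python
import math
from typing import List

Mat = List[List[float]]


def matmul(a: Mat, b: Mat) -> Mat:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def phi(z: float) -> float:
    """elu(z) + 1: a strictly positive feature map."""
    return z + 1.0 if z > 0 else math.exp(z)


def linear_attention(q: Mat, k: Mat, v: Mat) -> Mat:
    """O(N) in sequence length: build the key/value summary once, then read it per query."""
    fq = [[phi(z) for z in row] for row in q]
    fk = [[phi(z) for z in row] for row in k]
    dk, dv, n = len(fk[0]), len(v[0]), len(v)
    kv = [[sum(fk[j][a] * v[j][c] for j in range(n)) for c in range(dv)] for a in range(dk)]
    ksum = [sum(fk[j][a] for j in range(n)) for a in range(dk)]
    out = []
    for row in fq:
        denom = sum(row[a] * ksum[a] for a in range(dk))
        out.append([sum(row[a] * kv[a][c] for a in range(dk)) / denom for c in range(dv)])
    return out


def quadratic_reference(q: Mat, k: Mat, v: Mat) -> Mat:
    """Same kernel evaluated the O(N^2) way, only for checking."""
    fq = [[phi(z) for z in row] for row in q]
    fk = [[phi(z) for z in row] for row in k]
    out = []
    for qi in fq:
        w = [sum(a * b for a, b in zip(qi, kj)) for kj in fk]
        s = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(len(v))) / s for c in range(len(v[0]))])
    return out


def linattn_surrogate(x: Mat, wq: Mat, wk: Mat, wv: Mat, wo: Mat) -> Mat:
    """x + LinAttn(x)·Wo; with Wo = 0 (zero-init) this is exactly the identity."""
    attn = linear_attention(matmul(x, wq), matmul(x, wk), matmul(x, wv))
    delta = matmul(attn, wo)
    return [[xi + di for xi, di in zip(xr, dr)] for xr, dr in zip(x, delta)]
```

The identity start means the dropped block initially behaves like a plain skip. Training then only has to learn the token-mixing delta.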
Session summary (2026-05-31) — full deep-dive in report/REPORT.pdf
Block-ID convention: single-stream blocks S0..S19 (S0 = first/shallowest after the 5 double blocks; S19 = last/deepest). "Deepest N" = highest IDs.
Where things stand (the measured speed↔quality frontier): best quality = linattn drop-6 (drop S12–S17) +RoPE+conv+warmstart → 0.231 @ 1.15×, 3.18B; best perceptual = drop-8 (drop S10–S17) +focused+FFN → 0.269 @ 1.20×; most compression = drop-10 (drop S8–S17, FFN on S14–S17) → ~0.322 @ 1.26×, 2.74B; v1 per-token drop-12 (drop S0–S11) collapsed.
Two structural ceilings (measured): (1) single blocks are 78% of transformer compute, the 5 double blocks only 21% (per-block ~1.08×) → single-block surgery is the right lever but it's near its useful limit; double-block FFN shrink ≈ 2% speedup, rejected. (2) Single blocks have low redundancy (every one shifts the output ≥43%).
Recommended next: lock drop-7 (FFN on deepest 3 incl. last) as the balanced single-block model (~3.08B, ~1.17×, ~0.25). For a genuine ~2×, the next project is step reduction (4→2) via step-distillation — not more block surgery — plus text-encoder caching and the B200 scale-up.
Basic training run result (2026-05-31): DIVERGED — recipe too aggressive. Loss stable ~0.83 to step ~120, then spiked (2.46 @160) and blew up (8.08 @300); samples degraded from step 0 and collapsed to checkerboard noise. Diagnosis (ranked): (1) Muon lr 0.02 applied to the well-pretrained kept blocks corrupts them — kept blocks should be frozen or trained at a much lower lr than the surrogates; (2) bf16 weights, no fp32 master; (3) no LR warmup, no grad clipping. The step-0 v2 warm-start remained the best checkpoint. This empirically confirms the §8 risk ("Muon on warm-start recovery under-validated → A/B vs AdamW"). Fixes for next run: per-group LRs (surrogates high, kept blocks ~0 or tiny / freeze), LR warmup + grad clip, fp32 master, AdamW-vs-Muon A/B, lower base LR.
Corrected recovery recipe (scripts/08_train_recover.py) — research-led. LR research: Muon's lr≈0.02 is a bulk-pretraining value (nanoGPT speedruns; Kimi/Moonshot RMS-match it to reuse AdamW-scale LRs) — wrong for a few adapter modules and the cause of the blowup. Our surrogates are adapter-like (fresh, small, frozen base) → diffusion/LoRA regime: AdamW ~1e-4–2e-4 (1e-5 too weak, full-finetune 1e-6). Recipe: freeze the whole pretrained net, train **only the 6 surrogates (18.9M, 0.6%)**; AdamW @ 1e-4, cosine decay, grad-clip 1.0, fp32 master on trained params (bf16 autocast compute); fixed held-out eval velocity-loss as the objective metric; per-step image logging; best-ckpt; divergence guard (stop if eval-loss > 3× baseline). LR schedule: cosine from base to a 15%-of-base floor (MIN_LR_FRAC=0.15), no warmup — NOT to 0. Decaying to exactly 0 wasted the last ~30 steps (lr <1e-6 ≈ dead tail) and would break resume/extend (a continued run would start at lr≈0). Also size STEPS to the data so the cosine tail isn't spent doing nothing (this dev run flattened by ~step 250). Muon stays the tool for the later full-recovery run (its §8 A/B), where much of the network trains. Note: surrogate-only recovery is bounded by the per-token surrogate's inability to token-mix — expect stability + modest improvement, not full sharpening; the token-mixing surrogate TODO is what unlocks real recovery.
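The schedule, clipping and guard in the recipe above are small enough to state exactly. This sketch is framework-free and the function names are hypothetical:
- The LR decays by cosine from base to a 15% floor and clamps there, so an extended or resumed run continues at the floor rather than near 0.
- Clipping uses the same global-norm rule as PyTorch's `clip_grad_norm_`, applied here to a flat list of floats.
- The divergence guard compares eval loss to the step-0 baseline.

```python
import math
from typing import List, Tuple


def floored_cosine_lr(step: int, total_steps: int, base_lr: float = 1e-4,
                      min_lr_frac: float = 0.15) -> float:
    """Cosine from base_lr down to min_lr_frac * base_lr (not to 0), no warmup."""
    min_lr = min_lr_frac * base_lr
    progress = min(step, total_steps) / total_steps  # clamp: extending the run stays at the floor
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads: List[float], max_norm: float = 1.0) -> Tuple[List[float], float]:
    """Scale all gradients by min(1, max_norm / ||g||); returns (clipped, pre-clip norm)."""
    total = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [g * scale for g in grads], total


def should_stop(eval_loss: float, baseline_loss: float, factor: float = 3.0) -> bool:
    """Divergence guard: stop if eval loss exceeds factor x the step-0 baseline."""
    return eval_loss > factor * baseline_loss
```

Usage: `floored_cosine_lr(300, 300)` returns 1.5e-5, and stays there for any later step. With the corrected run's baseline of 0.4756, the guard trips above ≈1.43.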
Result (corrected run, 300 steps, surrogate-only): eval velocity-loss 0.4756 → 0.308, −35%, monotonic, zero divergence, 28.7 GB, ~7 min; samples visibly sharper (mountain ridges/reflection, storefront detail) toward the teacher. Controlled Muon-vs-AdamW A/B (identical recipe, only optimizer swapped): AdamW@1e-4 → 0.3081 (−35.2%), Muon@2e-3 → 0.3056 (−35.8%) — statistically equivalent, both stable. Takeaway: the earlier blow-up was not Muon — it was lr 0.02 on the wrong params (all weights incl. pretrained kept blocks). Used correctly (low lr, only new params) Muon is fine but offers no edge for a handful of adapter modules; reserve it for the bulk full-recovery run.
Dev-environment scope (this box ≠ the target box): running on 1× A100-80GB, not B200. So for now: no FlashAttention-4 / Blackwell kernels — use PyTorch SDPA; smaller batch; goal is a clean, correct, scalable baseline (surgery + calibration + inference + eval + a basic training loop) that lifts onto a bigger GPU later, not a perfected/perf-tuned run. Memory/throughput numbers in §8 are the B200 target, not this box.
0. What changed from init-plan.md (read first)
- Teacher is the distilled 4-step 4B, NOT the 50-step base. Pure velocity/output matching (L2 regression) cannot compress step count — it clones a velocity field, and integrating a cloned field needs the same step budget as the source. Distilling from the 50-step base would leave us owing a separate 50→4 step-compression phase (DMD/consistency/adversarial), which we explicitly excluded. Cloning the already-4-step distilled teacher inherits 4-step behavior for free.
- Student is ~2.05B, not 2.45B. With true low-rank surrogates (r=512), the 12 surrogates total 38M params, not the ~410M the old plan assumed. Core kept ≈ 2.04B + surrogates ≈ 0.04B ≈ **2.05B**.
- Surrogates are low-rank + nonlinearity (x + B·σ(A·x)), r=512, not bare linear. ~Same FLOPs as bare linear (the activation is ~free), more faithful to a removed transformer block. Costs us the exact closed-form init → we add a short offline calibration fit (below).
- Distillation is trajectory-aware, sampled along the teacher's 4-step rollout, not the naive x_t=(1−t)·noise+t·latent data interpolant. With only 4 large Euler steps, exposure bias compounds viciously; matching where the student will actually operate is what kills it.
- Positioning: the 2B is a fast inference endpoint + distill-down target for narrow domains, not a directly-fine-tunable general base. (Distilled models are harder to fine-tune directly; the distill-down kit routes downstream specialization through the teacher, so the 2B never needs direct fine-tuning.)
- CFG-free inheritance: the distilled teacher is guidance-distilled (guidance_scale=1.0, single forward, no CFG doubling). The student inherits the guidance embedder and runs CFG-free too.
1. Objective
Compress FLUX.2 [klein] distilled 4B (rectified-flow MM-DiT, 4-step, guidance-distilled) into a ~2.05B model that:
- runs in 4 steps (same as teacher — we are not changing step count),
- is ~1.6–1.7× faster than the 4B and ~5 GB BF16,
- retains most of teacher-at-4-steps quality,
- serves as a distill-down target: third parties specialize it to a narrow domain via the kit in §11 (fine-tune a bigger teacher → amplify → distill into the 2B), rather than fine-tuning the 2B directly.
Recovered via knowledge distillation on a single B200, ~half-day.
2. Teacher
- Model: FLUX.2 [klein] distilled 4B: step- and guidance-distilled, 4 inference steps, guidance_scale=1.0, Apache 2.0.
- Why this one (not base): see §0.1. Regression preserves step count; we want a 4-step student, so we need a 4-step teacher.
- Conditioning: Qwen3-4B text encoder (intermediate-layer stack → conditioning vector). Frozen; not part of the denoiser. Student stays on Qwen3-4B (never bolt the 9B's Qwen3-8B encoder on — dimension mismatch).
- Latent space: shared FLUX.2 VAE. Identical across the klein family — this is what makes the §11 cross-size distill-down valid.
- Forward cost: single forward per step (CFG-free). Used frozen (eval, no_grad, BF16, compiled) as an online target generator.
Open item (verify before surgery): the exact 4B block layout (init-plan.md asserts 5 double + 20 single, d=3072, 24 heads; unverified from the model card). Pull config.json / the diffusers Flux2 impl and confirm block counts + hidden dim before cutting. The entire prune + correspondence map depends on these numbers. Everything below assumes the asserted layout; adjust if the config differs.
3. Target architecture (~2.05B student)
Full-width, depth-pruned, removed depth replaced by light low-rank surrogates.
| Component | Spec | Rationale |
|---|---|---|
| Width | d=3072 (unchanged) | Preserves texture/detail; widths match teacher → no projector needed for feature matching |
| Double-stream | 5 (all kept) | Preserve cross-modal / instruction binding (these are also the un-prunable FLOP floor) |
| Single-stream | 8 kept full (of 20) | Depth-only prune; recovers better than width prune |
| Removed depth | 12 single blocks → low-rank+nonlinearity surrogates (r=512) | Residual-only low-rank approximation, shortcut preserved; SVD-init + calibration warm-start |
| Embedders / modulation | Inherited from teacher (timestep, guidance, shared-across-block AdaLN) | Keeps CFG-free behavior; AdamW param group |
| Effective size | core ≈ 2.04B + surrogates ≈ 0.04B ≈ 2.05B | |
| Inference | ~1.6–1.7× faster than 4B, ~5 GB BF16, 4 steps | |
4. Surrogate design (the removed 12 blocks)
Form: out = x + B·σ(A·x), with A: d→r, B: r→d, r=512, σ = GELU-family (match teacher block activation). Shortcut x preserved exactly — we only approximate the block's residual delta.
Compression at r=512 (vs d=3072):
- Bottleneck = 1/6 (16.7%) of residual-stream width.
- 110M params/block → ~3.15M/block ⇒ **35× param compression** (keep ~3% of the block).
- 200 → ~10 GFLOP/block ⇒ **21× cheaper** (skips attention/N²).
- Rank is a quality knob, not a speed knob: across r∈{128,256,512} total-model FLOPs move <3% and speedup stays ~1.65×. 512 = most capacity at no speed cost. (Escape hatch: 768/1024 still <5% of total FLOPs if we ever want less compression.)
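The surrogate form out = x + B·σ(A·x) and its parameter count can be written out for a single token. This is a framework-free sketch with hypothetical names. The tanh-approximate GELU is a stand-in; §4 says to match the teacher block's actual activation. Matrices are lists of rows.

```python
import math
from typing import List


def gelu_tanh(z: float) -> float:
    """tanh-approximate GELU, standing in for the teacher block's activation."""
    return 0.5 * z * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (z + 0.044715 * z ** 3)))


def lowrank_surrogate(x: List[float], A: List[List[float]], B: List[List[float]]) -> List[float]:
    """out = x + B·σ(A·x) for one token; A has r rows of length d, B has d rows of length r."""
    h = [gelu_tanh(sum(a * xi for a, xi in zip(row, x))) for row in A]  # d -> r
    delta = [sum(b * hi for b, hi in zip(row, h)) for row in B]         # r -> d
    return [xi + di for xi, di in zip(x, delta)]                        # shortcut kept exactly


def surrogate_params(d: int = 3072, r: int = 512) -> int:
    return d * r + r * d  # A and B, no biases
```

At d=3072, r=512 this gives 3,145,728 params (the ~3.15M above), and 110M / 3.15M ≈ 35×. Each token is processed independently, which is why this form cannot mix information across tokens (see the addendum).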
Which 12 to surrogate vs which 8 to keep full: data-driven on a calibration batch. For each single block, compute the SVD energy of its residual delta (input→output map). Blocks whose top-512 singular values capture most energy are cheap to approximate → surrogate them; high-rank residuals → keep full. (Falls back to / cross-checks against block-influence/CKA importance. Middle blocks are usually the most compressible.)
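The selection rule above can be sketched in two steps. First, compute singular values of each block's residual delta (output − input) on the calibration tokens, for example with `torch.linalg.svdvals`. Second, rank the blocks by the energy left outside the top-r components. Function names are hypothetical.

The sketch supports two criteria. The relative fraction matches "capture most energy" above. The absolute tail is the variant the addendum reports using. On some inputs the two criteria pick different blocks.

```python
from typing import List, Sequence


def tail_energy(singular_values: Sequence[float], r: int = 512, relative: bool = True) -> float:
    """Squared singular-value energy outside the top-r components (optionally as a fraction)."""
    s2 = sorted((s * s for s in singular_values), reverse=True)
    tail = sum(s2[r:])
    if not relative:
        return tail
    total = sum(s2)
    return tail / total if total > 0 else 0.0


def pick_surrogate_blocks(svals_per_block: Sequence[Sequence[float]], n_surrogate: int,
                          r: int = 512, relative: bool = True) -> List[int]:
    """Lowest tail energy = most low-rank = cheapest to approximate -> surrogate those."""
    order = sorted(range(len(svals_per_block)),
                   key=lambda i: tail_energy(svals_per_block[i], r, relative))
    return sorted(order[:n_surrogate])
```

A large-norm block with a small relative tail can still have a large absolute tail. The choice of criterion therefore changes which blocks get surrogated.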
Init (warm start): the nonlinearity breaks the exact closed-form least-squares init, so:
- Init A, B via truncated SVD of the linear least-squares solution (treat σ≈identity / scale so A·x lands in σ's near-linear region).
- Short offline calibration fit: a few hundred gradient steps per surrogate to match each block's I/O on the calibration set, before the main run. Cheap, offline, not part of the training-budget run.
Sanity: shape + forward-parity checks after surgery; confirm full width d=3072 so feature matching needs no projector.
5. Distillation methodology (online teacher)
Convention: x_t = (1−t)·noise + t·latent (t=0 noise → t=1 data); straight-line velocity v = latent − noise.
Primary — trajectory velocity matching (instills 4-step behavior):
- Roll out the teacher's 4-step sampler from noise, collecting the 4 visited (x_{t_i}, t_i) points, the exact points the student sees at inference.
- At each point: teacher velocity v_teacher is the target; student predicts v_student; loss = MSE.
- Concentrate timesteps on the 4 schedule points, not uniform t. This is the departure from the old plan's data-interpolant sampling, and it's what kills few-step exposure bias.
- v1: teacher-trajectory rollout (stable, cheap). Optional polish later: a fraction of student-trajectory (on-policy) steps where the teacher corrects student-visited points (DAgger-style) to close residual exposure bias.
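The trajectory matching described above can be sketched as two steps. First, Euler-integrate the teacher under the §5 convention (t=0 noise → t=1 data) and record (x_t, t, v_teacher) at each visited point. Second, score the student only at those points.

Velocity functions here are plain callables on lists. The uniform schedule is a hypothetical placeholder; the real points come from the teacher's scheduler. Names are illustrative.

```python
from typing import Callable, List, Tuple

Vec = List[float]
VelocityFn = Callable[[Vec, float], Vec]

SCHEDULE = [0.0, 0.25, 0.5, 0.75, 1.0]  # hypothetical uniform 4-step schedule


def teacher_rollout(teacher: VelocityFn, noise: Vec, ts: List[float]) -> List[Tuple[Vec, float, Vec]]:
    """Euler-integrate the teacher from ts[0] (noise); record (x_t, t, v_teacher) per visited point."""
    x, visited = list(noise), []
    for t, t_next in zip(ts[:-1], ts[1:]):
        v = teacher(x, t)
        visited.append((list(x), t, v))
        x = [xi + (t_next - t) * vi for xi, vi in zip(x, v)]
    return visited


def trajectory_velocity_loss(student: VelocityFn, visited: List[Tuple[Vec, float, Vec]]) -> float:
    """MSE between student and teacher velocity, only at the teacher-visited points."""
    sq, n = 0.0, 0
    for x, t, v_teacher in visited:
        v_student = student(x, t)
        sq += sum((a - b) ** 2 for a, b in zip(v_student, v_teacher))
        n += len(v_teacher)
    return sq / n
```

As a check, take a toy teacher that returns the exact straight-line velocity toward a known latent. It visits t = 0, 0.25, 0.5, 0.75 and lands on the data. The DAgger-style variant would call `teacher_rollout` with the student driving the Euler updates while the teacher still supplies the targets.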
Secondary — masked feature matching on retained blocks:
- Match student↔teacher hidden states at the retained-block positions (no projector — widths match). Match each kept block to its original teacher position (indices renumber after the prune).
- Masked KD: exclude DiT extreme-activation outliers (massive activations) from the feature loss.
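The masked feature loss can be sketched as below. The plan does not specify how outliers are detected, so the median-based threshold and the `outlier_factor` value are hypothetical. The sketch masks individual entries on flat lists. Massive activations in DiTs tend to concentrate in a few channels and tokens, so a per-channel or per-token mask may fit better in practice.

```python
import statistics
from typing import List


def masked_feature_mse(h_student: List[float], h_teacher: List[float],
                       outlier_factor: float = 100.0) -> float:
    """MSE over hidden-state entries, skipping teacher entries far above the typical magnitude."""
    mags = [abs(v) for v in h_teacher]
    threshold = outlier_factor * statistics.median(mags)  # hypothetical outlier rule
    keep = [i for i, m in enumerate(mags) if m <= threshold]
    if not keep:
        return 0.0
    return sum((h_student[i] - h_teacher[i]) ** 2 for i in keep) / len(keep)
```

Without the mask, a single 500-magnitude teacher entry would dominate the loss. With it, only the ordinary entries drive the feature match.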
Grounding — real-data flow-matching:
- Standard FM loss on real latents at randomly sampled t (broad timestep coverage + grounding). Cheap (1 forward).
Loss: w1·velocity-MSE (trajectory) + w2·feature-match (masked) + w3·real-data FM. Tune weights; velocity primary.
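The grounding term and the weighted sum can be sketched under the §5 convention: x_t = (1−t)·noise + t·latent, with target v = latent − noise. Two assumptions are labeled here. Sampling t uniformly is one choice of "randomly sampled t". The weights are left as required arguments because the plan says to tune them. Names are hypothetical.

```python
import random
from typing import Callable, List

Vec = List[float]


def real_data_fm_loss(student: Callable[[Vec, float], Vec], latent: Vec, rng: random.Random) -> float:
    """Flow-matching loss on one real latent at a uniformly sampled t (0 = noise, 1 = data)."""
    t = rng.random()
    noise = [rng.gauss(0.0, 1.0) for _ in latent]
    x_t = [(1.0 - t) * n + t * l for n, l in zip(noise, latent)]
    target = [l - n for n, l in zip(noise, latent)]  # straight-line velocity
    pred = student(x_t, t)
    return sum((p - y) ** 2 for p, y in zip(pred, target)) / len(target)


def total_loss(l_vel: float, l_feat: float, l_fm: float, w1: float, w2: float, w3: float) -> float:
    """w1·velocity-MSE (trajectory) + w2·feature-match (masked) + w3·real-data FM."""
    return w1 * l_vel + w2 * l_feat + w3 * l_fm
```

Consider an oracle that knows the latent and returns (latent − x_t)/(1 − t). It recovers latent − noise exactly, so its FM loss is ~0. This confirms that the interpolant and target are consistent with each other.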
Multi-timestep / epochs: fresh noise + fresh rollout per image per epoch (new supervision, not repetition). 4 epochs × 300k ≈ 1.2M views.
6. Optimizer & precision
- Muon on 2D hidden weights (attention/MLP, incl. surrogate A, B) + AdamW on embedders, proj_out, norms, shared AdaLN modulation. Muon LR runs higher than AdamW; tune LR + WD per group.
- BF16 compute, FP32 master weights + FP32 moments. No quantization in training.
- Micro-batch 64, gradient-accumulate to effective ~256 (Muon's large-batch sweet spot). No gradient checkpointing planned (fits).
- torch.compile (CUDA-graph mode); FlashAttention (Blackwell-native, via FlexAttention) for the retained blocks.
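The Muon/AdamW split above can be sketched as a partition over (name, shape) pairs. With PyTorch, these would come from `model.named_parameters()` as `(name, p.shape)`. The name prefixes and the "modulation" substring are hypothetical; check them against the real module names. Any non-2D tensor (norm scales, biases) goes to AdamW.

```python
from typing import Dict, List, Tuple

# Hypothetical top-level module prefixes that should stay on AdamW.
ADAMW_PREFIXES = ("x_embedder", "context_embedder", "time_guidance_embed", "proj_out", "norm_out")


def split_param_groups(named_shapes: Dict[str, Tuple[int, ...]]) -> Tuple[List[str], List[str]]:
    """Return (muon_names, adamw_names): Muon only on 2D hidden weights, incl. surrogate A/B."""
    muon, adamw = [], []
    for name, shape in named_shapes.items():
        to_adamw = (len(shape) != 2
                    or name.startswith(ADAMW_PREFIXES)
                    or "modulation" in name)
        (adamw if to_adamw else muon).append(name)
    return muon, adamw
```

The two name lists then map to separate optimizer instances, each with its own LR and WD. The 2D embedder, proj_out and modulation matrices are deliberately excluded from Muon even though they are matrices.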
7. Data pipeline
- 300k curated images + captions, diversity-first (broad styles/subjects/composition + in-image text coverage). Stable IDs. (Sourcing/licensing for an Apache-2.0 release is an open item — §13.)
- VAE-encode offline → BF16 latents in streaming shards (WebDataset / tensordict). VAE never runs in the training loop.
- Qwen3-4B text embeddings: run online (intermediate layers, frozen). Optionally batch-precompute per shard if I/O-bound (deferred — not a v1 priority).
- Dataloader streams (latent, caption) with aggressive prefetch; per step samples noise and forms the teacher rollout.
- Bulk training at 512px; optional short 768/1024 polish if time/memory headroom remains.
8. Compute & infra (1× B200, 192 GB)
Memory (fixed ~45 GB): FP32 master ~9.8 + BF16 weights ~4.9 + BF16 grads ~4.9 + Muon momentum ~9.0 + AdamW m,v ~1.6 + teacher 4B (fwd) ~6.8 + Qwen3-4B (fwd) ~8.0. → ~137 GB free for activations. (Student is ~2.05B; weight/grad/state lines scale to that.)
Batch: activations ≈ 1.63 GB/sample @512px → micro-batch 64 ≈ ~104 GB activations; fits without checkpointing.
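The memory lines above can be checked as a small budget calculation. The numbers are transcribed from this section. The function name is hypothetical, and the calculation does not model allocator fragmentation or CUDA context overhead.

```python
from typing import Tuple

# Fixed per-run memory (GB), transcribed from the §8 memory line.
FIXED_GB = {
    "fp32_master": 9.8,
    "bf16_weights": 4.9,
    "bf16_grads": 4.9,
    "muon_momentum": 9.0,
    "adamw_m_v": 1.6,
    "teacher_4b_fwd": 6.8,
    "qwen3_4b_fwd": 8.0,
}


def memory_budget(micro_batch: int, act_gb_per_sample: float = 1.63,
                  hbm_gb: float = 192.0) -> Tuple[float, float, float]:
    """Return (fixed, activations, headroom) in GB for a given micro-batch."""
    fixed = sum(FIXED_GB.values())
    activations = micro_batch * act_gb_per_sample
    return fixed, activations, hbm_gb - fixed - activations
```

Usage: `memory_budget(64)` gives ≈45 GB fixed plus ≈104 GB of activations, which leaves positive headroom on a 192 GB B200 without checkpointing.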
Throughput / MFU: the loop runs three forwards (student fwd+bwd, teacher fwd, Qwen3-4B fwd) + compile + FlexAttention. Plan around 30–35% MFU; treat 40%+ as upside. Validate MFU on a ~100-step smoke run before committing the full run.
Time/cost (context only — budget de-prioritized): ~13–17 h depending on MFU; GPU cost ~$66–85. Data curation/captioning is a separate, larger cost not in this figure (§13).
9. De-risking / early experiments (run before the full run)
Two unproven components sit in the critical path — isolate each on a small first chunk:
- Optimizer A/B: Muon vs AdamW on the first training chunk (loss-curve + sample quality). Confirm Muon's warm-start recovery before committing.
- Surrogate A/B: low-rank+nonlinearity surrogate vs plain block-drop (no replacement), same chunk — confirms the surrogate is earning its keep.
- Rank/calibration check: verify SVD-init + calibration-fit reconstruction error per surrogate; confirm r=512 (vs 256) recovery delta.
- MFU smoke: 100-step compiled run to confirm ≥30% before the long run.
10. Eval & go/no-go
- Metrics: FID/CLIP + targeted prompt probes (in-image text/typography, multi-object composition/counting, spatial/attribute binding, texture/fine detail, small faces/hands) — all measured vs teacher-at-4-steps (the realistic ceiling).
- Go/no-go thresholds (set concretely before the run): e.g. CLIP within X of teacher@4, FID within Y, no probe category catastrophically below teacher@4. Define X/Y on a held-out probe set during the early chunk.
- Downstream smoke test: one narrow-domain distill-down pass to confirm the kit path works end-to-end.
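The go/no-go rule can be written down once X/Y are chosen. This is a minimal sketch in which `clip_tol` and `fid_tol` play the roles of X and Y. The "catastrophically below" criterion is expressed as a hypothetical fraction-of-teacher floor per probe category. The metric values in the usage note are toy numbers, not results.

```python
from typing import Dict, List, Tuple


def go_no_go(student: Dict, teacher: Dict, clip_tol: float, fid_tol: float,
             probe_floor: float) -> Tuple[bool, List[str]]:
    """student/teacher: {'clip': float, 'fid': float, 'probes': {category: score}}."""
    reasons = []
    if teacher["clip"] - student["clip"] > clip_tol:      # CLIP within X of teacher@4
        reasons.append("clip")
    if student["fid"] - teacher["fid"] > fid_tol:         # FID within Y of teacher@4
        reasons.append("fid")
    for cat, t_score in teacher["probes"].items():        # no probe catastrophically below
        if student["probes"].get(cat, 0.0) < probe_floor * t_score:
            reasons.append(f"probe:{cat}")
    return (not reasons), reasons
```

Usage (toy numbers): a student at CLIP 0.31 / FID 22 against a teacher at 0.32 / 20 passes with X=0.02, Y=5. If its in-image text probe collapses, the decision flips to no-go and names `probe:text`.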
11. Distill-down kit (downstream specialization)
Ship the 2B plus a kit so users specialize it via a bigger teacher rather than fine-tuning the 2B directly — best for small downstream datasets.
- Mechanism: fine-tune the larger teacher (4B or 9B) on the narrow domain → use it as a data amplifier (synthetic in-domain samples) → distill into the 2B via velocity matching (same loop as §5, teacher swappable).
- Cross-size validity: whole klein family shares one VAE/latent space → 9B→2B and 4B→2B velocity matching is well-defined. Student stays on Qwen3-4B regardless of teacher.
- Ship: the block-correspondence map (student block ↔ teacher block) + projector recipe for optional feature matching + scripts + minimal config — the hard part users can't easily derive.
- License: 4B teacher path = commercial-OK (Apache 2.0); 9B teacher path = non-commercial only.
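The block-correspondence map can be generated from the drop list. This sketch assumes dropped single blocks are removed outright and the kept full blocks are renumbered consecutively. It covers the "indices renumber after the prune" case in §5; the double blocks are all kept and map one-to-one. The function name is hypothetical.

```python
from typing import Dict, Iterable


def correspondence_map(n_single: int = 20, dropped: Iterable[int] = ()) -> Dict[int, int]:
    """Student single-block index -> original teacher single-block index."""
    dropped_set = set(dropped)
    kept = [i for i in range(n_single) if i not in dropped_set]
    return dict(enumerate(kept))
```

Usage: for drop-6 (S12–S17 removed), `correspondence_map(dropped=range(12, 18))` maps student blocks 0–11 to themselves and student 12, 13 to teacher S18, S19. These are the pairs the feature-matching loss should compare.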
12. Build order (milestones)
- Verify 4B config (block counts, d, encoder) — gate before surgery.
- Data + latents — curate 300k diverse pairs; VAE-encode to shards.
- Surgery — SVD-energy block selection (keep 8 / surrogate 12); build low-rank+nonlinearity surrogates; SVD-init + offline calibration fit; forward-parity checks.
- De-risking chunk (§9) — optimizer A/B, surrogate A/B, rank check, MFU smoke. Decide go/no-go.
- Recovery run — Muon+AdamW, online 4-step teacher, trajectory velocity + masked feature + real-data FM, 1.2M views @512px, BF16, effective batch 256.
- Eval (§10) — FID/CLIP + probes vs teacher@4; downstream smoke test.
- (Optional) polish — short 768/1024 phase.
- Release — 2B weights + distill-down kit + model card (state: 4-step, CFG-free, distill-down-target positioning).
13. Open items / risks
- Verify 4B architecture from config before surgery (gates everything). [§2, §12.1]
- Data sourcing & licensing — 300k curated/captioned images for an Apache-2.0 release; cost + provenance not yet settled. Curate for breadth, not volume; quality gaps will track wherever the set is thin.
- Muon on warm-start recovery under-validated → A/B (§9.1).
- Surrogate quality vs compression — r=512 default; per-block rank could be tuned by reconstruction target if some blocks under-recover.
- 1.2M views is partial recovery in absolute terms → monitor loss/probes; be willing to extend if a capability (esp. in-image text) stays weak.
- MFU sensitivity — a sloppy loop at ~25% ≈ 1.6× longer/costlier; the §8 optimizations + smoke test are the guardrail.
- Positioning discipline — model card must be explicit that the 2B is a narrow-domain distill-down target, not a directly-fine-tunable general base.