
FLUX.2 [klein] 4B → ~2B Distillation — Build Plan

Status: Locked v1 (supersedes init-plan.md) · Compute: 1× B200 (192 GB) · Teacher: klein distilled 4-step 4B (Apache 2.0) · Deliverable: a ~2.05B, 4-step, CFG-free text-to-image model positioned as a fast inference endpoint + distill-down target for narrow domains, plus the downstream distill-down kit.

This doc reflects the decisions locked in discussion. Where it departs from init-plan.md, this doc wins.


Addendum (2026-05-31) — verified config, dataset, dev-environment scope

Architecture verified from black-forest-labs/FLUX.2-klein-4B/transformer/config.json (the §2 gate is cleared):

  • num_attention_heads=24 × attention_head_dim=128 → d=3072 ✓; num_layers=5 (double) ✓; num_single_layers=20 (single) ✓; joint_attention_dim=7680 ✓; in_channels=128 ✓.
  • is_distilled=true ✓ (correct checkpoint); guidance_embeds=false → fully CFG-free, no guidance vector at all (nothing to inherit).
  • Corrections to earlier assumptions: mlp_ratio=3.0 (not 4.0 — re-derive per-block params from the real state dict); transformer weights = 7.75 GB BF16 ≈ 3.87B params (the "4B" is the denoiser). Stack: Flux2Transformer2DModel / Qwen3ForCausalLM / AutoencoderKLFlux2 / FlowMatchEulerDiscreteScheduler, diffusers 0.37.0.dev0 (install from git).
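
The derived numbers in this addendum can be re-checked with a few lines of arithmetic. This is a minimal sketch, assuming 1 GB = 1e9 bytes and that the FFN inner width is d × mlp_ratio; the dictionary simply repeats the config fields quoted above and is not loaded from the real file.

```python
# Values copied from the config fields quoted above (not read from disk).
config = {
    "num_attention_heads": 24,
    "attention_head_dim": 128,
    "num_layers": 5,          # double-stream blocks
    "num_single_layers": 20,  # single-stream blocks
    "joint_attention_dim": 7680,
    "in_channels": 128,
    "mlp_ratio": 3.0,
}

# Hidden width is heads x head_dim.
hidden_dim = config["num_attention_heads"] * config["attention_head_dim"]

# Assumption: the FFN inner width is hidden_dim * mlp_ratio.
ffn_dim = int(hidden_dim * config["mlp_ratio"])

# BF16 stores 2 bytes per parameter; 7.75 GB is the quoted checkpoint size.
BYTES_PER_BF16 = 2
approx_params_b = 7.75e9 / BYTES_PER_BF16 / 1e9

print(f"d={hidden_dim}, ffn={ffn_dim}, params≈{approx_params_b:.3f}B")
```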

Dataset: jasperai/monet — large; stream it. For this dev pass use 20–30k rows @ 512×512 for the distillation/grounding signal (full 300k is the later large-GPU run).

Empirical findings from the first surgery run (2026-05-31):

  • Student is ~2.44B, not ~2.05B. Real Flux2SingleTransformerBlock = 122.7M and Flux2TransformerBlock (double) = 245.4M (5 doubles = 1.23B un-prunable floor). Keep 8 single + 5 double + embedders ≈ 2.40B core; +12 surrogates @ r=512 (~3.15M each = 38M) → 2.44B (matches the original init-plan estimate; the "2.05B" correction was wrong).
  • A per-token low-rank surrogate cannot reproduce attention token-mixing — warm-start residual rel-err ≈ 0.9 (only the MLP-ish per-token component is captured). Those 12 positions effectively become per-token MLP layers; the student is "13 attention blocks + 12 per-token layers." Pre-training images are heavily degraded (text destroyed, coarse layout partially survives) — expected pre-recovery. → surrogate design is the primary thing to iterate (options: cheap token-mixing surrogate, keep more single blocks full, or rely on distillation to compensate via the remaining 13 attention layers). Block selection via absolute tail-energy chose to surrogate single blocks 0–11 and keep 12–19 full.
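
The ~2.44B figure follows from the measured block sizes. The sketch below reproduces it under one assumption that is not stated above: everything in the 3.87B transformer that is not a single or double block is lumped into an "other" term (embedders, modulation, output projection). The helper name `student_params` is illustrative.

```python
SINGLE_BLOCK = 122.7e6   # measured Flux2SingleTransformerBlock
DOUBLE_BLOCK = 245.4e6   # measured Flux2TransformerBlock (double)
TOTAL = 3.87e9           # transformer total, from the checkpoint size
D, R = 3072, 512

# Assumption: whatever is not a block counts as "embedders / other".
other = TOTAL - 20 * SINGLE_BLOCK - 5 * DOUBLE_BLOCK

def student_params(kept_single: int, n_surrogates: int, rank: int = R) -> float:
    """Total params for a student with the given kept blocks and surrogates."""
    surrogate = 2 * D * rank  # A: d->r plus B: r->d, biases ignored
    core = kept_single * SINGLE_BLOCK + 5 * DOUBLE_BLOCK + other
    return core + n_surrogates * surrogate

double_floor = 5 * DOUBLE_BLOCK
v1 = student_params(kept_single=8, n_surrogates=12)
print(f"double floor {double_floor / 1e9:.2f}B, v1 student {v1 / 1e9:.2f}B")
```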

Surgery v2 decision (2026-05-31, after eval): the per-token surrogate is the wrong function class (can't token-mix), so v1 (drop 12) is non-functional. v2: keep most single blocks full, drop only ~4–6 selected by measured importance (leave-one-out ablation: skip each block, measure relative change in the final latent; drop the least-important). Dropping 6 → keep 14 full → student 3.16B (1.2× faster) but functional; this is the conservative staged baseline. Aggressive compression returns once the surrogate can token-mix.
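
The leave-one-out selection can be sketched on toy data. This is an illustration only: the "blocks" here are scalar-gain residual functions standing in for real transformer blocks, and the function names are hypothetical. In the real pipeline the score would be measured on the final latent of a full sampler run.

```python
import math

def rel_change(a, b):
    """||a - b|| / ||b|| for two equal-length vectors."""
    num = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    den = math.sqrt(sum(y ** 2 for y in b))
    return num / den

def run(blocks, x, skip=None):
    """Apply residual blocks in order, treating the skipped one as identity."""
    for i, block in enumerate(blocks):
        if i != skip:
            x = [xi + di for xi, di in zip(x, block(x))]
    return x

def least_important(blocks, x, drop_k):
    """Score each block by how much skipping it changes the output; return the k lowest."""
    full = run(blocks, x)
    scores = {i: rel_change(run(blocks, x, skip=i), full) for i in range(len(blocks))}
    ranked = sorted(scores, key=scores.get)
    return sorted(ranked[:drop_k]), scores

# Toy stand-ins: each "block" returns a residual delta scaled by a gain.
gains = [0.50, 0.01, 0.30, 0.02, 0.40]
toy_blocks = [lambda x, g=g: [g * xi for xi in x] for g in gains]
dropped, scores = least_important(toy_blocks, [1.0, -2.0, 0.5], drop_k=2)
print(dropped)
```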

TODO / to think about — cheap token-mixing surrogate (the real lever for getting back to ~2B): replace the per-token x + B·σ(A·x) with a surrogate that moves information between token positions at low cost, so we can drop many blocks again without collapse. Candidates: (a) local-window attention (cheap neighborhood mixing — cheapest path to restore edges/texture); (b) linear / kernelized attention (Performer-style, O(N), genuine global mixing); (c) token-mixing MLP (MLP-Mixer along the token axis). Warm-start target: per-block residual rel-err ≤0.3–0.4 before training. This is what makes the original "drop 12 → ~2.4B" viable.

RESULT (2026-05-31): the linear-attention surrogate wins decisively.

  • Implementation: LinearAttentionSurrogate (O(N) linear attention, elu+1 feature map, 8 heads × 64, zero-init output → identity start; scripts/09_build_linattn.py).
  • Setup: same 6 dropped blocks (12–17), same held-out eval batch, AdamW @ constant 1e-4.
  • Eval velocity-loss: 0.5046 → 0.2528 (−50%) vs the per-token surrogate's 0.308, an 18% lower floor. It crossed below the per-token surrogate's converged floor by step 50 (0.282).
  • Samples: visibly sharper (near-teacher lake; coherent, detailed storefront with partially legible sign) vs the per-token surrogate's soft output.
  • Cost: comparable. The linear-attention surrogates total 38M params vs 19M for the per-token ones (1.2% vs 0.6% of the model), and the attention is O(N), so the speed impact is negligible.
  • Conclusion: token-mixing is the right surrogate; this unblocks dropping more blocks (toward the original ~2B target) while staying recoverable.
  • Next: push drop_k (10–12) with the linear-attn surrogate.
  • Minor: the constant LR jittered at the floor at step 250; use the floored cosine for the real run.
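
For intuition on why linear attention mixes tokens at O(N) cost, here is a minimal single-head sketch with the elu+1 feature map. It is not the code in scripts/09_build_linattn.py: it omits the q/k/v and output projections, multi-head split, RoPE and conv, and uses plain Python lists. The quadratic reference exists only to check that the O(N) path gives the same result.

```python
import math

def phi(x):
    """elu(x) + 1 feature map: strictly positive, so the normalizer is never zero."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def linear_attention(q, k, v):
    """O(N): accumulate sum_n phi(k_n) v_n^T and sum_n phi(k_n) once, reuse per query."""
    d_k, d_v = len(k[0]), len(v[0])
    kv = [[0.0] * d_v for _ in range(d_k)]
    z = [0.0] * d_k
    for k_n, v_n in zip(k, v):
        fk = phi(k_n)
        for a in range(d_k):
            z[a] += fk[a]
            for b in range(d_v):
                kv[a][b] += fk[a] * v_n[b]
    out = []
    for q_i in q:
        fq = phi(q_i)
        norm = dot(fq, z)
        out.append([sum(fq[a] * kv[a][b] for a in range(d_k)) / norm
                    for b in range(d_v)])
    return out

def quadratic_reference(q, k, v):
    """Same result via explicit N x N weights, for checking only."""
    out = []
    for q_i in q:
        w = [dot(phi(q_i), phi(k_n)) for k_n in k]
        total = sum(w)
        out.append([sum(w_n * v_n[b] for w_n, v_n in zip(w, v)) / total
                    for b in range(len(v[0]))])
    return out

q = [[0.2, -1.0], [1.5, 0.3], [-0.4, 0.8]]
k = [[0.1, 0.9], [-0.7, 0.2], [1.1, -0.3]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
fast, ref = linear_attention(q, k, v), quadratic_reference(q, k, v)
```

Because every output row depends on all keys and values through the shared accumulators, changing one token's value shifts every other token's output, which is exactly what the per-token surrogate could not do.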


Session summary (2026-05-31) — full deep-dive in report/REPORT.pdf

Block-ID convention: single-stream blocks S0..S19 (S0 = first/shallowest after the 5 double blocks; S19 = last/deepest). "Deepest N" = highest IDs.

Where things stand (the measured speed↔quality frontier): best quality = linattn drop-6 (drop S12–S17) +RoPE+conv+warmstart → 0.231 @ 1.15×, 3.18B; best perceptual = drop-8 (drop S10–S17) +focused+FFN → 0.269 @ 1.20×; most compression = drop-10 (drop S8–S17, FFN on S14–S17) → ~0.322 @ 1.26×, 2.74B; v1 per-token drop-12 (drop S0–S11) collapsed.

Two structural ceilings (measured): (1) single blocks are 78% of transformer compute, the 5 double blocks only 21% (per-block ~1.08×) → single-block surgery is the right lever but it's near its useful limit; double-block FFN shrink ≈ 2% speedup, rejected. (2) Single blocks have low redundancy (every one shifts the output ≥43%).

Recommended next: lock drop-7 (FFN on deepest 3 incl. last) as the balanced single-block model (~3.08B, ~1.17×, ~0.25). For a genuine ~2×, the next project is step reduction (4→2) via step-distillation — not more block surgery — plus text-encoder caching and the B200 scale-up.

Basic training run result (2026-05-31): DIVERGED. The recipe was too aggressive. Loss was stable at ~0.83 to step ~120, then spiked (2.46 @160) and blew up (8.08 @300); samples degraded from step 0 and collapsed to checkerboard noise. The step-0 v2 warm-start remained the best checkpoint. This empirically confirms the §8 risk ("Muon on warm-start recovery under-validated → A/B vs AdamW").

Diagnosis (ranked):

  1. Muon lr 0.02 applied to the well-pretrained kept blocks corrupts them; kept blocks should be frozen or trained at a much lower lr than the surrogates.
  2. bf16 weights, no fp32 master.
  3. No LR warmup, no grad clipping.

Fixes for next run: per-group LRs (surrogates high, kept blocks ~0 or tiny / freeze), LR warmup + grad clip, fp32 master, AdamW-vs-Muon A/B, lower base LR.

Corrected recovery recipe (scripts/08_train_recover.py), research-led.

LR research: Muon's lr≈0.02 is a bulk-pretraining value (nanoGPT speedruns; Kimi/Moonshot RMS-match it to reuse AdamW-scale LRs). It is wrong for a few adapter modules and was the cause of the blowup. Our surrogates are adapter-like (fresh, small, frozen base) → diffusion/LoRA regime: AdamW ~1e-4–2e-4 (1e-5 is too weak; 1e-6 is a full-finetune value).

Recipe:

  • Freeze the whole pretrained net and train **only the 6 surrogates (18.9M, 0.6%)**.
  • AdamW @ 1e-4, cosine decay, grad-clip 1.0, fp32 master on trained params (bf16 autocast compute).
  • Fixed held-out eval velocity-loss as the objective metric; per-step image logging; best-ckpt; divergence guard (stop if eval-loss > 3× baseline).

LR schedule: cosine from base to a 15%-of-base floor (MIN_LR_FRAC=0.15), no warmup, and NOT to 0. Decaying to exactly 0 wasted the last ~30 steps (lr <1e-6 ≈ dead tail) and would break resume/extend (a continued run would start at lr≈0). Also size STEPS to the data so the cosine tail isn't spent doing nothing (this dev run flattened by ~step 250).

Muon stays the tool for the later full-recovery run (its §8 A/B), where much of the network trains.

Note: surrogate-only recovery is bounded by the per-token surrogate's inability to token-mix. Expect stability + modest improvement, not full sharpening; the token-mixing surrogate TODO is what unlocks real recovery.
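
The floored-cosine schedule and the divergence guard are small enough to write out. This is a sketch, not the contents of scripts/08_train_recover.py; the function names are illustrative, while the 1e-4 base LR, the 0.15 floor fraction and the 3× guard factor come from the recipe above.

```python
import math

def floored_cosine_lr(step, total_steps, base_lr=1e-4, min_lr_frac=0.15):
    """Cosine decay from base_lr down to min_lr_frac * base_lr, with no warmup."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    floor = min_lr_frac * base_lr
    return floor + 0.5 * (base_lr - floor) * (1.0 + math.cos(math.pi * progress))

def diverged(eval_loss, baseline_loss, factor=3.0):
    """Divergence guard: True once eval loss exceeds factor x the baseline eval loss."""
    return eval_loss > factor * baseline_loss

schedule = [floored_cosine_lr(s, 300) for s in range(301)]
print(f"start={schedule[0]:.2e} mid={schedule[150]:.2e} end={schedule[-1]:.2e}")
```

Since the schedule ends at 15% of base rather than 0, a resumed or extended run can continue from the final step without starting at a dead learning rate.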

Result (corrected run, 300 steps, surrogate-only): eval velocity-loss 0.4756 → 0.308, −35%, monotonic, zero divergence, 28.7 GB, ~7 min; samples visibly sharper (mountain ridges/reflection, storefront detail) toward the teacher. Controlled Muon-vs-AdamW A/B (identical recipe, only optimizer swapped): AdamW@1e-4 → 0.3081 (−35.2%), Muon@2e-3 → 0.3056 (−35.8%) — statistically equivalent, both stable. Takeaway: the earlier blow-up was not Muon — it was lr 0.02 on the wrong params (all weights incl. pretrained kept blocks). Used correctly (low lr, only new params) Muon is fine but offers no edge for a handful of adapter modules; reserve it for the bulk full-recovery run.

Dev-environment scope (this box ≠ the target box): running on 1× A100-80GB, not B200. So for now: no FlashAttention-4 / Blackwell kernels — use PyTorch SDPA; smaller batch; goal is a clean, correct, scalable baseline (surgery + calibration + inference + eval + a basic training loop) that lifts onto a bigger GPU later, not a perfected/perf-tuned run. Memory/throughput numbers in §8 are the B200 target, not this box.


0. What changed from init-plan.md (read first)

  1. Teacher is the distilled 4-step 4B, NOT the 50-step base. Pure velocity/output matching (L2 regression) cannot compress step count — it clones a velocity field, and integrating a cloned field needs the same step budget as the source. Distilling from the 50-step base would leave us owing a separate 50→4 step-compression phase (DMD/consistency/adversarial), which we explicitly excluded. Cloning the already-4-step distilled teacher inherits 4-step behavior for free.
  2. Student is ~2.05B, not 2.45B. With true low-rank surrogates (r=512), the 12 surrogates total 38M params, not the ~410M the old plan assumed. Core kept ≈ 2.04B + surrogates ≈ 0.04B ≈ **2.05B**.
  3. Surrogates are low-rank + nonlinearity (x + B·σ(A·x)), r=512, not bare linear. ~Same FLOPs as bare linear (the activation is ~free), more faithful to a removed transformer block. Costs us the exact closed-form init → we add a short offline calibration fit (below).
  4. Distillation is trajectory-aware, sampled along the teacher's 4-step rollout — not the naive x_t=(1−t)·noise+t·latent data interpolant. With only 4 large Euler steps, exposure bias compounds viciously; matching where the student will actually operate is what kills it.
  5. Positioning: the 2B is a fast inference endpoint + distill-down target for narrow domains, not a directly-fine-tunable general base. (Distilled models are harder to fine-tune directly; the distill-down kit routes downstream specialization through the teacher, so the 2B never needs direct fine-tuning.)
  6. CFG-free inheritance: the distilled teacher is guidance-distilled (guidance_scale=1.0, single forward, no CFG doubling), so the student runs CFG-free too. Per the verified config in the Addendum (guidance_embeds=false), there is no guidance embedder to inherit.

1. Objective

Compress FLUX.2 [klein] distilled 4B (rectified-flow MM-DiT, 4-step, guidance-distilled) into a ~2.05B model that:

  • runs in 4 steps (same as teacher — we are not changing step count),
  • is ~1.6–1.7× faster than the 4B and ~5 GB BF16,
  • retains most of teacher-at-4-steps quality,
  • serves as a distill-down target: third parties specialize it to a narrow domain via the kit in §11 (fine-tune a bigger teacher → amplify → distill into the 2B), rather than fine-tuning the 2B directly.

Recovered via knowledge distillation on a single B200, ~half-day.


2. Teacher

  • Model: FLUX.2 [klein] distilled 4B — step- and guidance-distilled, 4 inference steps, guidance_scale=1.0, Apache 2.0.
  • Why this one (not base): see §0.1. Regression preserves step count; we want a 4-step student, so we need a 4-step teacher.
  • Conditioning: Qwen3-4B text encoder (intermediate-layer stack → conditioning vector). Frozen; not part of the denoiser. Student stays on Qwen3-4B (never bolt the 9B's Qwen3-8B encoder on — dimension mismatch).
  • Latent space: shared FLUX.2 VAE. Identical across the klein family — this is what makes the §11 cross-size distill-down valid.
  • Forward cost: single forward per step (CFG-free). Used frozen (eval, no_grad, BF16, compiled) as an online target generator.

Open item (verify before surgery): the exact 4B block layout (init-plan.md asserts 5 double + 20 single, d=3072, 24 heads — unverified from the model card). Pull config.json / the diffusers Flux2 impl and confirm block counts + hidden dim before cutting. The entire prune + correspondence map depends on these numbers. Everything below assumes the asserted layout; adjust if the config differs.


3. Target architecture (~2.05B student)

Full-width, depth-pruned, removed depth replaced by light low-rank surrogates.

| Component | Spec | Rationale |
| --- | --- | --- |
| Width | d=3072 (unchanged) | Preserves texture/detail; widths match teacher → no projector needed for feature matching |
| Double-stream | 5 (all kept) | Preserve cross-modal / instruction binding (these are also the un-prunable FLOP floor) |
| Single-stream | 8 kept full (of 20) | Depth-only prune; recovers better than width prune |
| Removed depth | 12 single blocks → low-rank+nonlinearity surrogates (r=512) | Residual-only low-rank approximation, shortcut preserved; SVD-init + calibration warm-start |
| Embedders / modulation | Inherited from teacher (timestep, guidance, shared-across-block AdaLN) | Keeps CFG-free behavior; AdamW param group |
| Effective size | core 2.04B + surrogates ~0.04B ≈ **2.05B** | |
| Inference | ~1.6–1.7× faster than 4B, ~5 GB BF16, 4 steps | |

4. Surrogate design (the removed 12 blocks)

Form: out = x + B·σ(A·x), with A: d→r, B: r→d, r=512, σ = GELU-family (match teacher block activation). Shortcut x preserved exactly — we only approximate the block's residual delta.
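
As a concrete reading of the form above, here is a toy forward pass on a single token vector. It is a sketch under assumptions: exact erf-based GELU stands in for the "GELU-family" activation, the dimensions are tiny (d=4, r=2) instead of 3072/512, and the weights are made up for illustration.

```python
import math

def gelu(z):
    """Exact GELU; assumed here as the stand-in for the teacher's activation."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

def matvec(m, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]

def surrogate(x, A, B):
    """out = x + B·gelu(A·x), applied to one token vector x."""
    hidden = [gelu(h) for h in matvec(A, x)]   # d -> r
    delta = matvec(B, hidden)                  # r -> d
    return [xi + di for xi, di in zip(x, delta)]

# Toy weights: A is r x d (2 x 4), B is d x r (4 x 2).
A = [[0.5, -0.2, 0.1, 0.0], [0.3, 0.3, -0.4, 0.2]]
B = [[0.1, 0.0], [0.0, 0.1], [0.2, -0.1], [0.0, 0.0]]
B_zero = [[0.0, 0.0] for _ in range(4)]

x = [1.0, -1.0, 0.5, 2.0]
out = surrogate(x, A, B)
identity_out = surrogate(x, A, B_zero)   # zero delta leaves the shortcut untouched
```

The function only ever sees one token at a time, so a sequence is processed as `[surrogate(tok, A, B) for tok in tokens]` with no information moving between positions.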

Compression at r=512 (vs d=3072):

  • Bottleneck = 1/6 (16.7%) of residual-stream width.
  • 110M params/block → ~3.15M/block ⇒ **35× param compression** (keep ~3% of the block).
  • 200 → ~10 GFLOP/block ⇒ **21× cheaper** (skips attention/N²).
  • Rank is a quality knob, not a speed knob: across r∈{128,256,512} total-model FLOPs move <3% and speedup stays ~1.65×. 512 = most capacity at no speed cost. (Escape hatch: 768/1024 still <5% of total FLOPs if we ever want less compression.)
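
The parameter-side numbers in this list are easy to tabulate across ranks. A minimal sketch, assuming each surrogate is just the two matrices A and B with no biases, and using this section's 110M per-block estimate as the reference size:

```python
D = 3072
BLOCK_PARAMS = 110e6  # per-block estimate used in this section

def surrogate_params(rank: int, d: int = D) -> int:
    """A: d->r and B: r->d, biases ignored."""
    return 2 * d * rank

for rank in (128, 256, 512, 768, 1024):
    p = surrogate_params(rank)
    print(f"r={rank:4d}  bottleneck={rank / D:6.1%}  "
          f"params={p / 1e6:5.2f}M  compression={BLOCK_PARAMS / p:5.1f}x")
```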

Which 12 to surrogate vs which 8 to keep full: data-driven on a calibration batch. For each single block, compute the SVD energy of its residual delta (input→output map). Blocks whose top-512 singular values capture most energy are cheap to approximate → surrogate them; high-rank residuals → keep full. (Falls back to / cross-checks against block-influence/CKA importance. Middle blocks are usually the most compressible.)
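
The selection rule can be sketched once the singular values of each block's residual delta are available (for example from `torch.linalg.svdvals` on a calibration batch, which is an assumption about tooling, not something this plan specifies). The spectra below are made-up toy values, and r=2 stands in for r=512.

```python
def top_r_energy(singular_values, r):
    """Fraction of squared-singular-value energy captured by the r largest values."""
    s = sorted(singular_values, reverse=True)
    total = sum(x * x for x in s)
    return sum(x * x for x in s[:r]) / total

def pick_surrogates(spectra, r, n_surrogate):
    """Rank blocks by captured energy (highest first) and surrogate the top n."""
    energy = {name: top_r_energy(s, r) for name, s in spectra.items()}
    ranked = sorted(energy, key=energy.get, reverse=True)
    return ranked[:n_surrogate], energy

# Hypothetical toy spectra, for illustration only.
spectra = {
    "S0": [5.0, 4.0, 3.0, 2.0],   # flat spectrum: high-rank residual, keep full
    "S1": [9.0, 1.0, 0.5, 0.1],   # fast decay: cheap to approximate
    "S2": [6.0, 5.0, 0.2, 0.1],   # two dominant directions: also cheap at r=2
}
chosen, energy = pick_surrogates(spectra, r=2, n_surrogate=2)
print(chosen, {k: round(e, 4) for k, e in energy.items()})
```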

Init (warm start): the nonlinearity breaks the exact closed-form least-squares init, so:

  1. Init A,B via truncated-SVD of the linear least-squares solution (treat σ≈identity / scale so A·x lands in σ's near-linear region).
  2. Short offline calibration fit — a few hundred gradient steps per surrogate to match each block's I/O on the calibration set, before the main run. Cheap, offline, not part of the training-budget run.

Sanity: shape + forward-parity checks after surgery; confirm full width d=3072 so feature matching needs no projector.


5. Distillation methodology (online teacher)

Convention: x_t = (1−t)·noise + t·latent (t=0 noise → t=1 data); straight-line velocity v = latent − noise.
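
A short numeric check of this convention, using made-up three-element vectors in place of latents: starting from x_t and moving along v for the remaining time (1 − t) should land exactly on the data endpoint.

```python
def interpolate(noise, latent, t):
    """x_t = (1 - t) * noise + t * latent, with t=0 pure noise and t=1 data."""
    return [(1.0 - t) * n + t * x for n, x in zip(noise, latent)]

def velocity(noise, latent):
    """Straight-line velocity v = latent - noise (constant along the path)."""
    return [x - n for n, x in zip(noise, latent)]

noise, latent = [0.5, -1.0, 2.0], [1.0, 0.0, -1.0]
t = 0.25
v = velocity(noise, latent)
x_t = interpolate(noise, latent, t)

# Moving from x_t along v for the remaining time reaches the data endpoint.
x_1 = [x + (1.0 - t) * vi for x, vi in zip(x_t, v)]
print(x_t, v, x_1)
```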

Primary — trajectory velocity matching (instills 4-step behavior):

  • Roll out the teacher's 4-step sampler from noise, collecting the 4 visited (x_{t_i}, t_i) points — the exact points the student sees at inference.
  • At each point: teacher velocity v_teacher is the target; student predicts v_student; loss = MSE.
  • Concentrate timesteps on the 4 schedule points, not uniform t. This is the departure from the old plan's data-interpolant sampling, and it's what kills few-step exposure bias.
  • v1: teacher-trajectory rollout (stable, cheap). Optional polish later: a fraction of student-trajectory (on-policy) steps where the teacher corrects student-visited points (DAgger-style) to close residual exposure bias.
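
The rollout-and-match loop can be sketched with toy stand-ins. Assumptions to note: the schedule here is uniform (t = 0, 0.25, 0.5, 0.75), whereas the real scheduler's timesteps may be shifted; the "teacher" is an ideal straight-line field toward a fixed target; and the "student" is the teacher scaled by 0.9 so the loss is non-zero.

```python
def teacher_rollout(noise, teacher_velocity, num_steps=4):
    """Euler-integrate from t=0 to t=1, recording each (x_t, t, v) the sampler visits."""
    ts = [i / num_steps for i in range(num_steps + 1)]  # assumed uniform schedule
    x, visited = list(noise), []
    for t, t_next in zip(ts[:-1], ts[1:]):
        v = teacher_velocity(x, t)
        visited.append((list(x), t, v))
        x = [xi + (t_next - t) * vi for xi, vi in zip(x, v)]
    return visited, x

def velocity_mse(v_student, v_teacher):
    return sum((a - b) ** 2 for a, b in zip(v_student, v_teacher)) / len(v_teacher)

# Toy teacher: the ideal straight-line field toward a fixed latent.
target = [1.0, 0.0, -1.0]

def toy_teacher(x, t):
    return [(y - xi) / (1.0 - t) for xi, y in zip(x, target)]

def toy_student(x, t):
    return [0.9 * vi for vi in toy_teacher(x, t)]  # deliberately off by 10%

visited, final = teacher_rollout([0.5, -1.0, 2.0], toy_teacher)

# Trajectory loss: average velocity MSE over the 4 visited points only.
loss = sum(velocity_mse(toy_student(x, t), v) for x, t, v in visited) / len(visited)
print(len(visited), final, loss)
```

The student is only ever queried at the states the 4-step sampler actually reaches, which is the difference from sampling x_t on the data interpolant at uniform t.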

Secondary — masked feature matching on retained blocks:

  • Match student↔teacher hidden states at the retained-block positions (no projector — widths match). Match each kept block to its original teacher position (indices renumber after the prune).
  • Masked KD: exclude DiT extreme-activation outliers (massive activations) from the feature loss.
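
One way to picture the masking is shown below. The outlier rule is an assumption made for illustration (teacher entries whose magnitude exceeds a multiple of the median magnitude are dropped); this plan does not specify the actual criterion, and the hidden-state values are invented.

```python
def masked_feature_mse(student, teacher, outlier_factor=20.0):
    """MSE over hidden-state entries, skipping teacher entries flagged as outliers.

    Outlier rule (an assumption): |teacher| > outlier_factor * median(|teacher|).
    """
    mags = sorted(abs(t) for t in teacher)
    median = mags[len(mags) // 2]
    keep = [abs(t) <= outlier_factor * median for t in teacher]
    kept = [(s - t) ** 2 for s, t, m in zip(student, teacher, keep) if m]
    return sum(kept) / max(len(kept), 1), keep

teacher_h = [0.5, -0.8, 1.1, 250.0, -0.3, 0.9]   # one massive activation
student_h = [0.4, -0.7, 1.0, 10.0, -0.2, 1.0]

loss, keep = masked_feature_mse(student_h, teacher_h)
unmasked = sum((s - t) ** 2 for s, t in zip(student_h, teacher_h)) / len(teacher_h)
print(f"masked={loss:.4f} unmasked={unmasked:.1f}")
```

Without the mask, the single massive activation dominates the loss by several orders of magnitude and the remaining entries contribute almost nothing to the gradient.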

Grounding — real-data flow-matching:

  • Standard FM loss on real latents at randomly sampled t (broad timestep coverage + grounding). Cheap (1 forward).

Loss: w1·velocity-MSE (trajectory) + w2·feature-match (masked) + w3·real-data FM. Tune weights; velocity primary.

Multi-timestep / epochs: fresh noise + fresh rollout per image per epoch (new supervision, not repetition). 4 epochs × 300k ≈ 1.2M views.


6. Optimizer & precision

  • Muon on 2D hidden weights (attention/MLP, incl. surrogate A,B) + AdamW on embedders, proj_out, norms, shared AdaLN modulation. Muon LR runs higher than AdamW — tune LR + WD per group.
  • BF16 compute, FP32 master weights + FP32 moments. No quantization in training.
  • Micro-batch 64, gradient-accumulate to effective ~256 (Muon's large-batch sweet spot). No gradient checkpointing planned (fits).
  • torch.compile (CUDA-graph mode); FlashAttention (Blackwell-native, via FlexAttention) for the retained blocks.
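
The Muon/AdamW split can be expressed as a routing rule over parameter names and shapes. The sketch below uses hypothetical parameter names and a keyword list chosen for illustration; with a real model the pairs would come from `model.named_parameters()`, and the keywords would need to match the actual module names.

```python
def split_param_groups(named_shapes):
    """Route 2D hidden weights to Muon; embedders, proj_out, norms, modulation to AdamW."""
    adamw_keywords = ("embed", "proj_out", "norm", "modulation")  # illustrative
    muon, adamw = [], []
    for name, shape in named_shapes:
        is_2d = len(shape) == 2
        excluded = any(key in name for key in adamw_keywords)
        (muon if is_2d and not excluded else adamw).append(name)
    return muon, adamw

# Hypothetical parameter names and shapes, for illustration only.
params = [
    ("single_blocks.12.attn.to_q.weight", (3072, 3072)),
    ("single_blocks.12.attn.to_q.bias", (3072,)),
    ("surrogates.3.A.weight", (512, 3072)),
    ("surrogates.3.B.weight", (3072, 512)),
    ("time_embed.linear_1.weight", (3072, 256)),
    ("norm_out.weight", (3072,)),
    ("proj_out.weight", (128, 3072)),
    ("shared_modulation.linear.weight", (9216, 3072)),
]
muon_names, adamw_names = split_param_groups(params)
print(muon_names)
```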

7. Data pipeline

  • 300k curated images + captions, diversity-first (broad styles/subjects/composition + in-image text coverage). Stable IDs. (Sourcing/licensing for an Apache-2.0 release is an open item — §13.)
  • VAE-encode offline → BF16 latents in streaming shards (WebDataset / tensordict). VAE never runs in the training loop.
  • Qwen3-4B text embeddings: run online (intermediate layers, frozen). Optionally batch-precompute per shard if I/O-bound (deferred — not a v1 priority).
  • Dataloader streams (latent, caption) with aggressive prefetch; per step samples noise and forms the teacher rollout.
  • Bulk training at 512px; optional short 768/1024 polish if time/memory headroom remains.

8. Compute & infra (1× B200, 192 GB)

Memory (fixed ~45 GB): FP32 master ~9.8 + BF16 weights ~4.9 + BF16 grads ~4.9 + Muon momentum ~9.0 + AdamW m,v ~1.6 + teacher 4B (fwd) ~6.8 + Qwen3-4B (fwd) ~8.0. → ~147 GB free for activations (192 − 45). (Student is ~2.05B; weight/grad/state lines scale to that.)

Batch: activations ≈ 1.63 GB/sample @512px → micro-batch 64 ≈ ~104 GB activations; fits without checkpointing.
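
The memory budget is plain addition, so it is worth keeping as a script that can be re-run when any line changes. The per-item figures are the estimates quoted above; the `reserve_gb` argument is a hypothetical knob for allocator and compile overhead, which this plan does not quantify.

```python
GPU_GB = 192.0
fixed_gb = {
    "fp32_master": 9.8, "bf16_weights": 4.9, "bf16_grads": 4.9,
    "muon_momentum": 9.0, "adamw_m_v": 1.6,
    "teacher_4b_fwd": 6.8, "qwen3_4b_fwd": 8.0,
}
ACT_GB_PER_SAMPLE = 1.63   # at 512px, from the estimate above

def max_micro_batch(gpu_gb, fixed, act_per_sample, reserve_gb=0.0):
    """Largest micro-batch whose activations fit after fixed costs and a reserve."""
    free = gpu_gb - sum(fixed.values()) - reserve_gb
    return int(free // act_per_sample)

fixed_total = sum(fixed_gb.values())
act_64 = 64 * ACT_GB_PER_SAMPLE
headroom = GPU_GB - fixed_total - act_64

print(f"fixed={fixed_total:.1f} GB, activations@64={act_64:.1f} GB, "
      f"headroom={headroom:.1f} GB, "
      f"max micro-batch={max_micro_batch(GPU_GB, fixed_gb, ACT_GB_PER_SAMPLE)}")
```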

Throughput / MFU: the loop runs three forwards (student fwd+bwd, teacher fwd, Qwen3-4B fwd) + compile + FlexAttention. Plan around 30–35% MFU; treat 40%+ as upside. Validate MFU on a ~100-step smoke run before committing the full run.

Time/cost (context only — budget de-prioritized): ~13–17 h depending on MFU; GPU cost ~$66–85. Data curation/captioning is a separate, larger cost not in this figure (§13).


9. De-risking / early experiments (run before the full run)

Two unproven components sit in the critical path — isolate each on a small first chunk:

  1. Optimizer A/B: Muon vs AdamW on the first training chunk (loss-curve + sample quality). Confirm Muon's warm-start recovery before committing.
  2. Surrogate A/B: low-rank+nonlinearity surrogate vs plain block-drop (no replacement), same chunk — confirms the surrogate is earning its keep.
  3. Rank/calibration check: verify SVD-init + calibration-fit reconstruction error per surrogate; confirm r=512 (vs 256) recovery delta.
  4. MFU smoke: 100-step compiled run to confirm ≥30% before the long run.

10. Eval & go/no-go

  • Metrics: FID/CLIP + targeted prompt probes (in-image text/typography, multi-object composition/counting, spatial/attribute binding, texture/fine detail, small faces/hands) — all measured vs teacher-at-4-steps (the realistic ceiling).
  • Go/no-go thresholds (set concretely before the run): e.g. CLIP within X of teacher@4, FID within Y, no probe category catastrophically below teacher@4. Define X/Y on a held-out probe set during the early chunk.
  • Downstream smoke test: one narrow-domain distill-down pass to confirm the kit path works end-to-end.
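
Writing the gate as a function forces X, Y and the meaning of "catastrophically below" to be fixed before the run. The sketch below is one possible shape: `clip_margin` and `fid_margin` play the roles of X and Y, and `probe_floor` (a fraction of the teacher's probe score) is an assumed interpretation of the catastrophic-failure rule. All numbers in the example are invented placeholders, not measurements.

```python
def go_no_go(student, teacher, clip_margin, fid_margin, probe_floor):
    """Return (decision, reasons). Every threshold is a placeholder to set before the run."""
    reasons = []
    if teacher["clip"] - student["clip"] > clip_margin:
        reasons.append("CLIP gap exceeds X")
    if student["fid"] - teacher["fid"] > fid_margin:
        reasons.append("FID gap exceeds Y")
    for name, t_score in teacher["probes"].items():
        if student["probes"][name] < probe_floor * t_score:
            reasons.append(f"probe '{name}' far below teacher@4")
    return (not reasons), reasons

# Invented placeholder scores, for illustration only.
teacher = {"clip": 0.320, "fid": 12.0, "probes": {"text": 0.80, "counting": 0.70}}
student = {"clip": 0.312, "fid": 13.5, "probes": {"text": 0.35, "counting": 0.66}}

decision, reasons = go_no_go(student, teacher,
                             clip_margin=0.01, fid_margin=2.0, probe_floor=0.5)
print(decision, reasons)
```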

11. Distill-down kit (downstream specialization)

Ship the 2B plus a kit so users specialize it via a bigger teacher rather than fine-tuning the 2B directly — best for small downstream datasets.

  • Mechanism: fine-tune the larger teacher (4B or 9B) on the narrow domain → use it as a data amplifier (synthetic in-domain samples) → distill into the 2B via velocity matching (same loop as §5, teacher swappable).
  • Cross-size validity: whole klein family shares one VAE/latent space → 9B→2B and 4B→2B velocity matching is well-defined. Student stays on Qwen3-4B regardless of teacher.
  • Ship: the block-correspondence map (student block ↔ teacher block) + projector recipe for optional feature matching + scripts + minimal config — the hard part users can't easily derive.
  • License: 4B teacher path = commercial-OK (Apache 2.0); 9B teacher path = non-commercial only.

12. Build order (milestones)

  1. Verify 4B config (block counts, d, encoder) — gate before surgery.
  2. Data + latents — curate 300k diverse pairs; VAE-encode to shards.
  3. Surgery — SVD-energy block selection (keep 8 / surrogate 12); build low-rank+nonlinearity surrogates; SVD-init + offline calibration fit; forward-parity checks.
  4. De-risking chunk (§9) — optimizer A/B, surrogate A/B, rank check, MFU smoke. Decide go/no-go.
  5. Recovery run — Muon+AdamW, online 4-step teacher, trajectory velocity + masked feature + real-data FM, 1.2M views @512px, BF16, effective batch 256.
  6. Eval (§10) — FID/CLIP + probes vs teacher@4; downstream smoke test.
  7. (Optional) polish — short 768/1024 phase.
  8. Release — 2B weights + distill-down kit + model card (state: 4-step, CFG-free, distill-down-target positioning).

13. Open items / risks

  • Verify 4B architecture from config before surgery (gates everything). [§2, §12.1]
  • Data sourcing & licensing — 300k curated/captioned images for an Apache-2.0 release; cost + provenance not yet settled. Curate for breadth, not volume; quality gaps will track wherever the set is thin.
  • Muon on warm-start recovery under-validated → A/B (§9.1).
  • Surrogate quality vs compression — r=512 default; per-block rank could be tuned by reconstruction target if some blocks under-recover.
  • 1.2M views is partial recovery in absolute terms → monitor loss/probes; be willing to extend if a capability (esp. in-image text) stays weak.
  • MFU sensitivity — a sloppy loop at ~25% ≈ 1.6× longer/costlier; the §8 optimizations + smoke test are the guardrail.
  • Positioning discipline — model card must be explicit that the 2B is a narrow-domain distill-down target, not a directly-fine-tunable general base.
