Status: Proposal · Compute target: 1× B200, ~13–15h, ~$66–75 · Deliverable: a ~2B open base model + a downstream distill-down kit.
1. Objective
Compress FLUX.2 [klein] 4B (rectified-flow MM-DiT) into a ~2B general-purpose text-to-image base model that retains most of the teacher's capability, runs meaningfully faster, and serves as a strong foundation for downstream narrow-domain fine-tuning by third parties (e.g. wallpapers, single styles, product imagery). The model is recovered via knowledge distillation on a single B200 within a strict ~$70 / half-day budget. We deliberately do not reduce inference steps; we accept 4→8 steps in exchange for the smaller, faster backbone.
2. Background
- Teacher: FLUX.2 [klein] base 4B (undistilled: full training signal, best for fine-tuning). Architecture: d=3072, 24 heads, 5 double-stream + 20 single-stream MM-DiT blocks, shared-across-block AdaLN modulation, no bias.
- Conditioning: Qwen3-4B text encoder (intermediate-layer stack → 7680-dim conditioning). Frozen; not part of the 4B denoiser.
- Latent space: shared FLUX.2 VAE (128-ch latent) — identical across the whole klein family (4B & 9B). This is what makes cross-size velocity distillation valid (see §7).
- Capacity distribution: the 20 single-stream blocks hold ~67% of backbone params; the 5 double-stream blocks (cross-modal binding) hold ~33%. This asymmetry drives the architecture choice below.
- License: klein 4B is Apache 2.0 (commercial-friendly); klein 9B is non-commercial.
3. Target Architecture (~2B student)
Full-width, depth-pruned, with lightweight block replacement for the removed depth.
| Component | Spec | Rationale |
|---|---|---|
| Width | d=3072 (unchanged) | Same as teacher → preserves texture/detail capacity; feature matching needs no projector |
| Double-stream | 5 (all kept) | Preserve cross-modal/instruction binding |
| Single-stream | 8 kept full (of 20) | Depth-only prune; recovers far better than width prune |
| Removed depth | 12 single blocks → light BRLL surrogates | Linear/low-rank residual-branch replacements, shortcut preserved; least-squares (training-free) init = warm start |
| Effective size | core | |
| Inference | ~1.7× faster than 4B, ~5 GB BF16 | Still runs on consumer GPUs |
4. Methodology
Initialization
- Prune from klein-base-4B; inherit surviving block weights.
- Replace the 12 removed single-stream blocks with BRLL-style surrogates, least-squares initialized from teacher activations (training-free warm start → less to recover).
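A minimal numpy sketch of the least-squares warm start, assuming each surrogate is a residual linear map `y ≈ x + x·W` (shortcut preserved) fit on token activations captured before and after the removed teacher block on a calibration batch. The function name, the ridge term, and the low-rank option are illustrative assumptions; the sketch also ignores the timestep-dependent AdaLN modulation that a real surrogate would have to account for.

```python
import numpy as np

def fit_residual_surrogate(x_in, y_out, rank=None, ridge=1e-4):
    """Fit W so that y_out ≈ x_in + x_in @ W (shortcut kept, linear residual branch).

    x_in, y_out: (n_tokens, d) activations recorded before/after the removed block.
    Returns the full W of shape (d, d), or factors (A, B) with W ≈ A @ B if `rank` is set.
    """
    residual = y_out - x_in                      # what the removed block contributed
    d = x_in.shape[1]
    # Ridge-regularised normal equations: (XᵀX + λI) W = Xᵀ R
    gram = x_in.T @ x_in + ridge * np.eye(d)
    W = np.linalg.solve(gram, x_in.T @ residual)
    if rank is None:
        return W
    # Low-rank variant: keep the top-`rank` singular directions of the full solution.
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    A = U[:, :rank] * S[:rank]                   # (d, rank)
    B = Vt[:rank]                                # (rank, d)
    return A, B
```

Truncating the SVD of the full solution is a simple heuristic for the low-rank case, not the optimal rank-constrained least-squares fit; since the fit is only a warm start, recovery training is expected to absorb the difference.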
Distillation (online teacher)
- Velocity/output matching against the 4B teacher (flow-matching target) — primary signal, runs on unlimited teacher-generated targets.
- Feature matching on retained blocks — free here because student/teacher widths match (no projector). Use masked KD (exclude DiT extreme-activation outliers).
- Real-data flow-matching loss on curated images for grounding.
- No adversarial (LADD) phase in v1 — too costly/finicky for the budget; revisit for a v2 texture polish.
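The three training signals can be combined into one objective roughly as follows. This is a numpy sketch for readability (the real loop would operate on PyTorch tensors); the outlier rule (dropping teacher feature entries above a per-tensor magnitude quantile), the default weights `w1`/`w2`/`w3`, and all function names are hypothetical, not a tuned recipe.

```python
import numpy as np

def mse(a, b, mask=None):
    """Mean squared error, optionally restricted to entries where mask is True."""
    err = (a - b) ** 2
    if mask is None:
        return float(err.mean())
    return float(err[mask].mean()) if mask.any() else 0.0

def masked_feature_kd(student_feats, teacher_feats, outlier_quantile=0.999):
    """Feature matching on retained blocks, skipping extreme teacher activations.

    Student and teacher widths match, so features are compared directly (no projector).
    """
    losses = []
    for s, t in zip(student_feats, teacher_feats):
        thresh = np.quantile(np.abs(t), outlier_quantile)
        keep = np.abs(t) <= thresh               # drop rare massive-activation entries
        losses.append(mse(s, t, keep))
    return float(np.mean(losses))

def distill_loss(v_student, v_teacher, student_feats, teacher_feats,
                 noise, latent, w1=1.0, w2=0.1, w3=0.5):
    """w1·velocity-MSE + w2·masked feature KD + w3·real-data flow matching."""
    l_vel = mse(v_student, v_teacher)            # primary: match the teacher's velocity
    l_feat = masked_feature_kd(student_feats, teacher_feats)
    # Real-data target for x_t = (1 - t)·noise + t·latent built from a real latent.
    l_fm = mse(v_student, latent - noise)
    return w1 * l_vel + w2 * l_feat + w3 * l_fm
```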
Data pipeline
- 300k curated images + captions, diversity-first (broad styles/subjects/composition + in-image text coverage).
- Precompute VAE latents offline (~79 GB @512px BF16, streaming shards). VAE never runs in the training loop.
- Recompute Qwen3-4B text embeddings online (caching them would be ~2.4 TB — not worth it).
- Multi-timestep sampling: fresh random timestep per image per epoch (epochs = new velocity-field supervision, not repetition).
- 4 epochs × 300k = 1.2M views. Views kept at full count (quality prioritized over the ~$17 saving).
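The storage and view figures above can be reproduced with a few lines of arithmetic. Two inputs are assumptions not stated in the proposal: a 512px image is taken to map to 32×32 latent positions × 128 channels, and each caption is taken to be padded to 512 text tokens of 7680-dim conditioning, both stored in BF16.

```python
# Back-of-envelope storage and view counts (decimal GB / TB).
N_IMAGES = 300_000
EPOCHS = 4
BYTES_BF16 = 2

# Assumed latent layout at 512px: 32x32 positions x 128 channels.
latent_elems = 32 * 32 * 128
latent_gb = N_IMAGES * latent_elems * BYTES_BF16 / 1e9

# Assumed text conditioning: 512 tokens x 7680 dims per caption.
text_elems = 512 * 7680
text_tb = N_IMAGES * text_elems * BYTES_BF16 / 1e12

views = N_IMAGES * EPOCHS

print(f"latents ≈ {latent_gb:.1f} GB, text embeddings ≈ {text_tb:.2f} TB, views = {views:,}")
```

Under these assumptions the latent cache is about 79 GB while a text-embedding cache would be roughly 30× larger, which is why embeddings are recomputed online.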
Optimizer & precision
- Muon on 2D hidden weights (attention/MLP) + AdamW on embedders, `proj_out`, norms, modulation. Tune LR + weight decay only.
- BF16 compute, FP32 master weights + FP32 moments. No quantization in training.
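One way to implement the Muon/AdamW split is to route parameters by shape and name. The name fragments in `ADAMW_NAME_HINTS` and the example parameter names are hypothetical placeholders; the real module names depend on the diffusers or BFL implementation and should be checked against the model's `named_parameters()`.

```python
from typing import Iterable, List, Sequence, Tuple

# Hypothetical name fragments for parameters that stay on AdamW even when they are 2D.
ADAMW_NAME_HINTS = ("img_in", "txt_in", "time_in", "embed", "proj_out", "norm", "modulation")

def split_param_groups(
    named_shapes: Iterable[Tuple[str, Sequence[int]]],
) -> Tuple[List[str], List[str]]:
    """Return (muon_names, adamw_names).

    Muon gets 2D hidden weight matrices (attention / MLP projections); everything else
    (embedders, proj_out, norms, modulation, 1D scales) goes to AdamW.
    """
    muon, adamw = [], []
    for name, shape in named_shapes:
        is_matrix = len(shape) == 2
        is_excluded = any(hint in name for hint in ADAMW_NAME_HINTS)
        (muon if is_matrix and not is_excluded else adamw).append(name)
    return muon, adamw
```

With PyTorch this would be called as `split_param_groups((n, p.shape) for n, p in model.named_parameters())`, after which the two name lists are turned into the two optimizers' parameter groups.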
Resolution
- Bulk training at 512px; optional short 768/1024 polish phase if memory/time headroom remains.
5. Compute & Infrastructure (1× B200, 192 GB)
Memory budget (fixed ~45 GB)
| Item | GB |
|---|---|
| FP32 master weights | ~9.8 |
| BF16 weights (compute) | ~4.9 |
| BF16 grads | ~4.9 |
| Muon momentum (1 moment, FP32) | ~9.0 |
| AdamW m,v (FP32, ~8% of params) | ~1.6 |
| Teacher 4B (BF16, fwd only) | ~6.8 |
| Qwen3-4B encoder (BF16, fwd only) | ~8.0 |
| Total fixed | ~45 |
→ ~137 GB free for activations. Muon saves ~8 GB vs full AdamW.
Batch sizing
- Activations ≈ 1.63 GB/sample @512px (FlashAttention, no checkpointing).
- Max micro-batch ≈ 84; run micro-batch 64, gradient-accumulate to effective ~256 (Muon's large-batch sweet spot). No gradient checkpointing needed.
- ~18,750 micro-steps (≈4,700 optimizer steps at accum-4).
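The memory table and the batch-sizing figures can be cross-checked with simple arithmetic. This sketch takes the ~137 GB activation budget and the 1.63 GB/sample estimate as given in this section rather than deriving them.

```python
# Fixed memory from the budget table (GB).
fixed_gb = {
    "fp32_master": 9.8, "bf16_weights": 4.9, "bf16_grads": 4.9,
    "muon_momentum": 9.0, "adamw_mv": 1.6, "teacher_4b": 6.8, "qwen3_4b": 8.0,
}
total_fixed = sum(fixed_gb.values())

ACTIVATION_BUDGET_GB = 137   # free memory for activations, as stated in the memory budget
GB_PER_SAMPLE = 1.63         # activations per 512px sample (FlashAttention, no checkpointing)
max_micro_batch = int(ACTIVATION_BUDGET_GB // GB_PER_SAMPLE)

MICRO_BATCH, ACCUM, VIEWS = 64, 4, 1_200_000
effective_batch = MICRO_BATCH * ACCUM            # samples per optimizer step
micro_steps = VIEWS // MICRO_BATCH               # forward/backward passes
optimizer_steps = micro_steps / ACCUM

print(f"fixed ≈ {total_fixed:.1f} GB, max micro-batch = {max_micro_batch}, "
      f"{micro_steps:,} micro-steps, {optimizer_steps:,.1f} optimizer steps")
```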
Throughput optimizations
- `torch.compile` (CUDA-graphs mode): required to hit target MFU.
- FlashAttention-4 via PyTorch FlexAttention (BF16, Blackwell-native).
- Precomputed latents + aggressive prefetch (keeps GPU fed, protects MFU).
Time & cost — 1.2M views × ~40 TFLOP/view:
| MFU | Wall-clock | Cost (@~$5/h) |
|---|---|---|
| 35% | ~17 h | ~$85 |
| 40% | ~15 h | ~$75 |
| 45% | ~13 h | ~$66 |
Plus a one-time ~30–60 min offline VAE-latent precompute (can run on cheaper hardware).
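The table follows from wall-clock = total FLOPs / (MFU × peak throughput). The sketch below assumes a dense BF16 peak of ~2.25 PFLOP/s for one B200; that figure is not stated in the proposal, but it reproduces the table. The ~$5/h rate comes from the table header.

```python
TOTAL_FLOP = 1.2e6 * 40e12      # 1.2M views × ~40 TFLOP/view
PEAK_FLOPS = 2.25e15            # assumed dense BF16 peak for a single B200
PRICE_PER_HOUR = 5.0            # USD, from the cost table

def hours_and_cost(mfu: float) -> tuple:
    """Wall-clock hours and dollar cost for a given model FLOPs utilization."""
    hours = TOTAL_FLOP / (mfu * PEAK_FLOPS) / 3600
    return hours, hours * PRICE_PER_HOUR

for mfu in (0.25, 0.35, 0.40, 0.45):
    h, c = hours_and_cost(mfu)
    print(f"MFU {mfu:.0%}: ~{h:.1f} h, ~${c:.0f}")
```

Because time scales as 1/MFU, dropping from 40% to 25% MFU stretches the run by 40/25 = 1.6×, which matches the MFU-sensitivity risk noted in §8.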
6. Expected Outcomes & Quality Profile
A full-width 2B with a small, defensible gap to the 4B.
- Robust (holds up well): global structure, color, lighting, coherence; simple/moderate prompt adherence (encoder + double-stream preserved); texture/fine detail (full width retained).
- Degrades (worst first): dense in-image text/typography; complex multi-object composition, counting, precise spatial/attribute binding; small faces/hands and micro-detail.
- Forgiving for the use case: downstream narrow fine-tuning re-concentrates the 2B's capacity on one domain, so lost broad generality matters far less for end users.
7. Downstream "Distill-Down" Kit
Ship the 2B base plus a kit so users can specialize it via a bigger teacher rather than fine-tuning the 2B directly — best for small downstream datasets.
- Mechanism: fine-tune the larger teacher (4B or 9B) on the narrow domain → use it as a data amplifier (generate synthetic in-domain samples) → distill into the 2B via velocity matching.
- Cross-size validity: the entire klein family shares one VAE/latent space, so 9B→2B and 4B→2B velocity matching is well-defined. Each model uses its own text encoder; the student stays on Qwen3-4B (do not bolt the 9B's Qwen3-8B encoder onto the 2B — dimension mismatch).
- Asset to ship: because the 2B descends from the 4B, document the block-correspondence map (which student block ↔ which teacher block) + projector recipe for optional feature matching — the hard part users can't easily derive themselves.
- License note: 4B teacher path = commercial-OK (Apache 2.0); 9B teacher path = non-commercial only.
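A sketch of what the shipped block-correspondence map could look like for the 4B→2B case, assuming each surrogate sits in place of the removed block it replaces so that student and teacher positions line up one-to-one. The block naming, the JSON layout, and the example set of retained blocks are hypothetical; a 9B teacher would need its own mapping.

```python
import json

def block_correspondence(kept_single, n_double=5, n_single=20):
    """Student -> teacher block map for a depth-pruned student that keeps block order.

    Double-stream blocks are inherited 1:1. Every teacher single-stream block is either
    kept in full or replaced in place by a surrogate, so positions line up.
    """
    kept = set(kept_single)
    entries = [{"student": f"double.{i}", "teacher": f"double.{i}", "kind": "full"}
               for i in range(n_double)]
    for i in range(n_single):
        entries.append({"student": f"single.{i}", "teacher": f"single.{i}",
                        "kind": "full" if i in kept else "surrogate"})
    return entries

# Hypothetical choice of the 8 retained single-stream blocks.
cmap = block_correspondence([0, 1, 2, 3, 16, 17, 18, 19])
print(json.dumps(cmap[5:7], indent=2))
```

Users doing optional feature matching would hook only the entries marked `"full"`, since those are the blocks whose outputs are expected to track the teacher's.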
8. Risks & Open Questions
- Muon on warm-start recovery is under-validated (strong evidence is LLM/from-scratch + one flow-matching benchmark). → A/B against AdamW on the first training chunk before committing.
- 1.2M views is a partial recovery in absolute terms. → Monitor the loss curve; be willing to extend if a specific capability (e.g. text) stays weak.
- Surrogate weight is a tradeoff: light = cheap warm start, near-2B FLOPs; heavy = better quality but more cost and less speedup. Default light.
- Coverage holes: quality gaps will track wherever the 300k set is thin. Curate for breadth, not volume.
- MFU sensitivity: a sloppy loop at ~25% MFU ≈ 1.6× longer/costlier; the §5 optimizations are how we hit 40–45%.
9. Implementation & Engineering Setup
What we actually build, in dependency order.
A. Data & latent caching (offline, one-time)
- Curate + caption 300k images; assign stable IDs.
- Batch VAE-encode all images → BF16 latents written to streaming shards (WebDataset / tensordict). The VAE runs only here, never in the training loop.
- Dataloader streams `(latent, caption)` with aggressive prefetch; per step it samples timestep `t` and noise and forms `x_t = (1−t)·noise + t·latent`.
- Text embeddings: run Qwen3-4B online, intercepting intermediate layers (→ 7680-dim); encoder frozen. (Optionally batch-precompute per shard if I/O-bound.)
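A numpy sketch of the per-sample construction, using the proposal's convention `x_t = (1−t)·noise + t·latent` (t=0 is pure noise, t=1 is data), under which the ground-truth velocity is `latent − noise`. Seeding the RNG from `(image_id, epoch)` is one hypothetical way to get a fresh but reproducible timestep for every image in every epoch; uniform `t` is also an assumption (logit-normal or shifted schedules are common for rectified flow).

```python
import numpy as np

def make_training_pair(latent, image_id, epoch, base_seed=1234):
    """Return (t, x_t, v_target) for one cached latent."""
    rng = np.random.default_rng([base_seed, image_id, epoch])  # fresh per (image, epoch)
    t = rng.uniform(0.0, 1.0)
    noise = rng.standard_normal(latent.shape).astype(latent.dtype)
    x_t = (1.0 - t) * noise + t * latent
    v_target = latent - noise          # d x_t / d t under this linear interpolation
    return t, x_t, v_target
```

Stepping from `x_t` along the target velocity for the remaining `1 − t` lands exactly on the clean latent, which is a cheap sanity check for sign conventions.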
B. Model surgery (prune + BRLL surrogates)
- Load klein-base-4B transformer (diffusers Flux2 / BFL reference impl).
- Depth-select: keep 5 double + 8 of 20 single (importance / CKA similarity or learnable mask); inherit weights for kept blocks.
- Replace the 12 removed single blocks with BRLL surrogates — linear/low-rank residual branch with the shortcut preserved; least-squares-init from teacher activations on a calibration batch (training-free warm start).
- Sanity: shape + forward-parity checks. Full width (`d=3072`) → no projector needed for feature matching.
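For CKA-based depth selection, one common heuristic scores each single-stream block by how similar its output is to its input: linear CKA close to 1 means the block changes the representation little and is a pruning candidate. A numpy sketch on calibration activations; the "keep the least self-similar blocks" rule and the function names are assumptions, not the proposal's fixed criterion.

```python
import numpy as np

def linear_cka(x, y):
    """Linear CKA between two activation matrices of shape (n_tokens, d)."""
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    return float(cross / (np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro")))

def select_blocks(block_inputs, block_outputs, n_keep=8):
    """Keep the n_keep blocks whose output is least similar to their input."""
    scores = [linear_cka(a, b) for a, b in zip(block_inputs, block_outputs)]
    order = np.argsort(scores)                 # low CKA = block does more work
    return sorted(int(i) for i in order[:n_keep])
```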
C. Distillation training loop (online teacher)
- Teacher 4B loaded frozen (`eval`, `no_grad`, BF16, compiled).
- Per step: latents + prompts → text emb → sample `t`, noise → `x_t`; teacher forward → target velocity `v_teacher`; student forward → `v_student`.
- Loss = `w1`·velocity-MSE (primary) + `w2`·feature-match (masked KD on retained blocks) + `w3`·real-data flow-matching. Tune weights.
- Param-group split: Muon on 2D hidden weights, AdamW on embedders / `proj_out` / norms / modulation.
- BF16 autocast, FP32 master weights + moments; grad-accumulate to effective ~256; `torch.compile` the student; FlashAttention-4 via FlexAttention.
D. Infra, logging & eval
- Single B200, micro-batch 64, resumable checkpointing.
- Log per-component losses, throughput / MFU, and periodic sample grids for eyeballing.
- Eval harness: FID/CLIP + targeted prompt probes (text rendering, composition/counting, texture) vs the 4B; downstream fine-tune smoke test.
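For the FID part of the eval harness, the core computation is the Fréchet distance between Gaussians fit to two sets of image features. Standard FID extracts those features with an Inception-v3 network, which this numpy sketch leaves out; it only shows the distance itself, computing the matrix square-root trace from eigenvalues rather than calling a dedicated `sqrtm` routine.

```python
import numpy as np

def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fit to two (n_samples, dim) feature sets."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr(sqrt(cov_a @ cov_b)) = sum of square roots of its (real, non-negative) eigenvalues.
    eig = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eig.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)
```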
E. Distill-down kit (reuses C)
- Same loop with the teacher swappable (4B or 9B); student stays on the Qwen3-4B encoder.
- Flow: domain fine-tune the teacher → generate synthetic in-domain samples → velocity-match into the 2B.
- Package: block-correspondence map + projector recipe + scripts + minimal config.
10. Milestones
- Data + latents — curate 300k diverse pairs; precompute VAE latents to shards.
- Architecture — prune to 5D+8S full-width; least-squares-init BRLL surrogates.
- Recovery run — Muon+AdamW, online teacher, velocity + feature + FM losses, 1.2M views @512px, BF16.
- Eval — FID/CLIP + targeted probes (text, composition, texture) vs 4B; downstream fine-tune smoke test.
- (Optional) Polish — short 768/1024 phase; optional LADD for v2.
- Release — 2B base weights + distill-down kit (correspondence map, scripts) + model card.
11. Budget Summary
| Line | Estimate |
|---|---|
| Recovery training (1× B200, 1.2M views) | ~$66–75 |
| VAE-latent precompute (one-time) | <$5 |
| Total per run | ~$70 |
Experiments to be tracked individually under this page per the standard experiment format.