> **Status:** Proposal · **Compute target:** 1× B200, ~13–15 h, ~$66–75 · **Deliverable:** a ~2B open base model + a downstream distill-down kit.

# 1. Objective

Compress **FLUX.2 [klein] 4B** (rectified-flow MM-DiT) into a **~2B general-purpose text-to-image base model**. The student should retain most of the teacher's capability, run meaningfully faster, and serve as a strong foundation for downstream **narrow-domain fine-tuning** by third parties (e.g. wallpapers, single styles, product imagery). The student is recovered via knowledge distillation on a single B200 within a strict ~$70, roughly half-day budget. We deliberately do **not** reduce inference steps: we accept moving from 4 to 8 sampling steps in exchange for the smaller, faster backbone.

# 2. Background

- **Teacher:** FLUX.2 [klein] **base** 4B. It is undistilled, so it keeps the full training signal and is the best starting point for fine-tuning. Architecture: `d=3072`, `24 heads`, **5 double-stream + 20 single-stream** MM-DiT blocks, AdaLN modulation shared across blocks, no bias terms.
- **Conditioning:** Qwen3-4B text encoder (a stack of intermediate-layer hidden states → 7680-dim conditioning). Frozen, and not counted in the 4B denoiser.
- **Latent space:** the shared FLUX.2 VAE (128-ch latent), **identical across the whole klein family (4B & 9B)**. This is what makes cross-size velocity distillation valid (see §7).
- **Capacity distribution:** the 20 single-stream blocks hold **~67%** of backbone params; the 5 double-stream blocks (cross-modal binding) hold ~33%. This asymmetry drives the architecture choice below.
- **License:** klein 4B is Apache 2.0 (commercial-friendly); klein 9B is non-commercial.

# 3. Target Architecture (~2B student)

The student keeps full width and is depth-pruned, with lightweight block replacements standing in for the removed depth.

| Component | Spec | Rationale |
| --- | --- | --- |
| Width | `d=3072` (unchanged) | Same as teacher → preserves texture/detail capacity; feature matching needs no projector |
| Double-stream | **5 (all kept)** | Preserve cross-modal/instruction binding |
| Single-stream | **8 kept in full** (of 20) | Depth-only prune; recovers far better than a width prune |
| Removed depth | **12 single blocks → light BRLL surrogates** | Linear/low-rank residual-branch replacements with the shortcut preserved; least-squares (training-free) init gives a warm start |
| Effective size | core ~2.04B + surrogates ~0.41B ≈ **~2.45B** | |
| Inference | **~1.7× faster per step than 4B**, ~5 GB BF16 | Still runs on consumer GPUs |

# 4. Methodology

**Initialization**

- Prune from klein-**base**-4B and inherit the weights of the surviving blocks.
- Replace the 12 removed single-stream blocks with BRLL-style surrogates, **least-squares initialized** from teacher activations. This is a training-free warm start, so less has to be recovered.
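
The following is a minimal NumPy sketch of the least-squares warm start. It is written under stated assumptions:

- Each surrogate is a residual map `h + (h @ A) @ B`.
- The map is fit on calibration activations of the removed block: `H` is the block input and `Y` the block output, with batch and tokens flattened into rows.
- The linear fit ignores the timestep/text modulation that the real blocks also receive.

The rank-`r` factors come from reduced-rank regression. `fit_lowrank_surrogate` and `surrogate_forward` are hypothetical names, not released code.

```python
import numpy as np


def fit_lowrank_surrogate(H: np.ndarray, Y: np.ndarray, rank: int):
    """Fit Y ≈ H + (H @ A) @ B by least squares (training-free init).

    H, Y: (num_tokens, d) input/output activations of one removed block,
    collected from the teacher on a calibration batch.
    Returns A with shape (d, rank) and B with shape (rank, d).
    """
    # The shortcut is kept, so only the residual branch R = Y - H is fit.
    R = Y - H
    # Full-rank ordinary least squares: R ≈ H @ W.
    W, *_ = np.linalg.lstsq(H, R, rcond=None)
    # Reduced-rank regression: restrict the fitted values H @ W to their
    # top-`rank` right singular directions (best rank-r least-squares fit).
    _, _, Vt = np.linalg.svd(H @ W, full_matrices=False)
    V_r = Vt[:rank].T          # (d, rank)
    A = W @ V_r                # (d, rank)
    B = V_r.T                  # (rank, d)
    return A, B


def surrogate_forward(h: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Shortcut plus a low-rank linear residual branch."""
    return h + (h @ A) @ B


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, r, n = 64, 8, 2048
    H = rng.standard_normal((n, d))
    # Synthetic "block" whose residual branch is exactly rank r.
    Y = H + H @ rng.standard_normal((d, r)) @ rng.standard_normal((r, d))
    A, B = fit_lowrank_surrogate(H, Y, rank=r)
    err = np.abs(surrogate_forward(H, A, B) - Y).max()
    print(f"max abs error: {err:.2e}")
```

On real activations the residual will not be exactly low-rank. The chosen rank is then the quality/FLOP knob discussed under surrogate weight in §8.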

**Distillation (online teacher)**

- **Velocity/output matching** against the 4B teacher (flow-matching target). This is the primary signal, with a fresh teacher target for every sampled noisy input.
- **Feature matching** on retained blocks. It is free here because student and teacher widths match (no projector). Use **masked KD** (exclude DiT extreme-activation outliers).
- **Real-data flow-matching loss** on curated images for grounding.
- **No adversarial (LADD) phase in v1.** It is too costly and finicky for the budget; revisit it as a v2 texture polish.
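
One way to implement the masked feature term is sketched below: entries where the teacher's hidden state is an extreme outlier are dropped from the MSE. The outlier rule and the default `outlier_factor` are assumptions chosen for illustration; the proposal does not fix a rule. The rule used here is magnitude above `outlier_factor` × the median absolute teacher activation.

```python
import numpy as np


def masked_feature_mse(h_student: np.ndarray, h_teacher: np.ndarray,
                       outlier_factor: float = 20.0):
    """MSE between student/teacher hidden states, ignoring teacher outliers.

    h_*: (batch, tokens, d). Entries where |teacher| exceeds
    outlier_factor * median(|teacher|) are excluded (assumed rule).
    Returns (loss, keep_mask).
    """
    mag = np.abs(h_teacher)
    keep = mag <= outlier_factor * np.median(mag)
    sq_err = (h_student - h_teacher) ** 2
    loss = float((sq_err * keep).sum() / max(int(keep.sum()), 1))
    return loss, keep
```

Returning the mask makes it easy to log how many entries each block excludes. A sudden jump in that count is worth eyeballing before trusting the feature loss.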

**Data pipeline**

- **300k curated images + captions**, diversity-first (broad styles/subjects/composition + in-image text coverage).
- **Precompute VAE latents offline** (~79 GB @512px BF16, streaming shards). The VAE never runs in the training loop.
- **Compute Qwen3-4B text embeddings online.** Caching them would take ~2.4 TB, which is not worth it.
- **Multi-timestep sampling:** draw a fresh random timestep (and noise sample) per image per epoch. Each epoch then supplies new velocity-field supervision rather than repetition.
- **4 epochs × 300k = 1.2M views.** The full view count is kept; quality is prioritized over the ~$17 a shorter run would save.
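
The two storage figures above follow from simple arithmetic. This sketch rests on two assumptions:

- The 128-channel latent sits on a 32×32 grid at 512px (16× total spatial reduction).
- Every caption is padded to 512 text tokens. This length is chosen because it reproduces the ~2.4 TB estimate.

```python
BYTES_BF16 = 2
N_IMAGES = 300_000


def latent_cache_bytes(res: int = 512, channels: int = 128, reduction: int = 16) -> int:
    """BF16 latent shards: channels x (res / reduction)^2 values per image (assumed grid)."""
    side = res // reduction
    return N_IMAGES * channels * side * side * BYTES_BF16


def text_cache_bytes(tokens: int = 512, dim: int = 7680) -> int:
    """Cost of caching the 7680-dim conditioning in BF16 (512 tokens assumed)."""
    return N_IMAGES * tokens * dim * BYTES_BF16


print(f"latents:  {latent_cache_bytes() / 1e9:.1f} GB")   # → latents:  78.6 GB
print(f"text emb: {text_cache_bytes() / 1e12:.2f} TB")    # → text emb: 2.36 TB
print(f"views:    {4 * N_IMAGES:,}")                      # → views:    1,200,000
```

Latent storage scales with the square of resolution. A 1024px polish set would therefore need ~4× the bytes per image, which matters when sizing shards for the optional polish phase.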

**Optimizer & precision**

- **Muon** on 2D hidden weights (attention/MLP) + **AdamW** on embedders, `proj_out`, norms, and modulation. Tune only LR and weight decay.
- **BF16** compute, FP32 master weights + FP32 moments. No quantization during training.
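
Here is a framework-agnostic sketch of the Muon/AdamW split. It operates on a `{parameter_name: shape}` mapping; in PyTorch this could be built from `model.named_parameters()`. The name fragments in `ADAMW_PATTERNS` and the example names are hypothetical. Check the actual FLUX.2 parameter names in diffusers or the BFL reference code before relying on them.

```python
from math import prod
from typing import Dict, List, Tuple

# Hypothetical name fragments for parameters that stay on AdamW even when 2D.
ADAMW_PATTERNS = ("embedder", "proj_out", "norm", "modulation")


def split_param_groups(named_shapes: Dict[str, Tuple[int, ...]]) -> Tuple[List[str], List[str]]:
    """Route 2D hidden weights (attention/MLP) to Muon and everything else to AdamW."""
    muon, adamw = [], []
    for name, shape in named_shapes.items():
        hidden_matrix = len(shape) == 2 and not any(p in name for p in ADAMW_PATTERNS)
        (muon if hidden_matrix else adamw).append(name)
    return muon, adamw


def adamw_fraction(named_shapes: Dict[str, Tuple[int, ...]]) -> float:
    """Share of parameters on AdamW (the memory budget assumes ~8%)."""
    _, adamw = split_param_groups(named_shapes)
    total = sum(prod(s) for s in named_shapes.values())
    return sum(prod(named_shapes[n]) for n in adamw) / total


# Hypothetical names and shapes, for illustration only.
example = {
    "x_embedder.weight": (3072, 128),
    "single_blocks.0.attn.qkv.weight": (9216, 3072),
    "single_blocks.0.mlp.fc1.weight": (12288, 3072),
    "single_blocks.0.norm.q_scale": (128,),
    "modulation.linear.weight": (18432, 3072),
    "proj_out.weight": (128, 3072),
}
muon, adamw = split_param_groups(example)
```

Running `adamw_fraction` on the real model is a cheap check that the ~8% assumption in the memory budget holds.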

**Resolution**

- Bulk training at **512px**, with an optional short 768/1024 polish phase if memory/time headroom remains.

# 5. Compute & Infrastructure (1× B200, 192 GB)

**Memory budget (fixed ~45 GB)**

| Item | GB |
| --- | --- |
| FP32 master weights | ~9.8 |
| BF16 weights (compute) | ~4.9 |
| BF16 grads | ~4.9 |
| Muon momentum (1 moment, FP32) | ~9.0 |
| AdamW m,v (FP32, ~8% of params) | ~1.6 |
| Teacher 4B (BF16, fwd only) | ~6.8 |
| Qwen3-4B encoder (BF16, fwd only) | ~8.0 |
| **Total fixed** | **~45** |

→ 192 − 45 ≈ 147 GB remain; budgeting ~137 GB for activations leaves ~10 GB of headroom. Muon saves ~9 GB vs full AdamW, because it keeps one FP32 moment instead of two on ~92% of params.
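
The fixed-memory table can be reproduced from the parameter count. This sketch takes the student at ~2.45B parameters with ~8% on AdamW. It copies the teacher and encoder footprints straight from the table rather than deriving them.

```python
GB = 1e9
STUDENT_PARAMS = 2.45e9    # core ~2.04B + surrogates ~0.41B
ADAMW_FRACTION = 0.08      # embedders, proj_out, norms, modulation
HBM_GB = 192


def fixed_memory_gb() -> dict:
    n = STUDENT_PARAMS
    bytes_by_item = {
        "fp32_master": 4 * n,
        "bf16_weights": 2 * n,
        "bf16_grads": 2 * n,
        "muon_momentum_fp32": 4 * n * (1 - ADAMW_FRACTION),   # one moment
        "adamw_m_v_fp32": 8 * n * ADAMW_FRACTION,             # two moments
        "teacher_4b_bf16": 6.8e9,                             # from the table
        "qwen3_4b_bf16": 8.0e9,                               # from the table
    }
    return {k: v / GB for k, v in bytes_by_item.items()}


items = fixed_memory_gb()
total = sum(items.values())
# Full AdamW would keep two FP32 moments for every student parameter.
muon_saving = 8 * STUDENT_PARAMS / GB - (items["muon_momentum_fp32"] + items["adamw_m_v_fp32"])
print(f"fixed: {total:.1f} GB, raw headroom: {HBM_GB - total:.0f} GB, Muon saving: {muon_saving:.1f} GB")
# → fixed: 45.0 GB, raw headroom: 147 GB, Muon saving: 9.0 GB
```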

**Batch sizing**

- Activations ≈ **1.63 GB/sample** @512px (FlashAttention, no checkpointing).
- **Max micro-batch ≈ 84** (137 / 1.63). Run **micro-batch 64** and gradient-accumulate to an **effective ~256** (Muon's large-batch sweet spot). No gradient checkpointing needed.
- ~18,750 micro-steps (≈4,700 optimizer steps at accum-4).
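
The batch-sizing numbers follow directly from the activation budget. Here is a quick check that takes the ~137 GB activation budget and the 1.63 GB/sample estimate as given:

```python
import math

ACT_GB_PER_SAMPLE = 1.63      # @512px, FlashAttention, no checkpointing
ACTIVATION_BUDGET_GB = 137
VIEWS = 4 * 300_000
MICRO_BATCH = 64
EFFECTIVE_BATCH = 256

max_micro_batch = math.floor(ACTIVATION_BUDGET_GB / ACT_GB_PER_SAMPLE)
accum_steps = EFFECTIVE_BATCH // MICRO_BATCH
micro_steps = VIEWS // MICRO_BATCH
# The last accumulation window is partial, hence the ceiling.
optimizer_steps = math.ceil(micro_steps / accum_steps)
print(max_micro_batch, accum_steps, micro_steps, optimizer_steps)  # → 84 4 18750 4688
```

Running at 64 instead of the 84 maximum leaves ~30 GB of activation slack. That covers sequence-length variation from longer captions and a later, short higher-resolution phase.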

**Throughput optimizations**

- `torch.compile` in CUDA-graphs mode (`mode="reduce-overhead"`), which is required to hit target MFU.
- **FlashAttention-4** via PyTorch FlexAttention (BF16, Blackwell-native).
- Precomputed latents + aggressive prefetch, to keep the GPU fed and protect MFU.

**Time & cost:** 1.2M views × ~40 TFLOP/view ≈ 4.8 × 10¹⁹ FLOP, against a B200 dense BF16 peak of ~2.25 PFLOP/s:

| MFU | Wall-clock | Cost (@~$5/h) |
| --- | --- | --- |
| 35% | ~17 h | ~$85 |
| 40% | ~15 h | ~$75 |
| 45% | ~13 h | ~$66 |

Plus a one-time ~30–60 min offline VAE-latent precompute, which can run on cheaper hardware.
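
The wall-clock column is total FLOP divided by achieved throughput. This sketch assumes a B200 dense BF16 peak of ~2.25 PFLOP/s, the value that reproduces the table, and the ~$5/h rate. It also covers the 25% MFU case from §8.

```python
VIEWS = 1.2e6
FLOP_PER_VIEW = 40e12          # per-view estimate from the text
PEAK_FLOPS = 2.25e15           # assumed B200 dense BF16 peak
USD_PER_HOUR = 5.0


def wall_clock_hours(mfu: float) -> float:
    """Hours to push VIEWS * FLOP_PER_VIEW at the given model-FLOP utilization."""
    return VIEWS * FLOP_PER_VIEW / (mfu * PEAK_FLOPS) / 3600


for mfu in (0.25, 0.35, 0.40, 0.45):
    hours = wall_clock_hours(mfu)
    print(f"MFU {mfu:.0%}: {hours:4.1f} h, ${hours * USD_PER_HOUR:.0f}")
```

Cost is inversely proportional to MFU. That is why 25% MFU costs exactly 0.40 / 0.25 = 1.6× the 40% run.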

# 6. Expected Outcomes & Quality Profile

The expected result is a full-width 2B with a small, defensible gap to the 4B.

- **Robust (holds up well):** global structure, color, lighting, coherence; simple/moderate prompt adherence (encoder + double-stream blocks preserved); texture/fine detail (full width retained).
- **Degrades (worst first):** dense in-image text/typography; complex multi-object composition, counting, and precise spatial/attribute binding; small faces/hands and micro-detail.
- **Forgiving for the use case:** downstream narrow fine-tuning re-concentrates the 2B's capacity on one domain, so the lost broad generality matters far less to end users.

# 7. Downstream "Distill-Down" Kit

Ship the 2B base **plus** a kit that lets users specialize it through a bigger teacher instead of fine-tuning the 2B directly. This works best for **small downstream datasets**.

- **Mechanism:** fine-tune the larger teacher (4B or 9B) on the narrow domain → use it as a data amplifier (generate synthetic in-domain samples) → distill into the 2B via **velocity matching**.
- **Cross-size validity:** the entire klein family shares one VAE/latent space, so 9B→2B and 4B→2B velocity matching is well-defined. Each model uses its own text encoder, and **the student stays on Qwen3-4B**. Do not bolt the 9B's Qwen3-8B encoder onto the 2B, because its conditioning dimension does not match.
- **Asset to ship:** because the 2B descends from the 4B, document the **block-correspondence map** (which student block ↔ which teacher block) + a projector recipe for optional feature matching. This is the hard part users cannot easily derive themselves.
- **License note:** 4B teacher path = commercial-OK (Apache 2.0); 9B teacher path = non-commercial only.
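
Here is a sketch of what the shipped correspondence map could look like: one JSON-serializable entry per student block. Both the schema (`student`, `teacher`, `kind`) and the kept indices are hypothetical placeholders. The sketch also assumes each surrogate sits at the depth of the block it replaces. For a 9B teacher, entries would additionally need the projector recipe, which this sketch does not cover.

```python
import json
from typing import Dict, Iterable, List

N_DOUBLE, N_SINGLE, N_KEEP = 5, 20, 8


def build_correspondence(kept_single: Iterable[int]) -> List[Dict[str, str]]:
    """Map each student block to the 4B teacher block it was derived from."""
    kept = set(kept_single)
    assert len(kept) == N_KEEP and all(0 <= i < N_SINGLE for i in kept)
    entries = [
        {"student": f"double.{i}", "teacher": f"double.{i}", "kind": "inherited"}
        for i in range(N_DOUBLE)
    ]
    n_full = n_surrogate = 0
    # Walk the teacher's single-stream stack in depth order: kept blocks are
    # inherited as-is, removed ones become BRLL surrogates at the same depth.
    for t in range(N_SINGLE):
        if t in kept:
            entries.append({"student": f"single.{n_full}", "teacher": f"single.{t}", "kind": "inherited"})
            n_full += 1
        else:
            entries.append({"student": f"surrogate.{n_surrogate}", "teacher": f"single.{t}", "kind": "brll_surrogate"})
            n_surrogate += 1
    return entries


# Placeholder indices; the real ones come from the depth-selection step.
mapping = build_correspondence([0, 2, 5, 8, 11, 14, 17, 19])
print(json.dumps(mapping[:7], indent=2))
```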

# 8. Risks & Open Questions

- **Muon for warm-start recovery is under-validated.** The strong evidence comes from LLM/from-scratch training plus one flow-matching benchmark. → A/B against AdamW on the first training chunk before committing.
- **1.2M views is a partial recovery in absolute terms.** → Monitor the loss curve; be willing to extend if a specific capability (e.g. text) stays weak.
- **Surrogate weight is a tradeoff.** Light = cheap warm start and near-2B FLOPs; heavy = better quality but more cost and less speedup. Default to light.
- **Coverage holes:** quality gaps will track wherever the 300k set is thin. Curate for breadth, not volume.
- **MFU sensitivity:** a sloppy loop at ~25% MFU takes ≈1.6× longer (and costs ≈1.6× more) than one at 40%. The §5 optimizations are how we hit 40–45%.

# 9. Implementation & Engineering Setup

What we actually build, in dependency order.

**A. Data & latent caching (offline, one-time)**

- Curate + caption 300k images; assign stable IDs.
- Batch **VAE-encode all images → BF16 latents written to streaming shards** (WebDataset / tensordict). The VAE runs only here, never in the training loop.
- The dataloader streams `(latent, caption)` with aggressive prefetch. Each step samples a timestep `t` and noise, then forms `x_t = (1−t)·noise + t·latent`.
- Text embeddings: run **Qwen3-4B online**, intercepting intermediate layers (→ 7680-dim); encoder frozen. Optionally batch-precompute per shard if the online encoder becomes the bottleneck.
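
Below is a NumPy sketch of the per-sample noising step in the convention written above: t=0 is pure noise and t=1 is the clean latent, so the rectified-flow velocity target is `latent − noise`.

- Uniform `t` is an assumption; a non-uniform schedule (e.g. logit-normal) may be preferable.
- diffusers' flow-matching schedulers use the reversed σ convention, whose training target is `noise − latent`.

```python
import numpy as np


def make_training_batch(latents: np.ndarray, rng: np.random.Generator):
    """Return (x_t, t, v_target) for a batch of clean latents.

    Convention: x_t = (1 - t) * noise + t * latent, so dx_t/dt = latent - noise.
    A fresh (t, noise) pair is drawn per sample on every visit.
    """
    b = latents.shape[0]
    t = rng.uniform(0.0, 1.0, size=b)                   # one timestep per sample (uniform, assumed)
    noise = rng.standard_normal(latents.shape)
    t_b = t.reshape((b,) + (1,) * (latents.ndim - 1))   # broadcast over C, H, W
    x_t = (1.0 - t_b) * noise + t_b * latents
    v_target = latents - noise                          # real-data flow-matching target
    return x_t, t, v_target
```

`x_t + (1 − t)·v_target` equals the clean latent exactly, so one Euler step with the true velocity lands on the data. This is a cheap check that the sign convention matches the teacher's.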

**B. Model surgery (prune + BRLL surrogates)**

- Load the klein-**base**-4B transformer (diffusers Flux2 / BFL reference impl).
- Depth-select: keep 5 double + 8 of 20 single blocks (by importance / CKA similarity or a learnable mask); inherit weights for the kept blocks.
- Replace the 12 removed single blocks with **BRLL surrogates**: a linear/low-rank residual branch with the shortcut preserved, **least-squares-initialized from teacher activations** on a calibration batch (training-free warm start).
- Sanity: shape and forward-parity checks. Full width `d=3072` → no projector needed for feature matching.
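
Here is a sketch of the CKA option for depth selection:

- Score each single-stream block by how much it changes its input: 1 − linear CKA between the block's input and output activations on a calibration batch.
- Keep the 8 highest-scoring blocks.

Using input/output CKA as the importance score is an illustrative choice. The function names are hypothetical.

```python
import numpy as np


def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Linear CKA between activation matrices X (n, p) and Y (n, q)."""
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(Y.T @ X, "fro") ** 2
    return float(cross / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")))


def select_blocks(block_io, n_keep: int = 8):
    """block_io: list of (input, output) activations, one pair per single block.

    Returns the sorted indices of the n_keep blocks that transform their
    input the most (lowest input/output CKA).
    """
    scores = [1.0 - linear_cka(h_in, h_out) for h_in, h_out in block_io]
    return sorted(np.argsort(scores)[::-1][:n_keep].tolist())
```

Linear CKA is invariant to orthogonal transforms and isotropic scaling of the features. A block that merely rotates or rescales its input therefore scores as near-redundant, which is the property that makes it a candidate for a linear surrogate.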

**C. Distillation training loop (online teacher)**

- Teacher 4B loaded **frozen** (eval, `no_grad`, BF16, compiled).
- Per step: `latents + prompts → text emb → sample t, noise → x_t`; **teacher forward → target velocity** `v_teacher`; student forward → `v_student`.
- Loss = `w1·velocity-MSE` (primary) + `w2·feature-match` (masked KD on retained blocks) + `w3·real-data flow-matching`. Tune the weights.
- Param-group split: **Muon** on 2D hidden weights, **AdamW** on embedders / `proj_out` / norms / modulation.
- BF16 autocast, FP32 master weights + moments; grad-accumulate to effective ~256; `torch.compile` the student; **FlashAttention-4** via FlexAttention.
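
Here is a NumPy sketch of how the three loss terms combine for one micro-batch, assuming all predictions and targets are already computed. The weights `w` are placeholders, since the proposal leaves them to tuning. The keep-masks are assumed to come from whatever outlier rule the masked-KD step uses.

```python
from typing import Dict, Iterable, Tuple

import numpy as np


def mse(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.mean((a - b) ** 2))


def distill_loss(
    v_student: np.ndarray,
    v_teacher: np.ndarray,
    v_real: np.ndarray,
    feats: Iterable[Tuple[np.ndarray, np.ndarray, np.ndarray]],
    w: Tuple[float, float, float] = (1.0, 0.1, 0.5),   # placeholder weights
) -> Tuple[float, Dict[str, float]]:
    """Combined objective for one micro-batch.

    v_student / v_teacher: velocities predicted on the same (x_t, t, text emb).
    v_real: real-data flow-matching target (latent - noise).
    feats: (h_student, h_teacher, keep_mask) per retained block; the mask
           drops teacher outlier entries (masked KD).
    """
    w1, w2, w3 = w
    l_vel = mse(v_student, v_teacher)
    per_block = [float(((hs - ht) ** 2 * m).sum() / max(int(m.sum()), 1)) for hs, ht, m in feats]
    l_feat = float(np.mean(per_block)) if per_block else 0.0
    l_fm = mse(v_student, v_real)
    total = w1 * l_vel + w2 * l_feat + w3 * l_fm
    # Per-component values are returned for logging alongside the total.
    return total, {"velocity": l_vel, "feature": l_feat, "flow_matching": l_fm}
```

The student output is compared against two targets on the same `x_t`: the teacher velocity and the real-data velocity. Their relative weights decide how far the student may drift from the teacher toward the curated data.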

**D. Infra, logging & eval**

- Single B200, micro-batch 64, resumable checkpointing.
- Log per-component losses, throughput / MFU, and periodic sample grids for visual inspection.
- Eval harness: FID/CLIP + targeted prompt probes (text rendering, composition/counting, texture) vs the 4B; downstream fine-tune smoke test.

**E. Distill-down kit (reuses C)**

- Same loop with a **swappable teacher (4B or 9B)**; the student stays on the Qwen3-4B encoder.
- Flow: domain fine-tune the teacher → generate synthetic in-domain samples → velocity-match into the 2B.
- Package: block-correspondence map + projector recipe + scripts + minimal config.

# 10. Milestones

1. **Data + latents:** curate 300k diverse pairs; precompute VAE latents to shards.
2. **Architecture:** prune to 5D+8S at full width; least-squares-init the BRLL surrogates.
3. **Recovery run:** Muon+AdamW, online teacher, velocity + feature + FM losses, 1.2M views @512px, BF16.
4. **Eval:** FID/CLIP + targeted probes (text, composition, texture) vs 4B; downstream fine-tune smoke test.
5. **(Optional) Polish:** short 768/1024 phase; optional LADD for v2.
6. **Release:** 2B base weights + distill-down kit (correspondence map, scripts) + model card.

# 11. Budget Summary

| Line | Estimate |
| --- | --- |
| Recovery training (1× B200, 1.2M views) | **~$66–75** |
| VAE-latent precompute (one-time) | <$5 |
| **Total per run** | **~$66–80** |

Experiments will be tracked individually under this page, following the standard experiment format.