> **Status:** Proposal · **Compute target:** 1× B200, ~13–15h, ~$66–75 · **Deliverable:** a ~2B open base model + a downstream distill-down kit.
>
# 1. Objective
Compress **FLUX.2 [klein] 4B** (rectified-flow MM-DiT) into a **~2B general-purpose text-to-image base model** that retains most of the teacher's capability, runs ~1.7× faster, and serves as a strong foundation for downstream **narrow-domain fine-tuning** by third parties (e.g. wallpapers, single styles, product imagery). Quality is recovered via knowledge distillation on a single B200 within a strict ~$70 / half-day budget. We deliberately do **not** reduce inference steps; we accept 4→8 steps in exchange for the smaller, faster backbone.
# 2. Background
- **Teacher:** FLUX.2 [klein] **base** 4B (undistilled — full training signal, best for fine-tuning). Architecture: `d=3072`, `24 heads`, **5 double-stream + 20 single-stream** MM-DiT blocks, shared-across-block AdaLN modulation, no bias.
- **Conditioning:** Qwen3-4B text encoder (intermediate-layer stack → 7680-dim conditioning). Frozen; not part of the 4B denoiser.
- **Latent space:** shared FLUX.2 VAE (128-ch latent) — **identical across the whole klein family (4B & 9B)**. This is what makes cross-size velocity distillation valid (see §7).
- **Capacity distribution:** the 20 single-stream blocks hold **~67%** of backbone params; the 5 double-stream blocks (cross-modal binding) hold ~33%. This asymmetry drives the architecture choice below.
- **License:** klein 4B is Apache 2.0 (commercial-friendly); klein 9B is non-commercial.
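The ~67% figure can be reproduced from block counts alone, assuming (as in FLUX-style MM-DiTs) that a double-stream block carries two streams' worth of attention + MLP weights and a single-stream block carries one. This minimal sketch ignores embedders and the shared modulation:

```python
def single_stream_share(n_double: int, n_single: int) -> float:
    """Fraction of block parameters held by single-stream blocks.

    Assumption: one double-stream block = two streams (image + text), each with
    the same weight count as one single-stream block. Embedders and the shared
    AdaLN modulation are ignored.
    """
    return n_single / (n_single + 2 * n_double)


teacher = single_stream_share(n_double=5, n_single=20)      # 20 / 30
student_core = single_stream_share(n_double=5, n_single=8)  # 8 / 18

print(f"teacher single-stream share: {teacher:.1%}")
print(f"student core single-stream share: {student_core:.1%}")
```

Under the same assumption, pruning to 8 single-stream blocks flips the balance: the five double-stream blocks hold ~56% of the student core's block parameters.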
# 3. Target Architecture (~2B student)
Full-width, depth-pruned, with lightweight block replacement for the removed depth.
| Component | Spec | Rationale |
| --- | --- | --- |
| Width | `d=3072` (unchanged) | Same as teacher → preserves texture/detail capacity; feature matching needs no projector |
| Double-stream | **5 (all kept)** | Preserve cross-modal/instruction binding |
| Single-stream | **8 kept full** (of 20) | Depth-only prune; recovers far better than width prune |
| Removed depth | **12 single blocks → light BRLL surrogates** | Linear/low-rank residual-branch replacements, shortcut preserved; least-squares (training-free) init = warm start |
| Effective size | core ~2.04B + surrogates ~0.41B ≈ **2.45B** | |
| Inference | **~1.7× faster than 4B**, ~5 GB BF16 | Still runs on consumer GPUs |
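The size and inference-memory rows follow from simple arithmetic on the table's own parameter counts (weights only; activations, the text encoder and the VAE are excluded):

```python
CORE_PARAMS = 2.04e9        # "core" figure from the table above
SURROGATE_PARAMS = 0.41e9   # all 12 surrogates together (table figure)
N_SURROGATES = 12
BYTES_PER_BF16 = 2

total_params = CORE_PARAMS + SURROGATE_PARAMS
bf16_weights_gb = total_params * BYTES_PER_BF16 / 1e9
avg_surrogate_m = SURROGATE_PARAMS / N_SURROGATES / 1e6

print(f"total: {total_params / 1e9:.2f}B params")
print(f"BF16 weights: {bf16_weights_gb:.1f} GB")
print(f"average surrogate: {avg_surrogate_m:.1f}M params")
```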
# 4. Methodology
**Initialization**
- Prune from klein-**base**-4B; inherit surviving block weights.
- Replace the 12 removed single-stream blocks with BRLL-style surrogates, **least-squares initialized** from teacher activations (training-free warm start → less to recover).
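The exact BRLL parameterization is not specified here, so the following is only a sketch of one plausible least-squares warm start, under stated assumptions. Each removed block is treated as a function of its input tokens only (AdaLN timestep conditioning is ignored), its residual `y − x` is fit with ridge-regularized least squares on calibration activations, and the result is optionally truncated to low rank. The function names are hypothetical.

```python
import numpy as np


def lstsq_residual_init(x, y, rank=None, ridge=1e-4):
    """Fit a linear residual surrogate y ≈ x + (x @ a) @ b from calibration activations.

    x, y: (n_tokens, d) input/output activations of ONE removed teacher block.
    Returns factors a (d, r) and b (r, d); rank=None keeps full rank (r = d).
    """
    residual = y - x                              # what the block adds on top of the shortcut
    d = x.shape[1]
    gram = x.T @ x + ridge * np.eye(d)            # ridge-regularized normal equations
    w = np.linalg.solve(gram, x.T @ residual)     # (d, d) least-squares solution
    if rank is None or rank >= d:
        return w, np.eye(d)
    # Truncated SVD: best rank-r approximation of w itself (Eckart-Young).
    # A simple heuristic, not the optimal reduced-rank regression solution.
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    return u[:, :rank] * s[:rank], vt[:rank]


def surrogate_forward(x, a, b):
    """Shortcut preserved: output = input + (low-rank) linear residual branch."""
    return x + (x @ a) @ b


rng = np.random.default_rng(0)
x = rng.standard_normal((512, 16))                # stand-in calibration activations
w_true = 0.1 * rng.standard_normal((16, 16))
y = x + x @ w_true                                # a toy block that happens to be linear

a, b = lstsq_residual_init(x, y, ridge=1e-8)
print(np.abs(surrogate_forward(x, a, b) - y).max())
```

Real DiT blocks are nonlinear, so the fit is a warm start rather than an exact replacement; distillation then recovers the remainder.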
**Distillation (online teacher)**
- **Velocity/output matching** against the 4B teacher (flow-matching target) — primary signal, runs on unlimited teacher-generated targets.
- **Feature matching** on retained blocks — free here because student/teacher widths match (no projector). Use **masked KD** (exclude DiT extreme-activation outliers).
- **Real-data flow-matching loss** on curated images for grounding.
- **No adversarial (LADD) phase in v1** — too costly/finicky for the budget; revisit for a v2 texture polish.
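The three loss terms can be combined as in the sketch below. It assumes the real-data target is `latent − noise` for the same `x_t` the teacher sees, so a single student forward serves both velocity terms. The quantile rule for the outlier mask and the default weights are placeholders, not values taken from the plan.

```python
import numpy as np


def mse(a, b):
    return float(np.mean((a - b) ** 2))


def masked_feature_mse(f_student, f_teacher, outlier_quantile=0.999):
    """Feature MSE that skips the teacher's extreme-magnitude activations.

    Elements whose |teacher activation| exceeds the given quantile are excluded
    (an assumed stand-in for the outlier mask).
    """
    thresh = np.quantile(np.abs(f_teacher), outlier_quantile)
    keep = np.abs(f_teacher) <= thresh
    return mse(f_student[keep], f_teacher[keep])


def distill_loss(v_student, v_teacher, v_data, feats_student, feats_teacher,
                 w1=1.0, w2=0.1, w3=0.5):
    """w1 * velocity MSE + w2 * masked feature KD + w3 * real-data flow matching."""
    l_vel = mse(v_student, v_teacher)
    l_feat = float(np.mean([masked_feature_mse(s, t)
                            for s, t in zip(feats_student, feats_teacher)]))
    l_fm = mse(v_student, v_data)
    total = w1 * l_vel + w2 * l_feat + w3 * l_fm
    return total, {"velocity": l_vel, "feature": l_feat, "flow_matching": l_fm}


f_t = np.ones(1000)
f_t[0] = 1000.0          # one extreme teacher activation
f_s = np.ones(1000)
f_s[0] = 0.0             # student misses it entirely
print(masked_feature_mse(f_s, f_t), mse(f_s, f_t))  # → 0.0 1000.0
```

The toy example shows why the mask matters: a single outlier dominates the unmasked feature loss.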
**Data pipeline**
- **300k curated images + captions**, diversity-first (broad styles/subjects/composition + in-image text coverage).
- **Precompute VAE latents offline** (~79 GB @512px BF16, streaming shards). VAE never runs in the training loop.
- **Recompute Qwen3-4B text embeddings online** (caching them would be ~2.4 TB — not worth it).
- **Multi-timestep sampling**: fresh random timestep per image per epoch (epochs = new velocity-field supervision, not repetition).
- **4 epochs × 300k = 1.2M views.** Views kept at full count: quality is prioritized over the ~$17 that dropping one epoch would save.
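The storage figures above can be checked with a few lines of arithmetic. Two inputs are assumptions not stated in the plan. First, a 512px image maps to a 128×32×32 latent (16× total spatial compression), which reproduces the ~79 GB figure. Second, captions are padded to 512 text tokens, which reproduces ~2.4 TB.

```python
N_IMAGES = 300_000
BYTES_BF16 = 2
LATENT_ELEMS = 128 * 32 * 32      # assumed: 128-ch latent on a 32x32 grid at 512px
TEXT_TOKENS = 512                 # assumed caption length after padding
TEXT_DIM = 7680                   # stacked Qwen3-4B intermediate layers
EPOCHS = 4

latent_gb = N_IMAGES * LATENT_ELEMS * BYTES_BF16 / 1e9
text_tb = N_IMAGES * TEXT_TOKENS * TEXT_DIM * BYTES_BF16 / 1e12
views = EPOCHS * N_IMAGES

print(f"latents: {latent_gb:.1f} GB")          # → latents: 78.6 GB
print(f"text embeddings: {text_tb:.2f} TB")    # → text embeddings: 2.36 TB
print(f"training views: {views:,}")            # → training views: 1,200,000
```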
**Optimizer & precision**
- **Muon** on 2D hidden weights (attention/MLP) + **AdamW** on embedders, `proj_out`, norms, modulation. Tune LR + weight decay only.
- **BF16** compute, FP32 master weights + FP32 moments. No quantization in training.
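For reference, Muon's distinguishing step orthogonalizes each 2D momentum matrix with a quintic Newton-Schulz iteration before applying it. The sketch below is a NumPy transcription of that idea. It uses the coefficients from Keller Jordan's public reference implementation. The learning rate and momentum are placeholders, and details such as per-shape update scaling differ between implementations.

```python
import numpy as np

NS_COEFFS = (3.4445, -4.7750, 2.0315)   # quintic coefficients from the reference Muon code


def newton_schulz5(g, steps=5, eps=1e-7):
    """Approximately map g to the nearest semi-orthogonal matrix (U V^T of its SVD)."""
    a, b, c = NS_COEFFS
    x = g / (np.linalg.norm(g) + eps)      # Frobenius normalization -> spectral norm <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:                         # iterate on the wide orientation (smaller Gram matrix)
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x   # odd polynomial in the singular values
    return x.T if transposed else x


def muon_step(weight, grad, buf, lr=0.02, momentum=0.95):
    """One Nesterov-style Muon update for a single 2D hidden weight (placeholder hyperparameters)."""
    buf = momentum * buf + grad
    update = newton_schulz5(grad + momentum * buf)
    return weight - lr * update, buf


rng = np.random.default_rng(0)
g = rng.standard_normal((64, 32))
sv = np.linalg.svd(newton_schulz5(g), compute_uv=False)
print(sv.min(), sv.max())
```

With these coefficients the singular values land in a band around 1 rather than exactly at 1. That is the intended trade-off: five cheap matmul-only iterations instead of an SVD.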
**Resolution**
- Bulk training at **512px**; optional short 768/1024 polish phase if memory/time headroom remains.
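The cost of a higher-resolution polish phase grows quickly. This sketch assumes 16× total spatial compression from pixels to latent tokens (consistent with the latent-size figure above) and ignores text tokens:

```python
COMPRESSION = 16   # assumed total spatial compression, pixels -> latent tokens


def image_tokens(side_px: int) -> int:
    return (side_px // COMPRESSION) ** 2


base = image_tokens(512)
for side in (512, 768, 1024):
    n = image_tokens(side)
    print(f"{side}px: {n} image tokens, "
          f"{n / base:.2f}x linear-layer FLOPs, {(n / base) ** 2:.2f}x attention-score FLOPs")
```

At 1024px, linear layers cost ~4× and attention scores ~16× per image relative to 512px, which is why the plan keeps that phase short and optional.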
# 5. Compute & Infrastructure (1× B200, 192 GB)
**Memory budget (fixed ~45 GB)**
| Item | GB |
| --- | --- |
| FP32 master weights | ~9.8 |
| BF16 weights (compute) | ~4.9 |
| BF16 grads | ~4.9 |
| Muon momentum (1 moment, FP32) | ~9.0 |
| AdamW m,v (FP32, ~8% of params) | ~1.6 |
| Teacher 4B (BF16, fwd only) | ~6.8 |
| Qwen3-4B encoder (BF16, fwd only) | ~8.0 |
| **Total fixed** | **~45** |
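→ ~147 GB nominally free (192 − 45); we plan against ~137 GB for activations, leaving ~10 GB of headroom (e.g. CUDA context, allocator fragmentation). Muon saves ~9 GB vs AdamW on every parameter (~10.6 GB vs ~19.6 GB of FP32 optimizer state).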
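The table can be regenerated from the student's parameter count. The 92%/8% Muon/AdamW split is not stated explicitly; it is the split implied by the ~9.0 GB (4 bytes/param) and ~1.6 GB (8 bytes/param) rows. The teacher and encoder rows are taken as given.

```python
P = 2.45e9            # student parameters (core + surrogates)
MUON_FRAC = 0.92      # assumed share of 2D hidden weights, implied by the table
GB = 1e9

budget_gb = {
    "FP32 master weights": P * 4 / GB,
    "BF16 weights": P * 2 / GB,
    "BF16 grads": P * 2 / GB,
    "Muon momentum (FP32)": P * MUON_FRAC * 4 / GB,
    "AdamW m, v (FP32)": P * (1 - MUON_FRAC) * 8 / GB,
    "Teacher 4B (BF16)": 6.8,          # table figure
    "Qwen3-4B encoder (BF16)": 8.0,    # table figure
}
total = sum(budget_gb.values())
adamw_everywhere = P * 8 / GB          # two FP32 moments on every parameter
muon_saving = adamw_everywhere - (budget_gb["Muon momentum (FP32)"]
                                  + budget_gb["AdamW m, v (FP32)"])

print(f"fixed total: {total:.1f} GB, nominal free: {192 - total:.1f} GB")
print(f"optimizer-state saving vs AdamW everywhere: {muon_saving:.1f} GB")
```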
**Batch sizing**
- Activations ≈ **1.63 GB/sample** @512px (FlashAttention, no checkpointing).
- **Max micro-batch ≈ 84**; run **micro-batch 64**, gradient-accumulate to **effective ~256** (Muon's large-batch sweet spot). No gradient checkpointing needed.
- ~18,750 micro-steps (≈4,700 optimizer steps at accum-4).
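The batch-sizing numbers follow directly from the activation budget and per-sample footprint:

```python
FREE_GB = 137          # activation budget from the memory section
ACT_GB_PER_SAMPLE = 1.63
MICRO_BATCH = 64
EFFECTIVE_BATCH = 256
VIEWS = 1_200_000

max_micro = int(FREE_GB // ACT_GB_PER_SAMPLE)
accum_steps = EFFECTIVE_BATCH // MICRO_BATCH
micro_steps = VIEWS // MICRO_BATCH
optimizer_steps = micro_steps / accum_steps

print(max_micro, accum_steps, micro_steps, optimizer_steps)  # → 84 4 18750 4687.5
```

With a mean-reduced loss, scaling each micro-step's loss by `1 / accum_steps` before backward makes the accumulated gradient equal the mean over all 256 samples.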
**Throughput optimizations**
- `torch.compile` with CUDA graphs (e.g. `mode="reduce-overhead"`), required to hit target MFU.
- **FlashAttention-4** via PyTorch FlexAttention (BF16, Blackwell-native).
- Precomputed latents + aggressive prefetch (keeps GPU fed, protects MFU).
**Time & cost** — 1.2M views × ~40 TFLOP/view:
| MFU | Wall-clock | Cost (@~$5/h) |
| --- | --- | --- |
| 35% | ~17 h | ~$85 |
| 40% | ~15 h | ~$75 |
| 45% | ~13 h | ~$66 |
Plus a one-time ~30–60 min offline VAE-latent precompute (can run on cheaper hardware).
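The table can be reproduced as below. The peak figure is an assumption: ~2.25 PFLOP/s dense BF16 for one B200, which matches the table's rows within rounding. The 40 TFLOP/view estimate is taken from the plan as given.

```python
VIEWS = 1.2e6
FLOP_PER_VIEW = 40e12       # plan's estimate per training view
PEAK_FLOPS = 2.25e15        # assumed dense BF16 peak for one B200
PRICE_PER_HOUR = 5.0


def wallclock_hours(mfu: float) -> float:
    return VIEWS * FLOP_PER_VIEW / (PEAK_FLOPS * mfu) / 3600


for mfu in (0.25, 0.35, 0.40, 0.45):
    h = wallclock_hours(mfu)
    print(f"MFU {mfu:.0%}: {h:.1f} h, ${h * PRICE_PER_HOUR:.0f}")
```

Wall-clock scales as 1/MFU, so dropping from 40% to 25% MFU costs exactly 1.6× the time and money.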
# 6. Expected Outcomes & Quality Profile
A full-width 2B with a small, defensible gap to the 4B.
- **Robust (holds up well):** global structure, color, lighting, coherence; simple/moderate prompt adherence (encoder + double-stream preserved); texture/fine detail (full width retained).
- **Degrades (worst first):** dense in-image text/typography; complex multi-object composition, counting, precise spatial/attribute binding; small faces/hands and micro-detail.
- **Forgiving for the use case:** downstream narrow fine-tuning re-concentrates the 2B's capacity on one domain, so lost broad generality matters far less for end users.
# 7. Downstream "Distill-Down" Kit
Ship the 2B base **plus** a kit so users can specialize it via a bigger teacher rather than fine-tuning the 2B directly — best for **small downstream datasets**.
- **Mechanism:** fine-tune the larger teacher (4B or 9B) on the narrow domain → use it as a data amplifier (generate synthetic in-domain samples) → distill into the 2B via **velocity matching**.
- **Cross-size validity:** the entire klein family shares one VAE/latent space, so 9B→2B and 4B→2B velocity matching is well-defined. Each model uses its own text encoder; **the student stays on Qwen3-4B** (do not bolt the 9B's Qwen3-8B encoder onto the 2B — dimension mismatch).
- **Asset to ship:** because the 2B descends from the 4B, document the **block-correspondence map** (which student block ↔ which teacher block) + projector recipe for optional feature matching — the hard part users can't easily derive themselves.
- **License note:** 4B teacher path = commercial-OK (Apache 2.0); 9B teacher path = non-commercial only.
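One possible shape for the shipped block-correspondence map is sketched below. It assumes each surrogate sits in the position of the block it replaces, so student single-stream position `j` always corresponds to teacher single block `j`. The kept indices are a hypothetical placeholder, not the real selection, and the field names are illustrative.

```python
import json


def build_correspondence(kept_single, n_double=5, n_single=20):
    """Student-block -> teacher-block map for a depth-pruned student."""
    kept = set(kept_single)
    if len(kept) != len(kept_single) or not kept <= set(range(n_single)):
        raise ValueError("kept_single must be unique indices in [0, n_single)")
    rows = [{"student": f"double.{i}", "teacher": f"double.{i}",
             "kind": "inherited", "feature_match": True}
            for i in range(n_double)]
    for j in range(n_single):
        inherited = j in kept
        rows.append({"student": f"single.{j}", "teacher": f"single.{j}",
                     "kind": "inherited" if inherited else "brll_surrogate",
                     "feature_match": inherited})   # feature-match only retained blocks
    return rows


KEPT = [0, 1, 2, 4, 7, 11, 16, 19]   # hypothetical placeholder selection
mapping = build_correspondence(KEPT)
print(json.dumps(mapping[5:8], indent=2))
```

Serializing the map as JSON keeps it consumable by user-side scripts without importing the training code.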
# 8. Risks & Open Questions
- **Muon on warm-start recovery is under-validated** (strong evidence is LLM/from-scratch + one flow-matching benchmark). → A/B against AdamW on the first training chunk before committing.
- **1.2M views is a partial recovery in absolute terms.** → Monitor the loss curve; be willing to extend if a specific capability (e.g. text) stays weak.
- **Surrogate size is a tradeoff:** light = cheap warm start, near-2B FLOPs; heavy = better quality but more cost and less speedup. Default: light.
- **Coverage holes:** quality gaps will track wherever the 300k set is thin. Curate for breadth, not volume.
- **MFU sensitivity:** a sloppy loop at ~25% MFU runs ≈1.6× longer (and costs ≈1.6× more) than one at 40%; the §5 optimizations are how we hit 40–45%.
# 9. Implementation & Engineering Setup
What we actually build, in dependency order.
**A. Data & latent caching (offline, one-time)**
- Curate + caption 300k images; assign stable IDs.
- Batch **VAE-encode all images → BF16 latents written to streaming shards** (WebDataset / tensordict). The VAE runs only here, never in the training loop.
- Dataloader streams `(latent, caption)` with aggressive prefetch; per step samples timestep `t` and noise, forms `x_t = (1−t)·noise + t·latent`.
- Text embeddings: run **Qwen3-4B online**, intercepting intermediate layers (→ 7680-dim); encoder frozen. (Optionally batch-precompute per shard if I/O-bound.)
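A minimal sketch of the per-sample noising step, under the plan's convention `x_t = (1−t)·noise + t·latent` (t=0 is pure noise, t=1 is data). Under this convention the flow-matching velocity target is `latent − noise`. Seeding the RNG from (seed, image ID, epoch) gives a fresh but reproducible timestep per image per epoch. The uniform `t` is a placeholder, since the plan does not specify the timestep distribution.

```python
import numpy as np


def make_training_sample(latent, image_id, epoch, base_seed=0):
    """Return (x_t, t, v_target) for one cached latent."""
    rng = np.random.default_rng([base_seed, image_id, epoch])  # reproducible per (image, epoch)
    t = rng.uniform(0.0, 1.0)            # placeholder timestep distribution
    noise = rng.standard_normal(latent.shape)
    x_t = (1.0 - t) * noise + t * latent
    v_target = latent - noise            # d x_t / dt under this convention
    return x_t, t, v_target


latent = np.random.default_rng(1).standard_normal((128, 32, 32))
x_t, t, v = make_training_sample(latent, image_id=42, epoch=0)
# Moving along the constant velocity from t to 1 recovers the clean latent.
print(np.allclose(x_t + (1.0 - t) * v, latent))  # → True
```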
**B. Model surgery (prune + BRLL surrogates)**
- Load klein-**base**-4B transformer (diffusers Flux2 / BFL reference impl).
- Depth-select: keep 5 double + 8 of 20 single (importance / CKA similarity or learnable mask); inherit weights for kept blocks.
- Replace the 12 removed single blocks with **BRLL surrogates** — linear/low-rank residual branch with the shortcut preserved; **least-squares-init from teacher activations** on a calibration batch (training-free warm start).
- Sanity: shape + forward-parity checks. Full width `d=3072` → no projector needed for feature matching.
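Of the depth-selection options listed, the CKA route can be sketched with linear CKA and one simple, assumed rule: a block whose output has nearly the same geometry as its input (high CKA) changes the representation least, so it is pruned first. The function names and the synthetic data are illustrative only.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between two (n_tokens, d) activation matrices; 1.0 = same geometry."""
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    return float(cross / (np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro")))


def select_single_blocks(block_io, n_keep=8):
    """Keep the n_keep blocks whose output differs most from their input.

    block_io: list of (input, output) activations per teacher single-stream block,
    collected on a calibration batch.
    """
    scores = [linear_cka(x, y) for x, y in block_io]
    order = np.argsort(scores)           # lowest CKA (most transformative) first
    return sorted(int(i) for i in order[:n_keep]), scores


rng = np.random.default_rng(0)
transformative = {1, 3, 5, 7, 9, 11, 13, 15}
block_io = []
for j in range(20):
    x = rng.standard_normal((256, 16))
    y = (rng.standard_normal((256, 16)) if j in transformative
         else x + 0.01 * rng.standard_normal((256, 16)))   # near-identity block
    block_io.append((x, y))

kept, _ = select_single_blocks(block_io)
print(kept)  # → [1, 3, 5, 7, 9, 11, 13, 15]
```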
**C. Distillation training loop (online teacher)**
- Teacher 4B loaded **frozen** (eval, `no_grad`, BF16, compiled).
- Per step: `latents + prompts → text emb → sample t, noise → x_t`; **teacher forward → target velocity** `v_teacher`; student forward → `v_student`.
- Loss = `w1·velocity-MSE` (primary) + `w2·feature-match` (masked KD on retained blocks) + `w3·real-data flow-matching`. Tune weights.
- Param-group split: **Muon** on 2D hidden weights, **AdamW** on embedders / `proj_out` / norms / modulation.
- BF16 autocast, FP32 master weights + moments; grad-accumulate to effective ~256; `torch.compile` the student; **FlashAttention-4** via FlexAttention.
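The Muon/AdamW split can be driven by parameter name and shape. In the sketch below, the parameter names and shapes are hypothetical (not the diffusers Flux2 names), and the substring markers would need to match the real module names.

```python
def split_param_groups(named_shapes,
                       adamw_markers=("embed", "proj_out", "norm", "modulation")):
    """Route parameters to Muon (2D hidden weights) or AdamW (everything else).

    named_shapes: iterable of (parameter_name, shape) pairs.
    adamw_markers: substrings that force a parameter onto AdamW even if it is 2D.
    """
    muon, adamw = [], []
    for name, shape in named_shapes:
        is_hidden_2d = len(shape) == 2 and not any(m in name for m in adamw_markers)
        (muon if is_hidden_2d else adamw).append(name)
    return muon, adamw


# Hypothetical names and illustrative shapes.
params = [
    ("x_embedder.weight", (3072, 128)),
    ("double_blocks.0.img_attn.qkv.weight", (9216, 3072)),
    ("double_blocks.0.img_attn.norm.query_norm.scale", (128,)),
    ("single_blocks.0.linear1.weight", (27648, 3072)),
    ("single_blocks.9.surrogate.a", (3072, 256)),
    ("modulation.double_img.weight", (18432, 3072)),
    ("proj_out.weight", (128, 3072)),
]
muon, adamw = split_param_groups(params)
print("Muon:", muon)
print("AdamW:", adamw)
```

Checking the split against a name/shape dump of the real model before launch is cheap insurance: a modulation weight silently landing on Muon would be easy to miss.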
**D. Infra, logging & eval**
- Single B200, micro-batch 64, resumable checkpointing.
- Log per-component losses, throughput / MFU, and periodic sample grids for eyeballing.
- Eval harness: FID/CLIP + targeted prompt probes (text rendering, composition/counting, texture) vs the 4B; downstream fine-tune smoke test.
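For the throughput/MFU log, a per-step helper is enough. It assumes the plan's ~40 TFLOP per training view and a ~2.25 PFLOP/s dense BF16 peak for one B200 (an assumption, not a measured value).

```python
def step_mfu(samples_per_step, step_seconds, flop_per_sample=40e12, peak_flops=2.25e15):
    """Model FLOPs utilization for one logged micro-step."""
    achieved = samples_per_step * flop_per_sample / step_seconds
    return achieved / peak_flops


# Micro-step time that corresponds to 40% MFU at micro-batch 64:
target_step_s = 64 * 40e12 / (0.40 * 2.25e15)
print(f"{target_step_s:.2f} s per micro-step")   # → 2.84 s per micro-step
print(f"{step_mfu(64, 3.2):.1%}")                # → 35.6%
```

Alerting when logged micro-steps drift well above ~2.8 s catches a degraded loop early, before it costs hours.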
**E. Distill-down kit (reuses C)**
- Same loop with the teacher **swappable (4B or 9B)**; student stays on the Qwen3-4B encoder.
- Flow: domain fine-tune the teacher → generate synthetic in-domain samples → velocity-match into the 2B.
- Package: block-correspondence map + projector recipe + scripts + minimal config.
# 10. Milestones
1. **Data + latents** — curate 300k diverse pairs; precompute VAE latents to shards.
2. **Architecture** — prune to 5D+8S full-width; least-squares-init BRLL surrogates.
3. **Recovery run** — Muon+AdamW, online teacher, velocity + feature + FM losses, 1.2M views @512px, BF16.
4. **Eval** — FID/CLIP + targeted probes (text, composition, texture) vs 4B; downstream fine-tune smoke test.
5. **(Optional) Polish** — short 768/1024 phase; optional LADD for v2.
6. **Release** — 2B base weights + distill-down kit (correspondence map, scripts) + model card.
# 11. Budget Summary
| Line | Estimate |
| --- | --- |
| Recovery training (1× B200, 1.2M views) | **~$66–75** |
| VAE-latent precompute (one-time) | <$5 |
| **Total per run** | **~$70–80** |
Experiments to be tracked individually under this page per the standard experiment format.
