
Status: Proposal · Compute target: 1× B200, ~13–15h, ~$66–75 · Deliverable: a ~2B open base model + a downstream distill-down kit.

1. Objective

Compress FLUX.2 [klein] 4B (rectified-flow MM-DiT) into a ~2B general-purpose text-to-image base model that retains most of the teacher's capability, runs meaningfully faster, and serves as a strong foundation for downstream narrow-domain fine-tuning by third parties (e.g. wallpapers, single styles, product imagery). The model is recovered via knowledge distillation on a single B200 within a strict ~$70 / half-day budget. We deliberately do not distill for fewer sampling steps; we accept an increase from 4 to 8 inference steps in exchange for the smaller, faster backbone.

2. Background

  • Teacher: FLUX.2 [klein] base 4B (undistilled — full training signal, best for fine-tuning). Architecture: d=3072, 24 heads, 5 double-stream + 20 single-stream MM-DiT blocks, shared-across-block AdaLN modulation, no bias.
  • Conditioning: Qwen3-4B text encoder (intermediate-layer stack → 7680-dim conditioning). Frozen; not part of the 4B denoiser.
  • Latent space: shared FLUX.2 VAE (128-ch latent) — identical across the whole klein family (4B & 9B). This is what makes cross-size velocity distillation valid (see §7).
  • Capacity distribution: the 20 single-stream blocks hold ~67% of backbone params; the 5 double-stream blocks (cross-modal binding) hold ~33%. This asymmetry drives the architecture choice below.
  • License: klein 4B is Apache 2.0 (commercial-friendly); klein 9B is non-commercial.
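A back-of-envelope count reproduces the ~67% / ~33% split. It assumes FLUX.1-style block layouts, which are not confirmed for FLUX.2 [klein]: a 4× MLP, no biases, with shared modulation and embedders excluded. Under those assumptions a double-stream block has 24·d² parameters and a single-stream block has 12·d².

```python
D = 3072  # hidden width of the teacher


def double_block_params(d: int) -> int:
    # Two streams (image + text), each: QKV (3d^2) + attention out-proj (d^2)
    # + 4x MLP (4d^2 up + 4d^2 down). Assumed layout, no biases.
    per_stream = 3 * d * d + d * d + 8 * d * d
    return 2 * per_stream


def single_block_params(d: int) -> int:
    # Fused QKV + MLP-up projection (d -> 3d + 4d) and fused output
    # projection (d + 4d -> d). Assumed layout, no biases.
    return d * (3 * d + 4 * d) + (d + 4 * d) * d


def backbone_params(n_double: int, n_single: int, d: int = D) -> int:
    return n_double * double_block_params(d) + n_single * single_block_params(d)


teacher = backbone_params(5, 20)
single_share = 20 * single_block_params(D) / teacher
student_core = backbone_params(5, 8)

print(f"single-stream share: {single_share:.0%}")            # 67%
print(f"double-stream share: {1 - single_share:.0%}")        # 33%
print(f"teacher block stack: {teacher / 1e9:.2f}B")          # 3.40B
print(f"student core (5D + 8S): {student_core / 1e9:.2f}B")  # 2.04B
```

The same count also yields the ~2.04B student core listed in §3. It puts the teacher's block stack at ~3.4B parameters, which matches the ~6.8 GB BF16 teacher row in §5.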

3. Target Architecture (~2B student)

Full-width, depth-pruned, with lightweight block replacement for the removed depth.

| Component | Spec | Rationale |
|---|---|---|
| Width | d=3072 (unchanged) | Same as teacher → preserves texture/detail capacity; feature matching needs no projector |
| Double-stream | 5 (all kept) | Preserve cross-modal/instruction binding |
| Single-stream | 8 of 20 kept in full | Depth-only prune; recovers far better than width prune |
| Removed depth | 12 single blocks → light BRLL surrogates | Linear/low-rank residual-branch replacements, shortcut preserved; least-squares (training-free) init = warm start |
| Effective size | core 2.04B + surrogates ~0.41B ≈ **2.45B** | |
| Inference | ~1.7× faster than 4B, ~5 GB BF16 | Still runs on consumer GPUs |

4. Methodology

Initialization

  • Prune from klein-base-4B; inherit surviving block weights.
  • Replace the 12 removed single-stream blocks with BRLL-style surrogates, least-squares initialized from teacher activations (training-free warm start → less to recover).
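The least-squares warm start amounts to ridge regression. On a calibration batch, record the teacher activations entering (`h_in`) and leaving (`h_out`) a removed block. Keep the shortcut and solve for the linear residual map W such that h_out ≈ h_in + h_in·W. The sketch below is a pure-Python toy with a full-rank W and tiny width. The ridge term `lam` and all helper names are illustrative assumptions, and a low-rank surrogate would additionally truncate W (e.g. via SVD). At d=3072 the same normal equations would be solved on GPU, e.g. with `torch.linalg.lstsq` or `torch.linalg.solve`.

```python
import random


def matmul(a, b):
    """(n x k) @ (k x m) for row-major nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def solve(a, b):
    """Solve a @ x = b (a: n x n, b: n x m) by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    aug = [ra[:] + rb[:] for ra, rb in zip(a, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                f = aug[r][col]
                aug[r] = [vr - f * vc for vr, vc in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def fit_linear_surrogate(h_in, h_out, lam=1e-6):
    """Ridge least squares for W in h_out ~ h_in + h_in @ W (shortcut kept, residual branch fitted)."""
    residual = [[o - i for o, i in zip(ro, ri)] for ro, ri in zip(h_out, h_in)]
    xt = transpose(h_in)
    gram = matmul(xt, h_in)
    for k in range(len(gram)):
        gram[k][k] += lam  # small ridge term keeps the normal equations well-conditioned
    return solve(gram, matmul(xt, residual))


# Toy calibration batch: 64 token activations of width 4 (a real block has d=3072).
rng = random.Random(0)
d, n = 4, 64
w_true = [[rng.gauss(0, 0.1) for _ in range(d)] for _ in range(d)]
h_in = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
h_out = [[x + dx for x, dx in zip(r, dr)] for r, dr in zip(h_in, matmul(h_in, w_true))]

w_fit = fit_linear_surrogate(h_in, h_out)
max_err = max(abs(a - b) for ra, rb in zip(w_fit, w_true) for a, b in zip(ra, rb))
print(f"max |W_fit - W_true| = {max_err:.1e}")
```

The toy block here is exactly linear, so the fit recovers it. A real single-stream block is nonlinear, and the fitted surrogate is only its best linear approximation on the calibration data, which is the "less to recover" starting point.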

Distillation (online teacher)

  • Velocity/output matching against the 4B teacher (flow-matching target) — primary signal, runs on unlimited teacher-generated targets.
  • Feature matching on retained blocks — free here because student/teacher widths match (no projector). Use masked KD (exclude DiT extreme-activation outliers).
  • Real-data flow-matching loss on curated images for grounding.
  • No adversarial (LADD) phase in v1 — too costly/finicky for the budget; revisit for a v2 texture polish.
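Here is a minimal sketch of the masked feature-matching term. It computes a per-position MSE between student and teacher features on a retained block and drops positions where the teacher shows an extreme activation. The outlier rule (k × median magnitude) and the value of k are assumptions. Real features would be per-token, per-channel tensors rather than a flat list.

```python
from statistics import median


def masked_feature_mse(student, teacher, k=10.0):
    """Feature-matching MSE that skips the teacher's extreme activations.

    A position counts as an outlier (excluded) when |teacher| > k * median(|teacher|).
    Both the threshold rule and k=10 are illustrative assumptions.
    """
    cutoff = k * median([abs(t) for t in teacher])
    kept = [(s - t) ** 2 for s, t in zip(student, teacher) if abs(t) <= cutoff]
    return sum(kept) / len(kept) if kept else 0.0


# Flattened toy features from one retained block; position 3 is a massive activation.
teacher = [0.5, -0.2, 0.1, 300.0, 0.3, -0.4]
student = [0.4, -0.2, 0.1, 20.0, 0.3, -0.5]

masked = masked_feature_mse(student, teacher)
unmasked = sum((s - t) ** 2 for s, t in zip(student, teacher)) / len(teacher)
print(f"masked: {masked:.4f}  unmasked: {unmasked:.1f}")
```

Without the mask, the single outlier position (300 vs 20, squared error 78,400) dominates the term. With the mask, the loss reflects the remaining positions.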

Data pipeline

  • 300k curated images + captions, diversity-first (broad styles/subjects/composition + in-image text coverage).
  • Precompute VAE latents offline (~79 GB @512px BF16, streaming shards). VAE never runs in the training loop.
  • Recompute Qwen3-4B text embeddings online (caching them would be ~2.4 TB — not worth it).
  • Multi-timestep sampling: fresh random timestep per image per epoch (epochs = new velocity-field supervision, not repetition).
  • 4 epochs × 300k = 1.2M views. Views kept at full count (quality prioritized over the ~$17 saving).
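The two cache sizes follow from simple arithmetic. The sketch assumes a 16× pixel-to-latent-token reduction (512px → 32×32 positions) at 128 channels, plus 512 text tokens per caption. Both assumptions were chosen because they reproduce the ~79 GB and ~2.4 TB figures.

```python
N_IMAGES = 300_000
BF16_BYTES = 2


def latent_bytes(res_px: int, channels: int = 128, downsample: int = 16) -> int:
    side = res_px // downsample  # 512px -> 32 latent positions per side (assumed factor)
    return channels * side * side * BF16_BYTES


def text_emb_bytes(tokens: int = 512, width: int = 7680) -> int:
    # 7680 = 3 x 2560, i.e. three stacked Qwen3-4B hidden states (assumed); 512 tokens assumed
    return tokens * width * BF16_BYTES


latents_gb = N_IMAGES * latent_bytes(512) / 1e9
text_tb = N_IMAGES * text_emb_bytes() / 1e12
views = 4 * N_IMAGES  # 4 epochs, fresh timestep each time

print(f"latent cache: ~{latents_gb:.0f} GB")       # ~79 GB
print(f"text-embedding cache: ~{text_tb:.1f} TB")  # ~2.4 TB
print(f"training views: {views:,}")                # 1,200,000
```

Under the same assumptions, caching all 300k images at 1024px would take 4× the space (~315 GB). That matters if the optional polish phase reuses the full set.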

Optimizer & precision

  • Muon on 2D hidden weights (attention/MLP) + AdamW on embedders, proj_out, norms, modulation. Tune LR + weight decay only.
  • BF16 compute, FP32 master weights + FP32 moments. No quantization in training.
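The Muon/AdamW split is a routing rule over parameter names and shapes, sketched here in plain Python. The name patterns and the example parameter names and shapes are hypothetical stand-ins, so match them against the real module names of the diffusers / BFL implementation. In PyTorch, the two resulting lists would be built from `model.named_parameters()` and handed to the two optimizers.

```python
# Hypothetical substrings marking embedders, output projection, norms and modulation.
ADAMW_PATTERNS = ("embedder", "proj_out", "norm", "modulation")


def optimizer_for(name: str, shape: tuple) -> str:
    """Route 2D hidden weights to Muon and everything else to AdamW."""
    if any(p in name for p in ADAMW_PATTERNS):
        return "adamw"
    return "muon" if len(shape) == 2 else "adamw"  # 1D / scalar params stay on AdamW


# Illustrative names and shapes only (not the real checkpoint keys).
params = {
    "x_embedder.weight": (3072, 128),
    "double_blocks.0.img_attn.qkv.weight": (9216, 3072),
    "double_blocks.0.img_mlp.down.weight": (3072, 12288),
    "single_blocks.0.linear1.weight": (21504, 3072),
    "single_blocks.0.norm.query_norm.scale": (128,),
    "single_stream_modulation.lin.weight": (9216, 3072),
    "proj_out.weight": (128, 3072),
}

groups = {"muon": [], "adamw": []}
for name, shape in params.items():
    groups[optimizer_for(name, shape)].append(name)

for opt, names in groups.items():
    print(opt, names)
```

Falling back on shape keeps any unexpected 1D tensor on AdamW. That is the safe default, since Muon's orthogonalized update is defined for matrices.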

Resolution

  • Bulk training at 512px; optional short 768/1024 polish phase if memory/time headroom remains.

5. Compute & Infrastructure (1× B200, 192 GB)

Memory budget (fixed ~45 GB)

| Item | GB |
|---|---|
| FP32 master weights | ~9.8 |
| BF16 weights (compute) | ~4.9 |
| BF16 grads | ~4.9 |
| Muon momentum (1 moment, FP32) | ~9.0 |
| AdamW m,v (FP32, ~8% of params) | ~1.6 |
| Teacher 4B (BF16, fwd only) | ~6.8 |
| Qwen3-4B encoder (BF16, fwd only) | ~8.0 |
| Total fixed | ~45 |

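→ ~137 GB left for activations (192 − 45 ≈ 147 GB nominally free, with ~10 GB kept in reserve). Muon saves ~9 GB of optimizer state vs full AdamW (~10.6 GB of moments instead of ~19.6 GB).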
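The fixed budget can be reproduced from parameter counts. The sketch assumes 2.45B trainable student parameters split ~92% Muon / ~8% AdamW, a ~3.4B-parameter teacher denoiser (the size implied by the ~6.8 GB row) and a ~4.0B-parameter text encoder. All figures are decimal GB.

```python
GB = 1e9
STUDENT = 2.45e9
MUON_FRAC, ADAMW_FRAC = 0.92, 0.08
TEACHER, ENCODER = 3.4e9, 4.0e9

rows = {
    "FP32 master weights": STUDENT * 4,
    "BF16 weights (compute)": STUDENT * 2,
    "BF16 grads": STUDENT * 2,
    "Muon momentum (FP32, 1 moment)": STUDENT * MUON_FRAC * 4,
    "AdamW m,v (FP32, 2 moments)": STUDENT * ADAMW_FRAC * 8,
    "Teacher (BF16, fwd only)": TEACHER * 2,
    "Qwen3-4B encoder (BF16, fwd only)": ENCODER * 2,
}
for item, nbytes in rows.items():
    print(f"{item:<36}{nbytes / GB:5.1f}")

fixed_gb = sum(rows.values()) / GB
full_adamw_gb = STUDENT * 8 / GB  # two FP32 moments on every parameter
mixed_gb = (rows["Muon momentum (FP32, 1 moment)"] + rows["AdamW m,v (FP32, 2 moments)"]) / GB
print(f"fixed total: ~{fixed_gb:.0f} GB")
print(f"optimizer-state saving vs full AdamW: ~{full_adamw_gb - mixed_gb:.1f} GB")
```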

Batch sizing

  • Activations ≈ 1.63 GB/sample @512px (FlashAttention, no checkpointing).
  • Max micro-batch ≈ 84; run micro-batch 64, gradient-accumulate to effective ~256 (Muon's large-batch sweet spot). No gradient checkpointing needed.
  • ~18,750 micro-steps (≈4,700 optimizer steps at accum-4).
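The batch numbers check out as integer arithmetic over the plan's figures. The 1.63 GB/sample activation cost is the plan's estimate and still needs to be confirmed by profiling.

```python
FREE_GB = 137              # activation headroom from the memory budget
ACT_GB_PER_SAMPLE = 1.63   # @512px, FlashAttention, no checkpointing (plan estimate)
MICRO_BATCH = 64
EFFECTIVE_BATCH = 256
TOTAL_VIEWS = 1_200_000

max_micro = int(FREE_GB / ACT_GB_PER_SAMPLE)  # largest micro-batch that fits
accum = EFFECTIVE_BATCH // MICRO_BATCH        # micro-steps per optimizer step
micro_steps = TOTAL_VIEWS // MICRO_BATCH
optimizer_steps = micro_steps // accum
spare_gb = FREE_GB - MICRO_BATCH * ACT_GB_PER_SAMPLE

print(max_micro, accum, micro_steps, optimizer_steps)                    # 84 4 18750 4687
print(f"spare activation memory at micro-batch 64: ~{spare_gb:.0f} GB")  # ~33 GB
```

Running at 64 rather than 84 leaves ~33 GB of slack for activation spikes and allocator fragmentation.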

Throughput optimizations

  • torch.compile with mode="reduce-overhead" (CUDA graphs): required to hit target MFU.
  • FlashAttention-4 via PyTorch FlexAttention (BF16, Blackwell-native).
  • Precomputed latents + aggressive prefetch (keeps GPU fed, protects MFU).

Time & cost — 1.2M views × ~40 TFLOP/view:

| MFU | Wall-clock | Cost (@~$5/h) |
|---|---|---|
| 35% | ~17 h | ~$85 |
| 40% | ~15 h | ~$75 |
| 45% | ~13 h | ~$66 |

Plus a one-time ~30–60 min offline VAE-latent precompute (can run on cheaper hardware).
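The table follows from dividing total training FLOPs by achieved throughput. The sketch assumes a B200 dense BF16 peak of ~2.25 PFLOP/s, which reproduces the rows above. The ~40 TFLOP/view figure is the plan's estimate, not a measurement.

```python
PEAK_FLOPS = 2.25e15   # assumed dense BF16 peak of one B200
FLOP_PER_VIEW = 40e12  # plan estimate per training view
VIEWS = 1.2e6
USD_PER_HOUR = 5.0


def run_hours(mfu: float) -> float:
    """Wall-clock hours = total FLOPs / (peak * MFU)."""
    return VIEWS * FLOP_PER_VIEW / (PEAK_FLOPS * mfu) / 3600


for mfu in (0.25, 0.35, 0.40, 0.45):
    hours = run_hours(mfu)
    print(f"MFU {mfu:.0%}: ~{hours:.1f} h, ~${hours * USD_PER_HOUR:.0f}")
```

The 25% row shows the ~1.6× penalty discussed under risks: the run stretches to roughly 24 h.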

6. Expected Outcomes & Quality Profile

A full-width 2B with a small, defensible gap to the 4B.

  • Robust (holds up well): global structure, color, lighting, coherence; simple/moderate prompt adherence (encoder + double-stream preserved); texture/fine detail (full width retained).
  • Degrades (worst first): dense in-image text/typography; complex multi-object composition, counting, precise spatial/attribute binding; small faces/hands and micro-detail.
  • Forgiving for the use case: downstream narrow fine-tuning re-concentrates the 2B's capacity on one domain, so lost broad generality matters far less for end users.

7. Downstream "Distill-Down" Kit

Ship the 2B base plus a kit so users can specialize it via a bigger teacher rather than fine-tuning the 2B directly — best for small downstream datasets.

  • Mechanism: fine-tune the larger teacher (4B or 9B) on the narrow domain → use it as a data amplifier (generate synthetic in-domain samples) → distill into the 2B via velocity matching.
  • Cross-size validity: the entire klein family shares one VAE/latent space, so 9B→2B and 4B→2B velocity matching is well-defined. Each model uses its own text encoder; the student stays on Qwen3-4B (do not bolt the 9B's Qwen3-8B encoder onto the 2B — dimension mismatch).
  • Asset to ship: because the 2B descends from the 4B, document the block-correspondence map (which student block ↔ which teacher block) + projector recipe for optional feature matching — the hard part users can't easily derive themselves.
  • License note: 4B teacher path = commercial-OK (Apache 2.0); 9B teacher path = non-commercial only.
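The block-correspondence map can ship as a small JSON file. In this minimal sketch, the kept indices and module names are hypothetical placeholders, since the actual 8 survivors come out of the depth-selection step. For the 4B teacher, every entry pairs blocks of the same width, so no projector is implied.

```python
import json

# Hypothetical choice of the 8 kept single-stream blocks; the real indices come
# from the depth-selection step in model surgery.
KEPT_SINGLE = [0, 1, 2, 5, 9, 14, 18, 19]


def correspondence_map(kept_single, n_double=5, n_single=20):
    """Map each student module name to the teacher block it inherits from or replaces."""
    kept = sorted(kept_single)
    cmap = {f"double_blocks.{i}": f"double_blocks.{i}" for i in range(n_double)}
    for s_idx, t_idx in enumerate(kept):
        cmap[f"single_blocks.{s_idx}"] = f"single_blocks.{t_idx}"  # inherited weights
    removed = [i for i in range(n_single) if i not in kept]
    for j, t_idx in enumerate(removed):
        cmap[f"surrogates.{j}"] = f"single_blocks.{t_idx}"  # BRLL stand-in
    return cmap


cmap = correspondence_map(KEPT_SINGLE)
print(json.dumps(cmap, indent=2))
```

With the 4B teacher, feature matching then pairs each inherited `single_blocks.k` with its mapped teacher block. For a 9B teacher, the pairings and a projector would need to be defined separately.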

8. Risks & Open Questions

  • Muon on warm-start recovery is under-validated (strong evidence is LLM/from-scratch + one flow-matching benchmark). → A/B against AdamW on the first training chunk before committing.
  • 1.2M views is a partial recovery in absolute terms. → Monitor the loss curve; be willing to extend if a specific capability (e.g. text) stays weak.
  • Surrogate weight is a tradeoff: light = cheap warm start, near-2B FLOPs; heavy = better quality but more cost and less speedup. Default light.
  • Coverage holes: quality gaps will track wherever the 300k set is thin. Curate for breadth, not volume.
  • MFU sensitivity: a sloppy loop at ~25% MFU ≈ 1.6× longer/costlier; the §5 optimizations are how we hit 40–45%.

9. Implementation & Engineering Setup

What we actually build, in dependency order.

A. Data & latent caching (offline, one-time)

  • Curate + caption 300k images; assign stable IDs.
  • Batch VAE-encode all images → BF16 latents written to streaming shards (WebDataset / tensordict). The VAE runs only here, never in the training loop.
  • Dataloader streams (latent, caption) with aggressive prefetch; per step samples timestep t and noise, forms x_t = (1−t)·noise + t·latent.
  • Text embeddings: run Qwen3-4B online, intercepting intermediate layers (→ 7680-dim); encoder frozen. (Optionally batch-precompute per shard if the online encoder pass becomes the bottleneck.)
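For each sample, the step forms x_t and, for the real-data flow-matching term, the analytic target velocity. With this convention (t = 0 noise, t = 1 data), the target is dx_t/dt = latent − noise. For the distillation term, the teacher's prediction at the same (x_t, t) replaces it. This is a minimal pure-Python sketch; the uniform timestep draw and the helper names are assumptions.

```python
import random


def sample_t(rng: random.Random) -> float:
    return rng.random()  # uniform in [0, 1); the actual timestep schedule is an assumption here


def make_training_pair(latent, rng):
    """Return (t, x_t, v_target) for one flattened latent under x_t = (1 - t)*noise + t*latent."""
    t = sample_t(rng)
    noise = [rng.gauss(0.0, 1.0) for _ in latent]
    x_t = [(1 - t) * n + t * z for n, z in zip(noise, latent)]
    v_target = [z - n for n, z in zip(noise, latent)]  # d(x_t)/dt: constant along the straight path
    return t, x_t, v_target


rng = random.Random(0)
latent = [0.3, -1.2, 0.8, 0.05]  # toy stand-in for one cached 128-channel latent
ts = []
for epoch in range(4):  # 4 epochs -> 4 different (t, noise) draws for the same image
    t, x_t, v = make_training_pair(latent, rng)
    # Integrating the exact velocity from t to 1 recovers the clean latent.
    recon = [x + (1 - t) * vi for x, vi in zip(x_t, v)]
    err = max(abs(r - z) for r, z in zip(recon, latent))
    ts.append(t)
    print(f"epoch {epoch}: t={t:.3f}, reconstruction error={err:.1e}")
```

Each epoch gives the same image a new (t, noise) pair, which is why repeated epochs add fresh velocity-field supervision rather than exact repeats.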

B. Model surgery (prune + BRLL surrogates)

  • Load klein-base-4B transformer (diffusers Flux2 / BFL reference impl).
  • Depth-select: keep 5 double + 8 of 20 single (importance / CKA similarity or learnable mask); inherit weights for kept blocks.
  • Replace the 12 removed single blocks with BRLL surrogates — linear/low-rank residual branch with the shortcut preserved; least-squares-init from teacher activations on a calibration batch (training-free warm start).
  • Sanity: shape + forward-parity checks. Full width d=3072 → no projector needed for feature matching.
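For the CKA option, one heuristic is to score each single-stream block by the linear CKA between its input and output activations on a calibration batch. Blocks whose output stays most similar to their input (CKA near 1) change the representation least, so they are the first candidates to hand to a surrogate. This pure-Python sketch runs on synthetic activations, and both the heuristic and the names are assumptions. The plan equally allows an importance score or a learnable mask.

```python
import random


def center(x):
    """Subtract the per-feature (column) mean from a samples x features matrix."""
    means = [sum(col) / len(x) for col in zip(*x)]
    return [[v - m for v, m in zip(row, means)] for row in x]


def cross_gram(a, b):
    """a^T b for two samples x features matrices."""
    return [[sum(x * y for x, y in zip(ca, cb)) for cb in zip(*b)] for ca in zip(*a)]


def fro2(m):
    return sum(v * v for row in m for v in row)


def linear_cka(x, y):
    """Linear CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F) on centered features."""
    x, y = center(x), center(y)
    return fro2(cross_gram(y, x)) / (fro2(cross_gram(x, x)) ** 0.5 * fro2(cross_gram(y, y)) ** 0.5)


# Synthetic calibration activations: 32 tokens x 4 features entering a block.
rng = random.Random(0)
h_in = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(32)]
block_outputs = {
    "block_a": [[v + 0.05 * rng.gauss(0, 1) for v in row] for row in h_in],  # barely changes input
    "block_b": [[rng.gauss(0, 1) for _ in row] for row in h_in],             # rewrites input
}
scores = {name: linear_cka(h_in, h_out) for name, h_out in block_outputs.items()}
for name, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(f"{name}: CKA(in, out) = {score:.3f}")
```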

C. Distillation training loop (online teacher)

  • Teacher 4B loaded frozen (eval, no_grad, BF16, compiled).
  • Per step: latents + prompts → text emb → sample t, noise → x_t; teacher forward → target velocity v_teacher; student forward → v_student.
  • Loss = w1·velocity-MSE (primary) + w2·feature-match (masked KD on retained blocks) + w3·real-data flow-matching. Tune weights.
  • Param-group split: Muon on 2D hidden weights, AdamW on embedders / proj_out / norms / modulation.
  • BF16 autocast, FP32 master weights + moments; grad-accumulate to effective ~256; torch.compile the student; FlashAttention-4 via FlexAttention.
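For each optimizer step, the three loss terms are combined and accumulated over 4 micro-batches, sketched below without a framework. The weight values are placeholders. `backward` and `optimizer_step` are hypothetical callbacks standing in for `loss.backward()` and the Muon/AdamW step plus gradient zeroing. Dividing each micro-batch loss by the accumulation count makes the accumulated gradient match that of one 256-sample batch, given equal-sized micro-batches with mean-reduced losses.

```python
from typing import Callable, Iterable, Tuple

W1, W2, W3 = 1.0, 0.1, 0.5  # placeholder loss weights (to be tuned)
ACCUM = 4                   # micro-batch 64 x 4 = effective batch 256


def total_loss(l_velocity: float, l_feature: float, l_flow: float) -> float:
    """w1 * velocity-MSE + w2 * masked feature match + w3 * real-data flow matching."""
    return W1 * l_velocity + W2 * l_feature + W3 * l_flow


def run_micro_steps(
    micro_losses: Iterable[Tuple[float, float, float]],
    backward: Callable[[float], None],
    optimizer_step: Callable[[], None],
) -> int:
    """Backward on every micro-step; step the optimizers every ACCUM micro-steps."""
    n_opt_steps = 0
    for i, losses in enumerate(micro_losses, start=1):
        backward(total_loss(*losses) / ACCUM)  # scale so accumulated grads average over 256 samples
        if i % ACCUM == 0:
            optimizer_step()  # stands in for the Muon + AdamW step, then zero grads
            n_opt_steps += 1
    return n_opt_steps


scaled = []
n = run_micro_steps([(0.2, 0.05, 0.3)] * 8, backward=scaled.append, optimizer_step=lambda: None)
print(n, sum(scaled[:ACCUM]))
```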

D. Infra, logging & eval

  • Single B200, micro-batch 64, resumable checkpointing.
  • Log per-component losses, throughput / MFU, and periodic sample grids for eyeballing.
  • Eval harness: FID/CLIP + targeted prompt probes (text rendering, composition/counting, texture) vs the 4B; downstream fine-tune smoke test.
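For the MFU log line, MFU can be derived from measured step time. This uses the per-view FLOP estimate from §5 and an assumed ~2.25 PFLOP/s dense BF16 peak. The inverse gives a concrete step-time target: at micro-batch 64, 40% MFU means roughly 2.8 s per micro-step.

```python
PEAK_FLOPS = 2.25e15   # assumed dense BF16 peak of one B200
FLOP_PER_VIEW = 40e12  # plan estimate per training view


def mfu(samples: int, seconds: float) -> float:
    """Model FLOPs utilization = achieved model FLOP/s / hardware peak."""
    return samples * FLOP_PER_VIEW / seconds / PEAK_FLOPS


def step_time_for(target_mfu: float, micro_batch: int = 64) -> float:
    """Seconds per micro-step needed to hit a target MFU."""
    return micro_batch * FLOP_PER_VIEW / (target_mfu * PEAK_FLOPS)


print(f"64 samples in 2.8 s -> MFU {mfu(64, 2.8):.0%}")                          # 41%
print(f"40% MFU needs <= {step_time_for(0.40):.2f} s per 64-sample micro-step")  # 2.84 s
```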

E. Distill-down kit (reuses C)

  • Same loop with the teacher swappable (4B or 9B); student stays on the Qwen3-4B encoder.
  • Flow: domain fine-tune the teacher → generate synthetic in-domain samples → velocity-match into the 2B.
  • Package: block-correspondence map + projector recipe + scripts + minimal config.

10. Milestones

  1. Data + latents — curate 300k diverse pairs; precompute VAE latents to shards.
  2. Architecture — prune to 5D+8S full-width; least-squares-init BRLL surrogates.
  3. Recovery run — Muon+AdamW, online teacher, velocity + feature + FM losses, 1.2M views @512px, BF16.
  4. Eval — FID/CLIP + targeted probes (text, composition, texture) vs 4B; downstream fine-tune smoke test.
  5. (Optional) Polish — short 768/1024 phase; optional LADD for v2.
  6. Release — 2B base weights + distill-down kit (correspondence map, scripts) + model card.

11. Budget Summary

| Line | Estimate |
|---|---|
| Recovery training (1× B200, 1.2M views) | ~$66–75 |
| VAE-latent precompute (one-time) | <$5 |
| Total per run | ~$66–80 |

Experiments to be tracked individually under this page per the standard experiment format.
