---
license: mit
language:
- en
tags:
- diffusion
- flow-matching
- flux
- text-to-image
- image-generation
- tinyflux
- lailah
- experimental
library_name: pytorch
pipeline_tag: text-to-image
base_model:
- AbstractPhil/tiny-flux
- black-forest-labs/FLUX.1-schnell
datasets:
- AbstractPhil/flux-schnell-teacher-latents
- AbstractPhil/imagenet-synthetic
---

# TinyFlux-Deep v4.1 (Lailah)

A compact **246M parameter** flow-matching diffusion model that distills knowledge from multiple teacher models into an efficient architecture. TinyFlux-Deep uses a dual expert system to capture both trajectory dynamics (from SD1.5) and structural attention patterns (from a geometric prior), enabling high-quality image generation at a fraction of the compute cost of full-scale models.

## Table of Contents

- [Key Features](#key-features)
- [Quick Start](#quick-start)
- [Architecture](#architecture)
- [Dual Expert System](#dual-expert-system)
- [Configuration](#configuration)
- [Inference](#inference)
- [Training](#training)
- [Checkpoint Conversion](#checkpoint-conversion)
- [Repository Structure](#repository-structure)
- [API Reference](#api-reference)
- [Samples](#samples)
- [Limitations](#limitations)
- [Citation](#citation)

---

## Key Features

| Feature | Description |
|---------|-------------|
| **Compact Size** | 246M params (~500MB bf16) vs Flux's 12B (~24GB) |
| **Dual Expert Distillation** | Learns from both SD1.5 trajectory features and geometric attention priors |
| **Flow Matching** | Rectified flow objective with Flux-style timestep shifting |
| **T5 + CLIP Conditioning** | Dual text encoder pathway with learnable balance |
| **Huber Loss** | Robust training that handles outliers gracefully |
| **Identity-Init Conversion** | v3→v4 conversion preserves pretrained weights exactly |

---

## Quick Start

### Colab Inference

```python
!pip install torch transformers safetensors huggingface_hub accelerate

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download model code and weights
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
weights = hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors")

# Load model
exec(open(model_py).read())
config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
model.load_state_dict(load_file(weights), strict=False)
model.eval()

# For the full inference pipeline with text encoders and sampling:
inference_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/inference_v3.py")
exec(open(inference_py).read())
# Then call: image = generate("your prompt here")
```

### Minimal Generation Loop

```python
import torch

def flux_shift(t, s=3.0):
    """Flux-style timestep shifting - biases steps toward the data end."""
    return s * t / (1 + (s - 1) * t)

def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
    """Euler sampling with classifier-free guidance."""
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Start from pure noise (t=0)
    x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)

    # Rectified flow: integrate from t=0 (noise) to t=1 (data)
    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device=device))

    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr

        t_batch = t_curr.expand(1)

        # Conditional prediction
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_emb,
            pooled_projections=clip_pooled,
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Unconditional prediction (for CFG)
        v_uncond = model(
            hidden_states=x,
            encoder_hidden_states=torch.zeros_like(t5_emb),
            pooled_projections=torch.zeros_like(clip_pooled),
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)

        # Euler step
        x = x + v * dt

    return x  # [1, 4096, 16] - decode with VAE
```
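
A quick numeric check of the shift (plain Python, no model required): the endpoints stay fixed while interior grid points are pushed toward `t=1`, so more of the integration steps land near the data end.

```python
def flux_shift(t, s=3.0):
    # Same formula as above, applied to plain floats
    return s * t / (1 + (s - 1) * t)

print(flux_shift(0.0))  # 0.0  - noise end is fixed
print(flux_shift(0.5))  # 0.75 - midpoint pushed toward the data end
print(flux_shift(1.0))  # 1.0  - data end is fixed
```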

---

## Architecture

### Model Comparison

| Component | TinyFlux | TinyFlux-Deep v3 | TinyFlux-Deep v4.1 | Flux-Schnell |
|-----------|----------|------------------|--------------------|--------------|
| Hidden size | 256 | 512 | 512 | 3072 |
| Attention heads | 2 | 4 | 4 | 24 |
| Head dimension | 128 | 128 | 128 | 128 |
| Double-stream layers | 3 | 15 | 15 | 19 |
| Single-stream layers | 3 | 25 | 25 | 38 |
| MLP ratio | 4.0 | 4.0 | 4.0 | 4.0 |
| RoPE dims | (16,56,56) | (16,56,56) | (16,56,56) | (16,56,56) |
| Lune Expert | ✗ | ✓ | ✓ | ✗ |
| Sol Attention Prior | ✗ | ✗ | ✓ | ✗ |
| T5 Vec Enhancement | ✗ | ✗ | ✓ | ✗ |
| **Total Parameters** | ~10.7M | ~244.7M | ~246.4M | ~12B |
| **Memory (bf16)** | ~22MB | ~490MB | ~493MB | ~24GB |
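
The bf16 memory column follows directly from the parameter counts, since bfloat16 stores each parameter in 2 bytes:

```python
def bf16_bytes(num_params):
    # bfloat16 uses 2 bytes per parameter
    return num_params * 2

print(bf16_bytes(246.4e6) / 1e6)  # ≈ 493 MB for TinyFlux-Deep v4.1
print(bf16_bytes(12e9) / 1e9)     # ≈ 24 GB for Flux-Schnell
```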

### Block Structure

**Double-Stream Blocks (15 layers):**
- Separate text and image pathways
- Joint attention between modalities
- AdaLN-Zero conditioning from vec
- Sol spatial modulation on image Q/K only

**Single-Stream Blocks (25 layers):**
- Concatenated text + image sequence
- Full self-attention with RoPE
- Sol modulation skips text tokens

```
Input: img_latents [B, 4096, 16], t5_emb [B, 77, 768], clip_pooled [B, 768]
                    │
    ┌───────────────┴───────────────┐
    ▼                               ▼
img_in (Linear)               txt_in (Linear)
    │                               │
    ▼                               ▼
[B, 4096, 512]                [B, 77, 512]
    │                               │
    └───────────────┬───────────────┘
                    │
  vec = time_emb + clip_vec + t5_vec + lune_signal
                    │
        ┌───────────┴───────────┐
        ▼                       ▼
Double Blocks (×15)    Sol Prior → temperature, spatial_mod
        │                       │
        ▼                       │
Single Blocks (×25) ◄───────────┘
        │
        ▼
final_norm → final_linear
        │
        ▼
Output: [B, 4096, 16]
```

---

## Dual Expert System

TinyFlux-Deep v4.1 uses two complementary expert pathways to inject knowledge from teacher models without the "twin-tail paradox" (mixing incompatible prediction targets).

### Lune Expert Predictor (Trajectory Guidance)

**Purpose:** Captures SD1.5's understanding of "how denoising should flow" - the trajectory through latent space.

**Architecture:**
```python
LuneExpertPredictor(
    time_dim=512,     # From timestep MLP
    clip_dim=768,     # CLIP pooled features
    expert_dim=1280,  # SD1.5 mid-block dimension (prediction target)
    hidden_dim=512,   # Internal MLP width
    output_dim=512,   # Output added to vec
    dropout=0.1,
)
```

**How it works:**
1. Concatenates timestep embedding + CLIP pooled → hidden
2. Predicts what SD1.5's mid-block features would be
3. During training: uses real SD1.5 features when available
4. During inference: uses predicted features
5. Gates output with learnable sigmoid (init 0.5)
6. Adds `expert_signal` to global `vec` conditioning

**Training signal:** Cosine similarity loss against real SD1.5 UNet mid-block features (soft directional matching, not exact reconstruction).
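
The six steps above can be sketched as a small module. This is an illustrative reconstruction from the description, not the repo's actual `LuneExpertPredictor` implementation; the class and attribute names here are assumptions.

```python
import torch
import torch.nn as nn

class LuneSketch(nn.Module):
    """Illustrative sketch of the Lune pathway described above (not the repo's code)."""
    def __init__(self, time_dim=512, clip_dim=768, expert_dim=1280,
                 hidden_dim=512, output_dim=512, dropout=0.1):
        super().__init__()
        self.to_hidden = nn.Sequential(
            nn.Linear(time_dim + clip_dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
        )
        self.to_expert = nn.Linear(hidden_dim, expert_dim)  # predicts SD1.5 mid-block features
        self.to_signal = nn.Linear(expert_dim, output_dim)  # projects back into vec's width
        self.gate = nn.Parameter(torch.zeros(1))            # sigmoid(0) = 0.5 at init

    def forward(self, time_emb, clip_pooled, real_features=None):
        h = self.to_hidden(torch.cat([time_emb, clip_pooled], dim=-1))  # step 1
        pred = self.to_expert(h)                                        # step 2
        feats = real_features if real_features is not None else pred    # steps 3-4
        signal = torch.sigmoid(self.gate) * self.to_signal(feats)       # steps 5-6
        return signal, pred  # signal joins vec; pred is trained against the teacher
```

During training `real_features` would come from the SD1.5 teacher; at inference it is `None` and the predicted features are used instead.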

### Sol Attention Prior (Structural Guidance)

**Purpose:** Captures geometric/structural knowledge about *where* attention should focus, without injecting incompatible features.

**Key insight:** Sol (a v-prediction DDPM model) has valuable attention patterns, but its features are incompatible with TinyFlux's linear flow matching. We extract attention *statistics* instead:
- **Locality:** How local vs global is attention?
- **Entropy:** How focused vs diffuse?
- **Clustering:** How structured vs uniform?
- **Spatial importance:** Which regions matter most?

**Architecture:**
```python
SolAttentionPrior(
    time_dim=512,
    clip_dim=768,
    hidden_dim=256,
    num_heads=4,           # Matches TinyFlux attention heads
    spatial_size=8,        # 8×8 importance map
    geometric_weight=0.7,  # David's 70/30 split
)
```

**How it works:**

1. **Geometric prior (70%):** Timestep-based heuristics
   - Early denoising (high t): higher temperature → softer, global attention
   - Late denoising (low t): lower temperature → sharper, local attention
   - Spatial: uniform early, center-biased late

2. **Learned prior (30%):** Content-based predictions
   - Predicts attention statistics from (timestep, CLIP)
   - Predicts a spatial importance map

3. **Blending:** `blend * geometric + (1 - blend) * learned` with a learnable blend gate

4. **Output application:**
   - `temperature [B, 4]`: scales attention logits per head
   - `spatial_mod [B, H, W]`: modulates Q/K at each position via `exp(conv(spatial))`

**Identity initialization:** All Sol components initialize to zero-effect:
- `spatial_to_mod` Conv2d: zero weight, zero bias → `exp(0) = 1` (identity)
- Allows gradual learning without disrupting pretrained attention
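
The zero-effect initialization is easy to verify: a Conv2d with zero weight and zero bias outputs zeros regardless of input, and `exp(0) = 1` is a multiplicative identity.

```python
import torch
import torch.nn as nn

# Zero-initialized conv, as described for spatial_to_mod
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
nn.init.zeros_(conv.weight)
nn.init.zeros_(conv.bias)

spatial = torch.randn(1, 1, 8, 8)  # any 8×8 importance map
mod = torch.exp(conv(spatial))     # exp(0) = 1 at every position
print(torch.allclose(mod, torch.ones_like(mod)))  # True
```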

### T5 Vec Enhancement

**Purpose:** Adds T5's semantic understanding to the global conditioning pathway (previously driven only by CLIP pooled features).

```python
# Attention-weighted pooling of the T5 sequence
t5_attn = torch.softmax(t5_emb.mean(dim=-1), dim=-1)     # [B, 77]
t5_pooled = (t5_emb * t5_attn.unsqueeze(-1)).sum(dim=1)  # [B, 768]
t5_vec = t5_pool_mlp(t5_pooled)                          # [B, 512]

# Learnable balance between CLIP and T5
balance = torch.sigmoid(text_balance)                    # Initialized to 0.5
text_vec = balance * clip_vec + (1 - balance) * t5_vec
```

---

## Configuration

### TinyFluxConfig

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class TinyFluxConfig:
    # Core architecture
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128  # hidden_size = heads × head_dim
    in_channels: int = 16          # VAE latent channels
    patch_size: int = 1
    joint_attention_dim: int = 768   # T5 embedding dim
    pooled_projection_dim: int = 768 # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # Must sum to head_dim

    # Lune expert predictor
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280  # SD1.5 mid-block dim
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior
    use_sol_prior: bool = True
    sol_spatial_size: int = 8          # 8×8 spatial importance map
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7  # 70% geometric, 30% learned

    # T5 enhancement
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"  # "attention", "mean", "cls"

    # Loss configuration
    lune_distill_mode: str = "cosine"  # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # Legacy compatibility
    guidance_embeds: bool = False
```

### Loading from JSON

```python
# From file
config = TinyFluxConfig.from_json("lailah_401434_v4_config.json")

# From dict
config = TinyFluxConfig.from_dict({
    "hidden_size": 512,
    "num_attention_heads": 4,
    ...
})

# Save with metadata
config.save_json("config.json", metadata={"source_step": 401434})
```

### Validation

```python
# Config validates constraints on creation
config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=128)
# ✓ OK: 512 = 4 × 128

config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=64)
# ✗ ValueError: hidden_size (512) must equal num_attention_heads * attention_head_dim (256)

# Validate checkpoint compatibility
warnings = config.validate_checkpoint(state_dict)
if warnings:
    print("Warnings:", warnings)
```
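
The two arithmetic constraints behind that `ValueError` can be checked in a few lines; this is an illustrative sketch of what the config presumably enforces on creation, not the repo's actual validation code.

```python
# Illustrative constraint check, mirroring the errors shown above
hidden_size, num_heads, head_dim = 512, 4, 128
axes_dims_rope = (16, 56, 56)

assert hidden_size == num_heads * head_dim  # 512 = 4 × 128
assert sum(axes_dims_rope) == head_dim      # 16 + 56 + 56 = 128
```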

---

## Inference

### Full Pipeline

```python
import torch
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Load text encoders
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)

# Load VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=torch.bfloat16
).to("cuda")

# Load TinyFlux-Deep
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
exec(open(model_py).read())

config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
model.load_state_dict(weights, strict=False)
model.eval()

def encode_prompt(prompt):
    """Encode prompt with both T5 and CLIP."""
    # T5
    t5_tokens = t5_tokenizer(prompt, return_tensors="pt", padding="max_length",
                             max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        t5_emb = t5_model(**t5_tokens).last_hidden_state.to(torch.bfloat16)

    # CLIP
    clip_tokens = clip_tokenizer(prompt, return_tensors="pt", padding="max_length",
                                 max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        clip_out = clip_model(**clip_tokens)
        clip_pooled = clip_out.pooler_output.to(torch.bfloat16)

    return t5_emb, clip_pooled

def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """
    Euler sampling for rectified flow.

    Flow: x_t = (1-t)*noise + t*data
    Integrate from t=0 (noise) to t=1 (data)
    """
    if seed is not None:
        torch.manual_seed(seed)

    t5_emb, clip_pooled = encode_prompt(prompt)

    # Null embeddings for CFG
    t5_null, clip_null = encode_prompt("")

    # Start from pure noise (t=0)
    x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")

    # Rectified flow: 0 → 1 with Flux shift
    def flux_shift(t, s=3.0):
        return s * t / (1 + (s - 1) * t)

    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device="cuda"))

    with torch.no_grad():
        for i in range(num_steps):
            t = timesteps[i].expand(1)
            dt = timesteps[i + 1] - timesteps[i]  # Positive

            # Conditional
            v_cond = model(x, t5_emb, clip_pooled, t, img_ids)

            # Unconditional
            v_uncond = model(x, t5_null, clip_null, t, img_ids)

            # CFG
            v = v_uncond + cfg_scale * (v_cond - v_uncond)

            # Euler step
            x = x + v * dt

    # Decode with VAE (the FLUX VAE applies both a scaling and a shift factor)
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)  # [B, C, H, W]
    x = x / vae.config.scaling_factor + vae.config.shift_factor
    with torch.no_grad():
        image = vae.decode(x).sample

    # Convert to PIL
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")

    from PIL import Image
    return Image.fromarray(image)

# Generate!
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
image.save("tiger.png")
```

### Batch Inference

```python
def generate_batch(prompts, **kwargs):
    """Generate multiple images."""
    return [generate_image(p, **kwargs) for p in prompts]

images = generate_batch([
    "a red bird with blue beak",
    "a mountain landscape at sunset",
    "an astronaut riding a horse",
], num_steps=25, cfg_scale=4.0)
```

---

## Training

### Loss Computation

```python
# Forward pass with expert info
output, expert_info = model(
    hidden_states=noisy_latents,
    encoder_hidden_states=t5_emb,
    pooled_projections=clip_pooled,
    timestep=timesteps,
    img_ids=img_ids,
    lune_features=sd15_midblock_features,  # From SD1.5 teacher
    sol_stats=sol_attention_stats,         # From Sol teacher (optional)
    sol_spatial=sol_spatial_importance,    # From Sol teacher (optional)
    return_expert_pred=True,
)

# Compute loss
losses = model.compute_loss(
    output=output,
    target=flow_target,  # data - noise for flow matching
    expert_info=expert_info,
    lune_features=sd15_midblock_features,
    sol_stats=sol_attention_stats,
    sol_spatial=sol_spatial_importance,

    # Loss weights
    lune_weight=0.1,   # Weight for Lune distillation
    sol_weight=0.05,   # Weight for Sol distillation

    # Loss options
    use_huber=True,              # Huber loss for main objective (robust to outliers)
    huber_delta=0.1,             # Huber delta (smaller = tighter MSE region)
    lune_distill_mode="cosine",  # "hard", "soft", "cosine", "huber"
    spatial_weighting=True,      # Weight loss by Sol spatial importance
)

# losses dict contains:
# - main: flow matching loss
# - lune_distill: Lune prediction loss
# - sol_stat_distill: Sol statistics prediction loss
# - sol_spatial_distill: Sol spatial prediction loss
# - total: weighted sum

loss = losses['total']
loss.backward()
```
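
A note on the `data - noise` target: along the linear path `x_t = (1-t)*noise + t*data` the true velocity is constant, so a model that predicts it exactly reaches the data in a single full-interval Euler step.

```python
import torch

torch.manual_seed(0)
data = torch.randn(4, 16)
noise = torch.randn(4, 16)

v = data - noise  # flow matching target (constant along the linear path)
x = noise         # start at t = 0
x = x + v * 1.0   # one Euler step across the whole interval [0, 1]
print(torch.allclose(x, data))  # True
```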

### Distillation Modes

| Mode | Description | Use Case |
|------|-------------|----------|
| `"hard"` | MSE against teacher features | Exact reconstruction |
| `"soft"` | Temperature-scaled MSE | Softer matching |
| `"cosine"` | Cosine similarity loss | Directional alignment (recommended) |
| `"huber"` | Huber loss on features | Robust to outliers |
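
The recommended `"cosine"` mode can be sketched in a few lines; this is an illustrative stand-in for what `compute_loss` does internally, not the repo's exact code.

```python
import torch
import torch.nn.functional as F

def cosine_distill_loss(pred, teacher):
    # Penalize direction mismatch only; feature magnitude is ignored
    return 1.0 - F.cosine_similarity(pred, teacher, dim=-1).mean()

pred = torch.randn(2, 1280)
print(cosine_distill_loss(pred, pred).item())   # ≈ 0.0: identical directions
print(cosine_distill_loss(pred, -pred).item())  # ≈ 2.0: opposite directions
```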

### Training Loop Example

```python
import copy

import torch
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.99), weight_decay=0.01)

# EMA
ema_decay = 0.9999
ema_model = copy.deepcopy(model)

def flux_shift(t, s=3.0):
    return s * t / (1 + (s - 1) * t)

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()

    # bfloat16 keeps the fp32 exponent range, so no GradScaler is needed
    with torch.autocast("cuda", dtype=torch.bfloat16):
        # Sample timesteps with logit-normal distribution
        u = torch.randn(batch['latents'].shape[0], device=device)
        t = torch.sigmoid(u)      # Logit-normal
        t = flux_shift(t, s=3.0)  # Flux shift

        # Add noise: x_t = (1-t)*noise + t*data
        noise = torch.randn_like(batch['latents'])
        noisy = t.view(-1, 1, 1) * batch['latents'] + (1 - t.view(-1, 1, 1)) * noise
        target = batch['latents'] - noise  # Flow matching target

        # Forward
        output, expert_info = model(
            hidden_states=noisy,
            encoder_hidden_states=batch['t5_emb'],
            pooled_projections=batch['clip_pooled'],
            timestep=t,
            img_ids=img_ids,
            lune_features=batch.get('sd15_features'),
            return_expert_pred=True,
        )

        # Loss
        losses = model.compute_loss(output, target, expert_info,
                                    lune_features=batch.get('sd15_features'))

    losses['total'].backward()
    optimizer.step()

    # EMA update
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), ema_model.parameters()):
            p_ema.data.lerp_(p.data, 1 - ema_decay)
```
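
The `lerp_` update above is the standard exponential moving average, `p_ema ← p_ema + (1 − decay) · (p − p_ema)`; a one-step numeric check:

```python
import torch

ema_decay = 0.9999
p_ema = torch.tensor([1.0])
p = torch.tensor([2.0])

p_ema.lerp_(p, 1 - ema_decay)  # 1.0 + 0.0001 * (2.0 - 1.0)
print(p_ema)                   # tensor([1.0001])
```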

### Hyperparameters

| Parameter | Value | Notes |
|-----------|-------|-------|
| Optimizer | AdamW | |
| Learning rate | 3e-4 | With cosine schedule |
| Betas | (0.9, 0.99) | |
| Weight decay | 0.01 | |
| Batch size | 32 | 16 × 2 gradient accumulation |
| EMA decay | 0.9999 | |
| Precision | bfloat16 | |
| Timestep shift | s=3.0 | Flux-style |
| Timestep sampling | Logit-normal | |
| Lune weight | 0.1 | |
| Sol weight | 0.05 | |
| Huber delta | 0.1 | |

---

## Checkpoint Conversion

### v3 → v4.1 Conversion

The converter preserves all pretrained weights and initializes the new v4.1 components to identity/zero-effect:

**What gets converted:**

| v3 Key | v4.1 Key | Action |
|--------|----------|--------|
| `expert_predictor.*` | `lune_predictor.*` | Rename |
| `expert_gate` (0.5) | `expert_gate` (0.0) | Convert to logit space |
| - | `sol_prior.*` | Initialize (zero-effect) |
| - | `t5_pool.*` | Initialize (Xavier) |
| - | `text_balance` | Initialize (0.0 = 50/50) |
| - | `*.spatial_to_mod.*` | Initialize (zero = identity) |

**Parameter growth:**
- v3: ~244.7M parameters
- v4.1: ~246.4M parameters
- Added: ~1.7M parameters (0.7% increase)
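
The gate conversion works because v4.1 stores the gate pre-sigmoid: applying the inverse sigmoid (the logit) to the old value keeps behavior identical. A minimal sketch:

```python
import math

def logit(p):
    # Inverse of sigmoid: sigmoid(logit(p)) == p
    return math.log(p / (1 - p))

print(logit(0.5))  # 0.0 - a v3 gate of 0.5 becomes 0.0 in logit space
```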

### Python API

```python
from huggingface_hub import hf_hub_download

# Download converter
converter = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/convert_v3_to_v4.py")
exec(open(converter).read())

# Simple: download, convert, upload
from convert_v3_to_v4 import run
result = run(401434)  # Step number

# With custom config
result = run(401434, config={
    "hidden_size": 512,
    "num_attention_heads": 4,
    "sol_geometric_weight": 0.8,  # More geometric, less learned
})

# From JSON config file
result = run(401434, config="my_config.json")

# Low-level API
from convert_v3_to_v4 import convert_state_dict, analyze_checkpoint, TinyFluxConfig

# Analyze checkpoint version
state_dict = load_file("checkpoint.safetensors")
info = analyze_checkpoint(state_dict)
print(f"Version: {info.version}")  # "v3", "v4.0", "v4.1", etc.
print(f"Has Sol prior: {info.has_sol_prior}")

# Convert state dict
config = TinyFluxConfig()
v4_state, report = convert_state_dict(state_dict, config)
print(f"Renamed {len(report['renamed'])} keys")
print(f"Initialized {len(report['initialized'])} keys")
```

### CLI

```bash
# Basic conversion
python convert_v3_to_v4.py --step 401434

# Local file
python convert_v3_to_v4.py --input model_v3.safetensors

# Analyze only (don't convert)
python convert_v3_to_v4.py --step 401434 --analyze-only

# Custom output
python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel

# With custom config
python convert_v3_to_v4.py --step 401434 --config my_config.json
```

### Output Structure

```
checkpoint_runs/v4_init/
├── lailah_401434_v4_init.safetensors                # Converted model
├── lailah_401434_v4_init_ema.safetensors            # Fresh EMA (copy of model)
├── lailah_401434_v4_init_ema_secondary.safetensors  # Converted old EMA
└── lailah_401434_v4_config.json                     # Config with conversion metadata
```

### Config JSON Format

```json
{
  "hidden_size": 512,
  "num_attention_heads": 4,
  "attention_head_dim": 128,
  "num_double_layers": 15,
  "num_single_layers": 25,
  "use_lune_expert": true,
  "use_sol_prior": true,
  "use_t5_vec": true,
  "sol_geometric_weight": 0.7,
  "lune_distill_mode": "cosine",
  "use_huber_loss": true,
  "huber_delta": 0.1,
  "_conversion_info": {
    "source_step": 401434,
    "source_repo": "AbstractPhil/tiny-flux-deep",
    "source_version": "v3",
    "target_version": "v4.1",
    "source_params": 244690849,
    "target_params": 246347234,
    "params_added": 1656385,
    "converter_version": "4.1.0"
  }
}
```

---

## Repository Structure

```
AbstractPhil/tiny-flux-deep/
│
├── model.safetensors                    # Latest training weights
├── model_ema.safetensors                # EMA weights (use for inference)
├── config.json                          # Model configuration
├── README.md
│
├── scripts/                             # All Python code
│   ├── model_v4.py                      # v4.1 architecture (current)
│   ├── model_v3.py                      # v3 architecture (reference)
│   ├── model_v2.py                      # v2 architecture (legacy)
│   ├── inference_v3.py                  # Full inference pipeline
│   ├── convert_v3_to_v4.py              # Checkpoint converter
│   ├── trainer_v3_expert_guidance.py    # Training with distillation
│   ├── trainer_v2.py                    # Previous trainer
│   ├── trainer.py                       # Original trainer
│   ├── port_tiny_to_deep.py             # TinyFlux → Deep port script
│   └── colab_inference_lailah_early.py  # Simple Colab notebook
│
├── checkpoints/                         # v3 checkpoints (legacy)
│   ├── step_401434.safetensors
│   └── step_401434_ema.safetensors
│
├── checkpoint_runs/                     # Organized checkpoint runs
│   └── v4_init/                         # v4.1 initialization from v3
│       ├── lailah_401434_v4_init.safetensors
│       ├── lailah_401434_v4_init_ema.safetensors
│       ├── lailah_401434_v4_init_ema_secondary.safetensors
│       └── lailah_401434_v4_config.json
│
├── samples/                             # Generated samples per step
│   └── 20260127_074318_step_401434.png
│
└── logs/                                # TensorBoard training logs
    └── run_20260126_220714/
```

---

## API Reference

### TinyFluxDeep

```python
class TinyFluxDeep(nn.Module):
    def __init__(self, config: Optional[TinyFluxConfig] = None):
        """Initialize model with config (uses defaults if None)."""

    def forward(
        self,
        hidden_states: torch.Tensor,           # [B, N, 16] image latents
        encoder_hidden_states: torch.Tensor,   # [B, L, 768] T5 embeddings
        pooled_projections: torch.Tensor,      # [B, 768] CLIP pooled
        timestep: torch.Tensor,                # [B] timestep in [0, 1]
        img_ids: torch.Tensor,                 # [N, 3] position IDs
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,         # Legacy
        lune_features: Optional[torch.Tensor] = None,    # [B, 1280] SD1.5 features
        sol_stats: Optional[torch.Tensor] = None,        # [B, 3] attention stats
        sol_spatial: Optional[torch.Tensor] = None,      # [B, 8, 8] spatial importance
        expert_features: Optional[torch.Tensor] = None,  # Legacy API
        return_expert_pred: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """
        Forward pass.

        Returns:
            output: [B, N, 16] predicted velocity
            expert_info: dict with predictions (if return_expert_pred=True)
        """

    def compute_loss(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        expert_info: Optional[Dict] = None,
        lune_features: Optional[torch.Tensor] = None,
        sol_stats: Optional[torch.Tensor] = None,
        sol_spatial: Optional[torch.Tensor] = None,
        lune_weight: float = 0.1,
        sol_weight: float = 0.05,
        use_huber: bool = True,
        huber_delta: float = 0.1,
        lune_distill_mode: str = "cosine",
        spatial_weighting: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Compute combined loss with distillation."""

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device) -> torch.Tensor:
        """Create image position IDs for RoPE."""

    @staticmethod
    def create_txt_ids(text_len: int, device) -> torch.Tensor:
        """Create text position IDs."""

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
```

### Converter Functions

```python
# High-level
def run(step, name="lailah", config=None, ...):
    """One-liner: download, convert, upload."""

def convert_checkpoint(step=None, input_path=None, config=None, ...) -> ConversionResult:
    """Convert checkpoint with full control."""

# Low-level
def analyze_checkpoint(state_dict) -> CheckpointInfo:
    """Analyze checkpoint version and contents."""

def convert_state_dict(state_dict, config=None) -> Tuple[Dict, Dict]:
    """Convert state dict, return (new_state, report)."""

def download_from_hf(step, repo_id, ...) -> Tuple[str, str]:
    """Download checkpoint from HuggingFace."""

# Config
class TinyFluxConfig:
    def to_dict(self) -> Dict
    def from_dict(cls, d) -> TinyFluxConfig
    def from_json(cls, path) -> TinyFluxConfig
    def save_json(self, path, metadata=None)
    def validate_checkpoint(self, state_dict) -> List[str]
```

---

## Samples

### Step 401434 (v3 weights)

**"subject, animal, cat, photograph of a tiger, natural habitat"**



**"subject, bird, blue beak, red eyes, green claws"**



**"subject, bird, red haired bird in a tree"**



---

## Limitations

| Limitation | Details |
|------------|---------|
| **Resolution** | 512×512 only (64×64 latent space) |
| **Text encoder** | flan-t5-base (768 dim) vs Flux's T5-XXL (4096 dim) |
| **Attention heads** | 4 heads vs Flux's 24 - limits capacity |
| **Training data** | Teacher latents, not real images |
| **v4.1 status** | New architecture; training just starting |
| **Artifacts** | Expect imperfections - this is a research model |

---

## Name

**Lailah** (לילה) — In Jewish tradition, the angel of the night who guards souls and teaches wisdom to the unborn. Chosen for this model's role as a smaller guardian exploring the same latent space as larger models, learning from their knowledge while finding its own path.

---

## Citation

```bibtex
@misc{tinyfluxdeep2026,
  title={TinyFlux-Deep: Compact Flow Matching with Dual Expert Distillation},
  author={AbstractPhil},
  year={2026},
  howpublished={\url{https://huggingface.co/AbstractPhil/tiny-flux-deep}},
  note={246M parameter text-to-image model with Lune trajectory guidance and Sol attention priors}
}
```

---

## Related Projects

| Project | Description |
|---------|-------------|
| [AbstractPhil/tiny-flux](https://huggingface.co/AbstractPhil/tiny-flux) | Original TinyFlux (10.7M params) |
| [AbstractPhil/flux-schnell-teacher-latents](https://huggingface.co/datasets/AbstractPhil/flux-schnell-teacher-latents) | Training dataset |
| [AbstractPhil/imagenet-synthetic](https://huggingface.co/datasets/AbstractPhil/imagenet-synthetic) | ImageNet-style synthetic data |
| [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | Teacher model |

---

## License

MIT License - free for research and commercial use.

---

**Status**: v4.1 architecture complete. Converting v3 checkpoints and resuming training with dual expert distillation.