---
license: mit
language:
- en
tags:
- diffusion
- flow-matching
- flux
- text-to-image
- image-generation
- tinyflux
- lailah
- experimental
library_name: pytorch
pipeline_tag: text-to-image
base_model:
- AbstractPhil/tiny-flux
- black-forest-labs/FLUX.1-schnell
datasets:
- AbstractPhil/flux-schnell-teacher-latents
- AbstractPhil/imagenet-synthetic
---
|
|
|
|
|
# TinyFlux-Deep v4.1 (Lailah) |
|
|
|
|
|
A compact **246M-parameter** flow-matching diffusion model that distills knowledge from several teacher models into a small architecture. TinyFlux-Deep uses a dual expert system to capture both trajectory dynamics (the Lune expert, distilled from SD1.5) and structural attention patterns (the Sol attention prior, which blends geometric heuristics with statistics learned from the Sol teacher). The goal is usable text-to-image generation with roughly 2% of the parameters of FLUX.1-schnell (246M vs 12B).
|
|
|
|
|
## Table of Contents |
|
|
|
|
|
- [Key Features](#key-features) |
|
|
- [Quick Start](#quick-start) |
|
|
- [Architecture](#architecture) |
|
|
- [Dual Expert System](#dual-expert-system) |
|
|
- [Configuration](#configuration) |
|
|
- [Inference](#inference) |
|
|
- [Training](#training) |
|
|
- [Checkpoint Conversion](#checkpoint-conversion) |
|
|
- [Repository Structure](#repository-structure) |
|
|
- [API Reference](#api-reference) |
|
|
- [Samples](#samples) |
|
|
- [Limitations](#limitations) |
|
|
- [Citation](#citation) |
|
|
|
|
|
--- |
|
|
|
|
|
## Key Features |
|
|
|
|
|
| Feature | Description |
|---------|-------------|
| **Compact Size** | 246M params (~500MB bf16) vs Flux's 12B (~24GB) |
| **Dual Expert Distillation** | Learns from both SD1.5 trajectory features and geometric attention priors |
| **Flow Matching** | Rectified flow objective with Flux-style timestep shifting |
| **T5 + CLIP Conditioning** | Dual text encoder pathway with learnable balance |
| **Huber Loss** | Robust training that handles outliers gracefully |
| **Identity-Init Conversion** | v3→v4 conversion preserves pretrained weights exactly |
|
|
|
|
|
--- |
|
|
|
|
|
## Quick Start |
|
|
|
|
|
### Colab Inference |
|
|
|
|
|
```python
!pip install torch transformers safetensors huggingface_hub accelerate

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download model code and weights
# (model_ema.safetensors holds the EMA weights, recommended for inference)
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
weights = hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors")

# Load model: exec() defines TinyFluxConfig and TinyFluxDeep in this namespace
exec(open(model_py).read())
config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)

# strict=False tolerates key mismatches, so report them instead of ignoring them silently
missing, unexpected = model.load_state_dict(load_file(weights), strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
model.eval()

# For the full inference pipeline with text encoders and sampling:
inference_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/inference_v3.py")
exec(open(inference_py).read())
# Then call: image = generate("your prompt here")
```
|
|
|
|
|
### Minimal Generation Loop |
|
|
|
|
|
```python
import torch


def flux_shift(t, s=3.0):
    """Flux-style timestep shift; biases timesteps toward the data end (t=1)."""
    return s * t / (1 + (s - 1) * t)


@torch.no_grad()  # sampling only: no autograd graph needed
def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
    """Euler sampling with classifier-free guidance.

    Returns packed latents [1, 4096, 16] for a 64x64 latent grid (512x512 image).
    """
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Start from pure noise (t=0)
    x = torch.randn(1, 64 * 64, 16, device=device, dtype=dtype)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)

    # Rectified flow: integrate from t=0 (noise) to t=1 (data)
    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device=device))

    # Null conditioning for CFG: zero embeddings here
    # (the Full Pipeline section encodes an empty prompt instead)
    t5_null = torch.zeros_like(t5_emb)
    clip_null = torch.zeros_like(clip_pooled)

    for i in range(num_steps):
        t_curr, t_next = timesteps[i], timesteps[i + 1]
        dt = t_next - t_curr
        t_batch = t_curr.expand(1)

        # Conditional prediction
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_emb,
            pooled_projections=clip_pooled,
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Unconditional prediction (for CFG)
        v_uncond = model(
            hidden_states=x,
            encoder_hidden_states=t5_null,
            pooled_projections=clip_null,
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)

        # Euler step
        x = x + v * dt

    return x  # [1, 4096, 16]: decode with the VAE
```
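To see what the shift does to the sampling grid, the standard-library sketch below reuses the same `flux_shift` formula and prints the shifted timesteps and Euler step sizes for a few steps. Because t=1 is the data end in this README's convention, the shift makes the early steps (near noise) large and the late steps (near data) small. For comparison, schedulers that apply the same formula to the noise level σ = 1 − t (as diffusers' flow-matching Euler scheduler does) push steps toward the noise end instead; in this README's t convention that mirror image equals `flux_shift(t, 1/s)`. The helper names are illustrative, and the sketch says nothing about which schedule gives better images.

```python
def flux_shift(t: float, s: float = 3.0) -> float:
    """Same formula as the sampler above (t=0 is noise, t=1 is data)."""
    return s * t / (1 + (s - 1) * t)


def shifted_schedule(num_steps: int, s: float = 3.0) -> list:
    """Uniform grid on [0, 1] pushed through flux_shift."""
    return [flux_shift(i / num_steps, s) for i in range(num_steps + 1)]


def step_sizes(schedule: list) -> list:
    """Euler step sizes dt between consecutive timesteps."""
    return [b - a for a, b in zip(schedule, schedule[1:])]


def shift_on_noise_level(t: float, s: float = 3.0) -> float:
    """Apply the formula to the noise level sigma = 1 - t, then map back to t."""
    return 1 - flux_shift(1 - t, s)


sched = shifted_schedule(4)  # [0.0, 0.5, 0.75, 0.9, 1.0]
for t, dt in zip(sched, step_sizes(sched)):
    print(f"t={t:.3f}  dt={dt:.3f}")  # step sizes shrink toward the data end

# Shifting sigma instead of t is the mirror image: it matches flux_shift(t, 1/s).
print(shift_on_noise_level(0.3), flux_shift(0.3, 1 / 3))
```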
|
|
|
|
|
--- |
|
|
|
|
|
## Architecture |
|
|
|
|
|
### Model Comparison |
|
|
|
|
|
| Component | TinyFlux | TinyFlux-Deep v3 | TinyFlux-Deep v4.1 | Flux-Schnell |
|-----------|----------|------------------|--------------------|--------------|
| Hidden size | 256 | 512 | 512 | 3072 |
| Attention heads | 2 | 4 | 4 | 24 |
| Head dimension | 128 | 128 | 128 | 128 |
| Double-stream layers | 3 | 15 | 15 | 19 |
| Single-stream layers | 3 | 25 | 25 | 38 |
| MLP ratio | 4.0 | 4.0 | 4.0 | 4.0 |
| RoPE dims | (16,56,56) | (16,56,56) | (16,56,56) | (16,56,56) |
| Lune Expert | ✗ | ✓ | ✓ | ✗ |
| Sol Attention Prior | ✗ | ✗ | ✓ | ✗ |
| T5 Vec Enhancement | ✗ | ✗ | ✓ | ✗ |
| **Total Parameters** | ~10.7M | ~244.7M | ~246.4M | ~12B |
| **Memory (bf16)** | ~22MB | ~490MB | ~493MB | ~24GB |
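Several columns follow from simple arithmetic: hidden size is heads × head dimension, the three RoPE axis widths must add up to the head dimension, and bf16 stores 2 bytes per parameter. The standard-library check below recomputes those entries from the table's own numbers. The parameter counts are the table's rounded values, so the memory figures are approximate too, and they cover weights only.

```python
# Values copied from the comparison table above.
MODELS = {
    "TinyFlux":           {"heads": 2,  "head_dim": 128, "hidden": 256,  "params": 10.7e6},
    "TinyFlux-Deep v3":   {"heads": 4,  "head_dim": 128, "hidden": 512,  "params": 244.7e6},
    "TinyFlux-Deep v4.1": {"heads": 4,  "head_dim": 128, "hidden": 512,  "params": 246.4e6},
    "Flux-Schnell":       {"heads": 24, "head_dim": 128, "hidden": 3072, "params": 12e9},
}
ROPE_AXES = (16, 56, 56)
BYTES_PER_PARAM_BF16 = 2


def bf16_megabytes(params: float) -> float:
    """Weight memory only; activations and optimizer state are not included."""
    return params * BYTES_PER_PARAM_BF16 / 1e6


for name, m in MODELS.items():
    assert m["hidden"] == m["heads"] * m["head_dim"], name
    assert sum(ROPE_AXES) == m["head_dim"], name
    print(f"{name:20s} ~{bf16_megabytes(m['params']):,.0f} MB")
```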
|
|
|
|
|
### Block Structure |
|
|
|
|
|
**Double-Stream Blocks (15 layers):** |
|
|
- Separate text and image pathways |
|
|
- Joint attention between modalities |
|
|
- AdaLN-Zero conditioning from vec |
|
|
- Sol spatial modulation on image Q/K only |
|
|
|
|
|
**Single-Stream Blocks (25 layers):** |
|
|
- Concatenated text + image sequence |
|
|
- Full self-attention with RoPE |
|
|
- Sol modulation skips text tokens |
|
|
|
|
|
```
Input: img_latents [B, 4096, 16], t5_emb [B, 77, 768], clip_pooled [B, 768]
                              │
              ┌───────────────┴───────────────┐
              ▼                               ▼
       img_in (Linear)                 txt_in (Linear)
              │                               │
              ▼                               ▼
       [B, 4096, 512]                   [B, 77, 512]
              │                               │
              └───────────────┬───────────────┘
                              │
      vec = time_emb + clip_vec + t5_vec + lune_signal
                              │
                  ┌───────────┴───────────┐
                  ▼                       ▼
         Double Blocks (×15)          Sol Prior → temperature, spatial_mod
                  │                       │
                  ▼                       │
         Single Blocks (×25) ◄────────────┘
                  │
                  ▼
      final_norm → final_linear
                  │
                  ▼
        Output: [B, 4096, 16]
```
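The `img_ids` passed to the model give each of the 4096 latent tokens a 3-axis position for RoPE, and the RoPE widths (16, 56, 56) split each 128-dim head between those axes. The standard-library sketch below assumes the Flux convention, where image IDs are `(0, row, col)` in row-major order; that order also matches the `reshape(1, 64, 64, 16)` used when decoding. `make_img_ids`, `token_to_grid`, and `rope_slices` are hypothetical helpers, and the real `TinyFluxDeep.create_img_ids` may differ in layout or dtype.

```python
from typing import List, Tuple

ROPE_AXES = (16, 56, 56)  # per-head channel widths for (axis 0, row, col)


def make_img_ids(height: int, width: int) -> List[Tuple[int, int, int]]:
    """Hypothetical Flux-style position IDs: (0, row, col) in row-major order."""
    return [(0, r, c) for r in range(height) for c in range(width)]


def token_to_grid(n: int, width: int) -> Tuple[int, int]:
    """Row-major token index -> (row, col), matching reshape(B, H, W, C)."""
    return divmod(n, width)


def rope_slices(axes=ROPE_AXES) -> List[Tuple[int, int]]:
    """Channel range of a head assigned to each position axis."""
    bounds, start = [], 0
    for width in axes:
        bounds.append((start, start + width))
        start += width
    return bounds


ids = make_img_ids(64, 64)
print(len(ids), ids[0], ids[65], rope_slices())
```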
|
|
|
|
|
--- |
|
|
|
|
|
## Dual Expert System |
|
|
|
|
|
TinyFlux-Deep v4.1 uses two complementary expert pathways to inject knowledge from teacher models without the "twin-tail paradox" (mixing incompatible prediction targets). |
|
|
|
|
|
### Lune Expert Predictor (Trajectory Guidance) |
|
|
|
|
|
**Purpose:** Captures SD1.5's sense of "how denoising should flow": the trajectory through latent space.
|
|
|
|
|
**Architecture:** |
|
|
```python
LuneExpertPredictor(
    time_dim=512,       # From timestep MLP
    clip_dim=768,       # CLIP pooled features
    expert_dim=1280,    # SD1.5 mid-block dimension (prediction target)
    hidden_dim=512,     # Internal MLP width
    output_dim=512,     # Output added to vec
    dropout=0.1,
)
```
|
|
|
|
|
**How it works:** |
|
|
1. Concatenates timestep embedding + CLIP pooled → hidden |
|
|
2. Predicts what SD1.5's mid-block features would be |
|
|
3. During training: uses real SD1.5 features when available |
|
|
4. During inference: uses predicted features |
|
|
5. Gates output with learnable sigmoid (init 0.5) |
|
|
6. Adds `expert_signal` to global `vec` conditioning |
|
|
|
|
|
**Training signal:** Cosine similarity loss against real SD1.5 UNet mid-block features (soft directional matching, not exact reconstruction). |
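The training signal and gate can be written out in a few lines. This standard-library sketch is illustrative: `cosine_distill_loss`, `gated_expert_signal`, and `choose_features` are hypothetical helpers modeled on the steps listed above, not code from `model_v4.py`. It shows why a cosine loss is "soft": rescaling the prediction leaves the loss unchanged, so only the direction of the 1280-dim feature vector is matched.

```python
import math
from typing import List, Optional


def cosine_distill_loss(pred: List[float], teacher: List[float]) -> float:
    """1 - cos(pred, teacher): 0 when directions match, 2 when opposite."""
    dot = sum(p * t for p, t in zip(pred, teacher))
    norm_p = math.sqrt(sum(p * p for p in pred))
    norm_t = math.sqrt(sum(t * t for t in teacher))
    return 1.0 - dot / max(norm_p * norm_t, 1e-8)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gated_expert_signal(projected: List[float], gate_logit: float = 0.0) -> List[float]:
    """Scale the expert output by sigmoid(gate); a logit of 0.0 gives the 0.5 init."""
    g = sigmoid(gate_logit)
    return [g * v for v in projected]


def choose_features(predicted: List[float], teacher: Optional[List[float]],
                    training: bool) -> List[float]:
    """Real SD1.5 features during training when available, predictions otherwise."""
    return teacher if (training and teacher is not None) else predicted
```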
|
|
|
|
|
### Sol Attention Prior (Structural Guidance) |
|
|
|
|
|
**Purpose:** Captures geometric/structural knowledge about WHERE attention should focus, without injecting incompatible features. |
|
|
|
|
|
**Key insight:** Sol (a V-prediction DDPM model) has valuable attention patterns, but its features are incompatible with TinyFlux's linear flow matching. We extract attention *statistics* instead: |
|
|
- **Locality:** How local vs global is attention? |
|
|
- **Entropy:** How focused vs diffuse? |
|
|
- **Clustering:** How structured vs uniform? |
|
|
- **Spatial importance:** Which regions matter most? |
|
|
|
|
|
**Architecture:** |
|
|
```python
SolAttentionPrior(
    time_dim=512,
    clip_dim=768,
    hidden_dim=256,
    num_heads=4,            # Matches TinyFlux attention heads
    spatial_size=8,         # 8×8 importance map
    geometric_weight=0.7,   # David's 70/30 split
)
```
|
|
|
|
|
**How it works:** |
|
|
1. **Geometric prior (70%):** Timestep-based heuristics
   - Early denoising (high noise): higher temperature → softer, more global attention
   - Late denoising (low noise): lower temperature → sharper, more local attention
   - Spatial: uniform early, center-biased late

2. **Learned prior (30%):** Content-based predictions
   - Predicts attention statistics from (timestep, CLIP)
   - Predicts a spatial importance map

3. **Blending:** `blend * geometric + (1 - blend) * learned`, with a learnable blend gate

4. **Output application:**
   - `temperature [B, 4]`: scales attention logits per head
   - `spatial_mod [B, H, W]`: modulates Q/K at each position via `exp(conv(spatial))`
|
|
|
|
|
**Identity initialization:** All Sol components initialize to zero-effect: |
|
|
- `spatial_to_mod` Conv2d: zero weight, zero bias → `exp(0) = 1` (identity) |
|
|
- Allows gradual learning without disrupting pretrained attention |
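The arithmetic behind the Sol prior's outputs is small enough to check by hand. The standard-library sketch below covers three pieces: the convex blend of geometric and learned statistics, how dividing attention logits by a per-head temperature softens or sharpens the attention distribution, and why a zero-initialized `spatial_to_mod` conv yields a modulation of exactly 1 everywhere. It is illustrative only: the function names are hypothetical, the 3×3 single-channel convolution is an assumption chosen for the example, and dividing logits by the temperature is the usual convention, assumed here rather than taken from the model code.

```python
import math
from typing import List

Grid = List[List[float]]


def blend(geometric: float, learned: float, w: float = 0.7) -> float:
    """w * geometric + (1 - w) * learned, with w starting at 0.7."""
    return w * geometric + (1.0 - w) * learned


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Higher temperature -> flatter (more global) attention weights."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: List[float]) -> float:
    return -sum(x * math.log(x) for x in p if x > 0)


def conv3x3(grid: Grid, kernel: Grid, bias: float) -> Grid:
    """Single-channel 3x3 convolution with zero padding (same output size)."""
    h, w = len(grid), len(grid[0])
    out = [[bias] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            for di in range(-1, 2):
                for dj in range(-1, 2):
                    y, x = i + di, j + dj
                    if 0 <= y < h and 0 <= x < w:
                        out[i][j] += kernel[di + 1][dj + 1] * grid[y][x]
    return out


def spatial_modulation(spatial: Grid, kernel: Grid, bias: float) -> Grid:
    """exp(conv(spatial)): the multiplicative factor applied to image Q/K."""
    return [[math.exp(v) for v in row] for row in conv3x3(spatial, kernel, bias)]


# Zero-initialized conv -> modulation is exactly 1 (identity) for any 8x8 map.
importance = [[(i * 8 + j) / 64 for j in range(8)] for i in range(8)]
zero_kernel = [[0.0] * 3 for _ in range(3)]
mod = spatial_modulation(importance, zero_kernel, bias=0.0)
print(all(v == 1.0 for row in mod for v in row))
```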
|
|
|
|
|
### T5 Vec Enhancement |
|
|
|
|
|
**Purpose:** Adds T5's semantic understanding to the global conditioning pathway (previously only CLIP pooled). |
|
|
|
|
|
```python
import torch

# Attention-weighted pooling of the T5 sequence.
# t5_emb: [B, 77, 768]; t5_pool_mlp, clip_vec and text_balance are model internals.
t5_attn = torch.softmax(t5_emb.mean(dim=-1), dim=-1)      # [B, 77], sums to 1 over tokens
t5_pooled = (t5_emb * t5_attn.unsqueeze(-1)).sum(dim=1)   # [B, 768]
t5_vec = t5_pool_mlp(t5_pooled)                           # [B, 512]

# Learnable balance between CLIP and T5
balance = torch.sigmoid(text_balance)  # text_balance starts at 0.0, so balance = 0.5
text_vec = balance * clip_vec + (1 - balance) * t5_vec
```
|
|
|
|
|
--- |
|
|
|
|
|
## Configuration |
|
|
|
|
|
### TinyFluxConfig |
|
|
|
|
|
```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class TinyFluxConfig:
    # Core architecture
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128          # hidden_size = heads × head_dim
    in_channels: int = 16                  # VAE latent channels
    patch_size: int = 1
    joint_attention_dim: int = 768         # T5 embedding dim
    pooled_projection_dim: int = 768       # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # Must sum to head_dim

    # Lune expert predictor
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280            # SD1.5 mid-block dim
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior
    use_sol_prior: bool = True
    sol_spatial_size: int = 8              # 8×8 spatial importance map
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7      # 70% geometric, 30% learned

    # T5 enhancement
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"        # "attention", "mean", "cls"

    # Loss configuration
    lune_distill_mode: str = "cosine"      # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # Legacy compatibility
    guidance_embeds: bool = False
```
|
|
|
|
|
### Loading from JSON |
|
|
|
|
|
```python
# From file
config = TinyFluxConfig.from_json("lailah_401434_v4_config.json")

# From dict
config = TinyFluxConfig.from_dict({
    "hidden_size": 512,
    "num_attention_heads": 4,
    ...
})

# Save with metadata
config.save_json("config.json", metadata={"source_step": 401434})
```
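A saved config JSON carries metadata such as `_conversion_info` next to the real fields, so a loader has to ignore keys it does not recognize. The standard-library sketch below shows one way to do that: filter the incoming dict to the dataclass's declared fields. `MiniConfig`, `to_json`, and `from_json_str` are hypothetical, cover only a few of the real fields, and are not necessarily how `TinyFluxConfig.from_dict` or `save_json` are implemented.

```python
import json
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class MiniConfig:
    """Hypothetical stand-in for TinyFluxConfig with a few of its fields."""
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    sol_geometric_weight: float = 0.7

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "MiniConfig":
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})

    def to_json(self, metadata: Optional[Dict[str, Any]] = None) -> str:
        data = asdict(self)
        if metadata:
            data["_conversion_info"] = metadata
        return json.dumps(data, indent=2)

    @classmethod
    def from_json_str(cls, text: str) -> "MiniConfig":
        return cls.from_dict(json.loads(text))


cfg = MiniConfig(sol_geometric_weight=0.8)
text = cfg.to_json(metadata={"source_step": 401434})
restored = MiniConfig.from_json_str(text)  # "_conversion_info" is skipped
print(restored == cfg)
```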
|
|
|
|
|
### Validation |
|
|
|
|
|
```python
# Config validates constraints on creation
config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=128)
# ✓ OK: 512 = 4 × 128

config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=64)
# ✗ ValueError: hidden_size (512) must equal num_attention_heads * attention_head_dim (256)

# Validate checkpoint compatibility
warnings = config.validate_checkpoint(state_dict)
if warnings:
    print("Warnings:", warnings)
```
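Validation on creation is the kind of check a dataclass `__post_init__` hook handles. The standard-library sketch below is hypothetical (`CheckedConfig` is not the repository's class): besides the head-size rule, it also checks that `axes_dims_rope` sums to `attention_head_dim`, a constraint stated in the `TinyFluxConfig` comments. The first error message mirrors the one shown in the previous example.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class CheckedConfig:
    """Hypothetical config that enforces TinyFlux's shape constraints on creation."""
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    def __post_init__(self) -> None:
        expected = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != expected:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({expected})"
            )
        if sum(self.axes_dims_rope) != self.attention_head_dim:
            raise ValueError(
                f"axes_dims_rope {self.axes_dims_rope} must sum to "
                f"attention_head_dim ({self.attention_head_dim})"
            )


CheckedConfig()  # passes: 512 = 4 × 128 and 16 + 56 + 56 = 128
try:
    CheckedConfig(attention_head_dim=64)
except ValueError as err:
    print(err)
```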
|
|
|
|
|
--- |
|
|
|
|
|
## Inference |
|
|
|
|
|
### Full Pipeline |
|
|
|
|
|
```python
# Requires diffusers and sentencepiece (for T5Tokenizer) in addition to the Quick Start packages
import torch
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Load text encoders
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)

# Load VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Load TinyFlux-Deep (exec defines TinyFluxConfig and TinyFluxDeep)
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
exec(open(model_py).read())

config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
model.load_state_dict(weights, strict=False)
model.eval()


@torch.no_grad()
def encode_prompt(prompt):
    """Encode a prompt with both T5 (sequence) and CLIP (pooled)."""
    # T5: [1, 77, 768]
    t5_tokens = t5_tokenizer(prompt, return_tensors="pt", padding="max_length",
                             max_length=77, truncation=True).to("cuda")
    t5_emb = t5_model(**t5_tokens).last_hidden_state.to(torch.bfloat16)

    # CLIP pooled: [1, 768]
    clip_tokens = clip_tokenizer(prompt, return_tensors="pt", padding="max_length",
                                 max_length=77, truncation=True).to("cuda")
    clip_pooled = clip_model(**clip_tokens).pooler_output.to(torch.bfloat16)

    return t5_emb, clip_pooled


def flux_shift(t, s=3.0):
    return s * t / (1 + (s - 1) * t)


@torch.no_grad()
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """
    Euler sampling for rectified flow.

    Flow: x_t = (1-t)*noise + t*data
    Integrate from t=0 (noise) to t=1 (data)
    """
    if seed is not None:
        torch.manual_seed(seed)

    t5_emb, clip_pooled = encode_prompt(prompt)

    # Null embeddings for CFG
    t5_null, clip_null = encode_prompt("")

    # Start from pure noise (t=0)
    x = torch.randn(1, 64 * 64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")

    # Rectified flow: 0 → 1 with Flux shift
    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device="cuda"))

    for i in range(num_steps):
        t = timesteps[i].expand(1)
        dt = timesteps[i + 1] - timesteps[i]  # Positive

        v_cond = model(x, t5_emb, clip_pooled, t, img_ids)   # Conditional
        v_uncond = model(x, t5_null, clip_null, t, img_ids)  # Unconditional
        v = v_uncond + cfg_scale * (v_cond - v_uncond)        # CFG
        x = x + v * dt                                        # Euler step

    # Decode with VAE: [1, 4096, 16] tokens -> [1, 16, 64, 64] latent
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)
    # Note: diffusers' FluxPipeline also adds vae.config.shift_factor after this
    # division; whether that applies here depends on how the training latents
    # were normalized.
    x = x / vae.config.scaling_factor
    image = vae.decode(x).sample

    # Convert [-1, 1] to a uint8 PIL image
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")
    return Image.fromarray(image)


# Generate!
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
image.save("tiger.png")
```
|
|
|
|
|
### Batch Inference |
|
|
|
|
|
```python
def generate_batch(prompts, **kwargs):
    """Generate one image per prompt (calls generate_image sequentially, not as one batch)."""
    return [generate_image(p, **kwargs) for p in prompts]


images = generate_batch([
    "a red bird with blue beak",
    "a mountain landscape at sunset",
    "an astronaut riding a horse",
], num_steps=25, cfg_scale=4.0)
```
|
|
|
|
|
--- |
|
|
|
|
|
## Training |
|
|
|
|
|
### Loss Computation |
|
|
|
|
|
```python
# Forward pass with expert info
output, expert_info = model(
    hidden_states=noisy_latents,
    encoder_hidden_states=t5_emb,
    pooled_projections=clip_pooled,
    timestep=timesteps,
    img_ids=img_ids,
    lune_features=sd15_midblock_features,  # From SD1.5 teacher
    sol_stats=sol_attention_stats,         # From Sol teacher (optional)
    sol_spatial=sol_spatial_importance,    # From Sol teacher (optional)
    return_expert_pred=True,
)

# Compute loss
losses = model.compute_loss(
    output=output,
    target=flow_target,                    # data - noise for flow matching
    expert_info=expert_info,
    lune_features=sd15_midblock_features,
    sol_stats=sol_attention_stats,
    sol_spatial=sol_spatial_importance,

    # Loss weights
    lune_weight=0.1,                       # Weight for Lune distillation
    sol_weight=0.05,                       # Weight for Sol distillation

    # Loss options
    use_huber=True,                        # Huber loss for main objective (robust to outliers)
    huber_delta=0.1,                       # Huber delta (smaller = tighter MSE region)
    lune_distill_mode="cosine",            # "hard", "soft", "cosine", "huber"
    spatial_weighting=True,                # Weight loss by Sol spatial importance
)

# losses dict contains:
# - main: flow matching loss
# - lune_distill: Lune prediction loss
# - sol_stat_distill: Sol statistics prediction loss
# - sol_spatial_distill: Sol spatial prediction loss
# - total: weighted sum

loss = losses['total']
loss.backward()
```
|
|
|
|
|
### Distillation Modes |
|
|
|
|
|
| Mode | Description | Use Case |
|------|-------------|----------|
| `"hard"` | MSE against teacher features | Exact reconstruction |
| `"soft"` | Temperature-scaled MSE | Softer matching |
| `"cosine"` | Cosine similarity loss | Directional alignment (recommended) |
| `"huber"` | Huber loss on features | Robust to outliers |
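For reference, the four modes can be sketched on plain feature vectors. This standard-library version is illustrative, not the repository's implementation. `hard`, `cosine`, and `huber` follow the standard definitions (Huber uses the same form as PyTorch's `huber_loss`: quadratic within `delta`, linear outside). The `soft` branch, which divides both vectors by a temperature before the MSE, is only one plausible reading of "temperature-scaled MSE", and its temperature of 2.0 is an arbitrary example value.

```python
import math
from typing import List


def mse(a: List[float], b: List[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def huber(a: List[float], b: List[float], delta: float = 0.1) -> float:
    """Quadratic for |error| <= delta, linear beyond it (mean over elements)."""
    total = 0.0
    for x, y in zip(a, b):
        err = abs(x - y)
        total += 0.5 * err * err if err <= delta else delta * (err - 0.5 * delta)
    return total / len(a)


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / max(norm, 1e-8)


def distill_loss(pred: List[float], teacher: List[float], mode: str = "cosine",
                 delta: float = 0.1, temperature: float = 2.0) -> float:
    if mode == "hard":
        return mse(pred, teacher)
    if mode == "soft":  # hypothetical reading of "temperature-scaled MSE"
        return mse([x / temperature for x in pred], [y / temperature for y in teacher])
    if mode == "cosine":
        return cosine(pred, teacher)
    if mode == "huber":
        return huber(pred, teacher, delta)
    raise ValueError(f"unknown distill mode: {mode}")


pred, teacher = [0.5, 1.0, -2.0], [1.0, 2.0, -4.0]
for mode in ("hard", "soft", "cosine", "huber"):
    print(mode, round(distill_loss(pred, teacher, mode), 4))
```

With the prediction pointing in exactly the teacher's direction at half its magnitude, the cosine loss is zero while the other three modes still penalize the size mismatch.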
|
|
|
|
|
### Training Loop Example |
|
|
|
|
|
```python
import copy

import torch
from torch.optim import AdamW

# Assumes model, dataloader, img_ids, device and flux_shift (from Quick Start) are defined.
optimizer = AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.99), weight_decay=0.01)
accum_steps = 2  # 16 × 2 gradient accumulation = effective batch 32
# The hyperparameter table lists a cosine LR schedule
# (e.g. torch.optim.lr_scheduler.CosineAnnealingLR); it is omitted here.

# EMA copy: updated manually, never trained directly
ema_decay = 0.9999
ema_model = copy.deepcopy(model).requires_grad_(False)

model.train()
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    latents = batch['latents']
    batch_size = latents.shape[0]

    # bf16 autocast needs no GradScaler (loss scaling is only required for fp16)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # Sample timesteps with a logit-normal distribution, then apply the Flux shift
        u = torch.randn(batch_size, device=device)
        t = torch.sigmoid(u)
        t = flux_shift(t, s=3.0)

        # Interpolate between noise (t=0) and data (t=1)
        noise = torch.randn_like(latents)
        t_ = t.view(-1, 1, 1)
        noisy = t_ * latents + (1 - t_) * noise
        target = latents - noise  # Flow matching velocity target

        # Forward
        output, expert_info = model(
            hidden_states=noisy,
            encoder_hidden_states=batch['t5_emb'],
            pooled_projections=batch['clip_pooled'],
            timestep=t,
            img_ids=img_ids,
            lune_features=batch.get('sd15_features'),
            return_expert_pred=True,
        )

        # Loss
        losses = model.compute_loss(output, target, expert_info,
                                    lune_features=batch.get('sd15_features'))

    (losses['total'] / accum_steps).backward()

    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

        # EMA update: p_ema <- p_ema + (1 - decay) * (p - p_ema)
        with torch.no_grad():
            for p, p_ema in zip(model.parameters(), ema_model.parameters()):
                p_ema.lerp_(p, 1 - ema_decay)
```
|
|
|
|
|
### Hyperparameters |
|
|
|
|
|
| Parameter | Value | Notes |
|-----------|-------|-------|
| Optimizer | AdamW | |
| Learning rate | 3e-4 | With cosine schedule |
| Betas | (0.9, 0.99) | |
| Weight decay | 0.01 | |
| Batch size | 32 | 16 × 2 gradient accumulation |
| EMA decay | 0.9999 | |
| Precision | bfloat16 | |
| Timestep shift | s=3.0 | Flux-style |
| Timestep sampling | Logit-normal | |
| Lune weight | 0.1 | |
| Sol weight | 0.05 | |
| Huber delta | 0.1 | |
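A few of these values are easier to interpret as numbers. The standard-library sketch below computes the effective batch size, the EMA half-life (the number of EMA updates after which a weight's contribution has decayed by half, assuming one update per optimizer step), and where the logit-normal sampler plus the s=3.0 shift puts the median training timestep. It only restates arithmetic implied by the table.

```python
import math


def flux_shift(t: float, s: float = 3.0) -> float:
    return s * t / (1 + (s - 1) * t)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


micro_batch, accum_steps = 16, 2
effective_batch = micro_batch * accum_steps

ema_decay = 0.9999
ema_time_constant = 1.0 / (1.0 - ema_decay)          # about 10,000 updates
ema_half_life = math.log(0.5) / math.log(ema_decay)  # about 6,931 updates

# Logit-normal: t = sigmoid(u) with u ~ N(0, 1), whose median is 0.
# flux_shift is monotonic, so the median training t is flux_shift(sigmoid(0)).
median_t = flux_shift(sigmoid(0.0), s=3.0)           # 0.75, i.e. closer to data

print(effective_batch, round(ema_time_constant), round(ema_half_life), median_t)
```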
|
|
|
|
|
--- |
|
|
|
|
|
## Checkpoint Conversion |
|
|
|
|
|
### v3 → v4.1 Conversion |
|
|
|
|
|
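The converter preserves all pretrained weights and initializes new v4.1 components to identity or zero effect where possible. Per the table, the T5 pooling MLP starts from Xavier initialization and `text_balance` from a 50/50 split, so the T5 pathway is likely not zero-effect at initialization: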
|
|
|
|
|
**What gets converted:** |
|
|
| v3 Key | v4.1 Key | Action |
|--------|----------|--------|
| `expert_predictor.*` | `lune_predictor.*` | Rename |
| `expert_gate` (0.5) | `expert_gate` (0.0) | Convert to logit space |
| - | `sol_prior.*` | Initialize (zero-effect) |
| - | `t5_pool.*` | Initialize (Xavier) |
| - | `text_balance` | Initialize (0.0 = 50/50) |
| - | `*.spatial_to_mod.*` | Initialize (zero = identity) |
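The two non-trivial rows of this table, the key rename and the gate conversion, amount to a prefix substitution and a logit. The standard-library sketch below applies both to a toy dict of floats; `rename_keys` and `gate_to_logit` are hypothetical helpers written for illustration, not functions exported by `convert_v3_to_v4.py`, and the toy key names are made up. It also re-derives the parameter growth from the exact counts recorded in the converted config's `_conversion_info`.

```python
import math
from typing import Dict


def rename_keys(state: Dict[str, float], old: str = "expert_predictor.",
                new: str = "lune_predictor.") -> Dict[str, float]:
    """Rename every key starting with `old`; leave other keys untouched."""
    return {(new + k[len(old):] if k.startswith(old) else k): v for k, v in state.items()}


def gate_to_logit(p: float) -> float:
    """v3 stored the gate as a probability; v4.1 stores a logit that passes through sigmoid."""
    return math.log(p / (1.0 - p))


# Toy state dict with made-up keys (real tensors replaced by floats).
v3 = {"expert_predictor.mlp.0.weight": 1.0, "expert_gate": 0.5, "img_in.weight": 2.0}
v4 = rename_keys(v3)
v4["expert_gate"] = gate_to_logit(v3["expert_gate"])
print(v4)

source_params, target_params = 244_690_849, 246_347_234
added = target_params - source_params
print(added, f"{100 * added / source_params:.2f}%")
```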
|
|
|
|
|
**Parameter growth:** |
|
|
- v3: ~244.7M parameters |
|
|
- v4.1: ~246.4M parameters |
|
|
- Added: ~1.7M parameters (0.7% increase) |
|
|
|
|
|
### Python API |
|
|
|
|
|
```python
import os
import sys

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download the converter and make it importable as a module
converter = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/convert_v3_to_v4.py")
sys.path.insert(0, os.path.dirname(converter))

# Simple: download, convert, upload (uploading needs Hugging Face write access)
from convert_v3_to_v4 import run

result = run(401434)  # Step number

# With custom config
result = run(401434, config={
    "hidden_size": 512,
    "num_attention_heads": 4,
    "sol_geometric_weight": 0.8,  # More geometric, less learned
})

# From JSON config file
result = run(401434, config="my_config.json")

# Low-level API
from convert_v3_to_v4 import convert_state_dict, analyze_checkpoint, TinyFluxConfig

# Analyze checkpoint version
state_dict = load_file("checkpoint.safetensors")
info = analyze_checkpoint(state_dict)
print(f"Version: {info.version}")  # "v3", "v4.0", "v4.1", etc.
print(f"Has Sol prior: {info.has_sol_prior}")

# Convert state dict
config = TinyFluxConfig()
v4_state, report = convert_state_dict(state_dict, config)
print(f"Renamed {len(report['renamed'])} keys")
print(f"Initialized {len(report['initialized'])} keys")
```
|
|
|
|
|
### CLI |
|
|
|
|
|
```bash |
|
|
# Basic conversion |
|
|
python convert_v3_to_v4.py --step 401434 |
|
|
|
|
|
# Local file |
|
|
python convert_v3_to_v4.py --input model_v3.safetensors |
|
|
|
|
|
# Analyze only (don't convert) |
|
|
python convert_v3_to_v4.py --step 401434 --analyze-only |
|
|
|
|
|
# Custom output |
|
|
python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel |
|
|
|
|
|
# With custom config |
|
|
python convert_v3_to_v4.py --step 401434 --config my_config.json |
|
|
``` |
|
|
|
|
|
### Output Structure |
|
|
|
|
|
```
checkpoint_runs/v4_init/
├── lailah_401434_v4_init.safetensors                # Converted model
├── lailah_401434_v4_init_ema.safetensors            # Fresh EMA (copy of model)
├── lailah_401434_v4_init_ema_secondary.safetensors  # Converted old EMA
└── lailah_401434_v4_config.json                     # Config with conversion metadata
```
|
|
|
|
|
### Config JSON Format |
|
|
|
|
|
```json
{
  "hidden_size": 512,
  "num_attention_heads": 4,
  "attention_head_dim": 128,
  "num_double_layers": 15,
  "num_single_layers": 25,
  "use_lune_expert": true,
  "use_sol_prior": true,
  "use_t5_vec": true,
  "sol_geometric_weight": 0.7,
  "lune_distill_mode": "cosine",
  "use_huber_loss": true,
  "huber_delta": 0.1,
  "_conversion_info": {
    "source_step": 401434,
    "source_repo": "AbstractPhil/tiny-flux-deep",
    "source_version": "v3",
    "target_version": "v4.1",
    "source_params": 244690849,
    "target_params": 246347234,
    "params_added": 1656385,
    "converter_version": "4.1.0"
  }
}
```
|
|
|
|
|
--- |
|
|
|
|
|
## Repository Structure |
|
|
|
|
|
```
AbstractPhil/tiny-flux-deep/
│
├── model.safetensors                     # Latest training weights
├── model_ema.safetensors                 # EMA weights (use for inference)
├── config.json                           # Model configuration
├── README.md
│
├── scripts/                              # All Python code
│   ├── model_v4.py                       # v4.1 architecture (current)
│   ├── model_v3.py                       # v3 architecture (reference)
│   ├── model_v2.py                       # v2 architecture (legacy)
│   ├── inference_v3.py                   # Full inference pipeline
│   ├── convert_v3_to_v4.py               # Checkpoint converter
│   ├── trainer_v3_expert_guidance.py     # Training with distillation
│   ├── trainer_v2.py                     # Previous trainer
│   ├── trainer.py                        # Original trainer
│   ├── port_tiny_to_deep.py              # TinyFlux → Deep port script
│   └── colab_inference_lailah_early.py   # Simple Colab notebook
│
├── checkpoints/                          # v3 checkpoints (legacy)
│   ├── step_401434.safetensors
│   └── step_401434_ema.safetensors
│
├── checkpoint_runs/                      # Organized checkpoint runs
│   └── v4_init/                          # v4.1 initialization from v3
│       ├── lailah_401434_v4_init.safetensors
│       ├── lailah_401434_v4_init_ema.safetensors
│       ├── lailah_401434_v4_init_ema_secondary.safetensors
│       └── lailah_401434_v4_config.json
│
├── samples/                              # Generated samples per step
│   └── 20260127_074318_step_401434.png
│
└── logs/                                 # TensorBoard training logs
    └── run_20260126_220714/
```
|
|
|
|
|
--- |
|
|
|
|
|
## API Reference |
|
|
|
|
|
### TinyFluxDeep |
|
|
|
|
|
```python
# Signatures only; method bodies are omitted.
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn


class TinyFluxDeep(nn.Module):
    def __init__(self, config: Optional["TinyFluxConfig"] = None):
        """Initialize model with config (uses defaults if None)."""

    def forward(
        self,
        hidden_states: torch.Tensor,                     # [B, N, 16] image latents
        encoder_hidden_states: torch.Tensor,             # [B, L, 768] T5 embeddings
        pooled_projections: torch.Tensor,                # [B, 768] CLIP pooled
        timestep: torch.Tensor,                          # [B] timestep in [0, 1]
        img_ids: torch.Tensor,                           # [N, 3] position IDs
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,         # Legacy
        lune_features: Optional[torch.Tensor] = None,    # [B, 1280] SD1.5 features
        sol_stats: Optional[torch.Tensor] = None,        # [B, 3] attention stats
        sol_spatial: Optional[torch.Tensor] = None,      # [B, 8, 8] spatial importance
        expert_features: Optional[torch.Tensor] = None,  # Legacy API
        return_expert_pred: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """
        Forward pass.

        Returns:
            output: [B, N, 16] predicted velocity
            expert_info: dict with predictions (if return_expert_pred=True)
        """

    def compute_loss(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        expert_info: Optional[Dict] = None,
        lune_features: Optional[torch.Tensor] = None,
        sol_stats: Optional[torch.Tensor] = None,
        sol_spatial: Optional[torch.Tensor] = None,
        lune_weight: float = 0.1,
        sol_weight: float = 0.05,
        use_huber: bool = True,
        huber_delta: float = 0.1,
        lune_distill_mode: str = "cosine",
        spatial_weighting: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Compute combined loss with distillation."""

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device) -> torch.Tensor:
        """Create image position IDs for RoPE."""

    @staticmethod
    def create_txt_ids(text_len: int, device) -> torch.Tensor:
        """Create text position IDs."""

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
```
|
|
|
|
|
### Converter Functions |
|
|
|
|
|
```python
# High-level
def run(step, name="lailah", config=None, ...):
    """One-liner: download, convert, upload."""

def convert_checkpoint(step=None, input_path=None, config=None, ...) -> ConversionResult:
    """Convert checkpoint with full control."""

# Low-level
def analyze_checkpoint(state_dict) -> CheckpointInfo:
    """Analyze checkpoint version and contents."""

def convert_state_dict(state_dict, config=None) -> Tuple[Dict, Dict]:
    """Convert state dict, return (new_state, report)."""

def download_from_hf(step, repo_id, ...) -> Tuple[str, str]:
    """Download checkpoint from HuggingFace."""

# Config
class TinyFluxConfig:
    def to_dict(self) -> Dict
    def from_dict(cls, d) -> TinyFluxConfig
    def from_json(cls, path) -> TinyFluxConfig
    def save_json(self, path, metadata=None)
    def validate_checkpoint(self, state_dict) -> List[str]
```
|
|
|
|
|
--- |
|
|
|
|
|
## Samples |
|
|
|
|
|
### Step 401434 (v3 weights) |
|
|
|
|
|
**"subject, animal, cat, photograph of a tiger, natural habitat"** |
|
|
|
|
|
 |
|
|
|
|
|
**"subject, bird, blue beak, red eyes, green claws"** |
|
|
|
|
|
 |
|
|
|
|
|
**"subject, bird, red haired bird in a tree"** |
|
|
|
|
|
 |
|
|
|
|
|
--- |
|
|
|
|
|
## Limitations |
|
|
|
|
|
| Limitation | Details |
|------------|---------|
| **Resolution** | 512×512 only (64×64 latent space) |
| **Text encoder** | flan-t5-base (768 dim) vs Flux's T5-XXL (4096 dim) |
| **Attention heads** | 4 heads vs Flux's 24, which limits capacity |
| **Training data** | Teacher latents, not real images |
| **v4.1 status** | New architecture, training just starting |
| **Artifacts** | Expect imperfections; this is a research model |
|
|
|
|
|
--- |
|
|
|
|
|
## Name |
|
|
|
|
|
**Lailah** (לילה) — In Jewish tradition, the angel of the night who guards souls and teaches wisdom to the unborn. Chosen for this model's role as a smaller guardian exploring the same latent space as larger models, learning from their knowledge while finding its own path. |
|
|
|
|
|
--- |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex |
|
|
@misc{tinyfluxdeep2026, |
|
|
title={TinyFlux-Deep: Compact Flow Matching with Dual Expert Distillation}, |
|
|
author={AbstractPhil}, |
|
|
year={2026}, |
|
|
howpublished={\url{https://huggingface.co/AbstractPhil/tiny-flux-deep}}, |
|
|
note={246M parameter text-to-image model with Lune trajectory guidance and Sol attention priors} |
|
|
} |
|
|
``` |
|
|
|
|
|
--- |
|
|
|
|
|
## Related Projects |
|
|
|
|
|
| Project | Description |
|---------|-------------|
| [AbstractPhil/tiny-flux](https://huggingface.co/AbstractPhil/tiny-flux) | Original TinyFlux (10.7M params) |
| [AbstractPhil/flux-schnell-teacher-latents](https://huggingface.co/datasets/AbstractPhil/flux-schnell-teacher-latents) | Training dataset |
| [AbstractPhil/imagenet-synthetic](https://huggingface.co/datasets/AbstractPhil/imagenet-synthetic) | ImageNet-style synthetic data |
| [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | Teacher model |
|
|
|
|
|
--- |
|
|
|
|
|
## License |
|
|
|
|
|
MIT License - free for research and commercial use. |
|
|
|
|
|
--- |
|
|
|
|
|
**Status**: v4.1 architecture complete. Converting v3 checkpoints and resuming training with dual expert distillation. |