---
license: mit
language:
- en
tags:
- diffusion
- flow-matching
- flux
- text-to-image
- image-generation
- tinyflux
- lailah
- experimental
library_name: pytorch
pipeline_tag: text-to-image
base_model:
- AbstractPhil/tiny-flux
- black-forest-labs/FLUX.1-schnell
datasets:
- AbstractPhil/flux-schnell-teacher-latents
- AbstractPhil/imagenet-synthetic
---
# TinyFlux-Deep v4.1 (Lailah)
A compact **246M parameter** flow-matching diffusion model that distills knowledge from multiple teacher models into an efficient architecture. TinyFlux-Deep uses a dual expert system to capture both trajectory dynamics (from SD1.5) and structural attention patterns (from a geometric prior), with the goal of generating 512×512 images at a fraction of the compute cost of full-scale models such as Flux (12B parameters).
## Table of Contents
- [Key Features](#key-features)
- [Quick Start](#quick-start)
- [Architecture](#architecture)
- [Dual Expert System](#dual-expert-system)
- [Configuration](#configuration)
- [Inference](#inference)
- [Training](#training)
- [Checkpoint Conversion](#checkpoint-conversion)
- [Repository Structure](#repository-structure)
- [API Reference](#api-reference)
- [Samples](#samples)
- [Limitations](#limitations)
- [Citation](#citation)
---
## Key Features
| Feature | Description |
|---------|-------------|
| **Compact Size** | 246M params (~500MB bf16) vs Flux's 12B (~24GB) |
| **Dual Expert Distillation** | Learns from both SD1.5 trajectory features and geometric attention priors |
| **Flow Matching** | Rectified flow objective with Flux-style timestep shifting |
| **T5 + CLIP Conditioning** | Dual text encoder pathway with learnable balance |
| **Huber Loss** | Robust training that handles outliers gracefully |
| **Identity-Init Conversion** | v3→v4 conversion preserves pretrained weights exactly |
---
## Quick Start
### Colab Inference
```python
!pip install torch transformers safetensors huggingface_hub accelerate
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Download model code and weights
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
weights = hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors")
# Load model
exec(open(model_py).read())
config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
model.load_state_dict(load_file(weights), strict=False)
model.eval()
# For full inference pipeline with text encoders and sampling:
inference_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/inference_v3.py")
exec(open(inference_py).read())
# Then call: image = generate("your prompt here")
```
### Minimal Generation Loop
```python
import torch
import torch.nn.functional as F

def flux_shift(t, s=3.0):
    """Flux-style timestep shifting - biases toward data end."""
    return s * t / (1 + (s - 1) * t)

def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
    """Euler sampling with classifier-free guidance."""
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Start from pure noise (t=0)
    x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)

    # Rectified flow: integrate from t=0 (noise) to t=1 (data)
    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device=device))

    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr
        t_batch = t_curr.expand(1)

        # Conditional prediction
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_emb,
            pooled_projections=clip_pooled,
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Unconditional prediction (for CFG)
        v_uncond = model(
            hidden_states=x,
            encoder_hidden_states=torch.zeros_like(t5_emb),
            pooled_projections=torch.zeros_like(clip_pooled),
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)

        # Euler step
        x = x + v * dt

    return x  # [1, 4096, 16] - decode with VAE
```
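To see what the timestep shift does to the step schedule, the following standalone sketch evaluates the same `flux_shift` formula on a uniform grid using plain Python floats, so no model or GPU is needed. The helper name `shifted_schedule` is introduced here purely for illustration.

```python
def flux_shift(t, s=3.0):
    """Same formula as in the sampler, applied to a plain float."""
    return s * t / (1 + (s - 1) * t)

def shifted_schedule(num_steps, s=3.0):
    """Hypothetical helper: shifted timesteps from t=0 (noise) to t=1 (data)."""
    return [flux_shift(i / num_steps, s) for i in range(num_steps + 1)]

schedule = shifted_schedule(4)
step_sizes = [b - a for a, b in zip(schedule, schedule[1:])]

print([round(t, 3) for t in schedule])     # [0.0, 0.5, 0.75, 0.9, 1.0]
print([round(dt, 3) for dt in step_sizes])  # [0.5, 0.25, 0.15, 0.1]
```

The step sizes shrink as `t` approaches 1, so with `s=3.0` the sampler takes large steps near the noise end and progressively finer steps near the data end.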
---
## Architecture
### Model Comparison
| Component | TinyFlux | TinyFlux-Deep v3 | TinyFlux-Deep v4.1 | Flux-Schnell |
|-----------|----------|------------------|--------------------| -------------|
| Hidden size | 256 | 512 | 512 | 3072 |
| Attention heads | 2 | 4 | 4 | 24 |
| Head dimension | 128 | 128 | 128 | 128 |
| Double-stream layers | 3 | 15 | 15 | 19 |
| Single-stream layers | 3 | 25 | 25 | 38 |
| MLP ratio | 4.0 | 4.0 | 4.0 | 4.0 |
| RoPE dims | (16,56,56) | (16,56,56) | (16,56,56) | (16,56,56) |
| Lune Expert | ✗ | ✓ | ✓ | ✗ |
| Sol Attention Prior | ✗ | ✗ | ✓ | ✗ |
| T5 Vec Enhancement | ✗ | ✗ | ✓ | ✗ |
| **Total Parameters** | ~10.7M | ~244.7M | ~246.4M | ~12B |
| **Memory (bf16)** | ~22MB | ~490MB | ~493MB | ~24GB |
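
The hidden sizes and memory figures in the table follow from two simple relations: `hidden_size = heads × head_dim`, and bf16 storage of 2 bytes per parameter. The sketch below recomputes them from the rounded parameter counts in the table; it covers weights only (no activations or optimizer state), so small differences from the quoted figures are expected.

```python
BYTES_PER_BF16 = 2  # bfloat16 stores each parameter in 2 bytes

# Rounded values copied from the comparison table above
models = {
    "TinyFlux": {"params": 10.7e6, "heads": 2, "head_dim": 128},
    "TinyFlux-Deep v3": {"params": 244.7e6, "heads": 4, "head_dim": 128},
    "TinyFlux-Deep v4.1": {"params": 246.4e6, "heads": 4, "head_dim": 128},
    "Flux-Schnell": {"params": 12e9, "heads": 24, "head_dim": 128},
}

def summarize(spec):
    """Return (hidden_size, weight memory in MB) for one table column."""
    hidden = spec["heads"] * spec["head_dim"]
    megabytes = spec["params"] * BYTES_PER_BF16 / 1e6
    return hidden, megabytes

for name, spec in models.items():
    hidden, mb = summarize(spec)
    print(f"{name}: hidden_size={hidden}, ~{mb:.0f} MB in bf16")
```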
### Block Structure
**Double-Stream Blocks (15 layers):**
- Separate text and image pathways
- Joint attention between modalities
- AdaLN-Zero conditioning from vec
- Sol spatial modulation on image Q/K only
**Single-Stream Blocks (25 layers):**
- Concatenated text + image sequence
- Full self-attention with RoPE
- Sol modulation skips text tokens
```
Input: img_latents [B, 4096, 16], t5_emb [B, 77, 768], clip_pooled [B, 768]
        ┌───────────────┴───────────────┐
        ▼                               ▼
 img_in (Linear)                 txt_in (Linear)
        │                               │
        ▼                               ▼
 [B, 4096, 512]                   [B, 77, 512]
        │                               │
        └───────────┬───────────────────┘
      vec = time_emb + clip_vec + t5_vec + lune_signal
        ┌───────────┴───────────┐
        ▼                       ▼
  Double Blocks (×15)      Sol Prior → temperature, spatial_mod
        │                       │
        ▼                       │
  Single Blocks (×25) ◄─────────┘
  final_norm → final_linear
  Output: [B, 4096, 16]
```
---
## Dual Expert System
TinyFlux-Deep v4.1 uses two complementary expert pathways to inject knowledge from teacher models without the "twin-tail paradox" (mixing incompatible prediction targets).
### Lune Expert Predictor (Trajectory Guidance)
**Purpose:** Captures SD1.5's understanding of "how denoising should flow" - the trajectory through latent space.
**Architecture:**
```python
LuneExpertPredictor(
    time_dim=512,      # From timestep MLP
    clip_dim=768,      # CLIP pooled features
    expert_dim=1280,   # SD1.5 mid-block dimension (prediction target)
    hidden_dim=512,    # Internal MLP width
    output_dim=512,    # Output added to vec
    dropout=0.1,
)
```
**How it works:**
1. Concatenates timestep embedding + CLIP pooled → hidden
2. Predicts what SD1.5's mid-block features would be
3. During training: uses real SD1.5 features when available
4. During inference: uses predicted features
5. Gates output with learnable sigmoid (init 0.5)
6. Adds `expert_signal` to global `vec` conditioning
**Training signal:** Cosine similarity loss against real SD1.5 UNet mid-block features (soft directional matching, not exact reconstruction).
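
Steps 3 to 5 above (teacher features during training, predicted features at inference, then a sigmoid gate) can be sketched in plain Python. This is an illustration under assumptions, not the code in `model_v4.py`: the function names are hypothetical, vectors are plain lists, and the projection from the 1280-dim expert space to the 512-dim `vec` is omitted.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def select_expert_features(predicted, real=None, training=False):
    """Use real teacher features in training when provided, else the prediction."""
    if training and real is not None:
        return real
    return predicted

def gated_signal(projected, gate_logit=0.0):
    """Scale the (already projected) expert signal by a sigmoid gate.

    A gate logit of 0.0 corresponds to the documented initial gate of 0.5.
    """
    gate = sigmoid(gate_logit)
    return [gate * v for v in projected]

predicted = [0.2, -0.4, 0.6]
teacher = [0.1, -0.5, 0.7]

print(select_expert_features(predicted, teacher, training=True))   # teacher features
print(select_expert_features(predicted, teacher, training=False))  # predicted features
print(gated_signal([1.0, -2.0]))                                   # [0.5, -1.0]
```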
### Sol Attention Prior (Structural Guidance)
**Purpose:** Captures geometric/structural knowledge about WHERE attention should focus, without injecting incompatible features.
**Key insight:** Sol (a V-prediction DDPM model) has valuable attention patterns, but its features are incompatible with TinyFlux's linear flow matching. We extract attention *statistics* instead:
- **Locality:** How local vs global is attention?
- **Entropy:** How focused vs diffuse?
- **Clustering:** How structured vs uniform?
- **Spatial importance:** Which regions matter most?
**Architecture:**
```python
SolAttentionPrior(
    time_dim=512,
    clip_dim=768,
    hidden_dim=256,
    num_heads=4,            # Matches TinyFlux attention heads
    spatial_size=8,         # 8×8 importance map
    geometric_weight=0.7,   # David's 70/30 split
)
```
**How it works:**
1. **Geometric prior (70%):** Timestep-based heuristics
- Early denoising (high t): Higher temperature → softer, global attention
- Late denoising (low t): Lower temperature → sharper, local attention
- Spatial: Uniform early, center-biased late
2. **Learned prior (30%):** Content-based predictions
- Predicts attention statistics from (timestep, CLIP)
- Predicts spatial importance map
3. **Blending:** `blend * geometric + (1-blend) * learned` with learnable blend gate
4. **Output application:**
- `temperature [B, 4]`: Scales attention logits per head
- `spatial_mod [B, H, W]`: Modulates Q/K at each position via `exp(conv(spatial))`
**Identity initialization:** All Sol components initialize to zero-effect:
- `spatial_to_mod` Conv2d: zero weight, zero bias → `exp(0) = 1` (identity)
- Allows gradual learning without disrupting pretrained attention
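
The blending rule and the identity initialization can be checked numerically with a small sketch. It assumes the blend gate is stored as a logit and passed through a sigmoid (so `geometric_weight=0.7` corresponds to a logit of about 0.85); the function names and the example temperature values are made up for illustration.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    return math.log(p / (1.0 - p))

def blend_prior(geometric, learned, blend_logit):
    """blend * geometric + (1 - blend) * learned, element by element."""
    blend = sigmoid(blend_logit)
    return [blend * g + (1.0 - blend) * l for g, l in zip(geometric, learned)]

def spatial_modulation(conv_output):
    """exp() of the conv output; a zero-initialized conv gives all ones."""
    return [[math.exp(v) for v in row] for row in conv_output]

# Assumption: the gate starts at geometric_weight = 0.7
blend_logit = logit(0.7)

# Per-head temperatures (4 heads): 0.7 * 1.2 + 0.3 * 1.0 = 1.14
temperature = blend_prior([1.2] * 4, [1.0] * 4, blend_logit)
print([round(t, 3) for t in temperature])

# Zero weight and zero bias mean the conv outputs zeros on the 8×8 map
mod = spatial_modulation([[0.0] * 8 for _ in range(8)])
print(mod[0][:3])  # [1.0, 1.0, 1.0] -> Q/K are multiplied by 1, i.e. unchanged
```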
### T5 Vec Enhancement
**Purpose:** Adds T5's semantic understanding to the global conditioning pathway (previously only CLIP pooled).
```python
# Attention-weighted pooling of T5 sequence
t5_attn = softmax(t5_emb.mean(dim=-1), dim=-1) # [B, 77], normalized over tokens
t5_pooled = (t5_emb * t5_attn.unsqueeze(-1)).sum(dim=1) # [B, 768]
t5_vec = t5_pool_mlp(t5_pooled) # [B, 512]
# Learnable balance between CLIP and T5
balance = sigmoid(text_balance) # Initialized to 0.5
text_vec = balance * clip_vec + (1 - balance) * t5_vec
```
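The pooling and balance steps can be traced on a tiny example without tensors. The sketch below uses plain lists for one sequence of three 2-dim tokens, skips the `t5_pool_mlp` projection, and uses hypothetical function names; it only illustrates the arithmetic described above.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(sequence):
    """sequence is [L][D]; each token is scored by the mean of its features."""
    scores = [sum(tok) / len(tok) for tok in sequence]
    weights = softmax(scores)
    dim = len(sequence[0])
    pooled = [sum(w * tok[d] for w, tok in zip(weights, sequence)) for d in range(dim)]
    return weights, pooled

def balance_text(clip_vec, t5_vec, text_balance=0.0):
    """sigmoid(text_balance) weights CLIP; 0.0 gives an even 50/50 mix."""
    balance = 1.0 / (1.0 + math.exp(-text_balance))
    return [balance * c + (1.0 - balance) * t for c, t in zip(clip_vec, t5_vec)]

tokens = [[1.0, 3.0], [0.0, 0.0], [2.0, 2.0]]
weights, pooled = attention_pool(tokens)
print([round(w, 3) for w in weights])  # tokens with larger mean get more weight
print(balance_text([1.0, 0.0], [0.0, 1.0]))  # [0.5, 0.5]
```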
---
## Configuration
### TinyFluxConfig
```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class TinyFluxConfig:
    # Core architecture
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128  # hidden_size = heads × head_dim
    in_channels: int = 16  # VAE latent channels
    patch_size: int = 1
    joint_attention_dim: int = 768  # T5 embedding dim
    pooled_projection_dim: int = 768  # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # Must sum to head_dim

    # Lune expert predictor
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280  # SD1.5 mid-block dim
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior
    use_sol_prior: bool = True
    sol_spatial_size: int = 8  # 8×8 spatial importance map
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7  # 70% geometric, 30% learned

    # T5 enhancement
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"  # "attention", "mean", "cls"

    # Loss configuration
    lune_distill_mode: str = "cosine"  # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # Legacy compatibility
    guidance_embeds: bool = False
### Loading from JSON
```python
# From file
config = TinyFluxConfig.from_json("lailah_401434_v4_config.json")
# From dict
config = TinyFluxConfig.from_dict({
    "hidden_size": 512,
    "num_attention_heads": 4,
    ...
})
# Save with metadata
config.save_json("config.json", metadata={"source_step": 401434})
```
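For readers who want to see how such a JSON round trip can work, here is a minimal sketch built on a stand-in dataclass. `MiniConfig` is hypothetical and only mirrors two fields; storing the metadata under a `_conversion_info` key and ignoring unknown keys on load are assumptions based on the config JSON shown later in this README, not a description of the real `TinyFluxConfig`.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, fields

@dataclass
class MiniConfig:
    """Hypothetical stand-in for TinyFluxConfig with two fields."""
    hidden_size: int = 512
    num_attention_heads: int = 4

    @classmethod
    def from_dict(cls, d):
        # Keep only declared fields so metadata keys are ignored on load
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})

    @classmethod
    def from_json(cls, path):
        with open(path) as fh:
            return cls.from_dict(json.load(fh))

    def save_json(self, path, metadata=None):
        payload = asdict(self)
        if metadata:
            payload["_conversion_info"] = metadata  # assumed key name
        with open(path, "w") as fh:
            json.dump(payload, fh, indent=2)

path = os.path.join(tempfile.mkdtemp(), "config.json")
MiniConfig().save_json(path, metadata={"source_step": 401434})
restored = MiniConfig.from_json(path)
print(restored)
```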
### Validation
```python
# Config validates constraints on creation
config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=128)
# ✓ OK: 512 = 4 × 128
config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=64)
# ✗ ValueError: hidden_size (512) must equal num_attention_heads * attention_head_dim (256)
# Validate checkpoint compatibility
warnings = config.validate_checkpoint(state_dict)
if warnings:
    print("Warnings:", warnings)
```
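One common way to get validation at construction time with a dataclass is a `__post_init__` hook. The sketch below shows that pattern on a hypothetical `CheckedConfig`; whether `TinyFluxConfig` is implemented this way is an assumption. The two checks mirror the constraints stated in the config comments: `hidden_size = heads × head_dim`, and the RoPE axis dims summing to `head_dim`.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class CheckedConfig:
    """Hypothetical config that validates itself when created."""
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    def __post_init__(self):
        expected = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != expected:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({expected})"
            )
        if sum(self.axes_dims_rope) != self.attention_head_dim:
            raise ValueError(
                f"axes_dims_rope {self.axes_dims_rope} must sum to "
                f"attention_head_dim ({self.attention_head_dim})"
            )

ok = CheckedConfig()  # 512 = 4 × 128 and 16 + 56 + 56 = 128

try:
    CheckedConfig(attention_head_dim=64)
except ValueError as err:
    message = str(err)
    print(message)
```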
---
## Inference
### Full Pipeline
```python
import torch
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Load text encoders
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)
# Load VAE
vae = AutoencoderKL.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
subfolder="vae",
torch_dtype=torch.bfloat16
).to("cuda")
# Load TinyFlux-Deep
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
exec(open(model_py).read())
config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
model.load_state_dict(weights, strict=False)
model.eval()

def encode_prompt(prompt):
    """Encode prompt with both T5 and CLIP."""
    # T5
    t5_tokens = t5_tokenizer(prompt, return_tensors="pt", padding="max_length",
                             max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        t5_emb = t5_model(**t5_tokens).last_hidden_state.to(torch.bfloat16)

    # CLIP
    clip_tokens = clip_tokenizer(prompt, return_tensors="pt", padding="max_length",
                                 max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        clip_out = clip_model(**clip_tokens)
        clip_pooled = clip_out.pooler_output.to(torch.bfloat16)

    return t5_emb, clip_pooled

def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """
    Euler sampling for rectified flow.
    Flow: x_t = (1-t)*noise + t*data
    Integrate from t=0 (noise) to t=1 (data)
    """
    if seed is not None:
        torch.manual_seed(seed)

    t5_emb, clip_pooled = encode_prompt(prompt)

    # Null embeddings for CFG
    t5_null, clip_null = encode_prompt("")

    # Start from pure noise (t=0)
    x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")

    # Rectified flow: 0 → 1 with Flux shift
    def flux_shift(t, s=3.0):
        return s * t / (1 + (s - 1) * t)

    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device="cuda"))

    with torch.no_grad():
        for i in range(num_steps):
            t = timesteps[i].expand(1)
            dt = timesteps[i + 1] - timesteps[i]  # Positive

            # Conditional
            v_cond = model(x, t5_emb, clip_pooled, t, img_ids)

            # Unconditional
            v_uncond = model(x, t5_null, clip_null, t, img_ids)

            # CFG
            v = v_uncond + cfg_scale * (v_cond - v_uncond)

            # Euler step
            x = x + v * dt

    # Decode with VAE
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)  # [B, C, H, W]
    x = x / vae.config.scaling_factor
    with torch.no_grad():
        image = vae.decode(x).sample

    # Convert to PIL
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")

    from PIL import Image
    return Image.fromarray(image)

# Generate!
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
image.save("tiger.png")
```
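A quick way to sanity-check the sampler's sign and direction conventions is to run the same Euler loop on a toy problem where the velocity is known exactly. In the sketch below an "oracle" stands in for the model and always returns `data - noise`, the flow-matching target used in this README; the function names are hypothetical and plain lists replace tensors.

```python
def flux_shift(t, s=3.0):
    return s * t / (1 + (s - 1) * t)

def euler_integrate(noise, velocity_fn, num_steps=25):
    """Integrate dx/dt = v from t=0 (noise) to t=1 (data) on the shifted grid."""
    ts = [flux_shift(i / num_steps) for i in range(num_steps + 1)]
    x = list(noise)
    for t_curr, t_next in zip(ts, ts[1:]):
        dt = t_next - t_curr  # positive, because the grid runs 0 → 1
        v = velocity_fn(x, t_curr)
        x = [xi + vi * dt for xi, vi in zip(x, v)]
    return x

noise = [0.5, -1.0, 2.0]
data = [1.0, 0.0, -1.0]

def oracle_velocity(x, t):
    """Stand-in for the model: the exact rectified-flow velocity."""
    return [d - n for d, n in zip(data, noise)]

result = euler_integrate(noise, oracle_velocity)
print([round(v, 6) for v in result])  # matches data up to rounding
```

Because the step sizes sum to 1 regardless of the shift, a perfect constant velocity carries the sample from noise to data exactly. With a learned model the velocity varies along the path, which is where the step count and schedule start to matter.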
### Batch Inference
```python
def generate_batch(prompts, **kwargs):
    """Generate multiple images."""
    return [generate_image(p, **kwargs) for p in prompts]

images = generate_batch([
    "a red bird with blue beak",
    "a mountain landscape at sunset",
    "an astronaut riding a horse",
], num_steps=25, cfg_scale=4.0)
```
---
## Training
### Loss Computation
```python
# Forward pass with expert info
output, expert_info = model(
    hidden_states=noisy_latents,
    encoder_hidden_states=t5_emb,
    pooled_projections=clip_pooled,
    timestep=timesteps,
    img_ids=img_ids,
    lune_features=sd15_midblock_features,  # From SD1.5 teacher
    sol_stats=sol_attention_stats,         # From Sol teacher (optional)
    sol_spatial=sol_spatial_importance,    # From Sol teacher (optional)
    return_expert_pred=True,
)

# Compute loss
losses = model.compute_loss(
    output=output,
    target=flow_target,  # data - noise for flow matching
    expert_info=expert_info,
    lune_features=sd15_midblock_features,
    sol_stats=sol_attention_stats,
    sol_spatial=sol_spatial_importance,
    # Loss weights
    lune_weight=0.1,   # Weight for Lune distillation
    sol_weight=0.05,   # Weight for Sol distillation
    # Loss options
    use_huber=True,              # Huber loss for main objective (robust to outliers)
    huber_delta=0.1,             # Huber delta (smaller = tighter MSE region)
    lune_distill_mode="cosine",  # "hard", "soft", "cosine", "huber"
    spatial_weighting=True,      # Weight loss by Sol spatial importance
)

# losses dict contains:
# - main: flow matching loss
# - lune_distill: Lune prediction loss
# - sol_stat_distill: Sol statistics prediction loss
# - sol_spatial_distill: Sol spatial prediction loss
# - total: weighted sum
loss = losses['total']
loss.backward()
```
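To make the effect of `huber_delta` concrete, the sketch below implements the standard Huber loss on plain floats: quadratic inside `|r| <= delta`, linear outside. It assumes `compute_loss` uses this standard form (the same definition as `torch.nn.functional.huber_loss`); the exact reduction in the model code may differ.

```python
def huber(residual, delta=0.1):
    """Standard Huber loss for a single residual."""
    a = abs(residual)
    if a <= delta:
        return 0.5 * residual * residual
    return delta * (a - 0.5 * delta)

def huber_loss(output, target, delta=0.1):
    """Mean Huber loss over paired values."""
    return sum(huber(o - t, delta) for o, t in zip(output, target)) / len(output)

small, large = 0.05, 1.0
print(huber(small))            # quadratic region: 0.5 * 0.05**2
print(huber(large))            # linear region: 0.1 * (1.0 - 0.05)
print(0.5 * large * large)     # the same residual under (half) MSE, for comparison
```

With `delta=0.1`, a residual of 1.0 contributes about 0.095 instead of 0.5, which is how the loss limits the pull of outlier latents.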
### Distillation Modes
| Mode | Description | Use Case |
|------|-------------|----------|
| `"hard"` | MSE against teacher features | Exact reconstruction |
| `"soft"` | Temperature-scaled MSE | Softer matching |
| `"cosine"` | Cosine similarity loss | Directional alignment (recommended) |
| `"huber"` | Huber loss on features | Robust to outliers |
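
The difference between `"hard"` and `"cosine"` is easiest to see on a prediction that points in the right direction but has the wrong magnitude. The sketch below assumes the cosine mode is computed as `1 - cosine_similarity`, which is a common convention but is not confirmed by this README; function names are illustrative.

```python
import math

def mse(pred, target):
    """'hard' mode: mean squared error against teacher features."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def cosine_distill(pred, target, eps=1e-8):
    """'cosine' mode (assumed form): 1 - cosine similarity."""
    dot = sum(p * t for p, t in zip(pred, target))
    norm = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(t * t for t in target))
    return 1.0 - dot / max(norm, eps)

target = [1.0, 0.0]
same_direction = [2.0, 0.0]   # right direction, twice the magnitude
orthogonal = [0.0, 1.0]       # wrong direction

print(mse(same_direction, target), cosine_distill(same_direction, target))  # 0.5 0.0
print(mse(orthogonal, target), cosine_distill(orthogonal, target))          # 1.0 1.0
```

The cosine loss ignores the magnitude mismatch entirely, which matches the "soft directional matching, not exact reconstruction" description of the Lune training signal.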
### Training Loop Example
```python
import copy

import torch
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

# Assumes model, dataloader, flux_shift and img_ids are already defined (see Inference)
optimizer = AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.99), weight_decay=0.01)
scaler = GradScaler()

# EMA
ema_decay = 0.9999
ema_model = copy.deepcopy(model)

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    latents = batch['latents']
    batch_size, device = latents.shape[0], latents.device

    with autocast(dtype=torch.bfloat16):
        # Sample timesteps with logit-normal distribution
        u = torch.randn(batch_size, device=device)
        t = torch.sigmoid(u)      # Logit-normal
        t = flux_shift(t, s=3.0)  # Flux shift

        # Add noise
        noise = torch.randn_like(latents)
        noisy = t.view(-1, 1, 1) * latents + (1 - t.view(-1, 1, 1)) * noise
        target = latents - noise  # Flow matching target

        # Forward
        output, expert_info = model(
            hidden_states=noisy,
            encoder_hidden_states=batch['t5_emb'],
            pooled_projections=batch['clip_pooled'],
            timestep=t,
            img_ids=img_ids,
            lune_features=batch.get('sd15_features'),
            return_expert_pred=True,
        )

        # Loss
        losses = model.compute_loss(output, target, expert_info,
                                    lune_features=batch.get('sd15_features'))

    # Backward and optimizer step run outside the autocast context
    scaler.scale(losses['total']).backward()
    scaler.step(optimizer)
    scaler.update()

    # EMA update
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), ema_model.parameters()):
            p_ema.data.lerp_(p.data, 1 - ema_decay)
```
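The timestep sampling and noising steps in the loop can be reproduced with the standard library alone, which makes the conventions easy to verify. In this sketch the function names are hypothetical and lists stand in for latent tensors; it shows that shifted logit-normal samples stay inside (0, 1) and that `x_t` runs from pure noise at `t=0` to clean data at `t=1`.

```python
import math
import random

def sample_timestep(rng, s=3.0):
    """Logit-normal sample followed by the Flux shift."""
    u = rng.gauss(0.0, 1.0)
    t = 1.0 / (1.0 + math.exp(-u))   # sigmoid of a standard normal
    return s * t / (1 + (s - 1) * t)

def interpolate(data, noise, t):
    """Return (x_t, target) with x_t = t*data + (1-t)*noise, target = data - noise."""
    x_t = [t * d + (1 - t) * n for d, n in zip(data, noise)]
    target = [d - n for d, n in zip(data, noise)]
    return x_t, target

rng = random.Random(0)
ts = [sample_timestep(rng) for _ in range(1000)]
print(min(ts) > 0.0, max(ts) < 1.0)
print(sum(ts) / len(ts) > 0.5)  # the shift moves mass toward the data end

data, noise = [1.0, -2.0], [0.5, 0.5]
x_start, target = interpolate(data, noise, 0.0)
x_end, _ = interpolate(data, noise, 1.0)
print(x_start, x_end, target)
```

Since `x_t` is linear in `t`, its derivative with respect to `t` is exactly `data - noise`, which is why that difference serves as the velocity target.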
### Hyperparameters
| Parameter | Value | Notes |
|-----------|-------|-------|
| Optimizer | AdamW | |
| Learning rate | 3e-4 | With cosine schedule |
| Betas | (0.9, 0.99) | |
| Weight decay | 0.01 | |
| Batch size | 32 | 16 × 2 gradient accumulation |
| EMA decay | 0.9999 | |
| Precision | bfloat16 | |
| Timestep shift | s=3.0 | Flux-style |
| Timestep sampling | Logit-normal | |
| Lune weight | 0.1 | |
| Sol weight | 0.05 | |
| Huber delta | 0.1 | |
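
Two of the table entries are easier to interpret with a little arithmetic: the effective batch size under gradient accumulation, and the averaging horizon implied by the EMA decay. The sketch below uses a scalar in place of a parameter tensor; `ema_update` is a hypothetical name that mirrors what `lerp_(p, 1 - ema_decay)` does in the training loop.

```python
def ema_update(ema, value, decay=0.9999):
    """Scalar equivalent of ema.lerp_(value, 1 - decay)."""
    return ema + (1.0 - decay) * (value - ema)

micro_batch, accumulation_steps = 16, 2
effective_batch = micro_batch * accumulation_steps
print(effective_batch)  # 32

decay = 0.9999
horizon = 1.0 / (1.0 - decay)
print(round(horizon))  # roughly 10000 steps

# Track a constant signal of 1.0 starting from 0.0
ema = 0.0
for _ in range(10_000):
    ema = ema_update(ema, 1.0, decay)
print(round(ema, 3))  # about 1 - 1/e after one horizon
```

With a decay of 0.9999 the EMA weights need on the order of ten thousand steps to reflect a change in the live weights, which is worth keeping in mind when comparing EMA and non-EMA checkpoints early in a run.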
---
## Checkpoint Conversion
### v3 → v4.1 Conversion
The converter preserves all pretrained weights and initializes new v4.1 components to identity/zero-effect:
**What gets converted:**
| v3 Key | v4.1 Key | Action |
|--------|----------|--------|
| `expert_predictor.*` | `lune_predictor.*` | Rename |
| `expert_gate` (0.5) | `expert_gate` (0.0) | Convert to logit space |
| - | `sol_prior.*` | Initialize (zero-effect) |
| - | `t5_pool.*` | Initialize (Xavier) |
| - | `text_balance` | Initialize (0.0 = 50/50) |
| - | `*.spatial_to_mod.*` | Initialize (zero = identity) |
**Parameter growth:**
- v3: ~244.7M parameters
- v4.1: ~246.4M parameters
- Added: ~1.7M parameters (0.7% increase)
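
The conversion table can be read as a small key-rewriting pass over the state dict. The sketch below is a toy version under assumptions: values are plain floats rather than tensors, the key `expert_predictor.fc.weight` is made up, and only the rename, the gate-to-logit step, and the `text_balance` initialization are shown. It is not the logic of `convert_v3_to_v4.py`. The parameter counts are the ones recorded in the conversion metadata in this README.

```python
import math

def logit(p):
    return math.log(p / (1.0 - p))

def convert_keys(state_dict):
    """Toy v3 → v4.1 pass: rename the predictor, move the gate to logit space."""
    converted = {}
    for key, value in state_dict.items():
        if key.startswith("expert_predictor."):
            key = "lune_predictor." + key[len("expert_predictor."):]
        elif key == "expert_gate":
            value = logit(value)  # 0.5 → 0.0, so sigmoid(0.0) recovers 0.5
        converted[key] = value
    converted.setdefault("text_balance", 0.0)  # sigmoid(0.0) = 0.5, a 50/50 mix
    return converted

v3 = {"expert_predictor.fc.weight": 0.3, "expert_gate": 0.5}  # hypothetical keys
v4 = convert_keys(v3)
print(v4)

source_params, target_params = 244_690_849, 246_347_234
added = target_params - source_params
growth_pct = 100 * added / source_params
print(added, f"{growth_pct:.2f}%")
```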
### Python API
```python
import os
import sys

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download the converter and make its folder importable
converter = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/convert_v3_to_v4.py")
sys.path.insert(0, os.path.dirname(converter))

# Simple: download, convert, upload
from convert_v3_to_v4 import run
result = run(401434)  # Step number

# With custom config
result = run(401434, config={
    "hidden_size": 512,
    "num_attention_heads": 4,
    "sol_geometric_weight": 0.8,  # More geometric, less learned
})

# From JSON config file
result = run(401434, config="my_config.json")

# Low-level API
from convert_v3_to_v4 import convert_state_dict, analyze_checkpoint, TinyFluxConfig

# Analyze checkpoint version
state_dict = load_file("checkpoint.safetensors")
info = analyze_checkpoint(state_dict)
print(f"Version: {info.version}") # "v3", "v4.0", "v4.1", etc.
print(f"Has Sol prior: {info.has_sol_prior}")
# Convert state dict
config = TinyFluxConfig()
v4_state, report = convert_state_dict(state_dict, config)
print(f"Renamed {len(report['renamed'])} keys")
print(f"Initialized {len(report['initialized'])} keys")
```
### CLI
```bash
# Basic conversion
python convert_v3_to_v4.py --step 401434
# Local file
python convert_v3_to_v4.py --input model_v3.safetensors
# Analyze only (don't convert)
python convert_v3_to_v4.py --step 401434 --analyze-only
# Custom output
python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel
# With custom config
python convert_v3_to_v4.py --step 401434 --config my_config.json
```
### Output Structure
```
checkpoint_runs/v4_init/
├── lailah_401434_v4_init.safetensors # Converted model
├── lailah_401434_v4_init_ema.safetensors # Fresh EMA (copy of model)
├── lailah_401434_v4_init_ema_secondary.safetensors # Converted old EMA
└── lailah_401434_v4_config.json # Config with conversion metadata
```
### Config JSON Format
```json
{
  "hidden_size": 512,
  "num_attention_heads": 4,
  "attention_head_dim": 128,
  "num_double_layers": 15,
  "num_single_layers": 25,
  "use_lune_expert": true,
  "use_sol_prior": true,
  "use_t5_vec": true,
  "sol_geometric_weight": 0.7,
  "lune_distill_mode": "cosine",
  "use_huber_loss": true,
  "huber_delta": 0.1,
  "_conversion_info": {
    "source_step": 401434,
    "source_repo": "AbstractPhil/tiny-flux-deep",
    "source_version": "v3",
    "target_version": "v4.1",
    "source_params": 244690849,
    "target_params": 246347234,
    "params_added": 1656385,
    "converter_version": "4.1.0"
  }
}
```
---
## Repository Structure
```
AbstractPhil/tiny-flux-deep/
├── model.safetensors                    # Latest training weights
├── model_ema.safetensors                # EMA weights (use for inference)
├── config.json                          # Model configuration
├── README.md
├── scripts/                             # All Python code
│   ├── model_v4.py                      # v4.1 architecture (current)
│   ├── model_v3.py                      # v3 architecture (reference)
│   ├── model_v2.py                      # v2 architecture (legacy)
│   ├── inference_v3.py                  # Full inference pipeline
│   ├── convert_v3_to_v4.py              # Checkpoint converter
│   ├── trainer_v3_expert_guidance.py    # Training with distillation
│   ├── trainer_v2.py                    # Previous trainer
│   ├── trainer.py                       # Original trainer
│   ├── port_tiny_to_deep.py             # TinyFlux → Deep port script
│   └── colab_inference_lailah_early.py  # Simple Colab notebook
├── checkpoints/                         # v3 checkpoints (legacy)
│   ├── step_401434.safetensors
│   └── step_401434_ema.safetensors
├── checkpoint_runs/                     # Organized checkpoint runs
│   └── v4_init/                         # v4.1 initialization from v3
│       ├── lailah_401434_v4_init.safetensors
│       ├── lailah_401434_v4_init_ema.safetensors
│       ├── lailah_401434_v4_init_ema_secondary.safetensors
│       └── lailah_401434_v4_config.json
├── samples/                             # Generated samples per step
│   └── 20260127_074318_step_401434.png
└── logs/                                # TensorBoard training logs
    └── run_20260126_220714/
```
---
## API Reference
### TinyFluxDeep
```python
class TinyFluxDeep(nn.Module):
    def __init__(self, config: Optional[TinyFluxConfig] = None):
        """Initialize model with config (uses defaults if None)."""

    def forward(
        self,
        hidden_states: torch.Tensor,          # [B, N, 16] image latents
        encoder_hidden_states: torch.Tensor,  # [B, L, 768] T5 embeddings
        pooled_projections: torch.Tensor,     # [B, 768] CLIP pooled
        timestep: torch.Tensor,               # [B] timestep in [0, 1]
        img_ids: torch.Tensor,                # [N, 3] position IDs
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,         # Legacy
        lune_features: Optional[torch.Tensor] = None,    # [B, 1280] SD1.5 features
        sol_stats: Optional[torch.Tensor] = None,        # [B, 3] attention stats
        sol_spatial: Optional[torch.Tensor] = None,      # [B, 8, 8] spatial importance
        expert_features: Optional[torch.Tensor] = None,  # Legacy API
        return_expert_pred: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """
        Forward pass.

        Returns:
            output: [B, N, 16] predicted velocity
            expert_info: dict with predictions (if return_expert_pred=True)
        """

    def compute_loss(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        expert_info: Optional[Dict] = None,
        lune_features: Optional[torch.Tensor] = None,
        sol_stats: Optional[torch.Tensor] = None,
        sol_spatial: Optional[torch.Tensor] = None,
        lune_weight: float = 0.1,
        sol_weight: float = 0.05,
        use_huber: bool = True,
        huber_delta: float = 0.1,
        lune_distill_mode: str = "cosine",
        spatial_weighting: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Compute combined loss with distillation."""

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device) -> torch.Tensor:
        """Create image position IDs for RoPE."""

    @staticmethod
    def create_txt_ids(text_len: int, device) -> torch.Tensor:
        """Create text position IDs."""

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
```
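The `[N, 3]` shape of `img_ids` and the `N = 4096` token count both follow from the latent grid. The sketch below builds position triples for a 64×64 grid in plain Python. The `(0, row, col)` layout follows the convention used by Flux-style RoPE and is an assumption here; `make_img_ids` is a hypothetical name, and the real `create_img_ids` returns a tensor.

```python
VAE_DOWNSAMPLE = 8   # the Flux VAE reduces each spatial side by 8×
PATCH_SIZE = 1       # patch_size from TinyFluxConfig

def make_img_ids(height, width):
    """Hypothetical sketch: one (0, row, col) triple per latent position."""
    return [(0, y, x) for y in range(height) for x in range(width)]

latent_side = 512 // VAE_DOWNSAMPLE // PATCH_SIZE
ids = make_img_ids(latent_side, latent_side)

print(latent_side, len(ids))      # 64 4096
print(ids[0], ids[65], ids[-1])   # (0, 0, 0) (0, 1, 1) (0, 63, 63)
```

The three components line up with the three RoPE axes in `axes_dims_rope = (16, 56, 56)`, whose sizes sum to the head dimension of 128.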
### Converter Functions
```python
# High-level
def run(step, name="lailah", config=None, ...):
    """One-liner: download, convert, upload."""

def convert_checkpoint(step=None, input_path=None, config=None, ...) -> ConversionResult:
    """Convert checkpoint with full control."""

# Low-level
def analyze_checkpoint(state_dict) -> CheckpointInfo:
    """Analyze checkpoint version and contents."""

def convert_state_dict(state_dict, config=None) -> Tuple[Dict, Dict]:
    """Convert state dict, return (new_state, report)."""

def download_from_hf(step, repo_id, ...) -> Tuple[str, str]:
    """Download checkpoint from HuggingFace."""

# Config
class TinyFluxConfig:
    def to_dict(self) -> Dict
    def from_dict(cls, d) -> TinyFluxConfig
    def from_json(cls, path) -> TinyFluxConfig
    def save_json(self, path, metadata=None)
    def validate_checkpoint(self, state_dict) -> List[str]
```
---
## Samples
### Step 401434 (v3 weights)
**"subject, animal, cat, photograph of a tiger, natural habitat"**
![tiger](https://cdn-uploads.huggingface.co/production/uploads/630cf55b15433862cfc9556f/uJ9Ffh780iLgEIJhmafod.png)
**"subject, bird, blue beak, red eyes, green claws"**
![bird1](https://cdn-uploads.huggingface.co/production/uploads/630cf55b15433862cfc9556f/GRS5tyaFFa0HV2xSJCsin.png)
**"subject, bird, red haired bird in a tree"**
![bird2](https://cdn-uploads.huggingface.co/production/uploads/630cf55b15433862cfc9556f/rGourHokJsPtYNnoFi3Eq.png)
---
## Limitations
| Limitation | Details |
|------------|---------|
| **Resolution** | 512×512 only (64×64 latent space) |
| **Text encoder** | flan-t5-base (768 dim) vs Flux's T5-XXL (4096 dim) |
| **Attention heads** | 4 heads vs Flux's 24 - limits capacity |
| **Training data** | Teacher latents, not real images |
| **v4.1 status** | New architecture, training just starting |
| **Artifacts** | Expect imperfections - research model |
---
## Name
**Lailah** (לילה) — In Jewish tradition, the angel of the night who guards souls and teaches wisdom to the unborn. Chosen for this model's role as a smaller guardian exploring the same latent space as larger models, learning from their knowledge while finding its own path.
---
## Citation
```bibtex
@misc{tinyfluxdeep2026,
title={TinyFlux-Deep: Compact Flow Matching with Dual Expert Distillation},
author={AbstractPhil},
year={2026},
howpublished={\url{https://huggingface.co/AbstractPhil/tiny-flux-deep}},
note={246M parameter text-to-image model with Lune trajectory guidance and Sol attention priors}
}
```
---
## Related Projects
| Project | Description |
|---------|-------------|
| [AbstractPhil/tiny-flux](https://huggingface.co/AbstractPhil/tiny-flux) | Original TinyFlux (10.7M params) |
| [AbstractPhil/flux-schnell-teacher-latents](https://huggingface.co/datasets/AbstractPhil/flux-schnell-teacher-latents) | Training dataset |
| [AbstractPhil/imagenet-synthetic](https://huggingface.co/datasets/AbstractPhil/imagenet-synthetic) | ImageNet-style synthetic data |
| [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | Teacher model |
---
## License
MIT License - free for research and commercial use.
---
**Status**: v4.1 architecture complete. Converting v3 checkpoints and resuming training with dual expert distillation.