- AbstractPhil/imagenet-synthetic
---

# TinyFlux-Deep v4.1 (Lailah)

A compact **246M parameter** flow-matching diffusion model that distills knowledge from multiple teacher models into an efficient architecture. TinyFlux-Deep uses a dual expert system to capture both trajectory dynamics (from SD1.5) and structural attention patterns (from a geometric prior), enabling high-quality image generation at a fraction of the compute cost of full-scale models.

## Table of Contents

- [Key Features](#key-features)
- [Quick Start](#quick-start)
- [Architecture](#architecture)
- [Dual Expert System](#dual-expert-system)
- [Configuration](#configuration)
- [Inference](#inference)
- [Training](#training)
- [Checkpoint Conversion](#checkpoint-conversion)
- [Repository Structure](#repository-structure)
- [API Reference](#api-reference)
- [Samples](#samples)
- [Limitations](#limitations)
- [Citation](#citation)

---

## Key Features

| Feature | Description |
|---------|-------------|
| **Compact Size** | 246M params (~500MB bf16) vs Flux's 12B (~24GB) |
| **Dual Expert Distillation** | Learns from both SD1.5 trajectory features and geometric attention priors |
| **Flow Matching** | Rectified flow objective with Flux-style timestep shifting |
| **T5 + CLIP Conditioning** | Dual text encoder pathway with learnable balance |
| **Huber Loss** | Robust training that handles outliers gracefully |
| **Identity-Init Conversion** | v3→v4 conversion preserves pretrained weights exactly |

---

## Quick Start

### Colab Inference

```python
!pip install torch transformers safetensors huggingface_hub accelerate

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download model code and weights
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
weights = hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors")

# Load model
exec(open(model_py).read())
config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
model.load_state_dict(load_file(weights), strict=False)
model.eval()

# For full inference pipeline with text encoders and sampling:
inference_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/inference_v3.py")
exec(open(inference_py).read())
# Then call: image = generate("your prompt here")
```

### Minimal Generation Loop

```python
import torch

def flux_shift(t, s=3.0):
    """Flux-style timestep shifting - biases toward data end."""
    return s * t / (1 + (s - 1) * t)

def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
    """Euler sampling with classifier-free guidance."""
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Start from noise
    x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)

    # Timesteps with Flux shift
    timesteps = flux_shift(torch.linspace(1, 0, num_steps + 1, device=device))

    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr

        t_batch = t_curr.expand(1)

        # Conditional prediction
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_emb,
            pooled_projections=clip_pooled,
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Unconditional prediction (for CFG)
        v_uncond = model(
            hidden_states=x,
            encoder_hidden_states=torch.zeros_like(t5_emb),
            pooled_projections=torch.zeros_like(clip_pooled),
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)

        # Euler step
        x = x + v * dt

    return x  # [1, 4096, 16] - decode with VAE
```
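As a quick sanity check on the shift used above: `flux_shift` fixes both endpoints and pulls interior timesteps toward the data end, which is what concentrates sampling steps where they matter.

```python
import torch

def flux_shift(t, s=3.0):
    # Same formula as in the loop above
    return s * t / (1 + (s - 1) * t)

t = torch.linspace(0, 1, 5)   # [0.00, 0.25, 0.50, 0.75, 1.00]
shifted = flux_shift(t)
# Endpoints are fixed; interior points are pulled upward
assert shifted[0] == 0 and shifted[-1] == 1
assert torch.all(shifted[1:-1] > t[1:-1])
```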

---

## Architecture

### Model Comparison

| Component | TinyFlux | TinyFlux-Deep v3 | TinyFlux-Deep v4.1 | Flux-Schnell |
|-----------|----------|------------------|--------------------|--------------|
| Hidden size | 256 | 512 | 512 | 3072 |
| Attention heads | 2 | 4 | 4 | 24 |
| Head dimension | 128 | 128 | 128 | 128 |
| Double-stream layers | 3 | 15 | 15 | 19 |
| Single-stream layers | 3 | 25 | 25 | 38 |
| MLP ratio | 4.0 | 4.0 | 4.0 | 4.0 |
| RoPE dims | (16,56,56) | (16,56,56) | (16,56,56) | (16,56,56) |
| Lune Expert | ✗ | ✓ | ✓ | ✗ |
| Sol Attention Prior | ✗ | ✗ | ✓ | ✗ |
| T5 Vec Enhancement | ✗ | ✗ | ✓ | ✗ |
| **Total Parameters** | ~10.7M | ~244.7M | ~246.4M | ~12B |
| **Memory (bf16)** | ~22MB | ~490MB | ~493MB | ~24GB |

### Block Structure

**Double-Stream Blocks (15 layers):**
- Separate text and image pathways
- Joint attention between modalities
- AdaLN-Zero conditioning from vec
- Sol spatial modulation on image Q/K only

**Single-Stream Blocks (25 layers):**
- Concatenated text + image sequence
- Full self-attention with RoPE
- Sol modulation skips text tokens

```
Input: img_latents [B, 4096, 16], t5_emb [B, 77, 768], clip_pooled [B, 768]
                │
        ┌───────┴────────┐
        ▼                ▼
  img_in (Linear)   txt_in (Linear)
        │                │
        ▼                ▼
  [B, 4096, 512]    [B, 77, 512]
        │                │
        └───────┬────────┘
                │
  vec = time_emb + clip_vec + t5_vec + lune_signal
                │
        ┌───────┴────────┐
        ▼                ▼
  Double Blocks (×15)   Sol Prior → temperature, spatial_mod
        │                │
        ▼                │
  Single Blocks (×25) ◄──┘
        │
        ▼
  final_norm → final_linear
        │
        ▼
  Output: [B, 4096, 16]
```

---

## Dual Expert System

TinyFlux-Deep v4.1 uses two complementary expert pathways to inject knowledge from teacher models without the "twin-tail paradox" (mixing incompatible prediction targets).

### Lune Expert Predictor (Trajectory Guidance)

**Purpose:** Captures SD1.5's understanding of "how denoising should flow" - the trajectory through latent space.

**Architecture:**
```python
LuneExpertPredictor(
    time_dim=512,     # From timestep MLP
    clip_dim=768,     # CLIP pooled features
    expert_dim=1280,  # SD1.5 mid-block dimension (prediction target)
    hidden_dim=512,   # Internal MLP width
    output_dim=512,   # Output added to vec
    dropout=0.1,
)
```

**How it works:**
1. Concatenates timestep embedding + CLIP pooled → hidden
2. Predicts what SD1.5's mid-block features would be
3. During training: uses real SD1.5 features when available
4. During inference: uses predicted features
5. Gates output with learnable sigmoid (init 0.5)
6. Adds `expert_signal` to global `vec` conditioning

**Training signal:** Cosine similarity loss against real SD1.5 UNet mid-block features (soft directional matching, not exact reconstruction).
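The predict-then-gate flow in steps 1-6 can be sketched in a few lines. `GatedExpertSignal` and its layer layout are illustrative stand-ins, not the repository's actual `LuneExpertPredictor` implementation; only the dimensions and the sigmoid gate initialized at 0.5 follow the description above.

```python
import torch
import torch.nn as nn

class GatedExpertSignal(nn.Module):
    """Sketch: predict teacher features, then gate the projected signal into vec."""
    def __init__(self, time_dim=512, clip_dim=768, hidden_dim=512,
                 expert_dim=1280, output_dim=512):
        super().__init__()
        self.predict = nn.Sequential(
            nn.Linear(time_dim + clip_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, expert_dim),
        )
        self.project = nn.Linear(expert_dim, output_dim)
        # sigmoid(0.0) = 0.5 -> matches the "init 0.5" gate described above
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, time_emb, clip_pooled, teacher_features=None):
        pred = self.predict(torch.cat([time_emb, clip_pooled], dim=-1))
        # Training: prefer real SD1.5 features; inference: fall back to prediction
        feats = teacher_features if teacher_features is not None else pred
        signal = torch.sigmoid(self.gate) * self.project(feats)
        return signal, pred  # signal is added to vec; pred is distilled

sig = GatedExpertSignal()
time_emb, clip_pooled = torch.randn(2, 512), torch.randn(2, 768)
signal, pred = sig(time_emb, clip_pooled)  # [2, 512], [2, 1280]
```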

### Sol Attention Prior (Structural Guidance)

**Purpose:** Captures geometric/structural knowledge about WHERE attention should focus, without injecting incompatible features.

**Key insight:** Sol (a V-prediction DDPM model) has valuable attention patterns, but its features are incompatible with TinyFlux's linear flow matching. We extract attention *statistics* instead:
- **Locality:** How local vs global is attention?
- **Entropy:** How focused vs diffuse?
- **Clustering:** How structured vs uniform?
- **Spatial importance:** Which regions matter most?

**Architecture:**
```python
SolAttentionPrior(
    time_dim=512,
    clip_dim=768,
    hidden_dim=256,
    num_heads=4,           # Matches TinyFlux attention heads
    spatial_size=8,        # 8×8 importance map
    geometric_weight=0.7,  # David's 70/30 split
)
```

**How it works:**
1. **Geometric prior (70%):** Timestep-based heuristics
   - Early denoising (high t): Higher temperature → softer, global attention
   - Late denoising (low t): Lower temperature → sharper, local attention
   - Spatial: Uniform early, center-biased late

2. **Learned prior (30%):** Content-based predictions
   - Predicts attention statistics from (timestep, CLIP)
   - Predicts spatial importance map

3. **Blending:** `blend * geometric + (1 - blend) * learned` with a learnable blend gate

4. **Output application:**
   - `temperature [B, 4]`: Scales attention logits per head
   - `spatial_mod [B, H, W]`: Modulates Q/K at each position via `exp(conv(spatial))`

**Identity initialization:** All Sol components initialize to zero-effect:
- `spatial_to_mod` Conv2d: zero weight, zero bias → `exp(0) = 1` (identity)
- Allows gradual learning without disrupting pretrained attention
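The zero-init identity trick can be verified directly. The kernel size and channel counts here are illustrative, not taken from the repository; the point is that a zero-initialized conv followed by `exp` multiplies attention by exactly 1 at step zero.

```python
import torch
import torch.nn as nn

# Zero-initialized conv so exp(conv(x)) starts as an exact identity multiplier
spatial_to_mod = nn.Conv2d(1, 1, kernel_size=3, padding=1)
nn.init.zeros_(spatial_to_mod.weight)
nn.init.zeros_(spatial_to_mod.bias)

spatial = torch.randn(2, 1, 8, 8)          # 8×8 importance map
mod = torch.exp(spatial_to_mod(spatial))   # all ones at init
assert torch.allclose(mod, torch.ones_like(mod))
```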

### T5 Vec Enhancement

**Purpose:** Adds T5's semantic understanding to the global conditioning pathway (previously only CLIP pooled).

```python
# Attention-weighted pooling of T5 sequence
t5_attn = softmax(t5_emb.mean(dim=-1))                   # [B, 77]
t5_pooled = (t5_emb * t5_attn.unsqueeze(-1)).sum(dim=1)  # [B, 768]
t5_vec = t5_pool_mlp(t5_pooled)                          # [B, 512]

# Learnable balance between CLIP and T5
balance = sigmoid(text_balance)  # Initialized to 0.5
text_vec = balance * clip_vec + (1 - balance) * t5_vec
```

---

## Configuration

### TinyFluxConfig

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class TinyFluxConfig:
    # Core architecture
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128     # hidden_size = heads × head_dim
    in_channels: int = 16             # VAE latent channels
    patch_size: int = 1
    joint_attention_dim: int = 768    # T5 embedding dim
    pooled_projection_dim: int = 768  # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # Must sum to head_dim

    # Lune expert predictor
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280  # SD1.5 mid-block dim
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior
    use_sol_prior: bool = True
    sol_spatial_size: int = 8    # 8×8 spatial importance map
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7  # 70% geometric, 30% learned

    # T5 enhancement
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"  # "attention", "mean", "cls"

    # Loss configuration
    lune_distill_mode: str = "cosine"  # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # Legacy compatibility
    guidance_embeds: bool = False
```

### Loading from JSON

```python
# From file
config = TinyFluxConfig.from_json("lailah_401434_v4_config.json")

# From dict
config = TinyFluxConfig.from_dict({
    "hidden_size": 512,
    "num_attention_heads": 4,
    ...
})

# Save with metadata
config.save_json("config.json", metadata={"source_step": 401434})
```

### Validation

```python
# Config validates constraints on creation
config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=128)
# ✓ OK: 512 = 4 × 128

config = TinyFluxConfig(hidden_size=512, num_attention_heads=4, attention_head_dim=64)
# ✗ ValueError: hidden_size (512) must equal num_attention_heads * attention_head_dim (256)

# Validate checkpoint compatibility
warnings = config.validate_checkpoint(state_dict)
if warnings:
    print("Warnings:", warnings)
```

---

## Inference

### Full Pipeline

```python
import torch
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

# Load text encoders
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)

# Load VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Load TinyFlux-Deep
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
exec(open(model_py).read())

config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
model.load_state_dict(weights, strict=False)
model.eval()

def flux_shift(t, s=3.0):
    """Flux-style timestep shifting."""
    return s * t / (1 + (s - 1) * t)

def encode_prompt(prompt):
    """Encode prompt with both T5 and CLIP."""
    # T5
    t5_tokens = t5_tokenizer(prompt, return_tensors="pt", padding="max_length",
                             max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        t5_emb = t5_model(**t5_tokens).last_hidden_state.to(torch.bfloat16)

    # CLIP
    clip_tokens = clip_tokenizer(prompt, return_tensors="pt", padding="max_length",
                                 max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        clip_out = clip_model(**clip_tokens)
        clip_pooled = clip_out.pooler_output.to(torch.bfloat16)

    return t5_emb, clip_pooled

def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """Generate image from text prompt."""
    if seed is not None:
        torch.manual_seed(seed)

    t5_emb, clip_pooled = encode_prompt(prompt)

    # Null embeddings for CFG
    t5_null, clip_null = encode_prompt("")

    # Start from noise
    x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")

    # Flux-shifted timesteps
    timesteps = flux_shift(torch.linspace(1, 0, num_steps + 1, device="cuda"))

    with torch.no_grad():
        for i in range(num_steps):
            t = timesteps[i].expand(1)
            dt = timesteps[i + 1] - timesteps[i]

            # Conditional
            v_cond = model(x, t5_emb, clip_pooled, t, img_ids)

            # Unconditional
            v_uncond = model(x, t5_null, clip_null, t, img_ids)

            # CFG
            v = v_uncond + cfg_scale * (v_cond - v_uncond)

            # Euler step
            x = x + v * dt

    # Decode with VAE
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)  # [B, C, H, W]
    x = x / vae.config.scaling_factor
    with torch.no_grad():
        image = vae.decode(x).sample

    # Convert to PIL
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")
    return Image.fromarray(image)

# Generate!
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
image.save("tiger.png")
```

### Batch Inference

```python
def generate_batch(prompts, **kwargs):
    """Generate multiple images."""
    return [generate_image(p, **kwargs) for p in prompts]

images = generate_batch([
    "a red bird with blue beak",
    "a mountain landscape at sunset",
    "an astronaut riding a horse",
], num_steps=25, cfg_scale=4.0)
```

---

## Training

### Loss Computation

```python
# Forward pass with expert info
output, expert_info = model(
    hidden_states=noisy_latents,
    encoder_hidden_states=t5_emb,
    pooled_projections=clip_pooled,
    timestep=timesteps,
    img_ids=img_ids,
    lune_features=sd15_midblock_features,  # From SD1.5 teacher
    sol_stats=sol_attention_stats,         # From Sol teacher (optional)
    sol_spatial=sol_spatial_importance,    # From Sol teacher (optional)
    return_expert_pred=True,
)

# Compute loss
losses = model.compute_loss(
    output=output,
    target=flow_target,  # data - noise for flow matching
    expert_info=expert_info,
    lune_features=sd15_midblock_features,
    sol_stats=sol_attention_stats,
    sol_spatial=sol_spatial_importance,

    # Loss weights
    lune_weight=0.1,   # Weight for Lune distillation
    sol_weight=0.05,   # Weight for Sol distillation

    # Loss options
    use_huber=True,              # Huber loss for main objective (robust to outliers)
    huber_delta=0.1,             # Huber delta (smaller = tighter MSE region)
    lune_distill_mode="cosine",  # "hard", "soft", "cosine", "huber"
    spatial_weighting=True,      # Weight loss by Sol spatial importance
)

# losses dict contains:
# - main: flow matching loss
# - lune_distill: Lune prediction loss
# - sol_stat_distill: Sol statistics prediction loss
# - sol_spatial_distill: Sol spatial prediction loss
# - total: weighted sum

loss = losses['total']
loss.backward()
```

### Distillation Modes

| Mode | Description | Use Case |
|------|-------------|----------|
| `"hard"` | MSE against teacher features | Exact reconstruction |
| `"soft"` | Temperature-scaled MSE | Softer matching |
| `"cosine"` | Cosine similarity loss | Directional alignment (recommended) |
| `"huber"` | Huber loss on features | Robust to outliers |
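A minimal sketch of what the recommended `"cosine"` mode computes; the actual `compute_loss` internals may differ, but the idea is to penalize the direction of the predicted teacher features while ignoring their magnitude.

```python
import torch
import torch.nn.functional as F

def cosine_distill_loss(pred, teacher):
    """1 - cosine similarity, averaged over the batch:
    matches feature direction, not scale."""
    return (1 - F.cosine_similarity(pred, teacher, dim=-1)).mean()

pred = torch.randn(4, 1280)     # predicted SD1.5 mid-block features
teacher = torch.randn(4, 1280)  # real SD1.5 mid-block features
loss = cosine_distill_loss(pred, teacher)

# Perfectly aligned (even rescaled) features give ~zero loss:
assert cosine_distill_loss(teacher * 2.0, teacher) < 1e-5
```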

### Training Loop Example

```python
import copy

import torch
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

optimizer = AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.99), weight_decay=0.01)
scaler = GradScaler()

# EMA
ema_decay = 0.9999
ema_model = copy.deepcopy(model)

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()

    with autocast(dtype=torch.bfloat16):
        # Sample timesteps with a logit-normal distribution
        u = torch.randn(batch_size, device=device)
        t = torch.sigmoid(u)      # Logit-normal
        t = flux_shift(t, s=3.0)  # Flux shift

        # Add noise
        noise = torch.randn_like(batch['latents'])
        noisy = t.view(-1, 1, 1) * batch['latents'] + (1 - t.view(-1, 1, 1)) * noise
        target = batch['latents'] - noise  # Flow matching target

        # Forward
        output, expert_info = model(
            hidden_states=noisy,
            encoder_hidden_states=batch['t5_emb'],
            pooled_projections=batch['clip_pooled'],
            timestep=t,
            img_ids=img_ids,
            lune_features=batch.get('sd15_features'),
            return_expert_pred=True,
        )

        # Loss
        losses = model.compute_loss(output, target, expert_info,
                                    lune_features=batch.get('sd15_features'))

    scaler.scale(losses['total']).backward()
    scaler.step(optimizer)
    scaler.update()

    # EMA update
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), ema_model.parameters()):
            p_ema.data.lerp_(p.data, 1 - ema_decay)
```
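The `flux_shift` call in the loop is the Flux-style timestep shift. A minimal sketch, under the assumption that it uses the standard shift formula (endpoints 0 and 1 stay fixed while mid-range timesteps are pushed toward the noisier end):

```python
import torch

def flux_shift(t: torch.Tensor, s: float = 3.0) -> torch.Tensor:
    """Flux-style timestep shift: maps t to s*t / (1 + (s - 1) * t).

    With s > 1 this biases sampling toward higher (noisier) timesteps;
    s = 1 is the identity.
    """
    return s * t / (1 + (s - 1) * t)
```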

### Hyperparameters

| Parameter | Value | Notes |
|-----------|-------|-------|
| Optimizer | AdamW | |
| Learning rate | 3e-4 | With cosine schedule |
| Betas | (0.9, 0.99) | |
| Weight decay | 0.01 | |
| Batch size | 32 | 16 × 2 gradient accumulation |
| EMA decay | 0.9999 | |
| Precision | bfloat16 | |
| Timestep shift | s=3.0 | Flux-style |
| Timestep sampling | Logit-normal | |
| Lune weight | 0.1 | |
| Sol weight | 0.05 | |
| Huber delta | 0.1 | |
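The cosine schedule noted next to the learning rate can be sketched as follows; the warmup length and minimum LR here are illustrative assumptions, not values from the training runs:

```python
import math

def cosine_lr(step: int, total_steps: int, base_lr: float = 3e-4,
              warmup_steps: int = 1000, min_lr: float = 0.0) -> float:
    """Cosine decay with linear warmup (warmup/min_lr are assumed values)."""
    if step < warmup_steps:
        # Linear ramp from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Cosine decay from base_lr down to min_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```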

---

## Checkpoint Conversion

### v3 → v4.1 Conversion

The converter preserves all pretrained weights and initializes new v4.1 components to identity/zero-effect:

**What gets converted:**

| v3 Key | v4.1 Key | Action |
|--------|----------|--------|
| `expert_predictor.*` | `lune_predictor.*` | Rename |
| `expert_gate` (0.5) | `expert_gate` (0.0) | Convert to logit space |
| - | `sol_prior.*` | Initialize (zero-effect) |
| - | `t5_pool.*` | Initialize (Xavier) |
| - | `text_balance` | Initialize (0.0 = 50/50) |
| - | `*.spatial_to_mod.*` | Initialize (zero = identity) |

**Parameter growth:**
- v3: ~244.7M parameters
- v4.1: ~246.4M parameters
- Added: ~1.7M parameters (0.7% increase)
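"Convert to logit space" means the stored v3 gate (a sigmoid-space value such as 0.5) is mapped through an inverse sigmoid, so the v4.1 logit-space parameter reproduces the same effective gate. A sketch of the idea, not the converter's exact code:

```python
import math

def gate_to_logit(gate: float, eps: float = 1e-6) -> float:
    """Inverse sigmoid: a v3 gate of 0.5 becomes a v4.1 logit of 0.0."""
    gate = min(max(gate, eps), 1 - eps)  # clamp away from 0 and 1
    return math.log(gate / (1 - gate))
```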

### Python API

```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download converter
converter = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/convert_v3_to_v4.py")
exec(open(converter).read())

# Simple: download, convert, upload
from convert_v3_to_v4 import run
result = run(401434)  # Step number

# With custom config
result = run(401434, config={
    "hidden_size": 512,
    "num_attention_heads": 4,
    "sol_geometric_weight": 0.8,  # More geometric, less learned
})

# From JSON config file
result = run(401434, config="my_config.json")

# Low-level API
from convert_v3_to_v4 import convert_state_dict, analyze_checkpoint, TinyFluxConfig

# Analyze checkpoint version
state_dict = load_file("checkpoint.safetensors")
info = analyze_checkpoint(state_dict)
print(f"Version: {info.version}")  # "v3", "v4.0", "v4.1", etc.
print(f"Has Sol prior: {info.has_sol_prior}")

# Convert state dict
config = TinyFluxConfig()
v4_state, report = convert_state_dict(state_dict, config)
print(f"Renamed {len(report['renamed'])} keys")
print(f"Initialized {len(report['initialized'])} keys")
```

### CLI

```bash
# Basic conversion
python convert_v3_to_v4.py --step 401434

# Local file
python convert_v3_to_v4.py --input model_v3.safetensors

# Analyze only (don't convert)
python convert_v3_to_v4.py --step 401434 --analyze-only

# Custom output
python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel

# With custom config
python convert_v3_to_v4.py --step 401434 --config my_config.json
```

### Output Structure

```
checkpoint_runs/v4_init/
├── lailah_401434_v4_init.safetensors               # Converted model
├── lailah_401434_v4_init_ema.safetensors           # Fresh EMA (copy of model)
├── lailah_401434_v4_init_ema_secondary.safetensors # Converted old EMA
└── lailah_401434_v4_config.json                    # Config with conversion metadata
```

### Config JSON Format

```json
{
  "hidden_size": 512,
  "num_attention_heads": 4,
  "attention_head_dim": 128,
  "num_double_layers": 15,
  "num_single_layers": 25,
  "use_lune_expert": true,
  "use_sol_prior": true,
  "use_t5_vec": true,
  "sol_geometric_weight": 0.7,
  "lune_distill_mode": "cosine",
  "use_huber_loss": true,
  "huber_delta": 0.1,
  "_conversion_info": {
    "source_step": 401434,
    "source_repo": "AbstractPhil/tiny-flux-deep",
    "source_version": "v3",
    "target_version": "v4.1",
    "source_params": 244690849,
    "target_params": 246347234,
    "params_added": 1656385,
    "converter_version": "4.1.0"
  }
}
```

---

## Repository Structure

```
AbstractPhil/tiny-flux-deep/
│
├── model.safetensors                    # Latest training weights
├── model_ema.safetensors                # EMA weights (use for inference)
├── config.json                          # Model configuration
├── README.md
│
├── scripts/                             # All Python code
│   ├── model_v4.py                      # v4.1 architecture (current)
│   ├── model_v3.py                      # v3 architecture (reference)
│   ├── model_v2.py                      # v2 architecture (legacy)
│   ├── inference_v3.py                  # Full inference pipeline
│   ├── convert_v3_to_v4.py              # Checkpoint converter
│   ├── trainer_v3_expert_guidance.py    # Training with distillation
│   ├── trainer_v2.py                    # Previous trainer
│   ├── trainer.py                       # Original trainer
│   ├── port_tiny_to_deep.py             # TinyFlux → Deep port script
│   └── colab_inference_lailah_early.py  # Simple Colab notebook
│
├── checkpoints/                         # v3 checkpoints (legacy)
│   ├── step_401434.safetensors
│   └── step_401434_ema.safetensors
│
├── checkpoint_runs/                     # Organized checkpoint runs
│   └── v4_init/                         # v4.1 initialization from v3
│       ├── lailah_401434_v4_init.safetensors
│       ├── lailah_401434_v4_init_ema.safetensors
│       ├── lailah_401434_v4_init_ema_secondary.safetensors
│       └── lailah_401434_v4_config.json
│
├── samples/                             # Generated samples per step
│   └── 20260127_074318_step_401434.png
│
└── logs/                                # TensorBoard training logs
    └── run_20260126_220714/
```

---

## API Reference

### TinyFluxDeep

```python
class TinyFluxDeep(nn.Module):
    def __init__(self, config: Optional[TinyFluxConfig] = None):
        """Initialize model with config (uses defaults if None)."""

    def forward(
        self,
        hidden_states: torch.Tensor,                    # [B, N, 16] image latents
        encoder_hidden_states: torch.Tensor,            # [B, L, 768] T5 embeddings
        pooled_projections: torch.Tensor,               # [B, 768] CLIP pooled
        timestep: torch.Tensor,                         # [B] timestep in [0, 1]
        img_ids: torch.Tensor,                          # [N, 3] position IDs
        txt_ids: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,        # Legacy
        lune_features: Optional[torch.Tensor] = None,   # [B, 1280] SD1.5 features
        sol_stats: Optional[torch.Tensor] = None,       # [B, 3] attention stats
        sol_spatial: Optional[torch.Tensor] = None,     # [B, 8, 8] spatial importance
        expert_features: Optional[torch.Tensor] = None, # Legacy API
        return_expert_pred: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """
        Forward pass.

        Returns:
            output: [B, N, 16] predicted velocity
            expert_info: dict with predictions (if return_expert_pred=True)
        """

    def compute_loss(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        expert_info: Optional[Dict] = None,
        lune_features: Optional[torch.Tensor] = None,
        sol_stats: Optional[torch.Tensor] = None,
        sol_spatial: Optional[torch.Tensor] = None,
        lune_weight: float = 0.1,
        sol_weight: float = 0.05,
        use_huber: bool = True,
        huber_delta: float = 0.1,
        lune_distill_mode: str = "cosine",
        spatial_weighting: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Compute combined loss with distillation."""

    @staticmethod
    def create_img_ids(batch_size: int, height: int, width: int, device) -> torch.Tensor:
        """Create image position IDs for RoPE."""

    @staticmethod
    def create_txt_ids(text_len: int, device) -> torch.Tensor:
        """Create text position IDs."""

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
```
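For context on the `img_ids` argument: it holds per-patch position indices for RoPE. A standalone sketch of the same layout, assuming the (0, row, column) convention used by Flux and ignoring batch handling (`make_img_ids` is a hypothetical helper, not the model's `create_img_ids`):

```python
import torch

def make_img_ids(height: int, width: int, device="cpu") -> torch.Tensor:
    """[H*W, 3] position IDs: column 1 holds the row, column 2 the column."""
    ids = torch.zeros(height, width, 3, device=device)
    ids[..., 1] = torch.arange(height, device=device)[:, None]  # row index
    ids[..., 2] = torch.arange(width, device=device)[None, :]   # column index
    return ids.reshape(height * width, 3)
```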

### Converter Functions

```python
# High-level
def run(step, name="lailah", config=None, ...):
    """One-liner: download, convert, upload."""

def convert_checkpoint(step=None, input_path=None, config=None, ...) -> ConversionResult:
    """Convert checkpoint with full control."""

# Low-level
def analyze_checkpoint(state_dict) -> CheckpointInfo:
    """Analyze checkpoint version and contents."""

def convert_state_dict(state_dict, config=None) -> Tuple[Dict, Dict]:
    """Convert state dict, return (new_state, report)."""

def download_from_hf(step, repo_id, ...) -> Tuple[str, str]:
    """Download checkpoint from HuggingFace."""

# Config
class TinyFluxConfig:
    def to_dict(self) -> Dict
    def from_dict(cls, d) -> TinyFluxConfig
    def from_json(cls, path) -> TinyFluxConfig
    def save_json(self, path, metadata=None)
    def validate_checkpoint(self, state_dict) -> List[str]
```

---

## Samples

### Step 401434 (v3 weights)

**"subject, animal, cat, photograph of a tiger, natural habitat"**

![tiger](./samples/tiger_sample.png)

**"subject, bird, blue beak, red eyes, green claws"**

![bird](./samples/bird_sample.png)

**"subject, bird, red haired bird in a tree"**

![red bird](./samples/red_bird_sample.png)

---

## Limitations

| Limitation | Details |
|------------|---------|
| **Resolution** | 512×512 only (64×64 latent space) |
| **Text encoder** | flan-t5-base (768 dim) vs Flux's T5-XXL (4096 dim) |
| **Attention heads** | 4 heads vs Flux's 24, which limits capacity |
| **Training data** | Teacher latents, not real images |
| **v4.1 status** | New architecture; training is just starting |
| **Artifacts** | Expect imperfections; this is a research model |

---

## Name

**Lailah** (לילה): In Jewish tradition, the angel of the night who guards souls and teaches wisdom to the unborn. Chosen for this model's role as a smaller guardian exploring the same latent space as larger models, learning from their knowledge while finding its own path.

---

## Citation

```bibtex
@misc{tinyfluxdeep2026,
  title={TinyFlux-Deep: Compact Flow Matching with Dual Expert Distillation},
  author={AbstractPhil},
  year={2026},
  howpublished={\url{https://huggingface.co/AbstractPhil/tiny-flux-deep}},
  note={246M parameter text-to-image model with Lune trajectory guidance and Sol attention priors}
}
```

---

## Related Projects

| Project | Description |
|---------|-------------|
| [AbstractPhil/tiny-flux](https://huggingface.co/AbstractPhil/tiny-flux) | Original TinyFlux (10.7M params) |
| [AbstractPhil/flux-schnell-teacher-latents](https://huggingface.co/datasets/AbstractPhil/flux-schnell-teacher-latents) | Training dataset |
| [AbstractPhil/imagenet-synthetic](https://huggingface.co/datasets/AbstractPhil/imagenet-synthetic) | ImageNet-style synthetic data |
| [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | Teacher model |

---

## License

MIT License - free for research and commercial use.

---

**Status**: v4.1 architecture complete. Converting v3 checkpoints and resuming training with dual expert distillation.