---
license: mit
language:
- en
tags:
- diffusion
- flow-matching
- flux
- text-to-image
- image-generation
- deep
- experimental
library_name: pytorch
pipeline_tag: text-to-image
base_model:
- AbstractPhil/tiny-flux
- black-forest-labs/FLUX.1-schnell
datasets:
- AbstractPhil/flux-schnell-teacher-latents
---
# TinyFlux-Deep
An **expanded** TinyFlux architecture that increases depth and width while preserving learned representations. TinyFlux-Deep is ported from [TinyFlux](https://huggingface.co/AbstractPhil/tiny-flux) with strategic layer expansion and attention head doubling.
## Model Description
TinyFlux-Deep extends the base TinyFlux model by:
- **Doubling attention heads** (2 → 4), which doubles the hidden dimension (256 → 512) at a fixed head dimension of 128
- **5× more double-stream layers** (3 → 15)
- **~8× more single-stream layers** (3 → 25)
- **Preserving learned weights** from TinyFlux in frozen anchor positions
### Architecture Comparison
| Component | TinyFlux | TinyFlux-Deep | Flux |
|-----------|----------|---------------|------|
| Hidden size | 256 | **512** | 3072 |
| Attention heads | 2 | **4** | 24 |
| Head dimension | 128 | 128 | 128 |
| Double-stream layers | 3 | **15** | 19 |
| Single-stream layers | 3 | **25** | 38 |
| VAE channels | 16 | 16 | 16 |
| **Total params** | ~8M | **~85M** | ~12B |
### Layer Mapping (Ported from TinyFlux)
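The original TinyFlux weights are placed at anchor positions and frozen. Layer 0 goes to the first block, layer 2 goes to the last block, and layer 1 fills three evenly spaced slots around the middle:
**Single blocks (3 → 25):**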
| TinyFlux Layer | TinyFlux-Deep Position | Status |
|----------------|------------------------|--------|
| 0 | 0 | Frozen |
| 1 | 8, 12, 16 | Frozen (3 copies) |
| 2 | 24 | Frozen |
| (new) | 1-7, 9-11, 13-15, 17-23 | Trainable |
**Double blocks (3 → 15):**
| TinyFlux Layer | TinyFlux-Deep Position | Status |
|----------------|------------------------|--------|
| 0 | 0 | Frozen |
| 1 | 4, 7, 10 | Frozen (3 copies) |
| 2 | 14 | Frozen |
| (new) | 1-3, 5-6, 8-9, 11-13 | Trainable |
**Trainable ratio:** ~70% of parameters
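The two tables can be written down as data and checked for consistency. In the plain-Python sketch below, the trainable indices are derived from the anchor positions and then printed in the same range notation the tables use. The dict and function names are illustrative and are not taken from `port_to_deep.py`.

```python
# Anchor positions from the tables above: {TinyFlux-Deep block index: TinyFlux block index}
SINGLE_ANCHORS = {0: 0, 8: 1, 12: 1, 16: 1, 24: 2}
DOUBLE_ANCHORS = {0: 0, 4: 1, 7: 1, 10: 1, 14: 2}


def trainable_positions(num_layers, anchors):
    """Every block index that is not a frozen anchor is trainable."""
    return [i for i in range(num_layers) if i not in anchors]


def to_ranges(indices):
    """Compress sorted, non-empty indices into the 'a-b, c-d' notation used in the tables."""
    parts = []
    start = prev = indices[0]
    for i in indices[1:]:
        if i == prev + 1:
            prev = i
            continue
        parts.append(f"{start}-{prev}" if start != prev else str(start))
        start = prev = i
    parts.append(f"{start}-{prev}" if start != prev else str(start))
    return ", ".join(parts)


single_trainable = trainable_positions(25, SINGLE_ANCHORS)
double_trainable = trainable_positions(15, DOUBLE_ANCHORS)
print(to_ranges(single_trainable))  # 1-7, 9-11, 13-15, 17-23
print(to_ranges(double_trainable))  # 1-3, 5-6, 8-9, 11-13
```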
### Attention Head Expansion
The original 2 heads keep their slots in the widened attention layers, and the 2 new heads are randomly initialized:
- Old head 0 → New head 0
- Old head 1 → New head 1
- Heads 2-3 → Xavier initialized (scaled 0.02×)
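Here is a minimal PyTorch sketch of the head mapping for a single Q, K or V projection weight in `nn.Linear` layout `(out, in)`. Because rows `[h*128, (h+1)*128)` belong to head `h`, copying the old weight into the top-left block leaves old heads 0-1 in new slots 0-1. Everything else gets Xavier-uniform noise scaled by 0.02. This README does not say how the port initializes the new input columns of the old heads. This sketch gives them the same small noise, which is an assumption. The function names are hypothetical, and the fused-QKV helper assumes a `(3*d, d)` Q/K/V layout.

```python
import torch


def expand_heads(old_w, old_heads=2, new_heads=4, head_dim=128, scale=0.02):
    """Widen one Q, K or V weight from old_heads to new_heads (hidden = heads * head_dim)."""
    old_out, old_in = old_w.shape
    assert old_out == old_heads * head_dim and old_in == old_heads * head_dim
    new_dim = new_heads * head_dim
    new_w = torch.empty(new_dim, new_dim, dtype=old_w.dtype)
    torch.nn.init.xavier_uniform_(new_w)  # new heads (and new input columns): Xavier...
    new_w.mul_(scale)                     # ...scaled down by 0.02
    # Old head h keeps head slot h: copy the leading rows/columns unchanged
    new_w[:old_out, :old_in] = old_w
    return new_w


def expand_fused_qkv(old_qkv, **kwargs):
    """Apply expand_heads separately to the Q, K and V chunks of a fused (3*d, d) weight."""
    q, k, v = old_qkv.chunk(3, dim=0)
    return torch.cat([expand_heads(w, **kwargs) for w in (q, k, v)], dim=0)
```

Expanding Q, K and V separately keeps each old head's query, key and value rows aligned with each other. Expanding the fused matrix as a whole would not do this.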
### Text Encoders
Same as TinyFlux:
| Role | Model |
|------|-------|
| Sequence encoder | flan-t5-base (768 dim) |
| Pooled encoder | CLIP-L (768 dim) |
## Training
### Strategy
1. **Port** TinyFlux weights with dimension expansion
2. **Freeze** ported layers as "anchor" knowledge
3. **Train** new layers to interpolate between anchors
4. **Optional:** Unfreeze all and fine-tune at lower LR
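Steps 2 and 4 come down to what the optimizer sees. The AdamW settings below are taken from the training details. The toy two-layer model stands in for TinyFlux-Deep, and the phase-2 learning rate is a placeholder rather than a documented value. After unfreezing, the optimizer is rebuilt so that the anchors receive AdamW state.

```python
import torch
from torch import nn


def build_optimizer(model: nn.Module, lr: float = 5e-5) -> torch.optim.AdamW:
    # Frozen anchors (requires_grad=False) are left out of the optimizer entirely
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, betas=(0.9, 0.99), weight_decay=0.01)


def unfreeze_all(model: nn.Module) -> None:
    # Optional step 4: make the anchor layers trainable again
    model.requires_grad_(True)


def count_params(opt: torch.optim.Optimizer) -> int:
    return sum(p.numel() for group in opt.param_groups for p in group["params"])


# Toy stand-in for the real model: layer 0 plays the frozen anchor
toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
toy[0].requires_grad_(False)
n_phase1 = count_params(build_optimizer(toy))           # layer 1 only: 8*8 + 8 = 72
unfreeze_all(toy)
n_phase2 = count_params(build_optimizer(toy, lr=1e-5))  # hypothetical lower LR; all 144
```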
### Dataset
Trained on [AbstractPhil/flux-schnell-teacher-latents](https://huggingface.co/datasets/AbstractPhil/flux-schnell-teacher-latents):
- 10,000 samples
- Pre-computed VAE latents (16, 64, 64) from 512×512 images
- Diverse prompts covering people, objects, scenes, styles
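Each (16, 64, 64) latent reaches the model as 64 × 64 = 4,096 tokens of 16 channels each. FLUX.1 packs 2×2 latent patches into 64-channel tokens. TinyFlux-Deep's `in_channels=16` and the shapes in the inference example instead imply one token per latent pixel. A minimal sketch of that packing and its inverse follows; the function names are illustrative.

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """(B, C, H, W) VAE latents -> (B, H*W, C) tokens, one token per latent pixel."""
    b, c, h, w = latents.shape
    return latents.permute(0, 2, 3, 1).reshape(b, h * w, c)


def unpack_latents(tokens: torch.Tensor, h: int = 64, w: int = 64) -> torch.Tensor:
    """Inverse of pack_latents (same reshape/permute as the decode step in the inference example)."""
    b, n, c = tokens.shape
    assert n == h * w
    return tokens.reshape(b, h, w, c).permute(0, 3, 1, 2)
```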
### Training Details
- **Objective**: Flow matching (rectified flow)
- **Timestep sampling**: Logit-normal with Flux shift (s=3.0)
- **Loss weighting**: Min-SNR-γ (γ=5.0)
- **Optimizer**: AdamW (lr=5e-5, β=(0.9, 0.99), wd=0.01)
- **Schedule**: Cosine with warmup
- **Precision**: bfloat16
- **Batch size**: 32 (16 × 2 gradient accumulation)
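The objective, timestep sampling, weighting and schedule can be sketched as shown below. This is an illustration and not the code in `train_deep_colab.py`. It assumes the convention of the inference example, where t=0 is noise and t=1 is data, so the velocity target is data minus noise. The Flux shift is applied to logit-normal samples. Weighting uses one simple form of Min-SNR-γ, min(SNR, γ)/SNR with SNR = (t / (1 − t))². The exact weighting used in training is not documented here. `cosine_with_warmup` returns a learning-rate multiplier for `torch.optim.lr_scheduler.LambdaLR`.

```python
import math
import torch


def flux_shift(t: torch.Tensor, s: float = 3.0) -> torch.Tensor:
    # Same shift formula as in the inference example
    return s * t / (1 + (s - 1) * t)


def sample_timesteps(batch: int, s: float = 3.0, device="cpu") -> torch.Tensor:
    # Logit-normal: t = sigmoid(u) with u ~ N(0, 1), then the Flux shift
    return flux_shift(torch.sigmoid(torch.randn(batch, device=device)), s)


def flow_matching_loss(model_fn, x1: torch.Tensor, gamma: float = 5.0, s: float = 3.0) -> torch.Tensor:
    # x1: clean latent tokens; assumed convention: t=0 is noise, t=1 is data
    b = x1.shape[0]
    t = sample_timesteps(b, s, x1.device)
    noise = torch.randn_like(x1)
    tt = t.view(b, *([1] * (x1.dim() - 1)))
    x_t = tt * x1 + (1 - tt) * noise  # straight-line (rectified flow) path
    target = x1 - noise               # constant velocity along that path
    pred = model_fn(x_t, t)
    mse = ((pred.float() - target.float()) ** 2).flatten(1).mean(dim=1)
    # Min-SNR-gamma (one simple variant): caps the weight of low-noise timesteps
    snr = (t / (1 - t).clamp(min=1e-5)) ** 2
    weight = snr.clamp(max=gamma) / snr.clamp(min=1e-8)
    return (weight * mse).mean()


def cosine_with_warmup(step: int, warmup: int, total: int) -> float:
    # Linear warmup to 1.0, then cosine decay to 0.0 at `total`
    if step < warmup:
        return step / max(1, warmup)
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```

`model_fn` stands in for a call to the model with the text conditioning already bound, for example a lambda around `model(hidden_states=..., timestep=..., ...)`.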
## Usage
### Installation
```bash
pip install torch transformers diffusers safetensors huggingface_hub sentencepiece
```
### Inference
```python
import os
import sys

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from PIL import Image

# TinyFlux and TinyFluxDeepConfig live in this repo's model.py (assumed to define both);
# download it and make it importable, or paste the class definitions here instead
sys.path.insert(0, os.path.dirname(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.py")))
from model import TinyFlux, TinyFluxDeepConfig

device = "cuda"

# Load model
config = TinyFluxDeepConfig()
model = TinyFlux(config).to(device).to(torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
model.load_state_dict(weights, strict=False)  # strict=False for precomputed buffers
model.eval()

# Load encoders and VAE
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=torch.bfloat16).to(device)
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16).to(device)
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=torch.bfloat16).to(device)


def flux_shift(t, s=3.0):
    # Flux-style timestep shift
    return s * t / (1 + (s - 1) * t)


prompt = "a photo of a cat sitting on a windowsill"
num_steps = 20

with torch.inference_mode():
    # Encode prompt: T5 sequence features + CLIP pooled features
    t5_in = t5_tok(prompt, max_length=128, padding="max_length", truncation=True, return_tensors="pt").to(device)
    t5_out = t5_enc(**t5_in).last_hidden_state
    clip_in = clip_tok(prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt").to(device)
    clip_out = clip_enc(**clip_in).pooler_output

    # Euler sampling from noise (t=0) toward data (t=1) with Flux shift
    x = torch.randn(1, 64 * 64, 16, device=device, dtype=torch.bfloat16)  # 64x64 latent grid, 16 channels per token
    img_ids = TinyFlux.create_img_ids(1, 64, 64, device)
    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device=device))
    guidance = torch.tensor([3.5], device=device, dtype=torch.bfloat16)
    for i in range(num_steps):
        t = timesteps[i].unsqueeze(0)
        dt = timesteps[i + 1] - timesteps[i]
        v = model(
            hidden_states=x,
            encoder_hidden_states=t5_out,
            pooled_projections=clip_out,
            timestep=t,
            img_ids=img_ids,
            guidance=guidance,
        )
        x = x + v * dt

    # Decode: tokens -> (1, 16, 64, 64) latents -> image
    latents = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)
    latents = latents / vae.config.scaling_factor
    # If the training latents were normalized as (z - shift_factor) * scaling_factor,
    # add vae.config.shift_factor here as well
    image = vae.decode(latents.to(vae.dtype)).sample  # VAE weights are bf16, so the input must be bf16 too

image = (image / 2 + 0.5).clamp(0, 1)
Image.fromarray((image[0].permute(1, 2, 0).float().cpu().numpy() * 255).round().astype("uint8")).save("output.png")
```
### Configuration
```python
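from dataclasses import dataclass
from typing import Tuple


@dataclass
class TinyFluxDeepConfig:
    hidden_size: int = 512               # num_attention_heads * attention_head_dim
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16                # Flux VAE latent channels
    joint_attention_dim: int = 768       # flan-t5-base hidden size
    pooled_projection_dim: int = 768     # CLIP-L pooled output size
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # sums to attention_head_dim
    guidance_embeds: bool = True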
```
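The config values depend on each other. The hidden size must equal heads × head dimension, the three RoPE axes must add up to one head's dimension, and the text dimensions must match the encoders. The helper below is a hypothetical sanity check that works with a `TinyFluxDeepConfig` instance. Here a `SimpleNamespace` stands in for the config so the snippet runs without `model.py`. Reading `mlp_ratio` as hidden-size multiplier for the MLP width is an assumption based on the usual Flux convention.

```python
from types import SimpleNamespace


def check_config(cfg, t5_dim=768, clip_dim=768):
    # Attention splits the hidden state evenly across heads
    assert cfg.hidden_size == cfg.num_attention_heads * cfg.attention_head_dim
    # The RoPE axes together must cover one head's dimensions
    assert sum(cfg.axes_dims_rope) == cfg.attention_head_dim
    # Text inputs: flan-t5-base sequence features and CLIP-L pooled features
    assert cfg.joint_attention_dim == t5_dim
    assert cfg.pooled_projection_dim == clip_dim
    # MLP hidden width implied by mlp_ratio (assumed convention)
    return int(cfg.hidden_size * cfg.mlp_ratio)


deep = SimpleNamespace(hidden_size=512, num_attention_heads=4, attention_head_dim=128,
                       joint_attention_dim=768, pooled_projection_dim=768,
                       axes_dims_rope=(16, 56, 56), mlp_ratio=4.0)
mlp_width = check_config(deep)  # 2048
```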
## Files
```
AbstractPhil/tiny-flux-deep/
├── model.safetensors      # Model weights (~340MB)
├── config.json            # Model configuration
├── frozen_params.json     # List of frozen parameter names
├── README.md              # This file
├── model.py               # Model architecture (includes TinyFluxDeepConfig)
├── inference_colab.py     # Inference script
├── train_deep_colab.py    # Training script with layer freezing
├── port_to_deep.py        # Porting script from TinyFlux
├── checkpoints/           # Training checkpoints
│   └── step_*.safetensors
├── logs/                  # TensorBoard logs
└── samples/               # Generated samples during training
```
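Use `frozen_params.json` to reproduce the freezing when resuming training. The sketch below assumes the file is a JSON list of names matching `model.named_parameters()`. This README does not show the file's exact format. The function returns any listed names that are missing from the model, so a naming mismatch is easy to catch. `apply_frozen_params` is a hypothetical helper.

```python
import json

from torch import nn


def apply_frozen_params(model: nn.Module, path: str) -> set:
    """Set requires_grad=False on every parameter listed in frozen_params.json."""
    with open(path) as f:
        frozen = set(json.load(f))  # assumed: a JSON list of parameter names
    found = set()
    for name, param in model.named_parameters():
        found.add(name)
        if name in frozen:
            param.requires_grad_(False)
    return frozen - found  # names in the file that the model does not have


# Usage (downloads the file from the Hub):
# from huggingface_hub import hf_hub_download
# missing = apply_frozen_params(model, hf_hub_download("AbstractPhil/tiny-flux-deep", "frozen_params.json"))
```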
## Porting from TinyFlux
To recreate TinyFlux-Deep from the TinyFlux weights, run `port_to_deep.py`. It:
1. Downloads the `AbstractPhil/tiny-flux` weights
2. Creates a TinyFlux-Deep model (512 hidden, 4 heads, 15 double and 25 single blocks)
3. Expands the attention heads (2 → 4) and hidden dimension (256 → 512)
4. Copies the TinyFlux layers to their anchor positions
5. Saves the result to `AbstractPhil/tiny-flux-deep`
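Step 4 is essentially a state-dict key rename driven by the anchor tables. The sketch below assumes keys of the form `"{prefix}.{index}.<rest>"`, for example `double_blocks.1.attn...`. That naming is hypothetical. The sketch only places entries; widening the tensors from 256 to 512 has to happen separately. Values are copied by reference, so call `.clone()` on tensors if the copies must be independent.

```python
def remap_block_keys(state_dict: dict, prefix: str, anchors: dict) -> dict:
    """Copy old block entries to new block indices; anchors maps new index -> old index."""
    out = {}
    for new_idx, old_idx in anchors.items():
        old_prefix = f"{prefix}.{old_idx}."  # trailing dot: "blocks.1." must not match "blocks.10."
        for key, value in state_dict.items():
            if key.startswith(old_prefix):
                out[f"{prefix}.{new_idx}." + key[len(old_prefix):]] = value
    return out


# Double-block anchors from the layer-mapping table
double_anchors = {0: 0, 4: 1, 7: 1, 10: 1, 14: 2}
old_sd = {"double_blocks.0.mlp.w": "A", "double_blocks.1.attn.w": "B", "double_blocks.2.norm.w": "C"}
new_sd = remap_block_keys(old_sd, "double_blocks", double_anchors)
```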
## Comparison with TinyFlux
| Aspect | TinyFlux | TinyFlux-Deep |
|--------|----------|---------------|
| Parameters | ~8M | ~85M |
| Memory (bf16) | ~16MB | ~170MB |
| Forward pass | ~15ms | ~60ms |
| Capacity | Limited | Moderate |
| Training | From scratch | Ported + fine-tuned |
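The memory row counts weights only: parameters × 2 bytes in bfloat16. Activations, optimizer state, the text encoders and the VAE all come on top of that. The ~340MB `model.safetensors` listed under Files would fit the same ~85M parameters stored at 4 bytes each (float32). That is an inference from the numbers, not something this card states.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_megabytes(num_params: float, dtype: str = "bfloat16") -> float:
    # Weights only; excludes activations, optimizer state, encoders and VAE
    return num_params * BYTES_PER_PARAM[dtype] / 1e6


print(weight_megabytes(8e6))              # 16.0  (TinyFlux, bf16)
print(weight_megabytes(85e6))             # 170.0 (TinyFlux-Deep, bf16)
print(weight_megabytes(85e6, "float32"))  # 340.0
```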
## Limitations
- **Resolution**: Trained on 512×512 only
- **Quality**: Better than TinyFlux, still below full Flux
- **Text understanding**: Limited by smaller T5 encoder (768 vs 4096 dim)
- **Early training**: Model is actively being trained
- **Experimental**: Intended for research, not production
## Intended Use
- Studying model scaling and expansion techniques
- Testing layer freezing and knowledge transfer
- Rapid prototyping with moderate capacity
- Educational purposes
- Baseline for architecture experiments
## Citation
```bibtex
@misc{tinyfluxdeep2026,
title={TinyFlux-Deep: Expanded Flux Architecture with Knowledge Preservation},
author={AbstractPhil},
year={2026},
url={https://huggingface.co/AbstractPhil/tiny-flux-deep}
}
```
## Related Models
- [AbstractPhil/tiny-flux](https://huggingface.co/AbstractPhil/tiny-flux) - Base model (8M params)
- [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) - Original Flux
## Acknowledgments
- [Black Forest Labs](https://blackforestlabs.ai/) for the original Flux architecture
- [Hugging Face](https://huggingface.co/) for diffusers and transformers libraries
## License
MIT License - See LICENSE file for details.
---
**Note**: This is an experimental research model under active development. Training is ongoing and weights may be updated frequently.