---
tags:
- image-generation
- latent-recurrent-flow
- lrf
- mobile-first
- flow-matching
- recursive-reasoning
- novel-architecture
- subquadratic-attention
- research
library_name: lrf
pipeline_tag: text-to-image
license: apache-2.0
---
# LatentRecurrentFlow (LRF) — A Novel Mobile-First Image Generation Architecture
> A new image generation architecture, designed from scratch to run on consumer devices with 3–4 GB of RAM and to train within a 16 GB budget.
## 🔥 v2 Training Results (CIFAR-10)
**Trained end-to-end on CIFAR-10** (50K images, 10 classes) using:
- **Pre-trained TAESD** (2.4M frozen params) as the VAE — f=8 compression, 32×32 → 4×4×4 latents
- **1.47M parameter denoising core** with recursive refinement (4 shared blocks × 2 recursions = 8 effective layers)
- **Rectified flow** matching with SNR-weighted loss and 10% CFG dropout
- Training: 30 epochs, AdamW with cosine schedule, EMA decay 0.999
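
The 10% CFG dropout listed above can be sketched as follows. This is a minimal, framework-free illustration rather than the code in `lrf/train_v2.py`: the names `drop_labels` and `guided_velocity`, and the choice of index 10 as the "null" (unconditional) class, are assumptions made for the example. During training each label is replaced by the null class with probability 0.1; at sampling time the conditional and unconditional velocity predictions are blended with a guidance scale.

```python
import random

NUM_CLASSES = 10          # CIFAR-10
NULL_CLASS = NUM_CLASSES  # assumption: one extra index stands for "no class"


def drop_labels(labels, p_drop=0.1, rng=random):
    """Replace each label with NULL_CLASS with probability p_drop."""
    return [NULL_CLASS if rng.random() < p_drop else y for y in labels]


def guided_velocity(v_cond, v_uncond, cfg_scale):
    """Classifier-free guidance: start from the unconditional prediction
    and move toward the conditional one, scaled by cfg_scale."""
    return [u + cfg_scale * (c - u) for c, u in zip(v_cond, v_uncond)]


rng = random.Random(0)
labels = [3] * 10_000
dropped = drop_labels(labels, p_drop=0.1, rng=rng)
drop_rate = dropped.count(NULL_CLASS) / len(dropped)   # close to 0.1

# cfg_scale=1.0 returns the conditional prediction unchanged.
v = guided_velocity([1.0, 2.0], [0.0, 0.0], cfg_scale=1.0)
```
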
| Metric | Value |
|--------|-------|
| Final Loss | 0.931 |
| Training Time | ~70 min (CPU only!) |
| VAE Recon MSE | 0.068 |
| All 10 classes produce colorful images | ✅ |
### Sample Outputs
VAE Reconstruction (top: original, bottom: TAESD reconstruction):
![VAE Reconstruction](samples/vae_reconstruction.png)
Training progression (epoch 5 → 30):
![Epoch 5](samples/samples_epoch005.png)
![Epoch 30](samples/samples_epoch030.png)
Class-conditional generation (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck):
![Final Samples](samples/final_class_conditional.png)
Loss curve:
![Loss](samples/loss.png)
### Validation: No Grey Images
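Every class produces images with non-trivial pixel variation. The numbers below are the per-class standard deviation and value range of generated samples; a collapsed grey image would have both close to zero: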
```
airplane : std=0.383, range=1.908 ✅
automobile : std=0.448, range=2.000 ✅
bird : std=0.341, range=1.663 ✅
cat : std=0.521, range=2.000 ✅
deer : std=0.401, range=1.869 ✅
dog : std=0.477, range=1.994 ✅
frog : std=0.366, range=1.996 ✅
horse : std=0.499, range=1.972 ✅
ship : std=0.448, range=1.786 ✅
truck : std=0.510, range=1.944 ✅
```
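
A check like the one above takes only a few lines. The sketch below is an assumption about how such a test might look, not the repository's validation script: `image_stats`, `is_grey`, and the thresholds `min_std` and `min_range` are hypothetical, and pixel values are assumed to lie in [-1, 1], so the largest possible range is 2.0.

```python
from statistics import pstdev


def image_stats(pixels):
    """Return (std, range) over a flat list of pixel values."""
    return pstdev(pixels), max(pixels) - min(pixels)


def is_grey(pixels, min_std=0.05, min_range=0.5):
    """Flag a collapsed ("grey") sample: nearly constant pixel values.
    The thresholds are illustrative, not taken from the repo."""
    std, value_range = image_stats(pixels)
    return std < min_std or value_range < min_range


grey = [0.01, -0.02, 0.00, 0.02]      # almost constant: collapsed sample
varied = [-1.0, -0.3, 0.4, 1.0]       # spans the full [-1, 1] interval

grey_flag = is_grey(grey)
varied_flag = is_grey(varied)
```
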
---
## Architecture Overview
LRF combines five key innovations into a single coherent architecture:
| Innovation | Source Inspiration | What It Does |
|---|---|---|
| **Recursive Latent Refinement (RLR)** | HRM/TRM (2025) | Iterative fixed-point reasoning with O(1) memory backprop |
| **Efficient Spatial Mixer** | ViG/GLA + DyDiLA | Attention + DW-Conv locality (adapts to sequence length) |
| **Pre-trained TAESD VAE** | madebyollin/taesd | f=8 compression, 2.4M params, works out-of-box |
| **Rectified Flow** objective | SD3 / Liu et al. | Clean linear ODE for training and few-step sampling |
| **Additive Image Conditioning** | OmniGen | Same core supports text-to-image AND editing |
### v2 Architecture (Trained & Validated)
| Component | Parameters | Description |
|---|---|---|
| TAESD VAE (frozen) | 2.4M | Pre-trained image encoder/decoder |
| Denoising Core | 1.47M | 4 shared blocks × 2 inner recursions |
| Class Conditioner | 1.4K | Learned class embeddings for CIFAR-10 |
| **Trainable Total** | **1.47M** | |
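
The "4 shared blocks × 2 inner recursions" entry means the same four blocks are applied twice: eight layer applications, but only four blocks' worth of stored weights. The toy below illustrates that weight sharing with plain Python callables. `make_block` and `recursive_core` are hypothetical names, and the real `RecursiveLatentCore` in `lrf/model_v2.py` may differ in detail (for example, in how timestep and class conditioning enter each block).

```python
def make_block(scale, call_log):
    """A stand-in for one residual block; records each application."""
    def block(x):
        call_log.append(scale)
        return x + scale * x      # toy residual update: x <- x + f(x)
    return block


def recursive_core(x, blocks, num_recursions):
    """Apply the same list of blocks num_recursions times in sequence."""
    for _ in range(num_recursions):
        for block in blocks:
            x = block(x)
    return x


call_log = []
blocks = [make_block(s, call_log) for s in (0.1, 0.2, 0.3, 0.4)]
out = recursive_core(1.0, blocks, num_recursions=2)

unique_blocks = len(blocks)          # 4 sets of weights stored
effective_layers = len(call_log)     # 8 block applications
```

Raising `num_recursions` increases compute but not the parameter count, which is why the trainable total stays at the size of four blocks.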
### How It Works
```python
# 1. Encode image to latent (TAESD, frozen)
z_0 = vae.encode(image) # [B, 4, 4, 4]
# 2. Add noise (rectified flow)
z_t = (1-t) * z_0 + t * noise # Linear interpolation
# 3. Predict velocity (recursive denoising core)
v = core(z_t, t, class_label) # 4 blocks × 2 recursions
# 4. Training target
loss = MSE(v, noise - z_0) # Velocity matching
# 5. Sampling (Euler ODE solver, t=1→0), starting from z = noise
for t in timesteps:              # t runs from 1 toward 0 in steps of dt
    v = core(z, t, class_label)
    z = z - dt * v
# 6. Decode to image (TAESD, frozen)
image = vae.decode(z)
```
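
To see why a straight-line (rectified) path suits few-step sampling, consider the idealized case in which the model predicts the exact velocity `noise - z_0`. The scalar toy below is only a sanity check of the update rule in step 5 and does not use the trained model: `euler_sample` is a hypothetical helper, and the "oracle" velocity stands in for `core(...)`. Because the velocity is constant along the path, Euler integration from t=1 to t=0 lands on `z_0` for any number of steps. With a learned, imperfect velocity field, more steps generally help.

```python
def euler_sample(z, velocity_fn, num_steps):
    """Integrate dz/dt = v from t=1 down to t=0 with fixed-size Euler steps."""
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = 1.0 - i * dt              # current time: 1.0, 1 - dt, ..., dt
        z = z - dt * velocity_fn(z, t)
    return z


z_0, noise = 0.5, -1.2                # toy 1-D "latent" and noise sample


def oracle_velocity(z, t):
    """Exact rectified-flow target, used here in place of the network."""
    return noise - z_0


one_step = euler_sample(noise, oracle_velocity, num_steps=1)
many_steps = euler_sample(noise, oracle_velocity, num_steps=50)
```
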
---
## Quick Start
### Generate from trained model:
```python
import torch
from lrf.model_v2 import LRFv2, RectifiedFlowScheduler
from diffusers import AutoencoderTiny
# Load
vae = AutoencoderTiny.from_pretrained('madebyollin/taesd')
ckpt = torch.load('trained/cifar10_checkpoint.pt', map_location='cpu', weights_only=False)
model = LRFv2(ckpt['config'])
for name, p in model.named_parameters():
    p.data.copy_(ckpt['ema_params'][name])
model.eval()
# Generate (class 3 = cat)
scheduler = RectifiedFlowScheduler()
labels = torch.full((4,), 3, dtype=torch.long)
z = scheduler.sample(model, (4,4,4,4), labels, num_steps=50, cfg_scale=3.0)
images = vae.decode(z).sample.clamp(-1, 1)
```
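
The snippet above loads weights from `ema_params`, the exponential moving average (EMA) of the parameters kept during training with decay 0.999. Below is a minimal sketch of such an update using plain floats instead of tensors. `ema_update` is a hypothetical helper, and the actual bookkeeping in `lrf/train_v2.py` may differ.

```python
def ema_update(ema_params, params, decay=0.999):
    """In-place EMA: ema <- decay * ema + (1 - decay) * current value."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


params = {"w": 1.0}
ema = {"w": 0.0}                      # EMA initialised away from the weights

for _ in range(1000):                 # weights held fixed for illustration
    ema_update(ema, params)

# After n updates toward a fixed value of 1.0, ema = 1 - decay**n.
expected = 1.0 - 0.999 ** 1000
```
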
### Train from scratch:
```bash
python lrf/train_v2.py
```
---
## Files
| File | Description |
|---|---|
| `lrf/model_v2.py` | Core architecture (EfficientSpatialMixer, RecursiveLatentCore, LRFv2) |
| `lrf/train_v2.py` | CIFAR-10 training pipeline with TAESD VAE |
| `trained/cifar10_checkpoint.pt` | Trained weights (30 epochs, EMA) |
| `trained/config.json` | Model configuration |
| `samples/` | Generated sample images at various epochs |
| `lrf/model.py` | v1 architecture (research prototype) |
| `lrf/training.py` | v1 training pipeline |
| `lrf/pipeline.py` | HF-compatible inference pipeline |
| `notebook.ipynb` | Interactive walkthrough |
---
## Training Curriculum (Full Scale)
| Stage | Resolution | Data | Freeze | Train | LR | Steps |
|---|---|---|---|---|---|---|
| 1. VAE | 256² | ImageNet/COCO | - | VAE | 1e-4 | 50K |
| 2. Flow (low) | 64² | LAION-aesthetic | VAE | Core+Text | 1e-4 | 100K |
| 3. Flow (mid) | 256² | Filtered LAION | VAE | Core+Text | 5e-5 | 200K |
| 4. Flow (high) | 512² | Curated+JourneyDB | VAE | Core+Text | 2e-5 | 100K |
| 5. Distill | 512² | Same as 4 | VAE+Text | Core | 1e-5 | 50K |
| 6. Editing | 512² | InstructPix2Pix | VAE | Core+Text | 1e-5 | 50K |
**Shortcut (used in this repo):** Skip Stage 1 entirely by using the pre-trained TAESD VAE and start directly at Stage 2.
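
With an f=8 VAE such as TAESD, the latent grid the core has to process grows quadratically with image resolution. The arithmetic below covers the CIFAR-10 setting and the curriculum resolutions, assuming 4 latent channels and one token per latent position (a patch size of 1 is an assumption, not something stated in the table). `latent_grid` and `num_tokens` are illustrative helpers.

```python
def latent_grid(resolution, downsample=8, channels=4):
    """Return (channels, height, width) of the latent for a square image."""
    side = resolution // downsample
    return channels, side, side


def num_tokens(resolution, downsample=8):
    """Number of spatial positions in the latent (one token each)."""
    side = resolution // downsample
    return side * side


for res in (32, 64, 256, 512):
    print(res, latent_grid(res), num_tokens(res))
# 32 (4, 4, 4) 16
# 64 (4, 8, 8) 64
# 256 (4, 32, 32) 1024
# 512 (4, 64, 64) 4096
```
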
---
## Relevant Papers (Grouped by Problem)
### Subquadratic Spatial Mixing
- PDE-SSM-DiT (2603.13663): O(N log N) via Fourier PDE, 34× speedup
- DiMSUM (2411.04168): Mamba + wavelet, FID 2.11
- ViG/GLA (2405.18425): Gated Linear Attention, 90% memory savings
- DyDiLA (2601.13683): Dynamic differential linear attention
### Recursive Reasoning
- HRM (2506.21734): Fixed-point recurrence, O(1) memory via IFT
- TRM (2510.04871): 7M params → 45% ARC-AGI-1
### Compact Latent Spaces
- SANA DC-AE (2410.10629): f=32, PSNR 29.29
- SnapGen (2412.09619): 1.38M tiny decoder
- TAESD (madebyollin): 2.4M params, f=8, works immediately
### Few-Step Generation
- Consistency Models (2303.01469): One-step from diffusion
- LCM (2310.04378): 2-4 step via consistency distillation
### Editing Architectures
- OmniGen (2409.11340): Unified generation + editing
- InstructPix2Pix (2211.09800): Text-guided editing
---
## License
Apache 2.0