# LatentRecurrentFlow / lrf/training.py
# Uploaded by krystv with huggingface_hub (commit eba54b1, verified).
"""
LatentRecurrentFlow (LRF) - Training Pipeline
Implements:
1. VAE training (stage 1)
2. Flow matching denoiser training (stage 2)
3. Consistency distillation for few-step generation (stage 3)
4. Editing fine-tuning (stage 4)
All stages designed for 16GB RAM training.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from typing import Optional, Dict, Tuple
import os
import json
# ============================================================================
# Rectified Flow Scheduler
# ============================================================================
class RectifiedFlowScheduler:
    """
    Rectified flow (linear interpolation) scheduler.

    Forward process: z_t = (1 - t) * z_0 + t * epsilon
    Velocity target: v = epsilon - z_0
    At inference: solve the ODE from t=1 (noise) to t=0 (clean) with Euler
    steps: z_{t-dt} = z_t - dt * v_theta(z_t, t, c)

    For few-step generation, consistency distillation can be used to learn
    the full ODE solution in 1-4 steps.
    """

    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
        # Stored for callers; not read by any method of this class.
        self.num_train_timesteps = num_train_timesteps
        # Spread of the logit-normal timestep sampler (see sample_timesteps);
        # 1.0 keeps uniform sampling.
        self.shift = shift

    def add_noise(self, z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward process: z_t = (1 - t) * z_0 + t * noise.

        Args:
            z_0: clean latents, batch dimension first (any rank).
            noise: noise tensor with the same shape as ``z_0``.
            t: one timestep per sample (B elements), or a scalar.

        Returns:
            Noisy latents with the shape of ``z_0``.
        """
        t = torch.as_tensor(t, device=z_0.device)
        # Broadcast t over every non-batch dimension, whatever the rank of z_0
        # (a fixed 4-D view silently mis-broadcasts on other ranks).
        t = t.reshape(-1, *([1] * (z_0.dim() - 1)))
        return (1 - t) * z_0 + t * noise

    def get_velocity_target(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Target velocity: v = noise - z_0."""
        return noise - z_0

    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """
        Sample one training timestep per batch element.

        With ``shift == 1.0`` the timesteps are uniform on [0, 1). Otherwise a
        uniform u is mapped to ``sigmoid(shift * erfinv(2u - 1))``. Since
        ``erfinv(2u - 1)`` is normal with std 1/sqrt(2), this is a logit-normal
        sample centred at t = 0.5 with logit std ``shift / sqrt(2)``.
        The result is clamped to [1e-5, 1 - 1e-5].
        """
        t = torch.rand(batch_size, device=device)
        if self.shift != 1.0:
            t = torch.sigmoid(self.shift * torch.erfinv(2 * t - 1))
        return t.clamp(1e-5, 1 - 1e-5)

    @torch.no_grad()
    def euler_step(self, z_t: torch.Tensor, v: torch.Tensor, t: float, dt: float) -> torch.Tensor:
        """Single Euler step: z_{t-dt} = z_t - dt * v (``t`` is unused)."""
        return z_t - dt * v

    @torch.no_grad()
    def sample(
        self,
        model,
        shape: Tuple[int, ...],
        text_emb: Optional[torch.Tensor] = None,
        text_global: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        device: torch.device = torch.device('cpu'),
    ) -> torch.Tensor:
        """
        Generate latents by integrating the flow ODE with Euler steps.

        Args:
            model: object exposing ``predict_velocity(z, t, text_emb, text_global)``.
            shape: latent shape, batch dimension first (e.g. [B, C, H, W]).
            text_emb: [B, T, D] text token embeddings (or None).
            text_global: [B, D] global text embedding (or None).
            num_steps: number of Euler steps (20 for quality, 4-8 for speed).
            cfg_scale: classifier-free guidance scale. Guidance is applied only
                when ``cfg_scale > 1.0`` and ``text_emb`` is given.
            device: device for the initial noise and timesteps.

        Returns:
            Latents at t=0 with the requested ``shape``.
        """
        # Start from pure noise; integrate from t=1 (noise) to t=0 (clean).
        z = torch.randn(shape, device=device)
        timesteps = torch.linspace(1, 0, num_steps + 1, device=device)

        use_cfg = cfg_scale > 1.0 and text_emb is not None
        if use_cfg:
            # The unconditional branch must see what training showed the model
            # for dropped captions. LRFTrainer.train_flow_step drops
            # conditioning by multiplying the embeddings with a 0/1 mask, i.e.
            # all-zero embeddings, not None.
            null_emb = torch.zeros_like(text_emb)
            null_global = None if text_global is None else torch.zeros_like(text_global)

        for i in range(num_steps):
            t = timesteps[i]
            dt = timesteps[i] - timesteps[i + 1]
            t_batch = torch.full((shape[0],), t.item(), device=device)
            if use_cfg:
                v_cond = model.predict_velocity(z, t_batch, text_emb, text_global)
                v_uncond = model.predict_velocity(z, t_batch, null_emb, null_global)
                v = v_uncond + cfg_scale * (v_cond - v_uncond)
            else:
                v = model.predict_velocity(z, t_batch, text_emb, text_global)
            z = self.euler_step(z, v, t.item(), dt.item())
        return z
# ============================================================================
# Loss Functions
# ============================================================================
class VAELoss(nn.Module):
    """
    Stage-1 VAE objective: reconstruction + multi-scale MSE + KL divergence.

    * reconstruction: mean absolute error (L1) between ``recon`` and ``target``
      (L1 is more robust than L2 for images);
    * "perceptual": mean of the MSE at three resolutions (full size, then two
      successive 2x average-poolings). This is a cheap stand-in for LPIPS that
      needs no external network;
    * KL: ``-0.5 * mean(1 + logvar - mean^2 - exp(logvar))``, averaged over
      every element.

    There is no adversarial term in this stage, for simplicity.
    """

    def __init__(self, kl_weight: float = 1e-6, perceptual_weight: float = 1.0):
        super().__init__()
        self.kl_weight = kl_weight
        self.perceptual_weight = perceptual_weight

    def forward(self, recon, target, mean, logvar):
        """Return a dict of loss tensors keyed 'total', 'recon', 'perceptual', 'kl'."""
        l1_term = F.l1_loss(recon, target)

        pyramid_errors = []
        pred_level, true_level = recon, target
        for level in range(3):
            if level > 0:
                true_level = F.avg_pool2d(true_level, 2)
                pred_level = F.avg_pool2d(pred_level, 2)
            pyramid_errors.append(F.mse_loss(pred_level, true_level))
        multiscale_term = sum(pyramid_errors) / 3.0

        kl_term = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

        weighted_total = (
            l1_term
            + self.perceptual_weight * multiscale_term
            + self.kl_weight * kl_term
        )
        return {
            'total': weighted_total,
            'recon': l1_term,
            'perceptual': multiscale_term,
            'kl': kl_term,
        }
class FlowMatchingLoss(nn.Module):
    """
    Rectified flow matching loss.

        L = E_{t, z_0, eps} || v_theta(z_t, t, c) - (eps - z_0) ||^2

    The squared error is averaged over every non-batch dimension, giving one
    value per sample. With ``snr_weight=True`` each sample is then scaled by

        w(t) = 1 / (t * (1 - t) + snr_eps)

    This weight is largest at the endpoints (about 1 / snr_eps as t -> 0 or 1)
    and smallest at t = 0.5 (1 / (0.25 + snr_eps)). In other words it upweights
    timesteps near t=0 and t=1, not the middle.

    The weights are divided by their batch mean, so the result is a
    self-normalised weighted average whose overall scale does not depend on
    the sampled t. With a batch of one, the weighting therefore has no effect.
    """

    def __init__(self, snr_weight: bool = True, snr_eps: float = 0.01):
        super().__init__()
        self.snr_weight = snr_weight
        # Keeps w(t) finite at t = 0 and t = 1; the default reproduces 0.01.
        self.snr_eps = snr_eps

    def forward(self, v_pred, v_target, t):
        """
        Args:
            v_pred: predicted velocity, batch dimension first.
            v_target: target velocity, same shape as ``v_pred``.
            t: per-sample timesteps (B elements, any shape).

        Returns:
            Scalar loss tensor.
        """
        diff_sq = (v_pred - v_target).pow(2)
        # Mean over all non-batch dimensions -> [B], for latents of any rank.
        per_sample = diff_sq.reshape(diff_sq.shape[0], -1).mean(dim=1)
        if self.snr_weight:
            t = t.reshape(-1)
            w = 1.0 / (t * (1 - t) + self.snr_eps)
            w = w / w.mean()  # self-normalise within the batch
            per_sample = per_sample * w
        return per_sample.mean()
class ConsistencyDistillationLoss(nn.Module):
    """
    Consistency distillation objective for few-step generation.

    The student is trained to map any point on the ODE trajectory straight
    to the clean sample z_0:

        L_cd = || f_theta(z_{t_n}, t_n) - f_teacher(z_{t_{n-1}}, t_{n-1}) ||^2

    Here the teacher target comes from the pre-trained flow model after one
    Euler step. This module only computes the mean squared error between the
    two predictions; building the target is left to the caller.
    """

    def __init__(self, num_scales: int = 50):
        super().__init__()
        # Number of trajectory discretisation points; stored, not used by forward().
        self.num_scales = num_scales

    def forward(self, student_pred, teacher_target):
        """Mean squared error between student output and teacher target."""
        return F.mse_loss(student_pred, teacher_target, reduction='mean')
# ============================================================================
# Training Stages
# ============================================================================
class LRFTrainer:
    """
    Staged training pipeline for LRF.

    Stage 1: VAE training (learn image compression)
    Stage 2: Flow matching (learn denoising, VAE frozen)
    Stage 3: Consistency distillation (learn few-step generation)
    Stage 4: Editing fine-tuning (add conditioning channels)

    This class provides per-step training for stages 1 and 2
    (``train_vae_step``, ``train_flow_step``), text-to-image sampling
    (``generate``), and checkpoint save/load, so each stage can be run
    independently with checkpointing.
    """

    def __init__(
        self,
        model,
        device: torch.device = torch.device('cpu'),
        output_dir: str = './checkpoints',
    ):
        self.model = model
        self.device = device
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # shift=1.0 -> flow-training timesteps are sampled uniformly.
        self.scheduler = RectifiedFlowScheduler(shift=1.0)

    def train_vae_step(self, images: torch.Tensor, optimizer: torch.optim.Optimizer) -> Dict:
        """
        Run one VAE optimisation step.

        Args:
            images: image batch fed to ``model.vae``; moved to ``self.device``.
            optimizer: optimizer over the VAE parameters.

        Returns:
            Dict of Python floats keyed 'total', 'recon', 'perceptual', 'kl'.
        """
        self.model.vae.train()
        optimizer.zero_grad()
        images = images.to(self.device)
        recon, mean, logvar = self.model.vae(images)
        loss_fn = VAELoss(kl_weight=1e-6)
        losses = loss_fn(recon, images, mean, logvar)
        losses['total'].backward()
        torch.nn.utils.clip_grad_norm_(self.model.vae.parameters(), 1.0)
        optimizer.step()
        return {k: v.item() for k, v in losses.items()}

    def train_flow_step(
        self,
        images: torch.Tensor,
        token_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        optimizer: torch.optim.Optimizer = None,
        cfg_dropout: float = 0.1,
    ) -> Dict:
        """
        Run one flow-matching optimisation step.

        The VAE is frozen (eval mode, encoded under no_grad); only the core and
        the text encoder receive gradients.

        Args:
            images: image batch; encoded to latents with ``model.encode_image``.
            token_ids: text token ids for ``model.encode_text``.
            attention_mask: optional text attention mask.
            optimizer: optimizer over core + text-encoder parameters. Required;
                the ``None`` default exists only to keep the argument order.
            cfg_dropout: probability of zeroing a sample's text conditioning,
                which trains the unconditional branch used for guidance.

        Returns:
            ``{'flow_loss': float}``.

        Raises:
            ValueError: if ``optimizer`` is None.
        """
        if optimizer is None:
            raise ValueError("train_flow_step requires an optimizer")

        self.model.core.train()
        self.model.text_encoder.train()
        self.model.vae.eval()
        optimizer.zero_grad()
        images = images.to(self.device)
        token_ids = token_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        B = images.shape[0]

        # Encode images to latent space (no grad through the VAE).
        with torch.no_grad():
            z_0, _, _ = self.model.encode_image(images)

        text_emb, text_global = self.model.encode_text(token_ids, attention_mask)

        # Classifier-free guidance dropout: zero the conditioning for a random
        # subset of samples (keep where rand > cfg_dropout).
        if cfg_dropout > 0:
            mask = torch.rand(B, device=self.device) > cfg_dropout
            text_emb = text_emb * mask.view(B, 1, 1)
            text_global = text_global * mask.view(B, 1)

        # Sample timesteps and noise, build z_t, and regress the velocity.
        t = self.scheduler.sample_timesteps(B, self.device)
        noise = torch.randn_like(z_0)
        z_t = self.scheduler.add_noise(z_0, noise, t)
        v_pred = self.model.predict_velocity(z_t, t, text_emb, text_global)
        v_target = self.scheduler.get_velocity_target(z_0, noise)

        loss_fn = FlowMatchingLoss(snr_weight=True)
        loss = loss_fn(v_pred, v_target, t)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.model.core.parameters()) + list(self.model.text_encoder.parameters()),
            1.0
        )
        optimizer.step()
        return {'flow_loss': loss.item()}

    @torch.no_grad()
    def generate(
        self,
        token_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        latent_h: int = 16,
        latent_w: int = 16,
    ) -> torch.Tensor:
        """
        Generate images from text.

        Puts the whole model in eval mode, samples latents of shape
        [B, config['latent_channels'], latent_h, latent_w] with the Euler
        sampler, decodes them and clamps the result to [-1, 1].
        """
        self.model.eval()
        device = next(self.model.parameters()).device
        token_ids = token_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        B = token_ids.shape[0]

        text_emb, text_global = self.model.encode_text(token_ids, attention_mask)

        shape = (B, self.model.config['latent_channels'], latent_h, latent_w)
        z = self.scheduler.sample(
            self.model, shape, text_emb, text_global,
            num_steps=num_steps, cfg_scale=cfg_scale, device=device,
        )

        images = self.model.decode_latent(z)
        return images.clamp(-1, 1)

    def save_checkpoint(self, path: str, stage: str, epoch: int, extra: dict = None):
        """
        Save model weights, config, stage name and epoch to ``path``.

        Keys from ``extra`` are merged into the saved dict (and may override
        the defaults). Optimizer state is not saved unless passed via ``extra``.
        """
        ckpt = {
            'model_state': self.model.state_dict(),
            'config': self.model.config,
            'stage': stage,
            'epoch': epoch,
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)
        print(f"Saved checkpoint: {path}")

    def load_checkpoint(self, path: str):
        """
        Load a checkpoint written by ``save_checkpoint`` into ``self.model``.

        Returns:
            The full checkpoint dict (including 'stage', 'epoch' and any extras).
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code embedded in the file. Only load
        # checkpoints from trusted sources. Checkpoints from save_checkpoint
        # hold a state_dict plus plain metadata and may load with
        # weights_only=True, unless `extra` carried custom objects — verify
        # before switching.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt['model_state'])
        print(f"Loaded checkpoint: {path} (stage={ckpt.get('stage')}, epoch={ckpt.get('epoch')})")
        return ckpt
# ============================================================================
# Synthetic Data Generator (for prototype testing)
# ============================================================================
class SyntheticImageTextDataset(Dataset):
    """
    Synthetic image/text pairs for smoke-testing the pipeline.

    Each item holds a random image in [-1, 1] (Gaussian noise clamped to that
    range), ``max_text_length`` random token ids in [1, 31998], and an
    attention mask whose first ``text_len`` entries are 1. ``text_len`` is
    drawn uniformly from [min(5, max_text_length), max_text_length], inclusive.
    In production, replace this with real image-text pairs.
    """

    def __init__(self, num_samples: int = 1000, image_size: int = 64, max_text_length: int = 32):
        self.num_samples = num_samples
        self.image_size = image_size
        self.max_text_length = max_text_length

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Random image in [-1, 1]; `idx` is ignored (every item is freshly random).
        image = torch.randn(3, self.image_size, self.image_size).clamp(-1, 1)

        # torch.randint's upper bound is exclusive: +1 lets a caption fill the
        # whole sequence. Capping the lower bound keeps short max lengths valid.
        min_len = min(5, self.max_text_length)
        text_len = torch.randint(min_len, self.max_text_length + 1, (1,)).item()
        token_ids = torch.randint(1, 31999, (self.max_text_length,))
        attention_mask = torch.zeros(self.max_text_length)
        attention_mask[:text_len] = 1.0
        return {
            'image': image,
            'token_ids': token_ids,
            'attention_mask': attention_mask,
        }
# ============================================================================
# Complete Training Script (self-contained)
# ============================================================================
def run_prototype_training(
    config: Optional[Dict] = None,
    num_vae_steps: int = 100,
    num_flow_steps: int = 100,
    batch_size: int = 4,
    image_size: int = 64,
    lr: float = 1e-4,
    device: str = 'cpu',
    output_dir: str = './lrf_checkpoints',
):
    """
    Run a complete prototype training loop on synthetic data.

    This demonstrates the full pipeline:
    1. Create the model
    2. Train the VAE (stage 1); checkpoint to ``vae_checkpoint.pt``
    3. Train the flow-matching denoiser with the VAE frozen (stage 2);
       checkpoint to ``flow_checkpoint.pt``
    4. Generate samples and write ``config.json``

    On CPU, this is for testing only.
    On GPU with 16GB, this can train a real prototype.

    Args:
        config: model config; defaults to ``LatentRecurrentFlow.tiny_config()``.
        num_vae_steps: number of VAE optimizer steps.
        num_flow_steps: number of flow-matching optimizer steps.
        batch_size: batch size for both stages.
        image_size: side length of the square synthetic images.
        lr: AdamW learning rate for both stages.
        device: torch device string.
        output_dir: directory for checkpoints and the config file.

    Returns:
        ``(model, trainer)``.
    """
    import itertools

    from lrf.model import LatentRecurrentFlow

    device = torch.device(device)
    config = config or LatentRecurrentFlow.tiny_config()
    print("=" * 60)
    print("LatentRecurrentFlow (LRF) - Prototype Training")
    print("=" * 60)

    # Create model
    model = LatentRecurrentFlow(config).to(device)
    param_counts = model.count_parameters()
    print("\nModel parameters:")
    for name, count in param_counts.items():
        print(f" {name}: {count:,}")

    trainer = LRFTrainer(model, device, output_dir)

    # One pass over the dataset covers the longer stage. max(..., 1) keeps the
    # dataset non-empty, since a shuffling DataLoader rejects zero samples.
    dataset = SyntheticImageTextDataset(
        num_samples=max(num_vae_steps, num_flow_steps, 1) * batch_size,
        image_size=image_size,
        max_text_length=32,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # ===== STAGE 1: VAE Training =====
    print("\n" + "=" * 60)
    print("Stage 1: VAE Training")
    print("=" * 60)
    vae_optimizer = torch.optim.AdamW(model.vae.parameters(), lr=lr, weight_decay=0.01)
    # islice yields exactly num_vae_steps batches without fetching an extra one.
    for step, batch in enumerate(itertools.islice(dataloader, num_vae_steps)):
        losses = trainer.train_vae_step(batch['image'], vae_optimizer)
        if step % 20 == 0:
            print(f" Step {step}: loss={losses['total']:.4f}, "
                  f"recon={losses['recon']:.4f}, kl={losses['kl']:.4f}")
    trainer.save_checkpoint(
        os.path.join(output_dir, 'vae_checkpoint.pt'),
        stage='vae', epoch=0
    )

    # ===== STAGE 2: Flow Matching Training =====
    print("\n" + "=" * 60)
    print("Stage 2: Flow Matching Denoiser Training")
    print("=" * 60)
    # Freeze the VAE; only the core and text encoder are optimised.
    for p in model.vae.parameters():
        p.requires_grad = False
    flow_params = list(model.core.parameters()) + list(model.text_encoder.parameters())
    flow_optimizer = torch.optim.AdamW(flow_params, lr=lr, weight_decay=0.01)
    for step, batch in enumerate(itertools.islice(dataloader, num_flow_steps)):
        losses = trainer.train_flow_step(
            batch['image'], batch['token_ids'], batch['attention_mask'],
            flow_optimizer, cfg_dropout=0.1,
        )
        if step % 20 == 0:
            print(f" Step {step}: flow_loss={losses['flow_loss']:.4f}")
    trainer.save_checkpoint(
        os.path.join(output_dir, 'flow_checkpoint.pt'),
        stage='flow', epoch=0
    )

    # ===== STAGE 3: Generation =====
    print("\n" + "=" * 60)
    print("Stage 3: Sample Generation")
    print("=" * 60)
    # Generate with random text
    sample_tokens = torch.randint(1, 31999, (2, 32))
    sample_mask = torch.ones(2, 32)
    # Read the latent grid size off the VAE itself instead of assuming a
    # fixed downsampling factor, so samples live on the same latent grid the
    # denoiser was trained on.
    model.eval()
    with torch.no_grad():
        probe_latent, _, _ = model.encode_image(
            torch.zeros(1, 3, image_size, image_size, device=device)
        )
    latent_h, latent_w = (int(size) for size in probe_latent.shape[-2:])
    generated = trainer.generate(
        sample_tokens, sample_mask,
        num_steps=10, cfg_scale=3.0,
        latent_h=latent_h, latent_w=latent_w,
    )
    print(f" Generated {generated.shape[0]} images of shape {generated.shape[1:]}")
    print(f" Value range: [{generated.min():.3f}, {generated.max():.3f}]")

    # Save config
    config_path = os.path.join(output_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"\nConfig saved to {config_path}")
    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)
    return model, trainer
if __name__ == '__main__':
    # Quick CPU smoke run of the staged pipeline with a tiny step budget.
    run_prototype_training(
        device='cpu',
        image_size=64,
        batch_size=2,
        num_flow_steps=50,
        num_vae_steps=50,
    )