"""
LatentRecurrentFlow (LRF) - Training Pipeline

Implements:
1. VAE training (stage 1)
2. Flow matching denoiser training (stage 2)
3. Consistency distillation for few-step generation (stage 3)
4. Editing fine-tuning (stage 4)

All stages designed for 16GB RAM training.

NOTE: this file provides training steps for stages 1 and 2 plus sampling.
For stage 3 only the loss module (``ConsistencyDistillationLoss``) is
defined here; no stage 3 or stage 4 training step exists in this file.
"""

import json
import math
import os
from typing import Optional, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


# ============================================================================
# Rectified Flow Scheduler
# ============================================================================

class RectifiedFlowScheduler:
    """
    Rectified flow (linear interpolation) scheduler.

    Forward process: z_t = (1 - t) * z_0 + t * epsilon
    Velocity target: v = epsilon - z_0

    At inference: solve ODE from t=1 (noise) to t=0 (clean)
    Using Euler: z_{t-dt} = z_t - dt * v_theta(z_t, t, c)

    For few-step generation: use consistency distillation to learn
    the full ODE solution in 1-4 steps.
    """

    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
        # Stored for reference only; timesteps in this class are continuous
        # values in (0, 1) and never indexed by this count.
        self.num_train_timesteps = num_train_timesteps
        # shift == 1.0 means uniform timestep sampling; any other value
        # switches sample_timesteps() to a logit-normal-shaped distribution.
        self.shift = shift

    def add_noise(self, z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Forward process: z_t = (1-t) * z_0 + t * noise.

        Args:
            z_0: Clean latents with the batch on dim 0.
            noise: Noise tensor broadcastable against ``z_0``.
            t: One timestep per batch element (``z_0.shape[0]`` values).

        Returns:
            The interpolated latent, same shape as ``z_0``.
        """
        # Reshape t to [B, 1, ..., 1] so it broadcasts over every non-batch
        # dim. For 4-D latents this is exactly t.view(-1, 1, 1, 1).
        t = t.view(-1, *([1] * (z_0.dim() - 1)))
        return (1 - t) * z_0 + t * noise

    def get_velocity_target(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Target velocity: v = noise - z_0 (the time-derivative of add_noise)."""
        return noise - z_0

    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Sample ``batch_size`` training timesteps in the open interval (0, 1).

        With ``shift == 1.0`` the samples are uniform. Otherwise a uniform
        draw ``u`` is mapped through ``sigmoid(shift * erfinv(2u - 1))``.
        Since ``erfinv(2u - 1)`` is a standard-normal quantile divided by
        sqrt(2), the result is logit-normal with logit std ``shift / sqrt(2)``.

        Returns:
            Tensor of shape ``[batch_size]`` clamped to [1e-5, 1 - 1e-5].
        """
        t = torch.rand(batch_size, device=device)
        if self.shift != 1.0:
            t = torch.sigmoid(self.shift * torch.erfinv(2 * t - 1))
        # The clamp keeps t away from the exact endpoints (and maps the
        # u == 0 case, where erfinv returns -inf, to a finite value).
        return t.clamp(1e-5, 1 - 1e-5)

    @torch.no_grad()
    def euler_step(self, z_t: torch.Tensor, v: torch.Tensor, t: float, dt: float) -> torch.Tensor:
        """Single Euler step: z_{t-dt} = z_t - dt * v.

        ``t`` is accepted for interface symmetry but is not used.
        """
        return z_t - dt * v

    @torch.no_grad()
    def sample(
        self,
        model,
        shape: Tuple[int, ...],
        text_emb: Optional[torch.Tensor] = None,
        text_global: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        device: torch.device = torch.device('cpu'),
    ) -> torch.Tensor:
        """
        Generate samples using Euler ODE solver.

        Args:
            model: Object exposing
                ``predict_velocity(z, t, text_emb, text_global)``
                (the RecursiveLatentCore or LatentRecurrentFlow model)
            shape: [B, C, H, W] shape of the latent
            text_emb: [B, T, D] text token embeddings
            text_global: [B, D] global text embedding
            num_steps: Number of Euler steps (20 for quality, 4-8 for speed)
            cfg_scale: Classifier-free guidance scale; guidance is applied
                only when ``cfg_scale > 1.0`` and ``text_emb`` is given
            device: Device on which the latent and timesteps are created

        Returns:
            The latent after integrating from t=1 to t=0. With
            ``num_steps == 0`` the initial noise is returned unchanged.
        """
        # Start from pure noise
        z = torch.randn(shape, device=device)

        # Time steps from t=1 (noise) to t=0 (clean)
        timesteps = torch.linspace(1, 0, num_steps + 1, device=device)
        use_cfg = cfg_scale > 1.0 and text_emb is not None

        for i in range(num_steps):
            t = timesteps[i]
            dt = timesteps[i] - timesteps[i + 1]
            t_batch = torch.full((shape[0],), t.item(), device=device)

            if use_cfg:
                # Classifier-free guidance: extrapolate from the
                # unconditional prediction toward the conditional one.
                v_cond = model.predict_velocity(z, t_batch, text_emb, text_global)
                v_uncond = model.predict_velocity(z, t_batch, None, None)
                v = v_uncond + cfg_scale * (v_cond - v_uncond)
            else:
                v = model.predict_velocity(z, t_batch, text_emb, text_global)

            z = self.euler_step(z, v, t.item(), dt.item())

        return z


# ============================================================================
# Loss Functions
# ============================================================================

class VAELoss(nn.Module):
    """
    VAE training loss: reconstruction + KL divergence.

    Uses L1 reconstruction + a multi-scale MSE term (a cheap stand-in for a
    perceptual loss such as LPIPS) + KL.
    No adversarial loss in the first stage for simplicity.

    Args:
        kl_weight: Multiplier on the KL term.
        perceptual_weight: Multiplier on the multi-scale MSE term.
        num_scales: Number of resolutions in the multi-scale term. Scale 0
            is full resolution and each further scale halves H and W with
            average pooling, so inputs must be large enough to be pooled
            ``num_scales - 1`` times.
    """

    def __init__(self, kl_weight: float = 1e-6, perceptual_weight: float = 1.0,
                 num_scales: int = 3):
        super().__init__()
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")
        self.kl_weight = kl_weight
        self.perceptual_weight = perceptual_weight
        self.num_scales = num_scales

    def forward(self, recon, target, mean, logvar):
        """Return a dict with 'total', 'recon', 'perceptual' and 'kl' losses."""
        # Reconstruction loss (L1 is more robust than L2 for images)
        recon_loss = F.l1_loss(recon, target)

        # Multi-scale perceptual approximation (no external model needed)
        perceptual_loss = 0.0
        x_down = target
        r_down = recon
        for scale in range(self.num_scales):
            if scale > 0:
                x_down = F.avg_pool2d(x_down, 2)
                r_down = F.avg_pool2d(r_down, 2)
            perceptual_loss = perceptual_loss + F.mse_loss(r_down, x_down)
        perceptual_loss = perceptual_loss / self.num_scales

        # KL divergence to a standard normal, averaged over all elements
        kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

        total = (
            recon_loss
            + self.perceptual_weight * perceptual_loss
            + self.kl_weight * kl_loss
        )

        return {
            'total': total,
            'recon': recon_loss,
            'perceptual': perceptual_loss,
            'kl': kl_loss,
        }


class FlowMatchingLoss(nn.Module):
    """
    Rectified flow matching loss.

    L = E_{t, z_0, eps} || v_theta(z_t, t, c) - (eps - z_0) ||^2

    With optional per-sample timestep weighting (``snr_weight=True``):
    w(t) = 1 / (t * (1 - t) + 0.01), normalized to mean 1 over the batch.
    """

    def __init__(self, snr_weight: bool = True):
        super().__init__()
        self.snr_weight = snr_weight

    def forward(self, v_pred, v_target, t):
        """Return the scalar (optionally timestep-weighted) MSE.

        ``v_pred`` / ``v_target`` must be 4-D with the batch on dim 0;
        ``t`` holds one timestep per batch element.
        """
        # Per-sample MSE
        loss = (v_pred - v_target).pow(2).mean(dim=[1, 2, 3])  # [B]

        if self.snr_weight:
            # t * (1 - t) peaks at t = 0.5 and vanishes at the endpoints, so
            # this weight is LARGEST near t~0 and t~1 and smallest in the
            # middle. The 0.01 floor caps it at 100 before normalization.
            # NOTE(review): an earlier comment claimed this "upweights
            # middle timesteps"; the formula does the opposite. Confirm
            # which behaviour is intended before changing the code.
            w = 1.0 / (t * (1 - t) + 0.01)
            w = w / w.mean()  # Normalize so the average weight is 1
            loss = loss * w

        return loss.mean()


class ConsistencyDistillationLoss(nn.Module):
    """
    Consistency distillation loss for few-step generation.

    The student learns to map any point on the ODE trajectory
    directly to the clean sample z_0.

    L_cd = || f_theta(z_{t_n}, t_n) - f_teacher(z_{t_{n-1}}, t_{n-1}) ||^2

    Where f_teacher uses the pre-trained flow model with one Euler step.
    Building the student/teacher pair is left to the caller; this module
    only computes the MSE between the two tensors it is given.
    """

    def __init__(self, num_scales: int = 50):
        super().__init__()
        # Stored but not used by forward().
        self.num_scales = num_scales

    def forward(self, student_pred, teacher_target):
        return F.mse_loss(student_pred, teacher_target)


# ============================================================================
# Training Stages
# ============================================================================

class LRFTrainer:
    """
    Staged training pipeline for LRF.

    Stage 1: VAE training (learn image compression)
    Stage 2: Flow matching (learn denoising, VAE frozen)
    Stage 3: Consistency distillation (learn few-step generation)
    Stage 4: Editing fine-tuning (add conditioning channels)

    Each stage can run independently with checkpointing. Only the stage 1
    and stage 2 steps (plus generation) are implemented in this class.

    The wrapped ``model`` must expose ``vae``, ``core`` and ``text_encoder``
    submodules, a ``config`` mapping, and the methods ``encode_image``,
    ``encode_text``, ``predict_velocity`` and ``decode_latent``.
    """

    def __init__(
        self,
        model,
        device: torch.device = torch.device('cpu'),
        output_dir: str = './checkpoints',
    ):
        self.model = model
        self.device = device
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        self.scheduler = RectifiedFlowScheduler(shift=1.0)

        # The loss modules hold no trainable state, so build them once here
        # instead of on every training step.
        self._vae_loss_fn = VAELoss(kl_weight=1e-6)
        self._flow_loss_fn = FlowMatchingLoss(snr_weight=True)

    def train_vae_step(self, images: torch.Tensor, optimizer: torch.optim.Optimizer) -> Dict:
        """Single VAE training step.

        Runs forward, backward, gradient clipping (max norm 1.0) and an
        optimizer step, and returns every ``VAELoss`` entry as a float.
        """
        self.model.vae.train()
        optimizer.zero_grad()

        images = images.to(self.device)
        recon, mean, logvar = self.model.vae(images)

        losses = self._vae_loss_fn(recon, images, mean, logvar)
        losses['total'].backward()

        torch.nn.utils.clip_grad_norm_(self.model.vae.parameters(), 1.0)
        optimizer.step()

        return {k: v.item() for k, v in losses.items()}

    def train_flow_step(
        self,
        images: torch.Tensor,
        token_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        optimizer: torch.optim.Optimizer = None,
        cfg_dropout: float = 0.1,
    ) -> Dict:
        """
        Single flow matching training step.
        The VAE runs in eval mode under no_grad; only core + text encoder
        receive gradients.

        Args:
            images: Batch of images, moved to ``self.device``.
            token_ids: Batch of text token ids.
            attention_mask: Optional mask passed to ``model.encode_text``.
            optimizer: Optimizer over the core / text-encoder parameters.
                Required despite the ``None`` default, which exists only
                because it follows an optional parameter.
            cfg_dropout: Probability of zeroing a sample's text
                conditioning (classifier-free guidance training).

        Returns:
            ``{'flow_loss': float}``.

        Raises:
            ValueError: If ``optimizer`` is ``None``.
        """
        # Fail before touching module modes or gradients.
        if optimizer is None:
            raise ValueError("train_flow_step requires an optimizer")

        self.model.core.train()
        self.model.text_encoder.train()
        self.model.vae.eval()
        optimizer.zero_grad()

        images = images.to(self.device)
        token_ids = token_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        B = images.shape[0]

        # Encode images to latent space (no grad through VAE)
        with torch.no_grad():
            z_0, _, _ = self.model.encode_image(images)

        # Encode text
        text_emb, text_global = self.model.encode_text(token_ids, attention_mask)

        # Classifier-free guidance dropout: zero the conditioning of
        # roughly a cfg_dropout fraction of samples.
        # NOTE(review): sampling passes None (not zeros) for the
        # unconditional branch; confirm predict_velocity treats both alike.
        if cfg_dropout > 0:
            mask = torch.rand(B, device=self.device) > cfg_dropout
            text_emb = text_emb * mask.view(B, 1, 1)
            text_global = text_global * mask.view(B, 1)

        # Sample timesteps and noise
        t = self.scheduler.sample_timesteps(B, self.device)
        noise = torch.randn_like(z_0)

        # Create noisy latent
        z_t = self.scheduler.add_noise(z_0, noise, t)

        # Predict velocity
        v_pred = self.model.predict_velocity(z_t, t, text_emb, text_global)

        # Compute target
        v_target = self.scheduler.get_velocity_target(z_0, noise)

        # Loss
        loss = self._flow_loss_fn(v_pred, v_target, t)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(
            list(self.model.core.parameters()) + list(self.model.text_encoder.parameters()),
            1.0
        )
        optimizer.step()

        return {'flow_loss': loss.item()}

    @torch.no_grad()
    def generate(
        self,
        token_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        latent_h: int = 16,
        latent_w: int = 16,
    ) -> torch.Tensor:
        """Generate images from text.

        Puts the whole model in eval mode (it is not switched back), samples
        a latent of shape ``(B, latent_channels, latent_h, latent_w)`` and
        decodes it. The result is clamped to [-1, 1].
        """
        self.model.eval()
        # Use the device the parameters live on rather than self.device.
        device = next(self.model.parameters()).device

        token_ids = token_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        B = token_ids.shape[0]

        # Encode text
        text_emb, text_global = self.model.encode_text(token_ids, attention_mask)

        # Sample latents
        shape = (B, self.model.config['latent_channels'], latent_h, latent_w)
        z = self.scheduler.sample(
            self.model, shape, text_emb, text_global,
            num_steps=num_steps, cfg_scale=cfg_scale, device=device,
        )

        # Decode
        images = self.model.decode_latent(z)
        return images.clamp(-1, 1)

    def save_checkpoint(self, path: str, stage: str, epoch: int, extra: Optional[dict] = None):
        """Save training checkpoint.

        Writes model weights, the model config, ``stage`` and ``epoch``.
        Entries of ``extra`` are merged in last and can therefore override
        the standard keys.
        """
        ckpt = {
            'model_state': self.model.state_dict(),
            'config': self.model.config,
            'stage': stage,
            'epoch': epoch,
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)
        print(f"Saved checkpoint: {path}")

    def load_checkpoint(self, path: str):
        """Load training checkpoint into ``self.model`` and return the dict."""
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load checkpoints from
        # trusted sources. It is kept because checkpoints written by
        # save_checkpoint may carry arbitrary objects in config / extra.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt['model_state'])
        print(f"Loaded checkpoint: {path} (stage={ckpt.get('stage')}, epoch={ckpt.get('epoch')})")
        return ckpt


# ============================================================================
# Synthetic Data Generator (for prototype testing)
# ============================================================================

class SyntheticImageTextDataset(Dataset):
    """
    Generates synthetic data for testing the pipeline.
    Produces random images + random token sequences.

    In production, replace with real image-text pairs.

    Args:
        num_samples: Reported dataset length.
        image_size: Side length of the square 3-channel images.
        max_text_length: Fixed length of every token sequence (>= 1).
        vocab_size: Token ids are drawn from [1, vocab_size - 2], which is
            [1, 31998] for the default of 32000.
    """

    # Shortest "real" text length, used when max_text_length allows it.
    _MIN_TEXT_LEN = 5

    def __init__(self, num_samples: int = 1000, image_size: int = 64,
                 max_text_length: int = 32, vocab_size: int = 32000):
        if max_text_length < 1:
            raise ValueError(f"max_text_length must be >= 1, got {max_text_length}")
        if vocab_size < 3:
            raise ValueError(f"vocab_size must be >= 3, got {vocab_size}")
        self.num_samples = num_samples
        self.image_size = image_size
        self.max_text_length = max_text_length
        self.vocab_size = vocab_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return a fresh random sample; ``idx`` does not affect the content."""
        # Random image in [-1, 1]
        image = torch.randn(3, self.image_size, self.image_size).clamp(-1, 1)

        # Number of unmasked tokens, drawn from [5, max_text_length - 1].
        # When the sequence is too short for that range, the whole sequence
        # is treated as text instead of letting torch.randint raise.
        if self.max_text_length > self._MIN_TEXT_LEN:
            text_len = torch.randint(self._MIN_TEXT_LEN, self.max_text_length, (1,)).item()
        else:
            text_len = self.max_text_length

        # Every position gets a random id, including the masked-out tail;
        # the attention mask alone marks which tokens count as text.
        token_ids = torch.randint(1, self.vocab_size - 1, (self.max_text_length,))
        attention_mask = torch.zeros(self.max_text_length)
        attention_mask[:text_len] = 1.0

        return {
            'image': image,
            'token_ids': token_ids,
            'attention_mask': attention_mask,
        }


# ============================================================================
# Complete Training Script (self-contained)
# ============================================================================

def run_prototype_training(
    config: Optional[Dict] = None,
    num_vae_steps: int = 100,
    num_flow_steps: int = 100,
    batch_size: int = 4,
    image_size: int = 64,
    lr: float = 1e-4,
    device: str = 'cpu',
    output_dir: str = './lrf_checkpoints',
):
    """
    Run a complete prototype training loop.

    This demonstrates the full pipeline:
    1. Create model
    2. Train VAE
    3. Train flow matching denoiser
    4. Generate samples

    On CPU, this is for testing only. On GPU with 16GB, this can train
    a real prototype.

    Args:
        config: Model config; defaults to ``LatentRecurrentFlow.tiny_config()``.
        num_vae_steps: Number of stage 1 optimizer steps.
        num_flow_steps: Number of stage 2 optimizer steps.
        batch_size: Batch size for both stages.
        image_size: Side length of the synthetic training images.
        lr: AdamW learning rate for both stages.
        device: Device string passed to ``torch.device``.
        output_dir: Directory for checkpoints and ``config.json``.

    Returns:
        ``(model, trainer)``. The VAE parameters are left frozen
        (``requires_grad = False``) after stage 2.
    """
    # Imported lazily so this module can be imported without the lrf package.
    from lrf.model import LatentRecurrentFlow

    device = torch.device(device)
    config = config or LatentRecurrentFlow.tiny_config()

    print("=" * 60)
    print("LatentRecurrentFlow (LRF) - Prototype Training")
    print("=" * 60)

    # Create model
    model = LatentRecurrentFlow(config).to(device)
    param_counts = model.count_parameters()
    print("\nModel parameters:")
    for name, count in param_counts.items():
        print(f" {name}: {count:,}")

    # Create trainer
    trainer = LRFTrainer(model, device, output_dir)

    # Create synthetic data; sized so one pass over the loader yields enough
    # batches for the longer of the two stages.
    dataset = SyntheticImageTextDataset(
        num_samples=max(num_vae_steps, num_flow_steps) * batch_size,
        image_size=image_size,
        max_text_length=32,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # ===== STAGE 1: VAE Training =====
    print("\n" + "=" * 60)
    print("Stage 1: VAE Training")
    print("=" * 60)

    vae_optimizer = torch.optim.AdamW(model.vae.parameters(), lr=lr, weight_decay=0.01)

    for step, batch in enumerate(dataloader):
        if step >= num_vae_steps:
            break
        losses = trainer.train_vae_step(batch['image'], vae_optimizer)
        if step % 20 == 0:
            print(f" Step {step}: loss={losses['total']:.4f}, "
                  f"recon={losses['recon']:.4f}, kl={losses['kl']:.4f}")

    trainer.save_checkpoint(
        os.path.join(output_dir, 'vae_checkpoint.pt'),
        stage='vae', epoch=0
    )

    # ===== STAGE 2: Flow Matching Training =====
    print("\n" + "=" * 60)
    print("Stage 2: Flow Matching Denoiser Training")
    print("=" * 60)

    # Freeze VAE
    for p in model.vae.parameters():
        p.requires_grad = False

    flow_params = list(model.core.parameters()) + list(model.text_encoder.parameters())
    flow_optimizer = torch.optim.AdamW(flow_params, lr=lr, weight_decay=0.01)

    for step, batch in enumerate(dataloader):
        if step >= num_flow_steps:
            break
        losses = trainer.train_flow_step(
            batch['image'], batch['token_ids'], batch['attention_mask'],
            flow_optimizer, cfg_dropout=0.1,
        )
        if step % 20 == 0:
            print(f" Step {step}: flow_loss={losses['flow_loss']:.4f}")

    trainer.save_checkpoint(
        os.path.join(output_dir, 'flow_checkpoint.pt'),
        stage='flow', epoch=0
    )

    # ===== STAGE 3: Generation =====
    print("\n" + "=" * 60)
    print("Stage 3: Sample Generation")
    print("=" * 60)

    # Generate with random text
    sample_tokens = torch.randint(1, 31999, (2, 32))
    sample_mask = torch.ones(2, 32)

    # NOTE(review): assumes the VAE downsamples by 16x per side; confirm
    # against the model config before using a different VAE.
    latent_h = image_size // 16
    latent_w = image_size // 16

    generated = trainer.generate(
        sample_tokens, sample_mask,
        num_steps=10, cfg_scale=3.0,
        latent_h=latent_h, latent_w=latent_w,
    )
    print(f" Generated {generated.shape[0]} images of shape {generated.shape[1:]}")
    print(f" Value range: [{generated.min():.3f}, {generated.max():.3f}]")

    # Save config
    config_path = os.path.join(output_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"\nConfig saved to {config_path}")

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)

    return model, trainer


if __name__ == '__main__':
    run_prototype_training(
        num_vae_steps=50,
        num_flow_steps=50,
        batch_size=2,
        image_size=64,
        device='cpu',
    )