| """ |
| LatentRecurrentFlow (LRF) - Training Pipeline |
| |
| Implements: |
| 1. VAE training (stage 1) |
| 2. Flow matching denoiser training (stage 2) |
| 3. Consistency distillation for few-step generation (stage 3) |
| 4. Editing fine-tuning (stage 4) |
| |
| All stages designed for 16GB RAM training. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from typing import Optional, Dict, Tuple |
| import os |
| import json |
|
|
|
|
| |
| |
| |
|
|
class RectifiedFlowScheduler:
    """
    Rectified flow (linear interpolation) scheduler.

    Forward process:  z_t = (1 - t) * z_0 + t * epsilon
    Velocity target:  v = epsilon - z_0

    At inference the ODE is integrated from t=1 (noise) to t=0 (clean)
    with Euler steps:  z_{t-dt} = z_t - dt * v_theta(z_t, t, c)

    For few-step generation, consistency distillation can be used to learn
    the full ODE solution in 1-4 steps.
    """

    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
        """
        Args:
            num_train_timesteps: Stored on the instance; not read by any
                method of this class.
            shift: Controls the timestep sampling distribution used by
                `sample_timesteps`. 1.0 means plain uniform sampling.
        """
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift

    def add_noise(self, z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward process: z_t = (1 - t) * z_0 + t * noise.

        Args:
            z_0: Clean latents with the batch on dim 0 (any rank).
            noise: Noise tensor broadcastable with `z_0`.
            t: Per-sample timesteps, one value per batch element.

        Returns:
            The interpolated latents, same shape as `z_0`.
        """
        # Reshape t to (B, 1, ..., 1) so it broadcasts over every non-batch
        # dimension, whatever the rank of the latent is.
        t = t.view(-1, *([1] * (z_0.dim() - 1)))
        return (1 - t) * z_0 + t * noise

    def get_velocity_target(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Target velocity of the linear path: v = noise - z_0."""
        return noise - z_0

    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """
        Sample `batch_size` training timesteps in the open interval (0, 1).

        With `shift == 1.0` the samples are uniform. Otherwise each uniform
        draw u is mapped through sigmoid(shift * erfinv(2u - 1)); since
        erfinv(2u - 1) is a zero-mean Gaussian with variance 1/2, this is a
        logit-normal distribution centred on t = 0.5.
        """
        t = torch.rand(batch_size, device=device)
        if self.shift != 1.0:
            t = torch.sigmoid(self.shift * torch.erfinv(2 * t - 1))
        # Keep t away from the exact endpoints.
        return t.clamp(1e-5, 1 - 1e-5)

    @torch.no_grad()
    def euler_step(self, z_t: torch.Tensor, v: torch.Tensor, t: float, dt: float) -> torch.Tensor:
        """
        Single Euler step: z_{t-dt} = z_t - dt * v.

        `t` is accepted for interface symmetry but is not used.
        """
        return z_t - dt * v

    @torch.no_grad()
    def sample(
        self,
        model,
        shape: Tuple[int, ...],
        text_emb: Optional[torch.Tensor] = None,
        text_global: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        device: torch.device = torch.device('cpu'),
    ) -> torch.Tensor:
        """
        Generate samples using an Euler ODE solver.

        Args:
            model: Object exposing
                `predict_velocity(z, t, text_emb, text_global)`.
            shape: Shape of the latent; `shape[0]` is the batch size.
            text_emb: Text token embeddings, or None for unconditional.
            text_global: Global text embedding, or None.
            num_steps: Number of Euler steps (must be >= 1).
            cfg_scale: Classifier-free guidance scale. Guidance is applied
                only when it is > 1.0 and `text_emb` is given.
            device: Device on which the latent is created.

        Returns:
            The latent after integrating from t=1 down to t=0.

        Raises:
            ValueError: If `num_steps` is smaller than 1.
        """
        if num_steps < 1:
            # Zero steps would silently hand back the initial noise.
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        z = torch.randn(shape, device=device)

        # The time grid lives on the host as plain floats so the loop never
        # has to synchronise with the device through `.item()`.
        timesteps = torch.linspace(1.0, 0.0, num_steps + 1).tolist()
        use_cfg = cfg_scale > 1.0 and text_emb is not None

        for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
            t_batch = torch.full((shape[0],), t_cur, device=device)

            if use_cfg:
                v_cond = model.predict_velocity(z, t_batch, text_emb, text_global)
                # NOTE(review): the unconditional branch passes None for both
                # text inputs; presumably the model maps that to its null
                # conditioning -- confirm it matches how conditioning is
                # dropped during training.
                v_uncond = model.predict_velocity(z, t_batch, None, None)
                v = v_uncond + cfg_scale * (v_cond - v_uncond)
            else:
                v = model.predict_velocity(z, t_batch, text_emb, text_global)

            z = self.euler_step(z, v, t_cur, t_cur - t_next)

        return z
|
|
|
|
| |
| |
| |
|
|
class VAELoss(nn.Module):
    """
    VAE training loss: reconstruction + multi-scale term + KL divergence.

    Components:
      - `recon`: L1 distance between reconstruction and target.
      - `perceptual`: mean of MSE losses over an average-pooled pyramid
        (a cheap stand-in for LPIPS).
      - `kl`: KL divergence of the diagonal Gaussian posterior from N(0, I),
        averaged over all elements.

    No adversarial loss is used, for simplicity.
    """

    def __init__(
        self,
        kl_weight: float = 1e-6,
        perceptual_weight: float = 1.0,
        num_scales: int = 3,
    ):
        """
        Args:
            kl_weight: Multiplier on the KL term.
            perceptual_weight: Multiplier on the multi-scale MSE term.
            num_scales: Maximum number of pyramid levels (full resolution
                included).

        Raises:
            ValueError: If `num_scales` is smaller than 1.
        """
        super().__init__()
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")
        self.kl_weight = kl_weight
        self.perceptual_weight = perceptual_weight
        self.num_scales = num_scales

    def forward(self, recon, target, mean, logvar):
        """
        Compute the loss terms.

        Args:
            recon: Reconstructed images; the last two dims are pooled as
                spatial dims.
            target: Ground-truth images, same shape as `recon`.
            mean: Posterior mean.
            logvar: Posterior log-variance, same shape as `mean`.

        Returns:
            Dict with scalar tensors under 'total', 'recon', 'perceptual'
            and 'kl'.
        """
        recon_loss = F.l1_loss(recon, target)

        # Multi-scale MSE: full resolution first, then successive 2x
        # average-pooled copies.
        scale_losses = [F.mse_loss(recon, target)]
        r_down, x_down = recon, target
        for _ in range(self.num_scales - 1):
            # A 2x2 pool needs at least 2 pixels per side; stop early on
            # tiny inputs instead of letting avg_pool2d raise.
            if min(x_down.shape[-2:]) < 2:
                break
            x_down = F.avg_pool2d(x_down, 2)
            r_down = F.avg_pool2d(r_down, 2)
            scale_losses.append(F.mse_loss(r_down, x_down))
        # Average over the levels actually evaluated.
        perceptual_loss = sum(scale_losses) / len(scale_losses)

        kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

        total = recon_loss + self.perceptual_weight * perceptual_loss + self.kl_weight * kl_loss

        return {
            'total': total,
            'recon': recon_loss,
            'perceptual': perceptual_loss,
            'kl': kl_loss,
        }
|
|
|
|
class FlowMatchingLoss(nn.Module):
    """
    Rectified flow matching loss.

    L = E_{t, z_0, eps} || v_theta(z_t, t, c) - (eps - z_0) ||^2

    With optional timestep weighting: each sample is weighted by
    1 / (t * (1 - t) + 0.01), which is largest near t = 0 and t = 1, and
    the weights are normalised to have mean 1 within the batch.
    """

    def __init__(self, snr_weight: bool = True):
        """
        Args:
            snr_weight: Whether to apply the timestep weighting.
        """
        super().__init__()
        self.snr_weight = snr_weight

    def forward(self, v_pred, v_target, t):
        """
        Args:
            v_pred: Predicted velocity with the batch on dim 0 (any rank).
            v_target: Target velocity, same shape as `v_pred`.
            t: One timestep per batch element; extra singleton dims are
                flattened away.

        Returns:
            Scalar loss tensor.
        """
        sq_err = (v_pred - v_target).pow(2)
        # Per-sample MSE over every non-batch dimension, independent of the
        # rank of the latent.
        if sq_err.dim() > 1:
            loss = sq_err.flatten(start_dim=1).mean(dim=1)
        else:
            loss = sq_err

        if self.snr_weight:
            # Flatten t so a (B, 1, 1, 1)-shaped timestep cannot broadcast
            # the per-sample loss into a larger tensor.
            t = t.reshape(-1)
            w = 1.0 / (t * (1 - t) + 0.01)
            w = w / w.mean()
            loss = loss * w

        return loss.mean()
|
|
|
|
class ConsistencyDistillationLoss(nn.Module):
    """
    Loss used for consistency distillation (few-step generation).

    The student is trained so that its output at one point of the ODE
    trajectory agrees with the teacher's output at the neighbouring point:

        L_cd = || f_theta(z_{t_n}, t_n) - f_teacher(z_{t_{n-1}}, t_{n-1}) ||^2

    This module only evaluates the distance; producing the two predictions
    is left to the caller.
    """

    def __init__(self, num_scales: int = 50):
        super().__init__()
        # Stored on the instance; `forward` does not read it.
        self.num_scales = num_scales

    def forward(self, student_pred, teacher_target):
        """Return the mean squared error between student and teacher outputs."""
        distance = F.mse_loss(input=student_pred, target=teacher_target, reduction='mean')
        return distance
|
|
|
|
| |
| |
| |
|
|
class LRFTrainer:
    """
    Staged training pipeline for LRF.

    Stage 1: VAE training (learn image compression)
    Stage 2: Flow matching (learn denoising, VAE frozen)
    Stage 3: Consistency distillation (learn few-step generation)
    Stage 4: Editing fine-tuning (add conditioning channels)

    This class provides single-step routines for stages 1 and 2, sample
    generation, and checkpoint save/load.
    """

    def __init__(
        self,
        model,
        device: torch.device = torch.device('cpu'),
        output_dir: str = './checkpoints',
    ):
        """
        Args:
            model: The model to train. The methods below use its `vae`,
                `core`, `text_encoder` and `config` attributes and its
                `encode_image`, `encode_text`, `predict_velocity` and
                `decode_latent` methods.
            device: Device that training batches are moved to.
            output_dir: Directory created (if missing) for checkpoints.
        """
        self.model = model
        self.device = device
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        self.scheduler = RectifiedFlowScheduler(shift=1.0)

    def train_vae_step(self, images: torch.Tensor, optimizer: torch.optim.Optimizer) -> Dict:
        """
        Run a single VAE optimisation step.

        Args:
            images: Batch of images; moved to `self.device`.
            optimizer: Optimizer over the VAE parameters.

        Returns:
            Dict of Python floats, one per entry returned by `VAELoss`.
        """
        self.model.vae.train()
        optimizer.zero_grad()

        images = images.to(self.device)
        recon, mean, logvar = self.model.vae(images)

        loss_fn = VAELoss(kl_weight=1e-6)
        losses = loss_fn(recon, images, mean, logvar)

        losses['total'].backward()
        torch.nn.utils.clip_grad_norm_(self.model.vae.parameters(), 1.0)
        optimizer.step()

        return {k: v.item() for k, v in losses.items()}

    def train_flow_step(
        self,
        images: torch.Tensor,
        token_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        cfg_dropout: float = 0.1,
    ) -> Dict:
        """
        Run a single flow matching optimisation step.

        The VAE is kept in eval mode and used without gradients; the core
        and the text encoder are put in train mode.

        Args:
            images: Batch of images; moved to `self.device`.
            token_ids: Token ids for the captions.
            attention_mask: Optional mask passed through to `encode_text`.
            optimizer: Optimizer over the core and text-encoder parameters.
                Required, despite the `None` default kept for
                signature compatibility.
            cfg_dropout: Probability of zeroing a sample's text embeddings
                (classifier-free guidance training).

        Returns:
            `{'flow_loss': float}`.

        Raises:
            ValueError: If `optimizer` is None.
        """
        if optimizer is None:
            # Fail early with a clear message instead of an AttributeError
            # on `None.zero_grad()`.
            raise ValueError("train_flow_step requires an optimizer")

        self.model.core.train()
        self.model.text_encoder.train()
        self.model.vae.eval()

        optimizer.zero_grad()

        images = images.to(self.device)
        token_ids = token_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        B = images.shape[0]

        # Encode images to latents without tracking gradients (VAE frozen).
        with torch.no_grad():
            z_0, _, _ = self.model.encode_image(images)

        text_emb, text_global = self.model.encode_text(token_ids, attention_mask)

        # Classifier-free guidance dropout: zero the conditioning of a random
        # subset of the batch.
        # NOTE(review): RectifiedFlowScheduler.sample passes None (not zeros)
        # for its unconditional branch -- confirm the model treats both alike.
        if cfg_dropout > 0:
            keep = torch.rand(B, device=self.device) > cfg_dropout
            text_emb = text_emb * keep.view(B, 1, 1)
            text_global = text_global * keep.view(B, 1)

        t = self.scheduler.sample_timesteps(B, self.device)
        noise = torch.randn_like(z_0)

        z_t = self.scheduler.add_noise(z_0, noise, t)
        v_pred = self.model.predict_velocity(z_t, t, text_emb, text_global)
        v_target = self.scheduler.get_velocity_target(z_0, noise)

        loss_fn = FlowMatchingLoss(snr_weight=True)
        loss = loss_fn(v_pred, v_target, t)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.model.core.parameters()) + list(self.model.text_encoder.parameters()),
            1.0
        )
        optimizer.step()

        return {'flow_loss': loss.item()}

    @torch.no_grad()
    def generate(
        self,
        token_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        latent_h: int = 16,
        latent_w: int = 16,
    ) -> torch.Tensor:
        """
        Generate images from text.

        Args:
            token_ids: Token ids; dim 0 is the batch size.
            attention_mask: Optional mask passed through to `encode_text`.
            num_steps: Number of Euler steps for the sampler.
            cfg_scale: Classifier-free guidance scale.
            latent_h: Latent height.
            latent_w: Latent width.

        Returns:
            Decoded images clamped to [-1, 1].
        """
        self.model.eval()
        # Sampling runs on whichever device the model parameters live on.
        device = next(self.model.parameters()).device

        token_ids = token_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        B = token_ids.shape[0]

        text_emb, text_global = self.model.encode_text(token_ids, attention_mask)

        shape = (B, self.model.config['latent_channels'], latent_h, latent_w)
        z = self.scheduler.sample(
            self.model, shape, text_emb, text_global,
            num_steps=num_steps, cfg_scale=cfg_scale, device=device,
        )

        images = self.model.decode_latent(z)

        return images.clamp(-1, 1)

    def save_checkpoint(self, path: str, stage: str, epoch: int, extra: Optional[dict] = None):
        """
        Save a training checkpoint.

        Args:
            path: Destination file; its parent directory is created if
                needed.
            stage: Label of the training stage, stored as-is.
            epoch: Epoch counter, stored as-is.
            extra: Optional additional entries merged into the checkpoint
                dict (they can override the standard keys).
        """
        ckpt = {
            'model_state': self.model.state_dict(),
            'config': self.model.config,
            'stage': stage,
            'epoch': epoch,
        }
        if extra:
            ckpt.update(extra)
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(ckpt, path)
        print(f"Saved checkpoint: {path}")

    def load_checkpoint(self, path: str):
        """
        Load a training checkpoint into `self.model`.

        Args:
            path: Checkpoint file written by `save_checkpoint`.

        Returns:
            The full checkpoint dict.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt['model_state'])
        print(f"Loaded checkpoint: {path} (stage={ckpt.get('stage')}, epoch={ckpt.get('epoch')})")
        return ckpt
|
|
|
|
| |
| |
| |
|
|
class SyntheticImageTextDataset(Dataset):
    """
    Generates synthetic data for testing the pipeline.

    Each item is a random image plus a random token sequence with an
    attention mask. Items are drawn fresh on every access, so the same index
    does not return the same sample twice. In production, replace with real
    image-text pairs.
    """

    def __init__(
        self,
        num_samples: int = 1000,
        image_size: int = 64,
        max_text_length: int = 32,
        vocab_size: int = 32000,
        min_text_length: int = 5,
    ):
        """
        Args:
            num_samples: Reported dataset length.
            image_size: Height and width of the square 3-channel image.
            max_text_length: Length of the (padded) token sequence.
            vocab_size: Token ids are drawn from [1, vocab_size - 2].
            min_text_length: Smallest number of unmasked tokens, used when
                `max_text_length` is larger than it.

        Raises:
            ValueError: If `vocab_size` < 3 or `min_text_length` < 0.
        """
        if vocab_size < 3:
            raise ValueError(f"vocab_size must be >= 3, got {vocab_size}")
        if min_text_length < 0:
            raise ValueError(f"min_text_length must be >= 0, got {min_text_length}")
        self.num_samples = num_samples
        self.image_size = image_size
        self.max_text_length = max_text_length
        self.vocab_size = vocab_size
        self.min_text_length = min_text_length

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return a dict with 'image', 'token_ids' and 'attention_mask'."""
        # Gaussian noise clipped to [-1, 1].
        image = torch.randn(3, self.image_size, self.image_size).clamp(-1, 1)

        if self.max_text_length > self.min_text_length:
            # Upper bound is exclusive, so at least one position stays padded.
            text_len = torch.randint(self.min_text_length, self.max_text_length, (1,)).item()
        else:
            # Sequence too short for a random length: randint would raise on
            # an empty range, so use the whole sequence.
            text_len = self.max_text_length

        token_ids = torch.randint(1, self.vocab_size - 1, (self.max_text_length,))
        attention_mask = torch.zeros(self.max_text_length)
        attention_mask[:text_len] = 1.0

        return {
            'image': image,
            'token_ids': token_ids,
            'attention_mask': attention_mask,
        }
|
|
|
|
| |
| |
| |
|
|
def run_prototype_training(
    config: Optional[Dict] = None,
    num_vae_steps: int = 100,
    num_flow_steps: int = 100,
    batch_size: int = 4,
    image_size: int = 64,
    lr: float = 1e-4,
    device: str = 'cpu',
    output_dir: str = './lrf_checkpoints',
):
    """
    Run the prototype pipeline end to end on synthetic data.

    Steps, in order:
      1. Build the model (from `config`, or the model's tiny config).
      2. Train the VAE for up to `num_vae_steps` steps and checkpoint it.
      3. Freeze the VAE, train core + text encoder with flow matching for
         up to `num_flow_steps` steps and checkpoint again.
      4. Generate two samples and write the config as JSON.

    Args:
        config: Model configuration; falsy values select the tiny config.
        num_vae_steps: Maximum number of VAE steps.
        num_flow_steps: Maximum number of flow matching steps.
        batch_size: Batch size of the synthetic data loader.
        image_size: Side length of the synthetic images.
        lr: Learning rate for both AdamW optimizers.
        device: Device string handed to `torch.device`.
        output_dir: Where checkpoints and `config.json` are written.

    Returns:
        Tuple `(model, trainer)`.
    """
    from lrf.model import LatentRecurrentFlow

    rule = "=" * 60

    device = torch.device(device)
    if not config:
        config = LatentRecurrentFlow.tiny_config()

    print(rule)
    print("LatentRecurrentFlow (LRF) - Prototype Training")
    print(rule)

    # --- model -------------------------------------------------------------
    model = LatentRecurrentFlow(config).to(device)
    counts = model.count_parameters()
    print("\nModel parameters:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    trainer = LRFTrainer(model, device, output_dir)

    # --- data: one pass must cover the longer of the two stages ------------
    total_samples = max(num_vae_steps, num_flow_steps) * batch_size
    dataset = SyntheticImageTextDataset(
        num_samples=total_samples,
        image_size=image_size,
        max_text_length=32,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # --- stage 1: VAE ------------------------------------------------------
    print("\n" + rule)
    print("Stage 1: VAE Training")
    print(rule)

    vae_optimizer = torch.optim.AdamW(model.vae.parameters(), lr=lr, weight_decay=0.01)

    for step, batch in enumerate(dataloader):
        if step >= num_vae_steps:
            break

        losses = trainer.train_vae_step(batch['image'], vae_optimizer)

        if step % 20 == 0:
            message = (
                f" Step {step}: loss={losses['total']:.4f}, "
                f"recon={losses['recon']:.4f}, kl={losses['kl']:.4f}"
            )
            print(message)

    vae_ckpt_path = os.path.join(output_dir, 'vae_checkpoint.pt')
    trainer.save_checkpoint(vae_ckpt_path, stage='vae', epoch=0)

    # --- stage 2: flow matching -------------------------------------------
    print("\n" + rule)
    print("Stage 2: Flow Matching Denoiser Training")
    print(rule)

    # The VAE stays frozen from here on.
    for param in model.vae.parameters():
        param.requires_grad_(False)

    flow_params = [*model.core.parameters(), *model.text_encoder.parameters()]
    flow_optimizer = torch.optim.AdamW(flow_params, lr=lr, weight_decay=0.01)

    for step, batch in enumerate(dataloader):
        if step >= num_flow_steps:
            break

        losses = trainer.train_flow_step(
            batch['image'],
            batch['token_ids'],
            batch['attention_mask'],
            flow_optimizer,
            cfg_dropout=0.1,
        )

        if step % 20 == 0:
            print(f" Step {step}: flow_loss={losses['flow_loss']:.4f}")

    flow_ckpt_path = os.path.join(output_dir, 'flow_checkpoint.pt')
    trainer.save_checkpoint(flow_ckpt_path, stage='flow', epoch=0)

    # --- sample generation -------------------------------------------------
    print("\n" + rule)
    print("Stage 3: Sample Generation")
    print(rule)

    sample_tokens = torch.randint(1, 31999, (2, 32))
    sample_mask = torch.ones(2, 32)

    # Latent grid is the image size divided by 16 along each side.
    latent_side = image_size // 16

    generated = trainer.generate(
        sample_tokens,
        sample_mask,
        num_steps=10,
        cfg_scale=3.0,
        latent_h=latent_side,
        latent_w=latent_side,
    )

    print(f" Generated {generated.shape[0]} images of shape {generated.shape[1:]}")
    print(f" Value range: [{generated.min():.3f}, {generated.max():.3f}]")

    # --- persist config ----------------------------------------------------
    config_path = os.path.join(output_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"\nConfig saved to {config_path}")

    print("\n" + rule)
    print("Training complete!")
    print(rule)

    return model, trainer
|
|
|
|
if __name__ == '__main__':
    # Script entry point: a short CPU run of the prototype pipeline with
    # small step counts and batch size.
    run_prototype_training(
        device='cpu',
        image_size=64,
        batch_size=2,
        num_vae_steps=50,
        num_flow_steps=50,
    )
|
|