""" LatentRecurrentFlow (LRF) - Core Architecture Modules Architecture Overview: ===================== The LRF architecture consists of 4 main components: 1. CompactEncoder/Decoder (VAE): f=32 spatial compression with tiny decoder 2. TextConditioner: Lightweight text encoding (TinyCLIP or small LM) 3. RecursiveLatentCore: The novel HRM-inspired denoising backbone 4. FlowScheduler: Rectified flow for training and sampling The RecursiveLatentCore is the key innovation: - It contains N_blocks GLD (Gated Linear Diffusion) blocks - These blocks are applied recursively T_outer * T_inner times - The same parameters are reused across recursions (weight sharing) - Training uses IFT (Implicit Function Theorem) for O(1) memory backprop - This gives effective depth of T_outer * T_inner * N_blocks layers from only N_blocks parameter sets Memory budget at inference (1024x1024, INT8): - Text encoder: ~150MB (TinyCLIP-ViT-B/16) - VAE encoder: ~100MB (f32 encoder, only needed for editing) - VAE decoder: ~6MB (SnapGen-style tiny decoder) - LRF core: ~200-400MB (depending on config) - Activations: ~500MB peak - Total: ~1-1.5GB model + ~500MB activations = 1.5-2GB """ import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from typing import Optional, Tuple, Dict, Any # ============================================================================ # Utility Modules # ============================================================================ class RMSNorm(nn.Module): """RMSNorm - more stable than LayerNorm for small models.""" def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (x.float() * norm).type_as(x) * self.weight class SwiGLU(nn.Module): """SwiGLU FFN - better than GELU for small models, mobile-friendly (SiLU not GELU).""" def __init__(self, dim: int, hidden_dim: Optional[int] = None, dropout: float = 0.0): super().__init__() hidden_dim = hidden_dim or int(dim * 8 / 3) # Round to nearest multiple of 8 for efficiency hidden_dim = ((hidden_dim + 7) // 8) * 8 self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x): return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) class DepthwiseSeparableConv2d(nn.Module): """Mobile-optimized convolution.""" def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3): super().__init__() padding = kernel_size // 2 self.dw = nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, groups=in_channels, bias=False) self.pw = nn.Conv2d(in_channels, out_channels, 1, bias=False) def forward(self, x): return self.pw(self.dw(x)) # ============================================================================ # 2D Positional Encoding # ============================================================================ class RotaryPositionEncoding2D(nn.Module): """2D RoPE for spatial tokens - resolution-independent.""" def __init__(self, dim: int, max_res: int = 64): super().__init__() self.dim = dim half_dim = dim // 4 # Split into 4 parts: sin_h, cos_h, sin_w, cos_w freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000.0) / half_dim)) self.register_buffer('freqs', freqs) def forward(self, h: int, w: int, device=None): device = device or self.freqs.device pos_h = torch.arange(h, device=device).float() pos_w = torch.arange(w, device=device).float() freqs_h = torch.outer(pos_h, self.freqs.to(device)) # [H, D/4] freqs_w = torch.outer(pos_w, self.freqs.to(device)) # [W, D/4] # Expand to [H, W, D/4] each freqs_h = freqs_h.unsqueeze(1).expand(-1, w, -1) freqs_w = freqs_w.unsqueeze(0).expand(h, -1, -1) # Concatenate: [H, W, D/2] for sin, [H, W, D/2] for cos freqs = torch.cat([freqs_h, freqs_w], dim=-1) # [H, W, D/2] sin_enc = freqs.sin() cos_enc = freqs.cos() return sin_enc.reshape(h * w, -1), cos_enc.reshape(h * w, -1) def apply_rope_2d(x, sin_enc, cos_enc): """Apply 2D RoPE to queries/keys.""" d = x.shape[-1] half_d = d // 2 x1, x2 = x[..., :half_d], x[..., half_d:] # Expand sin/cos to match batch dims while sin_enc.dim() < x1.dim(): sin_enc = sin_enc.unsqueeze(0) cos_enc = cos_enc.unsqueeze(0) return torch.cat([x1 * cos_enc - x2 * sin_enc, x2 * cos_enc + x1 * sin_enc], dim=-1) # ============================================================================ # Gated Linear Diffusion (GLD) Block - The Core Spatial Mixer # ============================================================================ class GatedLinearAttention(nn.Module): """ Gated Linear Attention for 2D spatial mixing. O(N) complexity instead of O(N²) softmax attention. Based on ViG/GLA research but adapted for diffusion: - Bidirectional scan (forward + backward) - 2D locality injection via depthwise conv gating - Token-differential operator to prevent oversmoothing (from DyDiLA) Math: Q, K, V = linear(x), linear(x), linear(x) Q = phi(Q), K = phi(K) where phi = 1 + elu (non-negative feature map) Forward scan: S_i = decay * S_{i-1} + K_i^T V_i; O_i = Q_i S_i Backward scan: same in reverse Output = gate * (O_fwd + O_bwd) * local_gate Complexity: O(N * d²) where d is head dimension, N is sequence length """ def __init__(self, dim: int, num_heads: int = 8, head_dim: int = 32, dropout: float = 0.0): super().__init__() self.num_heads = num_heads self.head_dim = head_dim inner_dim = num_heads * head_dim self.qkv = nn.Linear(dim, 3 * inner_dim, bias=False) self.out_proj = nn.Linear(inner_dim, dim, bias=False) # Learnable decay for recurrence (per-head) self.log_decay = nn.Parameter(torch.zeros(num_heads)) # Gate for output self.gate = nn.Linear(dim, inner_dim, bias=False) # 2D locality injection (depthwise conv) - critical for spatial structure self.local_conv = nn.Conv2d(inner_dim, inner_dim, 3, padding=1, groups=inner_dim, bias=False) self.local_gate = nn.Linear(dim, inner_dim, bias=False) # Token differential parameter (from DyDiLA - prevents oversmoothing) self.diff_lambda = nn.Parameter(torch.tensor(0.1)) self.dropout = nn.Dropout(dropout) self.norm = RMSNorm(inner_dim) def _feature_map(self, x): """Non-negative feature map: 1 + elu(x)""" return 1.0 + F.elu(x) def _scan(self, Q, K, V, reverse=False): """Linear recurrent scan - O(N * d²) per direction.""" B, H, N, D = Q.shape decay = torch.sigmoid(self.log_decay).view(1, H, 1, 1) # [1, H, 1, 1] if reverse: Q = Q.flip(2) K = K.flip(2) V = V.flip(2) # Chunk-wise computation for memory efficiency chunk_size = min(64, N) outputs = [] S = torch.zeros(B, H, D, D, device=Q.device, dtype=Q.dtype) for i in range(0, N, chunk_size): q_chunk = Q[:, :, i:i+chunk_size] # [B, H, C, D] k_chunk = K[:, :, i:i+chunk_size] v_chunk = V[:, :, i:i+chunk_size] chunk_len = q_chunk.shape[2] # Update state: S = decay * S + K^T V kv = torch.einsum('bhcd,bhce->bhde', k_chunk, v_chunk) S = decay * S + kv # Query state: O = Q S o_chunk = torch.einsum('bhcd,bhde->bhce', q_chunk, S) outputs.append(o_chunk) output = torch.cat(outputs, dim=2) if reverse: output = output.flip(2) return output def forward(self, x, h: int, w: int): """ Args: x: [B, N, D] where N = H*W h, w: spatial dimensions Returns: [B, N, D] """ B, N, D = x.shape # Project to Q, K, V qkv = self.qkv(x) q, k, v = qkv.chunk(3, dim=-1) # Reshape to heads q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads) v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads) # Token differential (prevents oversmoothing) # Q_diff = Q_i - lambda * Q_{i-1}, K_diff = K_i - lambda * K_{i-1} lam = torch.sigmoid(self.diff_lambda) q_shifted = F.pad(q[:, :, :-1], (0, 0, 1, 0)) k_shifted = F.pad(k[:, :, :-1], (0, 0, 1, 0)) q = q - lam * q_shifted k = k - lam * k_shifted # Apply feature map (non-negative) q = self._feature_map(q) k = self._feature_map(k) # Bidirectional scan o_fwd = self._scan(q, k, v, reverse=False) o_bwd = self._scan(q, k, v, reverse=True) output = o_fwd + o_bwd # Normalize output = rearrange(output, 'b h n d -> b n (h d)') output = self.norm(output) # 2D locality injection (GaLI from ViG) x_2d = rearrange(x, 'b (h w) d -> b d h w', h=h, w=w) gate_input = rearrange(x, 'b n d -> b n d') local_feat = self.local_conv(rearrange(self.local_gate(gate_input), 'b (h w) d -> b d h w', h=h, w=w)) local_feat = rearrange(local_feat, 'b d h w -> b (h w) d') # Gated output g = torch.sigmoid(self.gate(x)) output = g * output * torch.sigmoid(local_feat) return self.dropout(self.out_proj(output)) class GLDBlock(nn.Module): """ Gated Linear Diffusion Block. Components: 1. GatedLinearAttention for spatial mixing (O(N) complexity) 2. SwiGLU FFN for channel mixing 3. Timestep + condition modulation (adaptive layer norm) 4. 2D RoPE for position encoding This replaces the standard transformer block in diffusion models. """ def __init__( self, dim: int, num_heads: int = 8, head_dim: int = 32, ffn_mult: float = 2.67, dropout: float = 0.0, cond_dim: int = 256, ): super().__init__() self.norm1 = RMSNorm(dim) self.norm2 = RMSNorm(dim) self.attn = GatedLinearAttention(dim, num_heads, head_dim, dropout) self.ffn = SwiGLU(dim, int(dim * ffn_mult), dropout) # Adaptive modulation (scale, shift, gate for each sub-layer) # Conditioned on timestep + text embedding self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(cond_dim, 6 * dim, bias=False), ) # Cross-attention to text (lightweight - only when text is available) self.cross_norm = RMSNorm(dim) self.cross_q = nn.Linear(dim, dim, bias=False) self.cross_kv = nn.Linear(cond_dim, 2 * dim, bias=False) self.cross_out = nn.Linear(dim, dim, bias=False) self.cross_gate = nn.Parameter(torch.zeros(1)) # Zero-init for residual def forward( self, x: torch.Tensor, # [B, N, D] cond: torch.Tensor, # [B, cond_dim] - timestep + global condition text_ctx: Optional[torch.Tensor] = None, # [B, T, cond_dim] - text tokens h: int = 32, w: int = 32, ) -> torch.Tensor: B, N, D = x.shape # Compute modulation parameters mod = self.adaLN_modulation(cond) # [B, 6*D] shift1, scale1, gate1, shift2, scale2, gate2 = mod.chunk(6, dim=-1) # Pre-norm + modulate + GLA x_norm = self.norm1(x) x_norm = x_norm * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1) x = x + gate1.unsqueeze(1) * self.attn(x_norm, h, w) # Cross-attention to text (if available) if text_ctx is not None: x_cross = self.cross_norm(x) q = self.cross_q(x_cross) kv = self.cross_kv(text_ctx) k, v = kv.chunk(2, dim=-1) # Simple dot-product attention (text sequence is short, so O(N*T) is fine) scale = q.shape[-1] ** -0.5 attn_weights = torch.bmm(q, k.transpose(-2, -1)) * scale attn_weights = F.softmax(attn_weights, dim=-1) cross_out = torch.bmm(attn_weights, v) x = x + torch.tanh(self.cross_gate) * self.cross_out(cross_out) # Pre-norm + modulate + FFN x_norm = self.norm2(x) x_norm = x_norm * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1) x = x + gate2.unsqueeze(1) * self.ffn(x_norm) return x # ============================================================================ # Recursive Latent Refinement (RLR) Core - THE KEY INNOVATION # ============================================================================ class RecursiveLatentCore(nn.Module): """ The Recursive Latent Refinement (RLR) Core. This is the key architectural innovation of LRF. Instead of stacking many unique transformer layers (like DiT with 28 layers), we use a small set of GLD blocks applied RECURSIVELY through an HRM-inspired iterative refinement loop. Architecture: - N_blocks GLD blocks (typically 4-6, shared across recursions) - T_inner recursive applications per outer step (typically 4-6) - T_outer outer steps with slow abstract state update (typically 2-3) Effective depth: T_outer * T_inner * N_blocks = 2*4*4 = 32 effective layers Actual parameters: only N_blocks sets = 4 unique block parameter sets Training uses IFT (Implicit Function Theorem): - Forward: run full recursion with torch.no_grad() for warmup - Backward: only backprop through the LAST recursion step - This gives O(1) memory cost regardless of recursion depth! Mathematical formulation: Let z be the noisy latent, c be the condition embedding. Outer loop (j = 1..T_outer): z_abstract = f_slow(z, c) # Abstract planning update Inner loop (i = 1..T_inner): z = f_blocks(z, z_abstract, c) # Apply N shared GLD blocks Where f_blocks applies the same N GLD blocks in sequence. The model learns a FIXED POINT: z* = f(z*, c) At convergence, the output is the denoised prediction v(z_t, t, c). """ def __init__( self, dim: int = 384, cond_dim: int = 256, num_blocks: int = 4, num_heads: int = 6, head_dim: int = 64, T_inner: int = 4, T_outer: int = 2, ffn_mult: float = 2.67, dropout: float = 0.0, use_ift_training: bool = True, ): super().__init__() self.dim = dim self.cond_dim = cond_dim self.num_blocks = num_blocks self.T_inner = T_inner self.T_outer = T_outer self.use_ift_training = use_ift_training # The shared GLD blocks (applied recursively) self.blocks = nn.ModuleList([ GLDBlock( dim=dim, num_heads=num_heads, head_dim=head_dim, ffn_mult=ffn_mult, dropout=dropout, cond_dim=cond_dim, ) for _ in range(num_blocks) ]) # Abstract state updater (the "slow" H-module from HRM) # This updates a global abstract representation every T_inner steps self.abstract_norm = RMSNorm(dim) self.abstract_update = nn.Sequential( nn.Linear(dim * 2, dim, bias=False), nn.SiLU(), nn.Linear(dim, dim, bias=False), ) self.abstract_gate = nn.Parameter(torch.zeros(1)) # Zero-init # Input projection self.input_proj = nn.Linear(dim, dim, bias=False) # Timestep embedding self.time_embed = nn.Sequential( nn.Linear(256, cond_dim), nn.SiLU(), nn.Linear(cond_dim, cond_dim), ) # Output projection (predicts velocity v for rectified flow) self.out_norm = RMSNorm(dim) self.out_proj = nn.Sequential( nn.Linear(dim, dim, bias=False), nn.SiLU(), nn.Linear(dim, dim, bias=False), ) # Recursion depth embedding (tells the model which recursion step it's on) self.recursion_embed = nn.Embedding(T_outer * T_inner + 1, cond_dim) # 2D positional encoding self.rope = RotaryPositionEncoding2D(head_dim) def _sinusoidal_embedding(self, t: torch.Tensor, dim: int = 256) -> torch.Tensor: """Sinusoidal timestep embedding.""" half_dim = dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb) emb = t.unsqueeze(-1) * emb.unsqueeze(0) return torch.cat([emb.sin(), emb.cos()], dim=-1) def _apply_blocks(self, z, cond, text_ctx, h, w): """Apply all GLD blocks once.""" for block in self.blocks: z = block(z, cond, text_ctx, h, w) return z def _recursive_refinement(self, z, cond_base, text_ctx, h, w): """ Full recursive refinement loop. Returns the refined latent z after T_outer * T_inner applications. """ z_abstract = z.mean(dim=1, keepdim=True).expand_as(z) # Initial abstract state step_idx = 0 for j in range(self.T_outer): # Abstract state update (slow H-module) z_pooled = z.mean(dim=1, keepdim=True).expand_as(z) abstract_input = torch.cat([self.abstract_norm(z), z_pooled], dim=-1) z_abstract = z_abstract + torch.tanh(self.abstract_gate) * self.abstract_update(abstract_input) for i in range(self.T_inner): # Add recursion depth information to conditioning rec_emb = self.recursion_embed( torch.tensor([step_idx], device=z.device) ).expand(z.shape[0], -1) cond = cond_base + rec_emb # Apply shared blocks with abstract state modulation z_input = z + z_abstract # Combine detail + abstract z = z + (self._apply_blocks(z_input, cond, text_ctx, h, w) - z) * 0.5 # Damped update step_idx += 1 return z def forward( self, z_t: torch.Tensor, # [B, C, H, W] - noisy latent t: torch.Tensor, # [B] - timestep (0 to 1) text_emb: Optional[torch.Tensor] = None, # [B, T, cond_dim] - text tokens text_global: Optional[torch.Tensor] = None, # [B, cond_dim] - global text embedding image_cond: Optional[torch.Tensor] = None, # [B, C, H, W] - for editing tasks ) -> torch.Tensor: """ Forward pass predicting velocity v_theta(z_t, t, c). For rectified flow: z_t = (1-t) * z_0 + t * epsilon Target: v = epsilon - z_0 """ B, C, H, W = z_t.shape # Flatten spatial dims z = rearrange(z_t, 'b c h w -> b (h w) c') # If editing: concatenate condition image (channel-wise before projection) if image_cond is not None: img_cond_flat = rearrange(image_cond, 'b c h w -> b (h w) c') z = z + img_cond_flat # Additive conditioning preserves spatial correspondence # Project z = self.input_proj(z) # Build conditioning t_emb = self._sinusoidal_embedding(t) t_emb = self.time_embed(t_emb) # [B, cond_dim] if text_global is not None: cond = t_emb + text_global else: cond = t_emb # Apply recursive refinement if self.training and self.use_ift_training: # IFT training: no_grad warmup + 1-step grad with torch.no_grad(): for _ in range(self.T_outer - 1): z = self._recursive_refinement(z, cond, text_emb, H, W) # Last step with gradients z = self._recursive_refinement(z, cond, text_emb, H, W) else: # Full recursion (inference or non-IFT training) z = self._recursive_refinement(z, cond, text_emb, H, W) # Output projection z = self.out_norm(z) v = self.out_proj(z) # Reshape back to spatial v = rearrange(v, 'b (h w) c -> b c h w', h=H, w=W) return v # ============================================================================ # Compact VAE (Tiny Decoder inspired by SnapGen) # ============================================================================ class TinyResBlock(nn.Module): """Ultra-compact residual block for tiny decoder.""" def __init__(self, in_channels: int, out_channels: int = None): super().__init__() out_channels = out_channels or in_channels self.norm1 = nn.GroupNorm(min(8, in_channels), in_channels) self.conv1 = DepthwiseSeparableConv2d(in_channels, out_channels, 3) self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels) self.conv2 = DepthwiseSeparableConv2d(out_channels, out_channels, 3) self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) if in_channels != out_channels else nn.Identity() def forward(self, x): h = self.conv1(F.silu(self.norm1(x))) h = self.conv2(F.silu(self.norm2(h))) return self.skip(x) + h class CompactEncoder(nn.Module): """ Compact image encoder: image -> latent space. f=16 spatial compression, C_latent channels. Uses strided depthwise-separable convolutions for efficiency. 4 downsampling stages: 256->128->64->32->16 (for 256x256 input) """ def __init__( self, in_channels: int = 3, latent_channels: int = 32, base_channels: int = 64, num_res_blocks: int = 2, ): super().__init__() channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 4] self.stem = nn.Conv2d(in_channels, channels[0], 3, padding=1, bias=False) self.downs = nn.ModuleList() ch_in = channels[0] for ch_out in channels: blocks = nn.ModuleList() # First block handles channel transition blocks.append(TinyResBlock(ch_in, ch_out)) for _ in range(num_res_blocks - 1): blocks.append(TinyResBlock(ch_out, ch_out)) # Downsample with strided conv down = nn.Conv2d(ch_out, ch_out, 4, stride=2, padding=1, bias=False) self.downs.append(nn.ModuleDict({ 'blocks': blocks, 'down': down, })) ch_in = ch_out # To latent self.to_latent = nn.Sequential( nn.GroupNorm(8, ch_in), nn.SiLU(), nn.Conv2d(ch_in, latent_channels * 2, 1, bias=False), # *2 for mean+logvar ) def forward(self, x): h = self.stem(x) for down_module in self.downs: for block in down_module['blocks']: h = block(h) h = down_module['down'](h) params = self.to_latent(h) mean, logvar = params.chunk(2, dim=1) logvar = torch.clamp(logvar, -30.0, 20.0) return mean, logvar class TinyDecoder(nn.Module): """ SnapGen-inspired tiny decoder: latent -> image. ~1-2M parameters. No attention layers. Uses depthwise-separable convolutions + minimal GroupNorm. 4 upsampling stages matching the encoder. """ def __init__( self, latent_channels: int = 32, out_channels: int = 3, base_channels: int = 128, num_res_blocks: int = 2, ): super().__init__() channels = [base_channels * 2, base_channels * 2, base_channels, base_channels // 2] self.from_latent = nn.Conv2d(latent_channels, channels[0], 1, bias=False) self.ups = nn.ModuleList() ch_in = channels[0] for ch_out in channels: blocks = nn.ModuleList() for _ in range(num_res_blocks): blocks.append(TinyResBlock(ch_in, ch_in)) # Upsample with channel transition up = nn.Sequential( nn.Upsample(scale_factor=2, mode='nearest'), DepthwiseSeparableConv2d(ch_in, ch_out, 3), ) self.ups.append(nn.ModuleDict({ 'blocks': blocks, 'up': up, })) ch_in = ch_out self.to_image = nn.Sequential( nn.GroupNorm(min(8, ch_in), ch_in), nn.SiLU(), nn.Conv2d(ch_in, out_channels, 3, padding=1), nn.Tanh(), # Output in [-1, 1] ) def forward(self, z): h = self.from_latent(z) for up_module in self.ups: for block in up_module['blocks']: h = block(h) h = up_module['up'](h) return self.to_image(h) class CompactVAE(nn.Module): """ Complete VAE with compact encoder + tiny decoder. f=16 compression, configurable latent channels. """ def __init__( self, in_channels: int = 3, latent_channels: int = 32, encoder_base_ch: int = 64, decoder_base_ch: int = 128, ): super().__init__() self.encoder = CompactEncoder(in_channels, latent_channels, encoder_base_ch) self.decoder = TinyDecoder(latent_channels, in_channels, decoder_base_ch) self.latent_channels = latent_channels def encode(self, x): mean, logvar = self.encoder(x) if self.training: std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mean + eps * std else: z = mean return z, mean, logvar def decode(self, z): return self.decoder(z) def forward(self, x): z, mean, logvar = self.encode(x) recon = self.decode(z) return recon, mean, logvar # ============================================================================ # Text Conditioner (Lightweight) # ============================================================================ class SimpleTextEncoder(nn.Module): """ Lightweight text encoder for the standalone prototype. In production, this would be replaced by TinyCLIP or a small LM. For the prototype: simple learned embeddings + small transformer. This lets us test the full pipeline without a heavy text encoder. """ def __init__( self, vocab_size: int = 32000, max_length: int = 77, dim: int = 256, num_layers: int = 4, num_heads: int = 4, ): super().__init__() self.dim = dim self.token_embed = nn.Embedding(vocab_size, dim) self.pos_embed = nn.Embedding(max_length, dim) encoder_layer = nn.TransformerEncoderLayer( d_model=dim, nhead=num_heads, dim_feedforward=dim*4, dropout=0.1, activation='gelu', batch_first=True, norm_first=True ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.norm = RMSNorm(dim) # Global pooling projection self.global_proj = nn.Sequential( nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim), ) def forward(self, token_ids, attention_mask=None): B, T = token_ids.shape pos_ids = torch.arange(T, device=token_ids.device).unsqueeze(0).expand(B, -1) x = self.token_embed(token_ids) + self.pos_embed(pos_ids) if attention_mask is not None: # Convert to transformer mask (True = ignore) src_key_padding_mask = ~attention_mask.bool() else: src_key_padding_mask = None x = self.transformer(x, src_key_padding_mask=src_key_padding_mask) x = self.norm(x) # Global embedding (mean pool over non-padded tokens) if attention_mask is not None: mask = attention_mask.unsqueeze(-1).float() global_emb = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) else: global_emb = x.mean(dim=1) global_emb = self.global_proj(global_emb) return x, global_emb # [B, T, D], [B, D] # ============================================================================ # Full LRF Model # ============================================================================ class LatentRecurrentFlow(nn.Module): """ LatentRecurrentFlow (LRF) - Complete model. Combines: 1. CompactVAE for image encoding/decoding 2. SimpleTextEncoder for text conditioning 3. RecursiveLatentCore for denoising Training modes: - 'vae': Train only the VAE - 'denoise': Train only the denoising core (freeze VAE) - 'e2e': End-to-end fine-tuning - 'distill': Consistency distillation from teacher """ def __init__(self, config: Optional[Dict[str, Any]] = None): super().__init__() config = config or self.default_config() self.config = config # VAE self.vae = CompactVAE( in_channels=3, latent_channels=config['latent_channels'], encoder_base_ch=config.get('encoder_base_ch', 64), decoder_base_ch=config.get('decoder_base_ch', 128), ) # Text encoder self.text_encoder = SimpleTextEncoder( vocab_size=config.get('vocab_size', 32000), max_length=config.get('max_text_length', 77), dim=config['cond_dim'], num_layers=config.get('text_layers', 4), num_heads=config.get('text_heads', 4), ) # Denoising core self.core = RecursiveLatentCore( dim=config['latent_channels'], cond_dim=config['cond_dim'], num_blocks=config['num_blocks'], num_heads=config.get('num_heads', 6), head_dim=config.get('head_dim', 64), T_inner=config.get('T_inner', 4), T_outer=config.get('T_outer', 2), ffn_mult=config.get('ffn_mult', 2.67), dropout=config.get('dropout', 0.0), use_ift_training=config.get('use_ift', True), ) # Latent scaling (learnable, stabilizes training) self.latent_scale = nn.Parameter(torch.tensor(1.0)) @staticmethod def default_config(): """Default config targeting ~50M params, trainable on 16GB.""" return { 'latent_channels': 32, 'cond_dim': 256, 'num_blocks': 4, 'num_heads': 4, 'head_dim': 64, 'T_inner': 4, 'T_outer': 2, 'ffn_mult': 2.67, 'dropout': 0.0, 'use_ift': True, 'encoder_base_ch': 64, 'decoder_base_ch': 128, 'vocab_size': 32000, 'max_text_length': 77, 'text_layers': 4, 'text_heads': 4, } @staticmethod def tiny_config(): """Tiny config for quick testing.""" return { 'latent_channels': 16, 'cond_dim': 128, 'num_blocks': 2, 'num_heads': 2, 'head_dim': 32, 'T_inner': 2, 'T_outer': 1, 'ffn_mult': 2.0, 'dropout': 0.0, 'use_ift': False, 'encoder_base_ch': 32, 'decoder_base_ch': 64, 'vocab_size': 32000, 'max_text_length': 77, 'text_layers': 2, 'text_heads': 2, } def encode_image(self, x): """Encode image to latent space.""" z, mean, logvar = self.vae.encode(x) return z * self.latent_scale, mean, logvar def decode_latent(self, z): """Decode latent to image.""" return self.vae.decode(z / self.latent_scale) def encode_text(self, token_ids, attention_mask=None): """Encode text to conditioning vectors.""" return self.text_encoder(token_ids, attention_mask) def predict_velocity(self, z_t, t, text_emb=None, text_global=None, image_cond=None): """Predict velocity for rectified flow.""" return self.core(z_t, t, text_emb, text_global, image_cond) def get_param_groups(self): """Return parameter groups for staged training.""" return { 'vae_encoder': list(self.vae.encoder.parameters()), 'vae_decoder': list(self.vae.decoder.parameters()), 'text_encoder': list(self.text_encoder.parameters()), 'core': list(self.core.parameters()), 'latent_scale': [self.latent_scale], } def count_parameters(self): """Count parameters per module.""" counts = {} for name, module in [ ('vae_encoder', self.vae.encoder), ('vae_decoder', self.vae.decoder), ('text_encoder', self.text_encoder), ('core', self.core), ]: counts[name] = sum(p.numel() for p in module.parameters()) counts['latent_scale'] = 1 counts['total'] = sum(counts.values()) return counts def forward(self, x=None, token_ids=None, attention_mask=None, **kwargs): """Full forward pass for training. See training script for usage.""" raise NotImplementedError( "Use the training pipeline functions instead of calling forward() directly. " "See LRFTrainer for VAE training, denoiser training, and distillation." )