| """ |
| LatentRecurrentFlow (LRF) - Core Architecture Modules |
| |
| Architecture Overview: |
| ===================== |
| The LRF architecture consists of 4 main components: |
| |
1. CompactEncoder/Decoder (VAE): f=16 spatial compression with tiny decoder
| 2. TextConditioner: Lightweight text encoding (TinyCLIP or small LM) |
| 3. RecursiveLatentCore: The novel HRM-inspired denoising backbone |
| 4. FlowScheduler: Rectified flow for training and sampling |
| |
| The RecursiveLatentCore is the key innovation: |
| - It contains N_blocks GLD (Gated Linear Diffusion) blocks |
| - These blocks are applied recursively T_outer * T_inner times |
| - The same parameters are reused across recursions (weight sharing) |
| - Training uses IFT (Implicit Function Theorem) for O(1) memory backprop |
| - This gives effective depth of T_outer * T_inner * N_blocks layers |
| from only N_blocks parameter sets |
| |
| Memory budget at inference (1024x1024, INT8): |
| - Text encoder: ~150MB (TinyCLIP-ViT-B/16) |
- VAE encoder: ~100MB (f=16 encoder, only needed for editing)
| - VAE decoder: ~6MB (SnapGen-style tiny decoder) |
| - LRF core: ~200-400MB (depending on config) |
| - Activations: ~500MB peak |
- Total: ~0.5-0.7GB weights + ~0.5GB activations = ~1-1.2GB peak
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange, repeat |
| from typing import Optional, Tuple, Dict, Any |
|
|
# ---------------------------------------------------------------------------
# Basic building blocks
# ---------------------------------------------------------------------------
|
|
| class RMSNorm(nn.Module): |
| """RMSNorm - more stable than LayerNorm for small models.""" |
| def __init__(self, dim: int, eps: float = 1e-6): |
| super().__init__() |
| self.eps = eps |
| self.weight = nn.Parameter(torch.ones(dim)) |
| |
| def forward(self, x): |
| norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() |
| return (x.float() * norm).type_as(x) * self.weight |
|
|
|
|
| class SwiGLU(nn.Module): |
| """SwiGLU FFN - better than GELU for small models, mobile-friendly (SiLU not GELU).""" |
| def __init__(self, dim: int, hidden_dim: Optional[int] = None, dropout: float = 0.0): |
| super().__init__() |
| hidden_dim = hidden_dim or int(dim * 8 / 3) |
| |
| hidden_dim = ((hidden_dim + 7) // 8) * 8 |
| self.w1 = nn.Linear(dim, hidden_dim, bias=False) |
| self.w2 = nn.Linear(hidden_dim, dim, bias=False) |
| self.w3 = nn.Linear(dim, hidden_dim, bias=False) |
| self.dropout = nn.Dropout(dropout) |
| |
| def forward(self, x): |
| return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) |
|
|
|
|
| class DepthwiseSeparableConv2d(nn.Module): |
| """Mobile-optimized convolution.""" |
| def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3): |
| super().__init__() |
| padding = kernel_size // 2 |
| self.dw = nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, groups=in_channels, bias=False) |
| self.pw = nn.Conv2d(in_channels, out_channels, 1, bias=False) |
| |
| def forward(self, x): |
| return self.pw(self.dw(x)) |
|
|
# ---------------------------------------------------------------------------
# Position encoding (2D RoPE)
# ---------------------------------------------------------------------------
|
|
| class RotaryPositionEncoding2D(nn.Module): |
| """2D RoPE for spatial tokens - resolution-independent.""" |
| def __init__(self, dim: int, max_res: int = 64): |
| super().__init__() |
| self.dim = dim |
| half_dim = dim // 4 |
| freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000.0) / half_dim)) |
| self.register_buffer('freqs', freqs) |
| |
| def forward(self, h: int, w: int, device=None): |
| device = device or self.freqs.device |
| pos_h = torch.arange(h, device=device).float() |
| pos_w = torch.arange(w, device=device).float() |
| |
| freqs_h = torch.outer(pos_h, self.freqs.to(device)) |
| freqs_w = torch.outer(pos_w, self.freqs.to(device)) |
| |
| |
| freqs_h = freqs_h.unsqueeze(1).expand(-1, w, -1) |
| freqs_w = freqs_w.unsqueeze(0).expand(h, -1, -1) |
| |
| |
| freqs = torch.cat([freqs_h, freqs_w], dim=-1) |
| |
| sin_enc = freqs.sin() |
| cos_enc = freqs.cos() |
| |
| return sin_enc.reshape(h * w, -1), cos_enc.reshape(h * w, -1) |
|
|
|
|
| def apply_rope_2d(x, sin_enc, cos_enc): |
| """Apply 2D RoPE to queries/keys.""" |
| d = x.shape[-1] |
| half_d = d // 2 |
| x1, x2 = x[..., :half_d], x[..., half_d:] |
| |
| while sin_enc.dim() < x1.dim(): |
| sin_enc = sin_enc.unsqueeze(0) |
| cos_enc = cos_enc.unsqueeze(0) |
| return torch.cat([x1 * cos_enc - x2 * sin_enc, x2 * cos_enc + x1 * sin_enc], dim=-1) |
|
|
# ---------------------------------------------------------------------------
# Gated linear attention and GLD blocks
# ---------------------------------------------------------------------------
|
|
| class GatedLinearAttention(nn.Module): |
| """ |
| Gated Linear Attention for 2D spatial mixing. |
| O(N) complexity instead of O(N²) softmax attention. |
| |
| Based on ViG/GLA research but adapted for diffusion: |
| - Bidirectional scan (forward + backward) |
| - 2D locality injection via depthwise conv gating |
| - Token-differential operator to prevent oversmoothing (from DyDiLA) |
| |
| Math: |
| Q, K, V = linear(x), linear(x), linear(x) |
| Q = phi(Q), K = phi(K) where phi = 1 + elu (non-negative feature map) |
| |
| Forward scan: S_i = decay * S_{i-1} + K_i^T V_i; O_i = Q_i S_i |
| Backward scan: same in reverse |
| |
| Output = gate * (O_fwd + O_bwd) * local_gate |
| |
| Complexity: O(N * d²) where d is head dimension, N is sequence length |
| """ |
| def __init__(self, dim: int, num_heads: int = 8, head_dim: int = 32, dropout: float = 0.0): |
| super().__init__() |
| self.num_heads = num_heads |
| self.head_dim = head_dim |
| inner_dim = num_heads * head_dim |
| |
| self.qkv = nn.Linear(dim, 3 * inner_dim, bias=False) |
| self.out_proj = nn.Linear(inner_dim, dim, bias=False) |
| |
| |
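        # Per-head decay logits for the recurrent state (passed through a sigmoid in _scan).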
| self.log_decay = nn.Parameter(torch.zeros(num_heads)) |
| |
| |
| self.gate = nn.Linear(dim, inner_dim, bias=False) |
| |
| |
| self.local_conv = nn.Conv2d(inner_dim, inner_dim, 3, padding=1, groups=inner_dim, bias=False) |
| self.local_gate = nn.Linear(dim, inner_dim, bias=False) |
| |
| |
| self.diff_lambda = nn.Parameter(torch.tensor(0.1)) |
| |
| self.dropout = nn.Dropout(dropout) |
| self.norm = RMSNorm(inner_dim) |
| |
| def _feature_map(self, x): |
| """Non-negative feature map: 1 + elu(x)""" |
| return 1.0 + F.elu(x) |
| |
| def _scan(self, Q, K, V, reverse=False): |
| """Linear recurrent scan - O(N * d²) per direction.""" |
| B, H, N, D = Q.shape |
| |
| decay = torch.sigmoid(self.log_decay).view(1, H, 1, 1) |
| |
| if reverse: |
| Q = Q.flip(2) |
| K = K.flip(2) |
| V = V.flip(2) |
| |
| |
| chunk_size = min(64, N) |
| outputs = [] |
| S = torch.zeros(B, H, D, D, device=Q.device, dtype=Q.dtype) |
| |
| for i in range(0, N, chunk_size): |
| q_chunk = Q[:, :, i:i+chunk_size] |
| k_chunk = K[:, :, i:i+chunk_size] |
| v_chunk = V[:, :, i:i+chunk_size] |
| |
| chunk_len = q_chunk.shape[2] |
| |
| |
| kv = torch.einsum('bhcd,bhce->bhde', k_chunk, v_chunk) |
| S = decay * S + kv |
| |
| |
| o_chunk = torch.einsum('bhcd,bhde->bhce', q_chunk, S) |
| outputs.append(o_chunk) |
| |
| output = torch.cat(outputs, dim=2) |
| |
| if reverse: |
| output = output.flip(2) |
| |
| return output |
| |
| def forward(self, x, h: int, w: int): |
| """ |
| Args: |
| x: [B, N, D] where N = H*W |
| h, w: spatial dimensions |
| Returns: |
| [B, N, D] |
| """ |
| B, N, D = x.shape |
| |
| |
| qkv = self.qkv(x) |
| q, k, v = qkv.chunk(3, dim=-1) |
| |
| |
| q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) |
| k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads) |
| v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads) |
| |
| |
| |
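        # Token-differential operator (anti-oversmoothing): subtract a lambda-scaled
        # shifted copy of Q and K before the feature map.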
| lam = torch.sigmoid(self.diff_lambda) |
| q_shifted = F.pad(q[:, :, :-1], (0, 0, 1, 0)) |
| k_shifted = F.pad(k[:, :, :-1], (0, 0, 1, 0)) |
| q = q - lam * q_shifted |
| k = k - lam * k_shifted |
| |
| |
| q = self._feature_map(q) |
| k = self._feature_map(k) |
| |
| |
| o_fwd = self._scan(q, k, v, reverse=False) |
| o_bwd = self._scan(q, k, v, reverse=True) |
| output = o_fwd + o_bwd |
| |
| |
| output = rearrange(output, 'b h n d -> b n (h d)') |
| output = self.norm(output) |
| |
| |
        local_feat = self.local_conv(rearrange(self.local_gate(x), 'b (h w) d -> b d h w', h=h, w=w))
        local_feat = rearrange(local_feat, 'b d h w -> b (h w) d')
| |
| |
| g = torch.sigmoid(self.gate(x)) |
| output = g * output * torch.sigmoid(local_feat) |
| |
| return self.dropout(self.out_proj(output)) |
|
|
|
|
| class GLDBlock(nn.Module): |
| """ |
| Gated Linear Diffusion Block. |
| |
| Components: |
| 1. GatedLinearAttention for spatial mixing (O(N) complexity) |
| 2. SwiGLU FFN for channel mixing |
| 3. Timestep + condition modulation (adaptive layer norm) |
    4. Gated cross-attention to the text context (optional)
| |
| This replaces the standard transformer block in diffusion models. |
| """ |
| def __init__( |
| self, |
| dim: int, |
| num_heads: int = 8, |
| head_dim: int = 32, |
| ffn_mult: float = 2.67, |
| dropout: float = 0.0, |
| cond_dim: int = 256, |
| ): |
| super().__init__() |
| self.norm1 = RMSNorm(dim) |
| self.norm2 = RMSNorm(dim) |
| |
| self.attn = GatedLinearAttention(dim, num_heads, head_dim, dropout) |
| self.ffn = SwiGLU(dim, int(dim * ffn_mult), dropout) |
| |
| |
| |
| self.adaLN_modulation = nn.Sequential( |
| nn.SiLU(), |
| nn.Linear(cond_dim, 6 * dim, bias=False), |
| ) |
| |
| |
| self.cross_norm = RMSNorm(dim) |
| self.cross_q = nn.Linear(dim, dim, bias=False) |
| self.cross_kv = nn.Linear(cond_dim, 2 * dim, bias=False) |
| self.cross_out = nn.Linear(dim, dim, bias=False) |
| self.cross_gate = nn.Parameter(torch.zeros(1)) |
| |
| def forward( |
| self, |
| x: torch.Tensor, |
| cond: torch.Tensor, |
| text_ctx: Optional[torch.Tensor] = None, |
| h: int = 32, |
| w: int = 32, |
| ) -> torch.Tensor: |
| B, N, D = x.shape |
| |
| |
| mod = self.adaLN_modulation(cond) |
| shift1, scale1, gate1, shift2, scale2, gate2 = mod.chunk(6, dim=-1) |
| |
| |
| x_norm = self.norm1(x) |
| x_norm = x_norm * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1) |
| x = x + gate1.unsqueeze(1) * self.attn(x_norm, h, w) |
| |
| |
| if text_ctx is not None: |
| x_cross = self.cross_norm(x) |
| q = self.cross_q(x_cross) |
| kv = self.cross_kv(text_ctx) |
| k, v = kv.chunk(2, dim=-1) |
| |
| |
| scale = q.shape[-1] ** -0.5 |
| attn_weights = torch.bmm(q, k.transpose(-2, -1)) * scale |
| attn_weights = F.softmax(attn_weights, dim=-1) |
| cross_out = torch.bmm(attn_weights, v) |
| x = x + torch.tanh(self.cross_gate) * self.cross_out(cross_out) |
| |
| |
| x_norm = self.norm2(x) |
| x_norm = x_norm * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1) |
| x = x + gate2.unsqueeze(1) * self.ffn(x_norm) |
| |
| return x |
|
|
# ---------------------------------------------------------------------------
# Recursive latent core
# ---------------------------------------------------------------------------
|
|
| class RecursiveLatentCore(nn.Module): |
| """ |
| The Recursive Latent Refinement (RLR) Core. |
| |
| This is the key architectural innovation of LRF. Instead of stacking |
| many unique transformer layers (like DiT with 28 layers), we use a |
| small set of GLD blocks applied RECURSIVELY through an HRM-inspired |
| iterative refinement loop. |
| |
| Architecture: |
| - N_blocks GLD blocks (typically 4-6, shared across recursions) |
| - T_inner recursive applications per outer step (typically 4-6) |
| - T_outer outer steps with slow abstract state update (typically 2-3) |
| |
| Effective depth: T_outer * T_inner * N_blocks = 2*4*4 = 32 effective layers |
| Actual parameters: only N_blocks sets = 4 unique block parameter sets |
| |
    Training uses an IFT-style (Implicit Function Theorem) approximation:
    - Forward: run T_outer - 1 full refinement passes under torch.no_grad()
      as a warm-up toward the fixed point
    - Backward: gradients flow only through the final refinement pass
      (truncated backprop, justified by the IFT at the fixed point)
    - Memory cost is therefore independent of the number of warm-up passes
| |
| Mathematical formulation: |
| |
| Let z be the noisy latent, c be the condition embedding. |
| |
| Outer loop (j = 1..T_outer): |
| z_abstract = f_slow(z, c) # Abstract planning update |
| Inner loop (i = 1..T_inner): |
| z = f_blocks(z, z_abstract, c) # Apply N shared GLD blocks |
| |
| Where f_blocks applies the same N GLD blocks in sequence. |
| |
| The model learns a FIXED POINT: z* = f(z*, c) |
| At convergence, the output is the denoised prediction v(z_t, t, c). |
| """ |
| |
| def __init__( |
| self, |
| dim: int = 384, |
| cond_dim: int = 256, |
| num_blocks: int = 4, |
| num_heads: int = 6, |
| head_dim: int = 64, |
| T_inner: int = 4, |
| T_outer: int = 2, |
| ffn_mult: float = 2.67, |
| dropout: float = 0.0, |
| use_ift_training: bool = True, |
| ): |
| super().__init__() |
| self.dim = dim |
| self.cond_dim = cond_dim |
| self.num_blocks = num_blocks |
| self.T_inner = T_inner |
| self.T_outer = T_outer |
| self.use_ift_training = use_ift_training |
| |
| |
| self.blocks = nn.ModuleList([ |
| GLDBlock( |
| dim=dim, |
| num_heads=num_heads, |
| head_dim=head_dim, |
| ffn_mult=ffn_mult, |
| dropout=dropout, |
| cond_dim=cond_dim, |
| ) |
| for _ in range(num_blocks) |
| ]) |
| |
| |
| |
| self.abstract_norm = RMSNorm(dim) |
| self.abstract_update = nn.Sequential( |
| nn.Linear(dim * 2, dim, bias=False), |
| nn.SiLU(), |
| nn.Linear(dim, dim, bias=False), |
| ) |
| self.abstract_gate = nn.Parameter(torch.zeros(1)) |
| |
| |
| self.input_proj = nn.Linear(dim, dim, bias=False) |
| |
| |
| self.time_embed = nn.Sequential( |
| nn.Linear(256, cond_dim), |
| nn.SiLU(), |
| nn.Linear(cond_dim, cond_dim), |
| ) |
| |
| |
| self.out_norm = RMSNorm(dim) |
| self.out_proj = nn.Sequential( |
| nn.Linear(dim, dim, bias=False), |
| nn.SiLU(), |
| nn.Linear(dim, dim, bias=False), |
| ) |
| |
| |
| self.recursion_embed = nn.Embedding(T_outer * T_inner + 1, cond_dim) |
| |
| |
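        # 2D RoPE helper for the latent grid (not currently applied inside the GLD blocks).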
| self.rope = RotaryPositionEncoding2D(head_dim) |
| |
| def _sinusoidal_embedding(self, t: torch.Tensor, dim: int = 256) -> torch.Tensor: |
| """Sinusoidal timestep embedding.""" |
| half_dim = dim // 2 |
| emb = math.log(10000) / (half_dim - 1) |
| emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb) |
| emb = t.unsqueeze(-1) * emb.unsqueeze(0) |
| return torch.cat([emb.sin(), emb.cos()], dim=-1) |
| |
| def _apply_blocks(self, z, cond, text_ctx, h, w): |
| """Apply all GLD blocks once.""" |
| for block in self.blocks: |
| z = block(z, cond, text_ctx, h, w) |
| return z |
| |
| def _recursive_refinement(self, z, cond_base, text_ctx, h, w): |
| """ |
| Full recursive refinement loop. |
| |
| Returns the refined latent z after T_outer * T_inner applications. |
| """ |
| z_abstract = z.mean(dim=1, keepdim=True).expand_as(z) |
| |
| step_idx = 0 |
| for j in range(self.T_outer): |
| |
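            # Slow update of the pooled "abstract" planning state (outer loop).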
| z_pooled = z.mean(dim=1, keepdim=True).expand_as(z) |
| abstract_input = torch.cat([self.abstract_norm(z), z_pooled], dim=-1) |
| z_abstract = z_abstract + torch.tanh(self.abstract_gate) * self.abstract_update(abstract_input) |
| |
| for i in range(self.T_inner): |
| |
| rec_emb = self.recursion_embed( |
| torch.tensor([step_idx], device=z.device) |
| ).expand(z.shape[0], -1) |
| cond = cond_base + rec_emb |
| |
| |
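                # Inject the abstract state, then take a damped (0.5) step toward the block output.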
| z_input = z + z_abstract |
| z = z + (self._apply_blocks(z_input, cond, text_ctx, h, w) - z) * 0.5 |
| |
| step_idx += 1 |
| |
| return z |
| |
| def forward( |
| self, |
| z_t: torch.Tensor, |
| t: torch.Tensor, |
| text_emb: Optional[torch.Tensor] = None, |
| text_global: Optional[torch.Tensor] = None, |
| image_cond: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| """ |
| Forward pass predicting velocity v_theta(z_t, t, c). |
| |
| For rectified flow: z_t = (1-t) * z_0 + t * epsilon |
| Target: v = epsilon - z_0 |
| """ |
| B, C, H, W = z_t.shape |
| |
| |
| z = rearrange(z_t, 'b c h w -> b (h w) c') |
| |
| |
| if image_cond is not None: |
| img_cond_flat = rearrange(image_cond, 'b c h w -> b (h w) c') |
| z = z + img_cond_flat |
| |
| |
| z = self.input_proj(z) |
| |
| |
| t_emb = self._sinusoidal_embedding(t) |
| t_emb = self.time_embed(t_emb) |
| |
| if text_global is not None: |
| cond = t_emb + text_global |
| else: |
| cond = t_emb |
| |
| |
| if self.training and self.use_ift_training: |
| |
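            # Warm-up passes toward the fixed point; no autograd graph is kept for these.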
| with torch.no_grad(): |
| for _ in range(self.T_outer - 1): |
| z = self._recursive_refinement(z, cond, text_emb, H, W) |
| |
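            # Final pass with gradients: the only pass that contributes to backprop.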
| z = self._recursive_refinement(z, cond, text_emb, H, W) |
| else: |
| |
| z = self._recursive_refinement(z, cond, text_emb, H, W) |
| |
| |
| z = self.out_norm(z) |
| v = self.out_proj(z) |
| |
| |
| v = rearrange(v, 'b (h w) c -> b c h w', h=H, w=W) |
| |
| return v |
|
|
# ---------------------------------------------------------------------------
# Compact VAE: encoder + tiny decoder
# ---------------------------------------------------------------------------
|
|
| class TinyResBlock(nn.Module): |
| """Ultra-compact residual block for tiny decoder.""" |
    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
| super().__init__() |
| out_channels = out_channels or in_channels |
| self.norm1 = nn.GroupNorm(min(8, in_channels), in_channels) |
| self.conv1 = DepthwiseSeparableConv2d(in_channels, out_channels, 3) |
| self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels) |
| self.conv2 = DepthwiseSeparableConv2d(out_channels, out_channels, 3) |
| self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) if in_channels != out_channels else nn.Identity() |
| |
| def forward(self, x): |
| h = self.conv1(F.silu(self.norm1(x))) |
| h = self.conv2(F.silu(self.norm2(h))) |
| return self.skip(x) + h |
|
|
|
|
| class CompactEncoder(nn.Module): |
| """ |
| Compact image encoder: image -> latent space. |
| f=16 spatial compression, C_latent channels. |
| |
| Uses strided depthwise-separable convolutions for efficiency. |
| 4 downsampling stages: 256->128->64->32->16 (for 256x256 input) |
| """ |
| def __init__( |
| self, |
| in_channels: int = 3, |
| latent_channels: int = 32, |
| base_channels: int = 64, |
| num_res_blocks: int = 2, |
| ): |
| super().__init__() |
| channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 4] |
| |
| self.stem = nn.Conv2d(in_channels, channels[0], 3, padding=1, bias=False) |
| |
| self.downs = nn.ModuleList() |
| ch_in = channels[0] |
| for ch_out in channels: |
| blocks = nn.ModuleList() |
| |
| blocks.append(TinyResBlock(ch_in, ch_out)) |
| for _ in range(num_res_blocks - 1): |
| blocks.append(TinyResBlock(ch_out, ch_out)) |
| |
| down = nn.Conv2d(ch_out, ch_out, 4, stride=2, padding=1, bias=False) |
| self.downs.append(nn.ModuleDict({ |
| 'blocks': blocks, |
| 'down': down, |
| })) |
| ch_in = ch_out |
| |
| |
| self.to_latent = nn.Sequential( |
| nn.GroupNorm(8, ch_in), |
| nn.SiLU(), |
| nn.Conv2d(ch_in, latent_channels * 2, 1, bias=False), |
| ) |
| |
| def forward(self, x): |
| h = self.stem(x) |
| for down_module in self.downs: |
| for block in down_module['blocks']: |
| h = block(h) |
| h = down_module['down'](h) |
| |
| params = self.to_latent(h) |
| mean, logvar = params.chunk(2, dim=1) |
| logvar = torch.clamp(logvar, -30.0, 20.0) |
| |
| return mean, logvar |
|
|
|
|
| class TinyDecoder(nn.Module): |
| """ |
| SnapGen-inspired tiny decoder: latent -> image. |
| ~1-2M parameters. No attention layers. |
| Uses depthwise-separable convolutions + minimal GroupNorm. |
| |
| 4 upsampling stages matching the encoder. |
| """ |
| def __init__( |
| self, |
| latent_channels: int = 32, |
| out_channels: int = 3, |
| base_channels: int = 128, |
| num_res_blocks: int = 2, |
| ): |
| super().__init__() |
| channels = [base_channels * 2, base_channels * 2, base_channels, base_channels // 2] |
| |
| self.from_latent = nn.Conv2d(latent_channels, channels[0], 1, bias=False) |
| |
| self.ups = nn.ModuleList() |
| ch_in = channels[0] |
| for ch_out in channels: |
| blocks = nn.ModuleList() |
| for _ in range(num_res_blocks): |
| blocks.append(TinyResBlock(ch_in, ch_in)) |
| |
| up = nn.Sequential( |
| nn.Upsample(scale_factor=2, mode='nearest'), |
| DepthwiseSeparableConv2d(ch_in, ch_out, 3), |
| ) |
| self.ups.append(nn.ModuleDict({ |
| 'blocks': blocks, |
| 'up': up, |
| })) |
| ch_in = ch_out |
| |
| self.to_image = nn.Sequential( |
| nn.GroupNorm(min(8, ch_in), ch_in), |
| nn.SiLU(), |
| nn.Conv2d(ch_in, out_channels, 3, padding=1), |
| nn.Tanh(), |
| ) |
| |
| def forward(self, z): |
| h = self.from_latent(z) |
| for up_module in self.ups: |
| for block in up_module['blocks']: |
| h = block(h) |
| h = up_module['up'](h) |
| return self.to_image(h) |
|
|
|
|
| class CompactVAE(nn.Module): |
| """ |
| Complete VAE with compact encoder + tiny decoder. |
| f=16 compression, configurable latent channels. |
| """ |
| def __init__( |
| self, |
| in_channels: int = 3, |
| latent_channels: int = 32, |
| encoder_base_ch: int = 64, |
| decoder_base_ch: int = 128, |
| ): |
| super().__init__() |
| self.encoder = CompactEncoder(in_channels, latent_channels, encoder_base_ch) |
| self.decoder = TinyDecoder(latent_channels, in_channels, decoder_base_ch) |
| self.latent_channels = latent_channels |
| |
| def encode(self, x): |
| mean, logvar = self.encoder(x) |
| if self.training: |
| std = torch.exp(0.5 * logvar) |
| eps = torch.randn_like(std) |
| z = mean + eps * std |
| else: |
| z = mean |
| return z, mean, logvar |
| |
| def decode(self, z): |
| return self.decoder(z) |
| |
| def forward(self, x): |
| z, mean, logvar = self.encode(x) |
| recon = self.decode(z) |
| return recon, mean, logvar |
|
|
# ---------------------------------------------------------------------------
# Text conditioning
# ---------------------------------------------------------------------------
|
|
| class SimpleTextEncoder(nn.Module): |
| """ |
| Lightweight text encoder for the standalone prototype. |
| In production, this would be replaced by TinyCLIP or a small LM. |
| |
| For the prototype: simple learned embeddings + small transformer. |
| This lets us test the full pipeline without a heavy text encoder. |
| """ |
| def __init__( |
| self, |
| vocab_size: int = 32000, |
| max_length: int = 77, |
| dim: int = 256, |
| num_layers: int = 4, |
| num_heads: int = 4, |
| ): |
| super().__init__() |
| self.dim = dim |
| self.token_embed = nn.Embedding(vocab_size, dim) |
| self.pos_embed = nn.Embedding(max_length, dim) |
| |
| encoder_layer = nn.TransformerEncoderLayer( |
| d_model=dim, nhead=num_heads, dim_feedforward=dim*4, |
| dropout=0.1, activation='gelu', batch_first=True, norm_first=True |
| ) |
| self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) |
| self.norm = RMSNorm(dim) |
| |
| |
| self.global_proj = nn.Sequential( |
| nn.Linear(dim, dim), |
| nn.SiLU(), |
| nn.Linear(dim, dim), |
| ) |
| |
| def forward(self, token_ids, attention_mask=None): |
| B, T = token_ids.shape |
| pos_ids = torch.arange(T, device=token_ids.device).unsqueeze(0).expand(B, -1) |
| |
| x = self.token_embed(token_ids) + self.pos_embed(pos_ids) |
| |
| if attention_mask is not None: |
| |
| src_key_padding_mask = ~attention_mask.bool() |
| else: |
| src_key_padding_mask = None |
| |
| x = self.transformer(x, src_key_padding_mask=src_key_padding_mask) |
| x = self.norm(x) |
| |
| |
| if attention_mask is not None: |
| mask = attention_mask.unsqueeze(-1).float() |
| global_emb = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) |
| else: |
| global_emb = x.mean(dim=1) |
| |
| global_emb = self.global_proj(global_emb) |
| |
| return x, global_emb |
|
|
# ---------------------------------------------------------------------------
# Complete model
# ---------------------------------------------------------------------------
|
|
| class LatentRecurrentFlow(nn.Module): |
| """ |
| LatentRecurrentFlow (LRF) - Complete model. |
| |
| Combines: |
| 1. CompactVAE for image encoding/decoding |
| 2. SimpleTextEncoder for text conditioning |
| 3. RecursiveLatentCore for denoising |
| |
| Training modes: |
| - 'vae': Train only the VAE |
| - 'denoise': Train only the denoising core (freeze VAE) |
| - 'e2e': End-to-end fine-tuning |
| - 'distill': Consistency distillation from teacher |
| """ |
| |
| def __init__(self, config: Optional[Dict[str, Any]] = None): |
| super().__init__() |
| |
| config = config or self.default_config() |
| self.config = config |
| |
| |
| self.vae = CompactVAE( |
| in_channels=3, |
| latent_channels=config['latent_channels'], |
| encoder_base_ch=config.get('encoder_base_ch', 64), |
| decoder_base_ch=config.get('decoder_base_ch', 128), |
| ) |
| |
| |
| self.text_encoder = SimpleTextEncoder( |
| vocab_size=config.get('vocab_size', 32000), |
| max_length=config.get('max_text_length', 77), |
| dim=config['cond_dim'], |
| num_layers=config.get('text_layers', 4), |
| num_heads=config.get('text_heads', 4), |
| ) |
| |
| |
| self.core = RecursiveLatentCore( |
| dim=config['latent_channels'], |
| cond_dim=config['cond_dim'], |
| num_blocks=config['num_blocks'], |
| num_heads=config.get('num_heads', 6), |
| head_dim=config.get('head_dim', 64), |
| T_inner=config.get('T_inner', 4), |
| T_outer=config.get('T_outer', 2), |
| ffn_mult=config.get('ffn_mult', 2.67), |
| dropout=config.get('dropout', 0.0), |
| use_ift_training=config.get('use_ift', True), |
| ) |
| |
| |
| self.latent_scale = nn.Parameter(torch.tensor(1.0)) |
| |
| @staticmethod |
| def default_config(): |
| """Default config targeting ~50M params, trainable on 16GB.""" |
| return { |
| 'latent_channels': 32, |
| 'cond_dim': 256, |
| 'num_blocks': 4, |
| 'num_heads': 4, |
| 'head_dim': 64, |
| 'T_inner': 4, |
| 'T_outer': 2, |
| 'ffn_mult': 2.67, |
| 'dropout': 0.0, |
| 'use_ift': True, |
| 'encoder_base_ch': 64, |
| 'decoder_base_ch': 128, |
| 'vocab_size': 32000, |
| 'max_text_length': 77, |
| 'text_layers': 4, |
| 'text_heads': 4, |
| } |
| |
| @staticmethod |
| def tiny_config(): |
| """Tiny config for quick testing.""" |
| return { |
| 'latent_channels': 16, |
| 'cond_dim': 128, |
| 'num_blocks': 2, |
| 'num_heads': 2, |
| 'head_dim': 32, |
| 'T_inner': 2, |
| 'T_outer': 1, |
| 'ffn_mult': 2.0, |
| 'dropout': 0.0, |
| 'use_ift': False, |
| 'encoder_base_ch': 32, |
| 'decoder_base_ch': 64, |
| 'vocab_size': 32000, |
| 'max_text_length': 77, |
| 'text_layers': 2, |
| 'text_heads': 2, |
| } |
| |
| def encode_image(self, x): |
| """Encode image to latent space.""" |
| z, mean, logvar = self.vae.encode(x) |
| return z * self.latent_scale, mean, logvar |
| |
| def decode_latent(self, z): |
| """Decode latent to image.""" |
| return self.vae.decode(z / self.latent_scale) |
| |
| def encode_text(self, token_ids, attention_mask=None): |
| """Encode text to conditioning vectors.""" |
| return self.text_encoder(token_ids, attention_mask) |
| |
| def predict_velocity(self, z_t, t, text_emb=None, text_global=None, image_cond=None): |
| """Predict velocity for rectified flow.""" |
| return self.core(z_t, t, text_emb, text_global, image_cond) |
| |
| def get_param_groups(self): |
| """Return parameter groups for staged training.""" |
| return { |
| 'vae_encoder': list(self.vae.encoder.parameters()), |
| 'vae_decoder': list(self.vae.decoder.parameters()), |
| 'text_encoder': list(self.text_encoder.parameters()), |
| 'core': list(self.core.parameters()), |
| 'latent_scale': [self.latent_scale], |
| } |
| |
| def count_parameters(self): |
| """Count parameters per module.""" |
| counts = {} |
| for name, module in [ |
| ('vae_encoder', self.vae.encoder), |
| ('vae_decoder', self.vae.decoder), |
| ('text_encoder', self.text_encoder), |
| ('core', self.core), |
| ]: |
| counts[name] = sum(p.numel() for p in module.parameters()) |
| counts['latent_scale'] = 1 |
| counts['total'] = sum(counts.values()) |
| return counts |
| |
| def forward(self, x=None, token_ids=None, attention_mask=None, **kwargs): |
| """Full forward pass for training. See training script for usage.""" |
| raise NotImplementedError( |
| "Use the training pipeline functions instead of calling forward() directly. " |
| "See LRFTrainer for VAE training, denoiser training, and distillation." |
| ) |
|
|