"""
LatentRecurrentFlow (LRF) - Core Architecture Modules
Architecture Overview:
=====================
The LRF architecture consists of 4 main components:
1. CompactEncoder/Decoder (VAE): f=32 spatial compression with tiny decoder
2. TextConditioner: Lightweight text encoding (TinyCLIP or small LM)
3. RecursiveLatentCore: The novel HRM-inspired denoising backbone
4. FlowScheduler: Rectified flow for training and sampling
The RecursiveLatentCore is the key innovation:
- It contains N_blocks GLD (Gated Linear Diffusion) blocks
- These blocks are applied recursively T_outer * T_inner times
- The same parameters are reused across recursions (weight sharing)
- Training uses IFT (Implicit Function Theorem) for O(1) memory backprop
- This gives effective depth of T_outer * T_inner * N_blocks layers
from only N_blocks parameter sets
Memory budget at inference (1024x1024, INT8):
- Text encoder: ~150MB (TinyCLIP-ViT-B/16)
- VAE encoder: ~100MB (f32 encoder, only needed for editing)
- VAE decoder: ~6MB (SnapGen-style tiny decoder)
- LRF core: ~200-400MB (depending on config)
- Activations: ~500MB peak
- Total: ~1-1.5GB model + ~500MB activations = 1.5-2GB
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Optional, Dict, Any
# ============================================================================
# Utility Modules
# ============================================================================
class RMSNorm(nn.Module):
"""RMSNorm - more stable than LayerNorm for small models."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
return (x.float() * norm).type_as(x) * self.weight
class SwiGLU(nn.Module):
"""SwiGLU FFN - better than GELU for small models, mobile-friendly (SiLU not GELU)."""
def __init__(self, dim: int, hidden_dim: Optional[int] = None, dropout: float = 0.0):
super().__init__()
hidden_dim = hidden_dim or int(dim * 8 / 3)
# Round to nearest multiple of 8 for efficiency
hidden_dim = ((hidden_dim + 7) // 8) * 8
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class DepthwiseSeparableConv2d(nn.Module):
"""Mobile-optimized convolution."""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
super().__init__()
padding = kernel_size // 2
self.dw = nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, groups=in_channels, bias=False)
self.pw = nn.Conv2d(in_channels, out_channels, 1, bias=False)
def forward(self, x):
return self.pw(self.dw(x))
# ============================================================================
# 2D Positional Encoding
# ============================================================================
class RotaryPositionEncoding2D(nn.Module):
"""2D RoPE for spatial tokens - resolution-independent."""
def __init__(self, dim: int, max_res: int = 64):
super().__init__()
self.dim = dim
        half_dim = dim // 4  # D/4 frequencies each for height and width; the sin/cos tables end up D/2 wide
freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000.0) / half_dim))
self.register_buffer('freqs', freqs)
def forward(self, h: int, w: int, device=None):
device = device or self.freqs.device
pos_h = torch.arange(h, device=device).float()
pos_w = torch.arange(w, device=device).float()
freqs_h = torch.outer(pos_h, self.freqs.to(device)) # [H, D/4]
freqs_w = torch.outer(pos_w, self.freqs.to(device)) # [W, D/4]
# Expand to [H, W, D/4] each
freqs_h = freqs_h.unsqueeze(1).expand(-1, w, -1)
freqs_w = freqs_w.unsqueeze(0).expand(h, -1, -1)
# Concatenate: [H, W, D/2] for sin, [H, W, D/2] for cos
freqs = torch.cat([freqs_h, freqs_w], dim=-1) # [H, W, D/2]
sin_enc = freqs.sin()
cos_enc = freqs.cos()
return sin_enc.reshape(h * w, -1), cos_enc.reshape(h * w, -1)
def apply_rope_2d(x, sin_enc, cos_enc):
"""Apply 2D RoPE to queries/keys."""
d = x.shape[-1]
half_d = d // 2
x1, x2 = x[..., :half_d], x[..., half_d:]
# Expand sin/cos to match batch dims
while sin_enc.dim() < x1.dim():
sin_enc = sin_enc.unsqueeze(0)
cos_enc = cos_enc.unsqueeze(0)
return torch.cat([x1 * cos_enc - x2 * sin_enc, x2 * cos_enc + x1 * sin_enc], dim=-1)
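
# Illustrative sketch (not part of the forward path): applying the 2D RoPE
# tables to per-head queries/keys. apply_rope_2d is provided but not yet wired
# into the attention blocks below; this demo only checks the intended shapes.
def _demo_rope_2d():
    head_dim, h, w = 32, 8, 8
    rope = RotaryPositionEncoding2D(head_dim)
    sin_enc, cos_enc = rope(h, w)              # each [h*w, head_dim // 2]
    q = torch.randn(2, 4, h * w, head_dim)     # [B, heads, N, head_dim]
    q_rot = apply_rope_2d(q, sin_enc, cos_enc)
    assert q_rot.shape == q.shape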
# ============================================================================
# Gated Linear Diffusion (GLD) Block - The Core Spatial Mixer
# ============================================================================
class GatedLinearAttention(nn.Module):
"""
Gated Linear Attention for 2D spatial mixing.
O(N) complexity instead of O(N²) softmax attention.
Based on ViG/GLA research but adapted for diffusion:
- Bidirectional scan (forward + backward)
- 2D locality injection via depthwise conv gating
- Token-differential operator to prevent oversmoothing (from DyDiLA)
Math:
Q, K, V = linear(x), linear(x), linear(x)
Q = phi(Q), K = phi(K) where phi = 1 + elu (non-negative feature map)
Forward scan: S_i = decay * S_{i-1} + K_i^T V_i; O_i = Q_i S_i
Backward scan: same in reverse
Output = gate * (O_fwd + O_bwd) * local_gate
Complexity: O(N * d²) where d is head dimension, N is sequence length
"""
def __init__(self, dim: int, num_heads: int = 8, head_dim: int = 32, dropout: float = 0.0):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
inner_dim = num_heads * head_dim
self.qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, dim, bias=False)
        # Learnable per-head decay for the recurrence (mapped to (0, 1) via sigmoid)
        self.log_decay = nn.Parameter(torch.zeros(num_heads))
# Gate for output
self.gate = nn.Linear(dim, inner_dim, bias=False)
# 2D locality injection (depthwise conv) - critical for spatial structure
self.local_conv = nn.Conv2d(inner_dim, inner_dim, 3, padding=1, groups=inner_dim, bias=False)
self.local_gate = nn.Linear(dim, inner_dim, bias=False)
# Token differential parameter (from DyDiLA - prevents oversmoothing)
self.diff_lambda = nn.Parameter(torch.tensor(0.1))
self.dropout = nn.Dropout(dropout)
self.norm = RMSNorm(inner_dim)
def _feature_map(self, x):
"""Non-negative feature map: 1 + elu(x)"""
return 1.0 + F.elu(x)
def _scan(self, Q, K, V, reverse=False):
"""Linear recurrent scan - O(N * d²) per direction."""
B, H, N, D = Q.shape
decay = torch.sigmoid(self.log_decay).view(1, H, 1, 1) # [1, H, 1, 1]
if reverse:
Q = Q.flip(2)
K = K.flip(2)
V = V.flip(2)
        # Chunk-wise computation for memory efficiency.
        # Note: the decay is applied once per chunk rather than per token, and
        # queries within a chunk see the whole chunk's K^T V. This is a coarser,
        # non-causal approximation of the per-token recurrence in the class
        # docstring; it is acceptable here because the scan is run
        # bidirectionally over spatial tokens rather than causally.
        chunk_size = min(64, N)
        outputs = []
        S = torch.zeros(B, H, D, D, device=Q.device, dtype=Q.dtype)
        for i in range(0, N, chunk_size):
            q_chunk = Q[:, :, i:i+chunk_size]  # [B, H, C, D]
            k_chunk = K[:, :, i:i+chunk_size]
            v_chunk = V[:, :, i:i+chunk_size]
            # Update state: S = decay * S + K^T V
            kv = torch.einsum('bhcd,bhce->bhde', k_chunk, v_chunk)
            S = decay * S + kv
            # Query state: O = Q S
            o_chunk = torch.einsum('bhcd,bhde->bhce', q_chunk, S)
            outputs.append(o_chunk)
output = torch.cat(outputs, dim=2)
if reverse:
output = output.flip(2)
return output
def forward(self, x, h: int, w: int):
"""
Args:
x: [B, N, D] where N = H*W
h, w: spatial dimensions
Returns:
[B, N, D]
"""
B, N, D = x.shape
# Project to Q, K, V
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
# Reshape to heads
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
# Token differential (prevents oversmoothing)
# Q_diff = Q_i - lambda * Q_{i-1}, K_diff = K_i - lambda * K_{i-1}
lam = torch.sigmoid(self.diff_lambda)
q_shifted = F.pad(q[:, :, :-1], (0, 0, 1, 0))
k_shifted = F.pad(k[:, :, :-1], (0, 0, 1, 0))
q = q - lam * q_shifted
k = k - lam * k_shifted
# Apply feature map (non-negative)
q = self._feature_map(q)
k = self._feature_map(k)
# Bidirectional scan
o_fwd = self._scan(q, k, v, reverse=False)
o_bwd = self._scan(q, k, v, reverse=True)
output = o_fwd + o_bwd
# Normalize
output = rearrange(output, 'b h n d -> b n (h d)')
output = self.norm(output)
        # 2D locality injection (GaLI from ViG)
        local_feat = self.local_conv(rearrange(self.local_gate(x), 'b (h w) d -> b d h w', h=h, w=w))
        local_feat = rearrange(local_feat, 'b d h w -> b (h w) d')
# Gated output
g = torch.sigmoid(self.gate(x))
output = g * output * torch.sigmoid(local_feat)
return self.dropout(self.out_proj(output))
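
# Illustrative sketch (hypothetical helper, not used by the model): running the
# O(N) gated linear attention over a small 8x8 token grid to verify shapes.
def _demo_gated_linear_attention():
    dim, h, w = 64, 8, 8
    attn = GatedLinearAttention(dim, num_heads=4, head_dim=16)
    x = torch.randn(2, h * w, dim)             # [B, N, D] with N = h * w
    out = attn(x, h, w)
    assert out.shape == (2, h * w, dim)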
class GLDBlock(nn.Module):
"""
Gated Linear Diffusion Block.
Components:
1. GatedLinearAttention for spatial mixing (O(N) complexity)
2. SwiGLU FFN for channel mixing
3. Timestep + condition modulation (adaptive layer norm)
4. 2D RoPE for position encoding
This replaces the standard transformer block in diffusion models.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
head_dim: int = 32,
ffn_mult: float = 2.67,
dropout: float = 0.0,
cond_dim: int = 256,
):
super().__init__()
self.norm1 = RMSNorm(dim)
self.norm2 = RMSNorm(dim)
self.attn = GatedLinearAttention(dim, num_heads, head_dim, dropout)
self.ffn = SwiGLU(dim, int(dim * ffn_mult), dropout)
# Adaptive modulation (scale, shift, gate for each sub-layer)
# Conditioned on timestep + text embedding
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(cond_dim, 6 * dim, bias=False),
)
# Cross-attention to text (lightweight - only when text is available)
self.cross_norm = RMSNorm(dim)
self.cross_q = nn.Linear(dim, dim, bias=False)
self.cross_kv = nn.Linear(cond_dim, 2 * dim, bias=False)
self.cross_out = nn.Linear(dim, dim, bias=False)
self.cross_gate = nn.Parameter(torch.zeros(1)) # Zero-init for residual
def forward(
self,
x: torch.Tensor, # [B, N, D]
cond: torch.Tensor, # [B, cond_dim] - timestep + global condition
text_ctx: Optional[torch.Tensor] = None, # [B, T, cond_dim] - text tokens
h: int = 32,
w: int = 32,
) -> torch.Tensor:
B, N, D = x.shape
# Compute modulation parameters
mod = self.adaLN_modulation(cond) # [B, 6*D]
shift1, scale1, gate1, shift2, scale2, gate2 = mod.chunk(6, dim=-1)
# Pre-norm + modulate + GLA
x_norm = self.norm1(x)
x_norm = x_norm * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
x = x + gate1.unsqueeze(1) * self.attn(x_norm, h, w)
# Cross-attention to text (if available)
if text_ctx is not None:
x_cross = self.cross_norm(x)
q = self.cross_q(x_cross)
kv = self.cross_kv(text_ctx)
k, v = kv.chunk(2, dim=-1)
# Simple dot-product attention (text sequence is short, so O(N*T) is fine)
scale = q.shape[-1] ** -0.5
attn_weights = torch.bmm(q, k.transpose(-2, -1)) * scale
attn_weights = F.softmax(attn_weights, dim=-1)
cross_out = torch.bmm(attn_weights, v)
x = x + torch.tanh(self.cross_gate) * self.cross_out(cross_out)
# Pre-norm + modulate + FFN
x_norm = self.norm2(x)
x_norm = x_norm * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
x = x + gate2.unsqueeze(1) * self.ffn(x_norm)
return x
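
# Illustrative sketch (hypothetical helper): one GLD block step with timestep
# conditioning and a short text context, matching the shapes documented above.
def _demo_gld_block():
    dim, cond_dim, h, w = 64, 32, 8, 8
    block = GLDBlock(dim, num_heads=4, head_dim=16, cond_dim=cond_dim)
    x = torch.randn(2, h * w, dim)             # spatial tokens
    cond = torch.randn(2, cond_dim)            # timestep + global condition
    text_ctx = torch.randn(2, 7, cond_dim)     # dummy text tokens
    out = block(x, cond, text_ctx=text_ctx, h=h, w=w)
    assert out.shape == x.shape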
# ============================================================================
# Recursive Latent Refinement (RLR) Core - THE KEY INNOVATION
# ============================================================================
class RecursiveLatentCore(nn.Module):
"""
The Recursive Latent Refinement (RLR) Core.
This is the key architectural innovation of LRF. Instead of stacking
many unique transformer layers (like DiT with 28 layers), we use a
small set of GLD blocks applied RECURSIVELY through an HRM-inspired
iterative refinement loop.
Architecture:
- N_blocks GLD blocks (typically 4-6, shared across recursions)
- T_inner recursive applications per outer step (typically 4-6)
- T_outer outer steps with slow abstract state update (typically 2-3)
Effective depth: T_outer * T_inner * N_blocks = 2*4*4 = 32 effective layers
Actual parameters: only N_blocks sets = 4 unique block parameter sets
    Training uses a 1-step gradient inspired by the IFT (Implicit Function Theorem):
    - Forward: run T_outer - 1 full refinement passes under torch.no_grad() as warmup
    - Backward: backprop only through the final refinement pass
    - Memory cost is therefore independent of the number of warmup passes
Mathematical formulation:
Let z be the noisy latent, c be the condition embedding.
Outer loop (j = 1..T_outer):
z_abstract = f_slow(z, c) # Abstract planning update
Inner loop (i = 1..T_inner):
z = f_blocks(z, z_abstract, c) # Apply N shared GLD blocks
Where f_blocks applies the same N GLD blocks in sequence.
The model learns a FIXED POINT: z* = f(z*, c)
At convergence, the output is the denoised prediction v(z_t, t, c).
"""
def __init__(
self,
dim: int = 384,
cond_dim: int = 256,
num_blocks: int = 4,
num_heads: int = 6,
head_dim: int = 64,
T_inner: int = 4,
T_outer: int = 2,
ffn_mult: float = 2.67,
dropout: float = 0.0,
use_ift_training: bool = True,
):
super().__init__()
self.dim = dim
self.cond_dim = cond_dim
self.num_blocks = num_blocks
self.T_inner = T_inner
self.T_outer = T_outer
self.use_ift_training = use_ift_training
# The shared GLD blocks (applied recursively)
self.blocks = nn.ModuleList([
GLDBlock(
dim=dim,
num_heads=num_heads,
head_dim=head_dim,
ffn_mult=ffn_mult,
dropout=dropout,
cond_dim=cond_dim,
)
for _ in range(num_blocks)
])
        # Abstract state updater (the "slow" H-module from HRM).
        # Once per outer step (i.e., every T_inner inner steps), it updates an
        # abstract state by mixing each token with a globally pooled summary.
self.abstract_norm = RMSNorm(dim)
self.abstract_update = nn.Sequential(
nn.Linear(dim * 2, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim, bias=False),
)
self.abstract_gate = nn.Parameter(torch.zeros(1)) # Zero-init
# Input projection
self.input_proj = nn.Linear(dim, dim, bias=False)
# Timestep embedding
self.time_embed = nn.Sequential(
nn.Linear(256, cond_dim),
nn.SiLU(),
nn.Linear(cond_dim, cond_dim),
)
# Output projection (predicts velocity v for rectified flow)
self.out_norm = RMSNorm(dim)
self.out_proj = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim, bias=False),
)
# Recursion depth embedding (tells the model which recursion step it's on)
self.recursion_embed = nn.Embedding(T_outer * T_inner + 1, cond_dim)
        # 2D positional encoding (RoPE tables; see apply_rope_2d above).
        # Note: not yet wired into the attention blocks in this prototype.
        self.rope = RotaryPositionEncoding2D(head_dim)
def _sinusoidal_embedding(self, t: torch.Tensor, dim: int = 256) -> torch.Tensor:
"""Sinusoidal timestep embedding."""
half_dim = dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t.unsqueeze(-1) * emb.unsqueeze(0)
return torch.cat([emb.sin(), emb.cos()], dim=-1)
def _apply_blocks(self, z, cond, text_ctx, h, w):
"""Apply all GLD blocks once."""
for block in self.blocks:
z = block(z, cond, text_ctx, h, w)
return z
def _recursive_refinement(self, z, cond_base, text_ctx, h, w):
"""
Full recursive refinement loop.
Returns the refined latent z after T_outer * T_inner applications.
"""
z_abstract = z.mean(dim=1, keepdim=True).expand_as(z) # Initial abstract state
step_idx = 0
for j in range(self.T_outer):
# Abstract state update (slow H-module)
z_pooled = z.mean(dim=1, keepdim=True).expand_as(z)
abstract_input = torch.cat([self.abstract_norm(z), z_pooled], dim=-1)
z_abstract = z_abstract + torch.tanh(self.abstract_gate) * self.abstract_update(abstract_input)
for i in range(self.T_inner):
# Add recursion depth information to conditioning
rec_emb = self.recursion_embed(
torch.tensor([step_idx], device=z.device)
).expand(z.shape[0], -1)
cond = cond_base + rec_emb
# Apply shared blocks with abstract state modulation
z_input = z + z_abstract # Combine detail + abstract
z = z + (self._apply_blocks(z_input, cond, text_ctx, h, w) - z) * 0.5 # Damped update
step_idx += 1
return z
def forward(
self,
z_t: torch.Tensor, # [B, C, H, W] - noisy latent
t: torch.Tensor, # [B] - timestep (0 to 1)
text_emb: Optional[torch.Tensor] = None, # [B, T, cond_dim] - text tokens
text_global: Optional[torch.Tensor] = None, # [B, cond_dim] - global text embedding
image_cond: Optional[torch.Tensor] = None, # [B, C, H, W] - for editing tasks
) -> torch.Tensor:
"""
Forward pass predicting velocity v_theta(z_t, t, c).
For rectified flow: z_t = (1-t) * z_0 + t * epsilon
Target: v = epsilon - z_0
"""
B, C, H, W = z_t.shape
# Flatten spatial dims
z = rearrange(z_t, 'b c h w -> b (h w) c')
        # If editing: add the condition image token-wise before projection
        if image_cond is not None:
            img_cond_flat = rearrange(image_cond, 'b c h w -> b (h w) c')
            z = z + img_cond_flat  # Additive conditioning preserves spatial correspondence
# Project
z = self.input_proj(z)
# Build conditioning
t_emb = self._sinusoidal_embedding(t)
t_emb = self.time_embed(t_emb) # [B, cond_dim]
if text_global is not None:
cond = t_emb + text_global
else:
cond = t_emb
# Apply recursive refinement
        if self.training and self.use_ift_training:
            # 1-step gradient (IFT-inspired): warmup passes carry no gradients,
            # so backprop memory does not grow with the number of warmup passes.
            with torch.no_grad():
                for _ in range(self.T_outer - 1):
                    z = self._recursive_refinement(z, cond, text_emb, H, W)
            # Final refinement pass with gradients
            z = self._recursive_refinement(z, cond, text_emb, H, W)
else:
# Full recursion (inference or non-IFT training)
z = self._recursive_refinement(z, cond, text_emb, H, W)
# Output projection
z = self.out_norm(z)
v = self.out_proj(z)
# Reshape back to spatial
v = rearrange(v, 'b (h w) c -> b c h w', h=H, w=W)
return v
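
# Illustrative sketch (hypothetical helper, not part of the pipeline): forming
# the rectified-flow training pair documented in forward() and running a tiny
# Euler sampler with the core. The 4-step sampler below is an assumption for
# demonstration; the real sampling schedule lives in the FlowScheduler.
def _demo_core_rectified_flow():
    torch.manual_seed(0)
    core = RecursiveLatentCore(dim=16, cond_dim=32, num_blocks=2,
                               num_heads=2, head_dim=8, T_inner=2, T_outer=1).eval()
    # Training pair: z_t = (1-t) * z_0 + t * eps, target v = eps - z_0
    z_0 = torch.randn(2, 16, 8, 8)             # clean latent
    eps = torch.randn_like(z_0)                # noise
    t = torch.rand(2)
    tb = t.view(-1, 1, 1, 1)
    z_t = (1 - tb) * z_0 + tb * eps
    target = eps - z_0                         # rectified-flow velocity target
    with torch.no_grad():
        v = core(z_t, t)
    assert v.shape == target.shape
    # Sampling: integrate dz/dt = v from t=1 (pure noise) down to t=0 via Euler
    steps = 4
    z = torch.randn(2, 16, 8, 8)
    with torch.no_grad():
        for i in range(steps, 0, -1):
            t_cur = torch.full((2,), i / steps)
            z = z - (1.0 / steps) * core(z, t_cur)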
# ============================================================================
# Compact VAE (Tiny Decoder inspired by SnapGen)
# ============================================================================
class TinyResBlock(nn.Module):
"""Ultra-compact residual block for tiny decoder."""
    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = nn.GroupNorm(min(8, in_channels), in_channels)
self.conv1 = DepthwiseSeparableConv2d(in_channels, out_channels, 3)
self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
self.conv2 = DepthwiseSeparableConv2d(out_channels, out_channels, 3)
self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) if in_channels != out_channels else nn.Identity()
def forward(self, x):
h = self.conv1(F.silu(self.norm1(x)))
h = self.conv2(F.silu(self.norm2(h)))
return self.skip(x) + h
class CompactEncoder(nn.Module):
"""
Compact image encoder: image -> latent space.
f=16 spatial compression, C_latent channels.
Uses strided depthwise-separable convolutions for efficiency.
4 downsampling stages: 256->128->64->32->16 (for 256x256 input)
"""
def __init__(
self,
in_channels: int = 3,
latent_channels: int = 32,
base_channels: int = 64,
num_res_blocks: int = 2,
):
super().__init__()
channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 4]
self.stem = nn.Conv2d(in_channels, channels[0], 3, padding=1, bias=False)
self.downs = nn.ModuleList()
ch_in = channels[0]
for ch_out in channels:
blocks = nn.ModuleList()
# First block handles channel transition
blocks.append(TinyResBlock(ch_in, ch_out))
for _ in range(num_res_blocks - 1):
blocks.append(TinyResBlock(ch_out, ch_out))
# Downsample with strided conv
down = nn.Conv2d(ch_out, ch_out, 4, stride=2, padding=1, bias=False)
self.downs.append(nn.ModuleDict({
'blocks': blocks,
'down': down,
}))
ch_in = ch_out
# To latent
self.to_latent = nn.Sequential(
nn.GroupNorm(8, ch_in),
nn.SiLU(),
nn.Conv2d(ch_in, latent_channels * 2, 1, bias=False), # *2 for mean+logvar
)
def forward(self, x):
h = self.stem(x)
for down_module in self.downs:
for block in down_module['blocks']:
h = block(h)
h = down_module['down'](h)
params = self.to_latent(h)
mean, logvar = params.chunk(2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
return mean, logvar
class TinyDecoder(nn.Module):
"""
SnapGen-inspired tiny decoder: latent -> image.
~1-2M parameters. No attention layers.
Uses depthwise-separable convolutions + minimal GroupNorm.
4 upsampling stages matching the encoder.
"""
def __init__(
self,
latent_channels: int = 32,
out_channels: int = 3,
base_channels: int = 128,
num_res_blocks: int = 2,
):
super().__init__()
channels = [base_channels * 2, base_channels * 2, base_channels, base_channels // 2]
self.from_latent = nn.Conv2d(latent_channels, channels[0], 1, bias=False)
self.ups = nn.ModuleList()
ch_in = channels[0]
for ch_out in channels:
blocks = nn.ModuleList()
for _ in range(num_res_blocks):
blocks.append(TinyResBlock(ch_in, ch_in))
# Upsample with channel transition
up = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
DepthwiseSeparableConv2d(ch_in, ch_out, 3),
)
self.ups.append(nn.ModuleDict({
'blocks': blocks,
'up': up,
}))
ch_in = ch_out
self.to_image = nn.Sequential(
nn.GroupNorm(min(8, ch_in), ch_in),
nn.SiLU(),
nn.Conv2d(ch_in, out_channels, 3, padding=1),
nn.Tanh(), # Output in [-1, 1]
)
def forward(self, z):
h = self.from_latent(z)
for up_module in self.ups:
for block in up_module['blocks']:
h = block(h)
h = up_module['up'](h)
return self.to_image(h)
class CompactVAE(nn.Module):
"""
Complete VAE with compact encoder + tiny decoder.
f=16 compression, configurable latent channels.
"""
def __init__(
self,
in_channels: int = 3,
latent_channels: int = 32,
encoder_base_ch: int = 64,
decoder_base_ch: int = 128,
):
super().__init__()
self.encoder = CompactEncoder(in_channels, latent_channels, encoder_base_ch)
self.decoder = TinyDecoder(latent_channels, in_channels, decoder_base_ch)
self.latent_channels = latent_channels
def encode(self, x):
mean, logvar = self.encoder(x)
if self.training:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mean + eps * std
else:
z = mean
return z, mean, logvar
def decode(self, z):
return self.decoder(z)
def forward(self, x):
z, mean, logvar = self.encode(x)
recon = self.decode(z)
return recon, mean, logvar
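
# Illustrative sketch (hypothetical helper): VAE roundtrip at f=16. A 64x64
# image maps to a 4x4 latent grid; in eval mode encode() returns the mean.
def _demo_compact_vae():
    vae = CompactVAE(latent_channels=16, encoder_base_ch=32, decoder_base_ch=64).eval()
    x = torch.randn(1, 3, 64, 64)              # dummy image in [-1, 1]
    z, mean, logvar = vae.encode(x)            # [1, 16, 4, 4]
    assert z.shape == (1, 16, 4, 4)
    recon = vae.decode(z)                      # back to [1, 3, 64, 64]
    assert recon.shape == x.shape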
# ============================================================================
# Text Conditioner (Lightweight)
# ============================================================================
class SimpleTextEncoder(nn.Module):
"""
Lightweight text encoder for the standalone prototype.
In production, this would be replaced by TinyCLIP or a small LM.
For the prototype: simple learned embeddings + small transformer.
This lets us test the full pipeline without a heavy text encoder.
"""
def __init__(
self,
vocab_size: int = 32000,
max_length: int = 77,
dim: int = 256,
num_layers: int = 4,
num_heads: int = 4,
):
super().__init__()
self.dim = dim
self.token_embed = nn.Embedding(vocab_size, dim)
self.pos_embed = nn.Embedding(max_length, dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim, nhead=num_heads, dim_feedforward=dim*4,
dropout=0.1, activation='gelu', batch_first=True, norm_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.norm = RMSNorm(dim)
# Global pooling projection
self.global_proj = nn.Sequential(
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim),
)
def forward(self, token_ids, attention_mask=None):
B, T = token_ids.shape
pos_ids = torch.arange(T, device=token_ids.device).unsqueeze(0).expand(B, -1)
x = self.token_embed(token_ids) + self.pos_embed(pos_ids)
if attention_mask is not None:
# Convert to transformer mask (True = ignore)
src_key_padding_mask = ~attention_mask.bool()
else:
src_key_padding_mask = None
x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
x = self.norm(x)
# Global embedding (mean pool over non-padded tokens)
if attention_mask is not None:
mask = attention_mask.unsqueeze(-1).float()
global_emb = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
else:
global_emb = x.mean(dim=1)
global_emb = self.global_proj(global_emb)
return x, global_emb # [B, T, D], [B, D]
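
# Illustrative sketch (hypothetical helper): encoding a padded token batch.
# The attention mask marks real tokens with 1; padded positions are excluded
# from both self-attention and the mean-pooled global embedding.
def _demo_text_encoder():
    enc = SimpleTextEncoder(vocab_size=1000, dim=64, num_layers=1, num_heads=2).eval()
    token_ids = torch.randint(0, 1000, (2, 10))
    attention_mask = torch.ones(2, 10, dtype=torch.long)
    attention_mask[:, 7:] = 0                  # last three positions are padding
    tokens, global_emb = enc(token_ids, attention_mask)
    assert tokens.shape == (2, 10, 64) and global_emb.shape == (2, 64)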
# ============================================================================
# Full LRF Model
# ============================================================================
class LatentRecurrentFlow(nn.Module):
"""
LatentRecurrentFlow (LRF) - Complete model.
Combines:
1. CompactVAE for image encoding/decoding
2. SimpleTextEncoder for text conditioning
3. RecursiveLatentCore for denoising
Training modes:
- 'vae': Train only the VAE
- 'denoise': Train only the denoising core (freeze VAE)
- 'e2e': End-to-end fine-tuning
- 'distill': Consistency distillation from teacher
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
super().__init__()
config = config or self.default_config()
self.config = config
# VAE
self.vae = CompactVAE(
in_channels=3,
latent_channels=config['latent_channels'],
encoder_base_ch=config.get('encoder_base_ch', 64),
decoder_base_ch=config.get('decoder_base_ch', 128),
)
# Text encoder
self.text_encoder = SimpleTextEncoder(
vocab_size=config.get('vocab_size', 32000),
max_length=config.get('max_text_length', 77),
dim=config['cond_dim'],
num_layers=config.get('text_layers', 4),
num_heads=config.get('text_heads', 4),
)
# Denoising core
self.core = RecursiveLatentCore(
dim=config['latent_channels'],
cond_dim=config['cond_dim'],
num_blocks=config['num_blocks'],
num_heads=config.get('num_heads', 6),
head_dim=config.get('head_dim', 64),
T_inner=config.get('T_inner', 4),
T_outer=config.get('T_outer', 2),
ffn_mult=config.get('ffn_mult', 2.67),
dropout=config.get('dropout', 0.0),
use_ift_training=config.get('use_ift', True),
)
# Latent scaling (learnable, stabilizes training)
self.latent_scale = nn.Parameter(torch.tensor(1.0))
@staticmethod
def default_config():
"""Default config targeting ~50M params, trainable on 16GB."""
return {
'latent_channels': 32,
'cond_dim': 256,
'num_blocks': 4,
'num_heads': 4,
'head_dim': 64,
'T_inner': 4,
'T_outer': 2,
'ffn_mult': 2.67,
'dropout': 0.0,
'use_ift': True,
'encoder_base_ch': 64,
'decoder_base_ch': 128,
'vocab_size': 32000,
'max_text_length': 77,
'text_layers': 4,
'text_heads': 4,
}
@staticmethod
def tiny_config():
"""Tiny config for quick testing."""
return {
'latent_channels': 16,
'cond_dim': 128,
'num_blocks': 2,
'num_heads': 2,
'head_dim': 32,
'T_inner': 2,
'T_outer': 1,
'ffn_mult': 2.0,
'dropout': 0.0,
'use_ift': False,
'encoder_base_ch': 32,
'decoder_base_ch': 64,
'vocab_size': 32000,
'max_text_length': 77,
'text_layers': 2,
'text_heads': 2,
}
def encode_image(self, x):
"""Encode image to latent space."""
z, mean, logvar = self.vae.encode(x)
return z * self.latent_scale, mean, logvar
def decode_latent(self, z):
"""Decode latent to image."""
return self.vae.decode(z / self.latent_scale)
def encode_text(self, token_ids, attention_mask=None):
"""Encode text to conditioning vectors."""
return self.text_encoder(token_ids, attention_mask)
def predict_velocity(self, z_t, t, text_emb=None, text_global=None, image_cond=None):
"""Predict velocity for rectified flow."""
return self.core(z_t, t, text_emb, text_global, image_cond)
def get_param_groups(self):
"""Return parameter groups for staged training."""
return {
'vae_encoder': list(self.vae.encoder.parameters()),
'vae_decoder': list(self.vae.decoder.parameters()),
'text_encoder': list(self.text_encoder.parameters()),
'core': list(self.core.parameters()),
'latent_scale': [self.latent_scale],
}
def count_parameters(self):
"""Count parameters per module."""
counts = {}
for name, module in [
('vae_encoder', self.vae.encoder),
('vae_decoder', self.vae.decoder),
('text_encoder', self.text_encoder),
('core', self.core),
]:
counts[name] = sum(p.numel() for p in module.parameters())
counts['latent_scale'] = 1
counts['total'] = sum(counts.values())
return counts
def forward(self, x=None, token_ids=None, attention_mask=None, **kwargs):
"""Full forward pass for training. See training script for usage."""
raise NotImplementedError(
"Use the training pipeline functions instead of calling forward() directly. "
"See LRFTrainer for VAE training, denoiser training, and distillation."
)