""" LiquidFlow Block — Hybrid CfC + Mamba-2 SSD architecture. CORRECTED VERSION: proper dimensions, no sequential loops. Architecture per block: Input → Mamba-2 SSD (bidirectional) → CfC adaptive gate → Output The CfC provides adaptive gating that modulates the SSM output based on input-dependent "liquid" time constants. """ import torch import torch.nn as nn import torch.nn.functional as F import math from .cfc_cell import CfCCell, CfCBlock from .mamba2_ssd import Mamba2SSD, Mamba2Block class LiquidMambaBlock(nn.Module): """ LiquidMamba: CfC-gated Mamba-2 SSD block. Flow: 1. Input → LayerNorm → Mamba-2 SSD (bidirectional scan) 2. SSM output → CfC adaptive gate (parallel over all positions) 3. Gated output → residual + feed-forward The CfC gate learns WHEN to trust the SSM output vs the raw input, creating content-aware adaptive processing. """ def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0): super().__init__() self.dim = dim # LayerNorms self.norm_ssm = nn.LayerNorm(dim) self.norm_gate = nn.LayerNorm(dim) self.norm_ff = nn.LayerNorm(dim) # Mamba-2 SSD: bidirectional scan self.ssd_fwd = Mamba2SSD(dim=dim, d_state=d_state, d_conv=d_conv, expand=expand) self.ssd_bwd = Mamba2SSD(dim=dim, d_state=d_state, d_conv=d_conv, expand=expand) self.merge = nn.Linear(dim * 2, dim, bias=False) # CfC gate: parallel adaptive gating self.cfc_gate = CfCCell(dim=dim, dropout=dropout) # Gate projection (learnable mixing) self.gate_proj = nn.Linear(dim, dim) # Feed-forward ff_dim = dim * expand self.ff = nn.Sequential( nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(ff_dim, dim), nn.Dropout(dropout), ) def forward(self, x): """ Args: x: [B, C, H, W] or [B, L, C] Returns: Same shape as input """ is_2d = x.dim() == 4 if is_2d: B, C, H, W = x.shape x = x.flatten(2).transpose(1, 2) # [B, HW, C] # === SSM branch === residual = x x_norm = self.norm_ssm(x) # Bidirectional Mamba-2 scan fwd_out = self.ssd_fwd(x_norm) bwd_out = torch.flip(self.ssd_bwd(torch.flip(x_norm, [1])), [1]) ssm_out = self.merge(torch.cat([fwd_out, bwd_out], dim=-1)) # === CfC gate === # CfC processes the SSM output and produces adaptive gate gate_input = self.norm_gate(ssm_out) cfc_out = self.cfc_gate(gate_input) # [B, L, D] — parallel! # Sigmoid gate: how much SSM output to use gate = torch.sigmoid(self.gate_proj(cfc_out)) # Gated residual: blend SSM output with residual x = residual + gate * ssm_out # === Feed-forward === x = x + self.ff(self.norm_ff(x)) if is_2d: x = x.transpose(1, 2).reshape(B, C, H, W) return x class LiquidFlowStage(nn.Module): """Stack of LiquidMamba blocks at the same resolution.""" def __init__(self, dim, num_blocks=4, d_state=16, expand=2, dropout=0.0): super().__init__() self.blocks = nn.ModuleList([ LiquidMambaBlock(dim=dim, d_state=d_state, expand=expand, dropout=dropout) for _ in range(num_blocks) ]) def forward(self, x): for block in self.blocks: x = block(x) return x class LiquidFlowBackbone(nn.Module): """ Complete LiquidFlow backbone — DiT-style noise predictor. FIXED: Output shape == Input shape (no patch_size confusion). 
Architecture: Input [B, in_ch, H, W] → Conv2d projection to hidden_dim → + sinusoidal timestep embedding (AdaLN-style) → + learnable positional encoding → N × LiquidMamba Stages → Conv2d projection back to in_ch → Output [B, in_ch, H, W] """ def __init__( self, in_channels=4, hidden_dim=256, num_stages=4, blocks_per_stage=4, d_state=16, expand=2, dropout=0.0, ): super().__init__() self.in_channels = in_channels self.hidden_dim = hidden_dim # Input projection (pointwise conv) self.in_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1) # Timestep embedding self.time_embed = nn.Sequential( nn.Linear(hidden_dim, hidden_dim * 4), nn.SiLU(), nn.Linear(hidden_dim * 4, hidden_dim), ) # AdaLN-style conditioning: scale and shift self.t_cond = nn.Sequential( nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 2), ) # Positional encoding (learnable, supports up to 64×64 = 4096 positions) self.pos_embed = nn.Parameter(torch.randn(1, 4096, hidden_dim) * 0.02) # LiquidFlow stages self.stages = nn.ModuleList([ LiquidFlowStage( dim=hidden_dim, num_blocks=blocks_per_stage, d_state=d_state, expand=expand, dropout=dropout, ) for _ in range(num_stages) ]) # Output projection self.out_norm = nn.LayerNorm(hidden_dim) self.out_proj = nn.Linear(hidden_dim, in_channels) self._init_weights() def _init_weights(self): # Zero-init output projection for residual learning nn.init.zeros_(self.out_proj.weight) nn.init.zeros_(self.out_proj.bias) def _sinusoidal_embedding(self, timesteps, dim): """Sinusoidal positional embedding for diffusion timesteps.""" half = dim // 2 freqs = torch.exp( -math.log(10000.0) * torch.arange(half, device=timesteps.device).float() / half ) args = timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0) emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: emb = F.pad(emb, (0, 1)) return emb def forward(self, x, t): """ Args: x: [B, in_channels, H, W] — noisy latent t: [B] — diffusion timesteps (integers 0..T-1) Returns: [B, in_channels, H, W] — predicted noise (same shape as input!) """ B, C, H, W = x.shape L = H * W # Project to hidden dim x = self.in_proj(x) # [B, hidden_dim, H, W] x = x.flatten(2).transpose(1, 2) # [B, HW, hidden_dim] # Timestep conditioning (AdaLN) t_emb = self._sinusoidal_embedding(t, self.hidden_dim) # [B, hidden_dim] t_emb = self.time_embed(t_emb) # [B, hidden_dim] t_cond = self.t_cond(t_emb) # [B, hidden_dim*2] scale, shift = t_cond.chunk(2, dim=-1) # each [B, hidden_dim] # Apply conditioning + positional encoding x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) x = x + self.pos_embed[:, :L, :] # Reshape to 2D for processing x = x.transpose(1, 2).reshape(B, self.hidden_dim, H, W) # Process through all stages for stage in self.stages: x = stage(x) # Output head x = x.flatten(2).transpose(1, 2) # [B, HW, hidden_dim] x = self.out_norm(x) x = self.out_proj(x) # [B, HW, in_channels] # Reshape back to image: [B, in_channels, H, W] x = x.transpose(1, 2).reshape(B, self.in_channels, H, W) return x
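

# Illustrative smoke test: a minimal sketch added for clarity, not part of the
# original module. It assumes CfCCell and Mamba2SSD accept the constructor
# arguments used above and return outputs with the same sequence shape as their
# inputs; it only checks that the block and the backbone preserve input shape.
# Because the module uses relative imports, run it as a package module, e.g.
# `python -m <your_package>.liquidflow_block` (module name is hypothetical).
if __name__ == "__main__":
    torch.manual_seed(0)

    # LiquidMambaBlock accepts both image-style [B, C, H, W] and sequence [B, L, C] inputs.
    block = LiquidMambaBlock(dim=64)
    img = torch.randn(2, 64, 8, 8)
    seq = torch.randn(2, 100, 64)
    assert block(img).shape == img.shape
    assert block(seq).shape == seq.shape

    # LiquidFlowBackbone: noisy latent in, same-shaped noise prediction out.
    model = LiquidFlowBackbone(in_channels=4, hidden_dim=128, num_stages=2, blocks_per_stage=2)
    latent = torch.randn(2, 4, 16, 16)   # 16*16 = 256 positions, within the 4096-position pos_embed
    t = torch.randint(0, 1000, (2,))     # integer diffusion timesteps
    out = model(latent, t)
    assert out.shape == latent.shape
    print("shapes OK:", tuple(out.shape))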