"""
CfC Cell — Closed-form Continuous-time neural network cell.

FULLY PARALLEL implementation — no sequential loops.

From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)

Core CfC equation (Eq. 10 from paper):

    x(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)

Key insight for parallelization:
    The CfC equation is a CLOSED-FORM expression: given its inputs and a
    time value it produces the output directly, without numerically
    integrating an ODE. In this module f, g and h are computed from the
    current input features only (no recurrent hidden state is carried
    between positions), so there is no dependency between timesteps or
    spatial positions, and for image processing ALL positions can be
    computed in a single parallel pass.

We use it as an adaptive gating mechanism:
    - the f network produces position-dependent time constants
    - the g/h networks produce two candidate feature maps
    - the sigmoid gate blends them adaptively per position and channel
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class CfCCell(nn.Module):
    """
    Parallel CfC cell — processes ALL positions simultaneously.

    The closed-form solution is evaluated as a function of
    (input, time) → output, applied to every position in parallel.

    For a sequence [B, L, D]:
        - a shared backbone and the f, g, h heads run on all L positions at once
        - the time parameter t (plus a learnable per-channel bias) scales f
          inside the sigmoid gate
        - the output is computed in a single vectorized operation

    Args:
        dim: Feature dimension D (last axis of the input).
        dropout: Dropout rate used inside the backbone.
        time_scale: (low, high) range for the time parameter. During training
            t is sampled uniformly from this range; during eval the midpoint
            is used.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.1, 1.0)):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared backbone (processes all positions in parallel).
        self.backbone = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.LayerNorm(dim * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
        )

        # f head: per-position, per-channel time constant, bounded to (-1, 1)
        # by tanh for stability.
        self.f_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
            nn.Tanh(),
        )

        # g head: candidate weighted by gate = σ(-f·t).
        # Note: as t → 0 the gate tends to 0.5 (an even blend of g and h);
        # larger |f·t| tilts the blend toward g or h depending on the sign
        # of f. With tanh-bounded f, the default time_scale and a zero
        # time_bias, the gate stays within roughly (σ(-1), σ(1)) ≈ (0.27, 0.73).
        self.g_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )

        # h head: candidate weighted by (1 - gate).
        self.h_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )

        # Learnable per-channel offset added to t, letting each channel shift
        # its effective time (initialized to zero).
        self.time_bias = nn.Parameter(torch.zeros(dim))

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform (gain 0.02) weights and zero biases for every Linear in this cell.

        The small gain keeps the head outputs small at initialization
        (presumably so the CfC branch starts close to a no-op inside a
        residual block — intent not stated in the code).
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _prepare_time(t, batch, ref):
        """Return ``t`` as a tensor on ``ref``'s device and dtype, with per-sample layout.

        Args:
            t: Python number, 0-d tensor, or tensor of shape [B], [B, 1] or [B, 1, 1].
            batch: Batch size B of the input.
            ref: Input tensor whose device and dtype ``t`` must match, so that
                ``t + time_bias`` does not promote low-precision inputs
                (e.g. bfloat16) to float32.

        Returns:
            A tensor that broadcasts against [B, L, D] one value per sample
            (or a single value for all samples).
        """
        if isinstance(t, torch.Tensor):
            t = t.to(device=ref.device, dtype=ref.dtype)
        else:
            t = torch.as_tensor(t, device=ref.device, dtype=ref.dtype)
        # A per-sample vector [B] or column [B, 1] would otherwise broadcast
        # against the channel / sequence axes; give it the [B, 1, 1] layout.
        if t.dim() in (1, 2) and t.shape[0] == batch and t.numel() == batch:
            t = t.reshape(batch, 1, 1)
        return t

    def forward(self, x, t=None):
        """
        Fully parallel CfC forward pass.

        Args:
            x: [B, L, D] — all positions processed simultaneously.
            t: Optional time parameter: a scalar, or a tensor of shape [B],
                [B, 1] or [B, 1, 1] (one time per sample). It is converted to
                x's device and dtype. If None, one time per sample is drawn
                uniformly from ``time_scale`` during training, and the
                midpoint of ``time_scale`` is used during eval.

        Returns:
            out: [B, L, D], same dtype as x.
        """
        # Unpacking also enforces a rank-3 input.
        B, _, _ = x.shape
        device = x.device
        lo, hi = self.time_scale

        if t is None:
            if self.training:
                # One random time per sample during training (data augmentation).
                t = torch.rand(B, 1, 1, device=device) * (hi - lo) + lo
            else:
                # Fixed midpoint during inference.
                t = torch.full((B, 1, 1), 0.5 * (lo + hi), device=device)
        t = self._prepare_time(t, B, x)

        # Shared backbone (parallel over all B*L positions).
        features = self.backbone(x)  # [B, L, dim*4]

        # Three heads (all parallel).
        f_out = self.f_head(features)  # [B, L, D] — bounded by tanh
        g_out = self.g_head(features)  # [B, L, D]
        h_out = self.h_head(features)  # [B, L, D]

        # CfC gating: σ(-f * (t + time_bias)); time_bias makes it per-channel.
        effective_t = t + self.time_bias.view(1, 1, -1)  # broadcasts to [B, 1, D]
        gate = torch.sigmoid(-f_out * effective_t)  # [B, L, D]

        # CfC output: gate * g + (1 - gate) * h
        out = gate * g_out + (1 - gate) * h_out  # [B, L, D]
        return out


class CfCBlock(nn.Module):
    """
    CfC block for 2D image processing. Fully parallel — no sequential loops.

    Architecture:
        Input [B, C, H, W] → flatten → CfC (parallel) → reshape → Output

    With: pre-norm, residual connection, feed-forward.

    Args:
        dim: Channel/feature dimension; must equal C (or the last axis of a
            [B, L, C] input).
        dropout: Dropout rate for the CfC cell and the feed-forward layers.
        expansion_factor: Hidden width multiplier of the feed-forward layer.
    """

    def __init__(self, dim, dropout=0.0, expansion_factor=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)

        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] or [B, L, C]

        Returns:
            Tensor of the same shape as the input.
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            x = x.flatten(2).transpose(1, 2)  # [B, HW, C]

        # Pre-norm + CfC + residual
        x = x + self.cfc(self.norm1(x))

        # Pre-norm + FF + residual
        x = x + self.ff(self.norm2(x))

        if is_2d:
            x = x.transpose(1, 2).reshape(B, C, H, W)

        return x