| """ |
| CfC Cell β Closed-form Continuous-time neural network cell. |
| FULLY PARALLEL implementation β no sequential loops. |
| |
| From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022) |
| |
| Core CfC equation (Eq. 10 from paper): |
| x(t) = Ο(-f(x,I;ΞΈ_f)Β·t) β g(x,I;ΞΈ_g) + (1 - Ο(-f(x,I;ΞΈ_f)Β·t)) β h(x,I;ΞΈ_h) |
| |
| Key insight for parallelization: |
| The CfC equation is a CLOSED-FORM expression. It maps (input, time) β output |
| with NO recurrent dependency between timesteps. This means for image processing |
| we can compute ALL spatial positions in a single parallel pass. |
| |
| We use it as an adaptive gating mechanism: |
| - f network produces position-dependent time constants |
| - g/h networks produce two candidate feature maps |
| - The sigmoid gate blends them adaptively per-position |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class CfCCell(nn.Module):
    """
    Parallel CfC cell — every position is processed simultaneously.

    The closed-form CfC solution is not recurrent: it is a pure function
    of (input, time) -> output, so all spatial/sequence positions of a
    [B, L, D] tensor can be pushed through it in one vectorized pass.

    Per position, three heads computed from a shared backbone drive the
    CfC gate:
      - f head: position-dependent decay rates (tanh-bounded)
      - g / h heads: two candidate feature maps
      - output = sigmoid(-f * t) * g + (1 - sigmoid(-f * t)) * h

    Args:
        dim: Feature dimension.
        dropout: Dropout rate applied inside the backbone.
        time_scale: (low, high) range the time parameter t is drawn from.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.1, 1.0)):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared 4x-expansion trunk feeding all three heads.
        # NOTE: module creation order below is load-bearing for RNG
        # reproducibility — keep backbone -> f -> g -> h.
        self.backbone = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.LayerNorm(dim * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
        )

        # Decay head: tanh keeps the per-position rates bounded.
        self.f_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
            nn.Tanh(),
        )

        # Two candidate feature maps blended by the gate.
        self.g_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )
        self.h_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )

        # Learnable per-channel offset added to the scalar time t.
        self.time_bias = nn.Parameter(torch.zeros(dim))

        self._init_weights()

    def _init_weights(self):
        # Small-gain Xavier init keeps the gate near 0.5 early in training.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t=None):
        """
        Fully parallel CfC forward pass.

        Args:
            x: [B, L, D] — all L positions processed at once.
            t: Optional time parameter, [B, 1, 1] or scalar. When None,
               it is drawn uniformly from ``time_scale`` in training mode
               and fixed to the midpoint of ``time_scale`` in eval mode.
        Returns:
            out: [B, L, D]
        """
        batch, _, _ = x.shape  # also validates that x is 3-D

        if t is None:
            lo, hi = self.time_scale
            if self.training:
                t = lo + (hi - lo) * torch.rand(batch, 1, 1, device=x.device)
            else:
                t = torch.full((batch, 1, 1), 0.5 * (lo + hi), device=x.device)

        hidden = self.backbone(x)

        decay = self.f_head(hidden)
        branch_g = self.g_head(hidden)
        branch_h = self.h_head(hidden)

        # Per-channel learned bias shifts the effective time constant.
        tau = t + self.time_bias.view(1, 1, -1)
        gate = torch.sigmoid(-decay * tau)

        # Closed-form blend of the two candidate maps (Eq. 10).
        return gate * branch_g + (1.0 - gate) * branch_h
|
|
|
|
class CfCBlock(nn.Module):
    """
    CfC block for 2D image processing — fully parallel, no sequential loops.

    Pre-norm residual layout:
        Input [B, C, H, W] -> flatten to [B, H*W, C] -> CfC (parallel)
        -> feed-forward -> reshape back to [B, C, H, W]
    A [B, L, C] input skips the flatten/reshape and is handled directly.
    """

    def __init__(self, dim, dropout=0.0, expansion_factor=2):
        super().__init__()
        # NOTE: submodule creation order (norm1, cfc, norm2, ff) is kept
        # stable so parameter init consumes the RNG identically.
        self.norm1 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, dropout=dropout)

        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] or [B, L, C]
        Returns:
            Tensor with the same shape as the input.
        """
        spatial = x.dim() == 4
        if spatial:
            batch, channels, height, width = x.shape
            # [B, C, H, W] -> [B, H*W, C] so each pixel is a token.
            x = x.flatten(2).transpose(1, 2)

        # Pre-norm residual sub-blocks: CfC mixing, then feed-forward.
        x = x + self.cfc(self.norm1(x))
        x = x + self.ff(self.norm2(x))

        if spatial:
            x = x.transpose(1, 2).reshape(batch, channels, height, width)
        return x
|
|