# LiquidFlow-Gen — liquid_flow/cfc_cell.py
"""
CfC Cell β€” Closed-form Continuous-time neural network cell.
FULLY PARALLEL implementation β€” no sequential loops.
From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)
Core CfC equation (Eq. 10 from paper):
x(t) = Οƒ(-f(x,I;ΞΈ_f)Β·t) βŠ™ g(x,I;ΞΈ_g) + (1 - Οƒ(-f(x,I;ΞΈ_f)Β·t)) βŠ™ h(x,I;ΞΈ_h)
Key insight for parallelization:
The CfC equation is a CLOSED-FORM expression. It maps (input, time) β†’ output
with NO recurrent dependency between timesteps. This means for image processing
we can compute ALL spatial positions in a single parallel pass.
We use it as an adaptive gating mechanism:
- f network produces position-dependent time constants
- g/h networks produce two candidate feature maps
- The sigmoid gate blends them adaptively per-position
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class CfCCell(nn.Module):
    """
    Parallel CfC cell — processes ALL positions simultaneously.

    The closed-form CfC solution has no recurrent dependency between
    timesteps: it maps (input, time) -> output. So for a [B, L, D]
    sequence every position can be evaluated in one vectorized pass:

      * the shared backbone and the f/g/h heads run over all L positions,
      * the time parameter t modulates the sigmoid gate per position,
      * the gated blend of g and h is a single elementwise expression.

    Args:
        dim: Feature dimension.
        dropout: Dropout rate applied inside the backbone.
        time_scale: (low, high) range for the time parameter.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.1, 1.0)):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared trunk, applied to every position in parallel.
        self.backbone = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.LayerNorm(dim * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
        )
        # f head: per-position time constant, tanh-bounded for stability.
        self.f_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
            nn.Tanh(),
        )
        # g head: "fast" candidate (dominates when the gate is near 1, small t).
        self.g_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )
        # h head: "slow" candidate (dominates when the gate is near 0, large t).
        self.h_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )
        # Learnable per-channel offset added to t (channel-adaptive timing).
        self.time_bias = nn.Parameter(torch.zeros(dim))
        self._init_weights()

    def _init_weights(self):
        # Small-gain Xavier init keeps the initial gate close to neutral.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t=None):
        """
        Fully parallel CfC forward pass.

        Args:
            x: [B, L, D] tensor — all positions processed at once.
            t: Optional time parameter ([B, 1, 1] or scalar). When None,
               it is drawn uniformly from ``time_scale`` in training mode
               and fixed to the midpoint of ``time_scale`` in eval mode.

        Returns:
            [B, L, D] tensor.
        """
        batch = x.shape[0]
        if t is None:
            lo, hi = self.time_scale
            if self.training:
                # Random time per batch acts as light data augmentation.
                t = torch.rand(batch, 1, 1, device=x.device) * (hi - lo) + lo
            else:
                # Deterministic midpoint at inference time.
                t = torch.full((batch, 1, 1), 0.5 * (lo + hi), device=x.device)

        feats = self.backbone(x)        # [B, L, dim*4], all positions at once
        f_out = self.f_head(feats)      # [B, L, D] — tanh-bounded
        g_out = self.g_head(feats)      # [B, L, D]
        h_out = self.h_head(feats)      # [B, L, D]

        # CfC gate: sigmoid(-f * (t + time_bias)); the per-channel bias
        # makes the effective time adaptive per feature.
        gate = torch.sigmoid(-f_out * (t + self.time_bias.view(1, 1, -1)))

        # Closed-form blend: gate * g + (1 - gate) * h.
        return gate * g_out + (1 - gate) * h_out
class CfCBlock(nn.Module):
    """
    CfC block for 2D image processing — fully parallel, no sequential loops.

    Pipeline for a [B, C, H, W] input:
        flatten to [B, HW, C] -> pre-norm CfC + residual
        -> pre-norm feed-forward + residual -> reshape back to [B, C, H, W].
    A [B, L, C] input skips the flatten/reshape steps.

    Args:
        dim: Channel / feature dimension.
        dropout: Dropout rate used in both the CfC cell and the feed-forward.
        expansion_factor: Width multiplier for the feed-forward hidden layer.
    """

    def __init__(self, dim, dropout=0.0, expansion_factor=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        hidden = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] image tensor or [B, L, C] sequence tensor.

        Returns:
            Tensor with the same shape as the input.
        """
        spatial = None
        if x.dim() == 4:
            spatial = x.shape                   # remember (B, C, H, W)
            x = x.flatten(2).transpose(1, 2)    # -> [B, HW, C]

        x = x + self.cfc(self.norm1(x))         # pre-norm CfC + residual
        x = x + self.ff(self.norm2(x))          # pre-norm FF + residual

        if spatial is not None:
            b, c, h, w = spatial
            x = x.transpose(1, 2).reshape(b, c, h, w)
        return x