File size: 5,869 Bytes
"""
CfC Cell — Closed-form Continuous-time neural network cell.
FULLY PARALLEL implementation — no sequential loops.
From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)
Core CfC equation (Eq. 10 from paper):
x(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)
Key insight for parallelization:
The CfC equation is a CLOSED-FORM expression: it gives the state at time t
directly instead of integrating an ODE step by step. In the paper, x is a
recurrent hidden state. In this module, however, the f, g and h networks are
fed only the current input features, and no hidden state is carried from one
position to the next. There is therefore NO dependency between positions, and
for image processing we can compute ALL spatial positions in a single parallel pass.
We use it as an adaptive gating mechanism:
- f network produces position-dependent time constants (rates that multiply t)
- g/h networks produce two candidate feature maps
- The sigmoid gate blends them adaptively per-position
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class CfCCell(nn.Module):
    """
    Parallel CfC cell — processes ALL positions simultaneously.

    The f/g/h heads see only the current input features, and no hidden state
    is carried between positions. The closed-form CfC update is therefore a
    pure function of (input, time) and is evaluated for every position at once.

    For an input of shape [B, ..., D] (typically a sequence [B, L, D]):
    - f, g, h networks are applied to every position in parallel
    - One time value per batch sample modulates the gate at every position
    - Output is computed in a single vectorized operation

    Args:
        dim: Feature dimension D (last axis of the input).
        dropout: Dropout rate inside the shared backbone.
        time_scale: (low, high) range for the time parameter. In training mode
            t is drawn uniformly from [low, high) per sample. In eval mode the
            midpoint (low + high) / 2 is used.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.1, 1.0)):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale
        # Shared backbone (Linear -> LayerNorm -> SiLU -> Dropout, width 4*dim),
        # applied position-wise.
        self.backbone = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.LayerNorm(dim * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
        )
        # f head: rate multiplying time inside the gate; tanh bounds it to (-1, 1).
        self.f_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
            nn.Tanh(),
        )
        # g head: candidate feature weighted by the gate.
        # It dominates where gate -> 1, i.e. where f * (t + time_bias) << 0.
        self.g_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )
        # h head: candidate feature weighted by (1 - gate).
        # It dominates where gate -> 0, i.e. where f * (t + time_bias) >> 0.
        self.h_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )
        # Learnable per-channel offset added to t (makes time adaptive per feature).
        self.time_bias = nn.Parameter(torch.zeros(dim))
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform (gain 0.02) weights and zero biases for every Linear.

        The small gain keeps all head outputs close to zero at initialization,
        so f is approximately 0 and the gate starts near sigmoid(0) = 0.5.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, t=None):
        """
        Fully parallel CfC forward pass.

        Args:
            x: [B, ..., D] input features; the usual case is [B, L, D].
            t: Optional time parameter. Accepted forms:
                - a Python number or 0-d tensor (shared by all samples);
                - a [B] tensor (one time per sample);
                - a tensor already shaped to broadcast against x, e.g. [B, 1, 1].
                It is converted to x's dtype and device. If None, t is sampled
                uniformly from ``time_scale`` per sample in training mode, and
                fixed to the midpoint of ``time_scale`` in eval mode.

        Returns:
            out: tensor with the same shape as x.

        Raises:
            ValueError: if x has fewer than 2 dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                f"CfCCell expects x of shape [B, ..., D], got {tuple(x.shape)}"
            )
        batch = x.shape[0]
        # One time value per sample, broadcast over all positions and channels.
        t_shape = (batch,) + (1,) * (x.dim() - 1)
        lo, hi = self.time_scale

        if t is None:
            if self.training:
                # Random time per sample during training (data augmentation).
                t = torch.rand(t_shape, device=x.device) * (hi - lo) + lo
            else:
                # Fixed midpoint during inference.
                t = torch.full(t_shape, 0.5 * (lo + hi), device=x.device)
        else:
            t = torch.as_tensor(t, device=x.device)
            if t.dim() == 1:
                # A [B] vector would otherwise broadcast against the channel
                # axis (or fail); make it [B, 1, ..., 1]. A 1-element vector
                # becomes [1, ..., 1] and still acts as a shared scalar.
                t = t.reshape((-1,) + (1,) * (x.dim() - 1))
        # Keep t in x's dtype. A float32 time tensor would otherwise promote a
        # bf16/fp16 model's output to float32 and break downstream layers.
        t = t.to(dtype=x.dtype)

        features = self.backbone(x)  # [..., 4*D]
        f_out = self.f_head(features)  # [..., D], bounded to (-1, 1) by tanh
        g_out = self.g_head(features)  # [..., D]
        h_out = self.h_head(features)  # [..., D]

        # CfC gate: sigmoid(-f * (t + time_bias)). time_bias ([D]) broadcasts
        # over the last (channel) axis, giving each channel its own time shift.
        # The gate is 0.5 where f * (t + time_bias) == 0 (an even blend of g and h).
        effective_t = t + self.time_bias
        gate = torch.sigmoid(-f_out * effective_t)

        # CfC output: gate * g + (1 - gate) * h
        return gate * g_out + (1 - gate) * h_out
class CfCBlock(nn.Module):
    """
    CfC block for 2D image processing (also accepts token sequences).
    Fully parallel — no sequential loops.

    Architecture:
        Input [B, C, H, W] -> flatten -> CfC (parallel) -> reshape -> Output
        With: pre-norm, residual connection, feed-forward

    Args:
        dim: Channel dimension (C for images, last dim for sequences).
        dropout: Dropout rate for the CfC cell backbone and the feed-forward layers.
        expansion_factor: Hidden-width multiplier of the feed-forward layer.
        time_scale: (low, high) range for the CfC cell's time parameter,
            forwarded unchanged to CfCCell (same default as the cell).
    """

    def __init__(self, dim, dropout=0.0, expansion_factor=2, time_scale=(0.1, 1.0)):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, dropout=dropout, time_scale=time_scale)
        self.norm2 = nn.LayerNorm(dim)
        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, t=None):
        """
        Args:
            x: [B, C, H, W] or [B, L, C]
            t: Optional time parameter passed unchanged to the CfC cell.
                None keeps the cell's default: random per sample in training,
                midpoint of ``time_scale`` in eval.
        Returns:
            Same shape as input
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            # Treat every spatial location as a token: [B, C, H, W] -> [B, HW, C].
            x = x.flatten(2).transpose(1, 2)

        # Pre-norm + CfC + residual
        x = x + self.cfc(self.norm1(x), t=t)
        # Pre-norm + FF + residual
        x = x + self.ff(self.norm2(x))

        if is_2d:
            # Undo the flattening: [B, HW, C] -> [B, C, H, W].
            x = x.transpose(1, 2).reshape(B, C, H, W)
        return x
|