krystv committed on
Commit
f4749f1
·
verified ·
1 Parent(s): 029ca89

Upload liquid_flow/cfc_cell.py

Browse files
Files changed (1) hide show
  1. liquid_flow/cfc_cell.py +116 -111
liquid_flow/cfc_cell.py CHANGED
@@ -1,28 +1,21 @@
1
  """
2
  CfC Cell — Closed-form Continuous-time neural network cell.
 
3
 
4
  From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)
5
 
6
- The CfC model provides an approximate closed-form solution to Liquid Time-Constant (LTC)
7
- network dynamics without needing ODE solvers.
8
 
9
- Architecture:
10
- x(t) = σ(-f(x,I;θ_f) · t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f) · t)) ⊙ h(x,I;θ_h)
11
-
12
- Where:
13
- - f, g, h are neural network heads sharing a backbone
14
- - σ is the sigmoid (replacing exponential decay for gradient stability)
15
- - t is a time parameter
16
- - The sigmoidal terms act as time-continuous gates between g and h
17
-
18
- Key properties:
19
- - No ODE solving → 100x+ faster than Neural ODEs
20
- - Time-continuous gating mechanism → adaptive computation
21
- - Closed-form → stable gradients, easy to train
22
- - Naturally causal → good for sequential processing
23
-
24
- For 2D image inputs: we treat the spatial sequence as "time" steps for the CfC,
25
- allowing the liquid dynamics to model spatial dependencies with adaptive gates.
26
  """
27
 
28
  import torch
@@ -32,138 +25,150 @@ import torch.nn.functional as F
32
 
33
  class CfCCell(nn.Module):
34
  """
35
- Single CfC cell with backbone + 3 heads (f, g, h).
 
 
 
 
 
 
 
 
 
36
 
37
  Args:
38
- dim: Hidden dimension
39
- backbone_dropout: Dropout in backbone layers
40
- time_scale: Range [a, b] for time parameter sampling
41
- use_conv: Add conv1d for local context
42
  """
43
 
44
- def __init__(self, dim, backbone_dropout=0.0, time_scale=(0.0, 1.0), use_conv=True):
45
  super().__init__()
46
  self.dim = dim
47
  self.time_scale = time_scale
48
 
49
- # Shared backbone
50
- backbone_dim = dim * 3
51
  self.backbone = nn.Sequential(
52
- nn.Linear(dim + dim, backbone_dim),
53
- nn.LayerNorm(backbone_dim),
54
- nn.SiLU(),
55
- nn.Dropout(backbone_dropout),
56
- nn.Linear(backbone_dim, dim * 4),
57
  nn.LayerNorm(dim * 4),
 
 
58
  )
59
 
60
- # Optional 1D conv
61
- self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim) if use_conv else None
 
 
 
62
 
63
- # Heads
64
- self.f_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.Tanh())
65
- self.g_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
66
- self.h_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
 
 
 
 
 
 
 
 
67
 
68
- self.out_proj = nn.Linear(dim, dim)
69
  self._init_weights()
70
 
71
  def _init_weights(self):
72
  for m in self.modules():
73
  if isinstance(m, nn.Linear):
74
- nn.init.normal_(m.weight, std=0.02)
75
  if m.bias is not None:
76
  nn.init.zeros_(m.bias)
77
 
78
- def forward(self, x, h_prev=None, t=None):
79
  """
 
 
80
  Args:
81
- x: [B, dim] or [B, L, dim]
82
- h_prev: Previous hidden state [B, dim]
83
- t: Time parameter
84
- Returns: h: [B, dim] or [B, L, dim]
 
85
  """
86
- is_seq = x.dim() == 3
87
- B, device = x.shape[0], x.device
88
-
89
- if is_seq:
90
- return self._forward_seq(x, h_prev, t)
91
-
92
- if h_prev is None:
93
- h_prev = torch.zeros(B, self.dim, device=device)
94
- if t is None:
95
- t = torch.rand(B, 1, device=device) * (self.time_scale[1] - self.time_scale[0]) + self.time_scale[0]
96
- elif t.dim() == 1:
97
- t = t.unsqueeze(1)
98
-
99
- return self._step(x, h_prev, t)
100
-
101
- def _forward_seq(self, x, h_prev=None, t=None):
102
  B, L, D = x.shape
103
  device = x.device
104
 
 
105
  if t is None:
106
- t = torch.rand(B, 1, 1, device=device) * (self.time_scale[1] - self.time_scale[0]) + self.time_scale[0]
107
-
108
- outputs = []
109
- h = torch.zeros(B, D, device=device) if h_prev is None else h_prev
110
- for step in range(L):
111
- h = self._step(x[:, step, :], h, t.squeeze(-1) if t.dim() == 3 else t)
112
- outputs.append(h)
113
- return torch.stack(outputs, dim=1)
114
-
115
- def _step(self, x, h_prev, t):
116
- """Core CfC step."""
117
- combined = torch.cat([x, h_prev], dim=-1)
118
- backbone_out = self.backbone(combined)
119
- f_base, g_base, h_base, skip = backbone_out.chunk(4, dim=-1)
120
-
121
- if self.conv is not None:
122
- f_base = f_base + self.conv(f_base.unsqueeze(1).transpose(1,2)).transpose(1,2).squeeze(1)
123
- g_base = g_base + self.conv(g_base.unsqueeze(1).transpose(1,2)).transpose(1,2).squeeze(1)
124
- h_base = h_base + self.conv(h_base.unsqueeze(1).transpose(1,2)).transpose(1,2).squeeze(1)
125
-
126
- f_out = self.f_head(f_base)
127
- g_out = self.g_head(g_base)
128
- h_out = self.h_head(h_base)
129
-
130
- gate = torch.sigmoid(-f_out * t)
131
- h = gate * g_out + (1 - gate) * h_out + skip
132
- return self.out_proj(h)
133
 
134
 
135
  class CfCBlock(nn.Module):
136
- """CfC block for 2D image processing with residual connection."""
 
 
137
 
138
- def __init__(self, dim, dropout=0.0, time_scale=(0.0, 1.0), expansion_factor=2):
 
 
 
 
 
139
  super().__init__()
140
- self.dim = dim
141
  self.norm1 = nn.LayerNorm(dim)
142
- self.norm2 = nn.LayerNorm(dim)
143
- self.cfc = CfCCell(dim=dim, backbone_dropout=dropout, time_scale=time_scale, use_conv=True)
144
 
 
145
  ff_dim = dim * expansion_factor
146
  self.ff = nn.Sequential(
147
- nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
148
- nn.Linear(ff_dim, dim), nn.Dropout(dropout),
 
 
 
149
  )
150
-
151
- self.pos_embed = nn.Parameter(torch.randn(1, 4096, dim) * 0.02)
152
 
153
- def forward(self, x, return_2d=True):
 
 
 
 
 
 
154
  is_2d = x.dim() == 4
155
  if is_2d:
156
  B, C, H, W = x.shape
157
- L = H * W
158
- x = x.flatten(2).transpose(1, 2)
159
- else:
160
- B, L, C = x.shape
161
-
162
- x_with_pos = x + self.pos_embed[:, :L, :]
163
- residual = x
164
- h = self.cfc(self.norm1(x_with_pos))
165
- x_out = h + self.ff(self.norm2(h + residual))
166
-
167
- if is_2d and return_2d:
168
- x_out = x_out.transpose(1, 2).reshape(B, C, H, W)
169
- return x_out
 
1
  """
2
  CfC Cell — Closed-form Continuous-time neural network cell.
3
+ FULLY PARALLEL implementation — no sequential loops.
4
 
5
  From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)
6
 
7
+ Core CfC equation (Eq. 10 from paper):
8
+ x(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)
9
 
10
+ Key insight for parallelization:
11
+ The CfC equation is a CLOSED-FORM expression. It maps (input, time) → output
12
+ with NO recurrent dependency between timesteps. This means for image processing
13
+ we can compute ALL spatial positions in a single parallel pass.
14
+
15
+ We use it as an adaptive gating mechanism:
16
+ - f network produces position-dependent time constants
17
+ - g/h networks produce two candidate feature maps
18
+ - The sigmoid gate blends them adaptively per-position
 
 
 
 
 
 
 
 
19
  """
20
 
21
  import torch
 
25
 
26
  class CfCCell(nn.Module):
27
  """
28
+ Parallel CfC cell — processes ALL positions simultaneously.
29
+
30
+ The key realization: CfC's closed-form solution is NOT recurrent.
31
+ It's a function of (input, time) → output. So we apply it to all
32
+ spatial positions in parallel.
33
+
34
+ For a sequence [B, L, D]:
35
+ - f, g, h networks are applied to ALL L positions in parallel
36
+ - The time parameter t modulates the gate per-position
37
+ - Output is computed in a single vectorized operation
38
 
39
  Args:
40
+ dim: Feature dimension
41
+ dropout: Dropout rate
42
+ time_scale: Range for time parameter
 
43
  """
44
 
45
+ def __init__(self, dim, dropout=0.0, time_scale=(0.1, 1.0)):
46
  super().__init__()
47
  self.dim = dim
48
  self.time_scale = time_scale
49
 
50
+ # Shared backbone (processes all positions in parallel)
 
51
  self.backbone = nn.Sequential(
52
+ nn.Linear(dim, dim * 4),
 
 
 
 
53
  nn.LayerNorm(dim * 4),
54
+ nn.SiLU(),
55
+ nn.Dropout(dropout),
56
  )
57
 
58
+ # f head: time-constant (bounded by tanh for stability)
59
+ self.f_head = nn.Sequential(
60
+ nn.Linear(dim * 4, dim),
61
+ nn.Tanh(),
62
+ )
63
 
64
+ # g head: "fast" feature (dominant when gate ≈ 1, i.e. small t)
65
+ self.g_head = nn.Sequential(
66
+ nn.Linear(dim * 4, dim),
67
+ )
68
+
69
+ # h head: "slow" feature (dominant when gate ≈ 0, i.e. large t)
70
+ self.h_head = nn.Sequential(
71
+ nn.Linear(dim * 4, dim),
72
+ )
73
+
74
+ # Learnable time-bias per channel (makes time adaptive per feature)
75
+ self.time_bias = nn.Parameter(torch.zeros(dim))
76
 
 
77
  self._init_weights()
78
 
79
  def _init_weights(self):
80
  for m in self.modules():
81
  if isinstance(m, nn.Linear):
82
+ nn.init.xavier_uniform_(m.weight, gain=0.02)
83
  if m.bias is not None:
84
  nn.init.zeros_(m.bias)
85
 
86
+ def forward(self, x, t=None):
87
  """
88
+ Fully parallel CfC forward pass.
89
+
90
  Args:
91
+ x: [B, L, D] — all positions processed simultaneously
92
+ t: Optional time parameter [B, 1, 1] or scalar.
93
+ If None, sampled randomly during training, fixed during eval.
94
+ Returns:
95
+ out: [B, L, D]
96
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  B, L, D = x.shape
98
  device = x.device
99
 
100
+ # Time parameter
101
  if t is None:
102
+ if self.training:
103
+ # Random time per batch during training (data augmentation)
104
+ t = torch.rand(B, 1, 1, device=device) * (
105
+ self.time_scale[1] - self.time_scale[0]
106
+ ) + self.time_scale[0]
107
+ else:
108
+ # Fixed midpoint during inference
109
+ t = torch.full((B, 1, 1), 0.5 * (self.time_scale[0] + self.time_scale[1]), device=device)
110
+
111
+ # Shared backbone (parallel over all B*L positions)
112
+ features = self.backbone(x) # [B, L, dim*4]
113
+
114
+ # Three heads (all parallel)
115
+ f_out = self.f_head(features) # [B, L, D] — bounded by tanh
116
+ g_out = self.g_head(features) # [B, L, D]
117
+ h_out = self.h_head(features) # [B, L, D]
118
+
119
+ # CfC gating: σ(-f * (t + time_bias))
120
+ # time_bias makes gating adaptive per-channel
121
+ effective_t = t + self.time_bias.view(1, 1, -1) # [B, 1, D] broadcast
122
+ gate = torch.sigmoid(-f_out * effective_t) # [B, L, D]
123
+
124
+ # CfC output: gate * g + (1-gate) * h
125
+ out = gate * g_out + (1 - gate) * h_out # [B, L, D]
126
+
127
+ return out
 
128
 
129
 
130
  class CfCBlock(nn.Module):
131
+ """
132
+ CfC block for 2D image processing.
133
+ Fully parallel — no sequential loops.
134
 
135
+ Architecture:
136
+ Input [B, C, H, W] → flatten → CfC (parallel) → reshape → Output
137
+ With: pre-norm, residual connection, feed-forward
138
+ """
139
+
140
+ def __init__(self, dim, dropout=0.0, expansion_factor=2):
141
  super().__init__()
 
142
  self.norm1 = nn.LayerNorm(dim)
143
+ self.cfc = CfCCell(dim=dim, dropout=dropout)
 
144
 
145
+ self.norm2 = nn.LayerNorm(dim)
146
  ff_dim = dim * expansion_factor
147
  self.ff = nn.Sequential(
148
+ nn.Linear(dim, ff_dim),
149
+ nn.GELU(),
150
+ nn.Dropout(dropout),
151
+ nn.Linear(ff_dim, dim),
152
+ nn.Dropout(dropout),
153
  )
 
 
154
 
155
+ def forward(self, x):
156
+ """
157
+ Args:
158
+ x: [B, C, H, W] or [B, L, C]
159
+ Returns:
160
+ Same shape as input
161
+ """
162
  is_2d = x.dim() == 4
163
  if is_2d:
164
  B, C, H, W = x.shape
165
+ x = x.flatten(2).transpose(1, 2) # [B, HW, C]
166
+
167
+ # Pre-norm + CfC + residual
168
+ x = x + self.cfc(self.norm1(x))
169
+ # Pre-norm + FF + residual
170
+ x = x + self.ff(self.norm2(x))
171
+
172
+ if is_2d:
173
+ x = x.transpose(1, 2).reshape(B, C, H, W)
174
+ return x