krystv commited on
Commit
3de88b7
·
verified ·
1 Parent(s): 6f8fc17

Add model.py — core LiquidFlow architecture

Browse files
Files changed (1) hide show
  1. liquidflow/model.py +590 -0
liquidflow/model.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiquidFlow: A Novel Liquid-SSM Flow Matching Image Generator
3
+
4
+ Architecture combines:
5
+ 1. Liquid Time-Constant (LTC) dynamics as the velocity field (Hasani et al. 2020)
6
+ 2. Selective State Space scanning (Mamba-style) in pure PyTorch for parallel training
7
+ 3. Zigzag scanning patterns for 2D spatial awareness (ZigMa, 2024)
8
+ 4. Physics-informed regularization (smoothness + continuity constraints)
9
+ 5. Closed-form Continuous-depth (CfC) approximation for fast forward pass
10
+ 6. Rectified Flow / Flow Matching training objective (Lipman et al. 2022)
11
+
12
+ Designed for:
13
+ - Training on Google Colab free tier (T4 16GB) or Kaggle (P100 16GB)
14
+ - Mobile deployment (< 15M parameters for 128x128, < 25M for 512x512)
15
+ - No custom CUDA kernels required - pure PyTorch
16
+ """
17
+
18
+ import math
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from einops import rearrange, repeat
23
+
24
+
25
+ # ============================================================
26
+ # 1. LIQUID TIME-CONSTANT CELL (CfC - Closed-Form Continuous)
27
+ # ============================================================
28
+
29
+ class LiquidCfCCell(nn.Module):
30
+ """
31
+ Closed-form Continuous-depth Liquid Cell.
32
+
33
+ Instead of solving the LTC ODE numerically:
34
+ dx/dt = -[1/τ + f(x,I,t)] * x + f(x,I,t)
35
+
36
+ We use the CfC closed-form solution:
37
+ x(t+Δt) = σ(-f_τ) ⊙ x(t) + (1 - σ(-f_τ)) ⊙ f_x
38
+
39
+ Where:
40
+ f_τ = learned time-constant modulation
41
+ f_x = learned state update
42
+ σ = sigmoid (ensures bounded dynamics → no explosion)
43
+
44
+ This is parallelizable (no sequential ODE steps) and stable by construction.
45
+ """
46
+
47
+ def __init__(self, input_dim, hidden_dim):
48
+ super().__init__()
49
+ self.hidden_dim = hidden_dim
50
+
51
+ # Time-constant network (τ modulation)
52
+ self.tau_net = nn.Sequential(
53
+ nn.Linear(hidden_dim + hidden_dim, hidden_dim),
54
+ nn.Tanh(), # Tanh per PINN stability research (Wang et al. 2020)
55
+ nn.Linear(hidden_dim, hidden_dim),
56
+ )
57
+
58
+ # State update network
59
+ self.state_net = nn.Sequential(
60
+ nn.Linear(hidden_dim + hidden_dim, hidden_dim),
61
+ nn.Tanh(),
62
+ nn.Linear(hidden_dim, hidden_dim),
63
+ )
64
+
65
+ # Backbone mixing (replaces wiring in original NCP)
66
+ self.backbone = nn.Linear(input_dim, hidden_dim)
67
+
68
+ def forward(self, x, h=None):
69
+ """
70
+ x: (B, L, input_dim) - input features
71
+ h: (B, hidden_dim) - hidden state (optional, zeros if None)
72
+
73
+ Returns: (B, L, hidden_dim) - output for all positions (parallelized)
74
+ """
75
+ B, L, D = x.shape
76
+
77
+ # Backbone projection: input preprocessing (NCP-style wiring)
78
+ x_proj = self.backbone(x) # (B, L, hidden_dim)
79
+
80
+ if h is None:
81
+ h = torch.zeros(B, self.hidden_dim, device=x.device, dtype=x.dtype)
82
+
83
+ # Expand h to match sequence length for parallel computation
84
+ h_expanded = h.unsqueeze(1).expand(-1, L, -1) # (B, L, hidden_dim)
85
+
86
+ # Use backbone-projected input + state for gating
87
+ xh = torch.cat([x_proj, h_expanded], dim=-1) # (B, L, hidden_dim + hidden_dim)
88
+
89
+ # Compute time-constant modulation and state update
90
+ f_tau = self.tau_net(xh) # (B, L, hidden_dim)
91
+ f_x = self.state_net(xh) # (B, L, hidden_dim)
92
+
93
+ # CfC closed-form update:
94
+ # gate = σ(-f_τ) controls how much old state to keep
95
+ # new_h = gate * h + (1 - gate) * f_x
96
+ gate = torch.sigmoid(-f_tau)
97
+ new_h = gate * h_expanded + (1.0 - gate) * f_x
98
+
99
+ return new_h # (B, L, hidden_dim)
100
+
101
+
102
+ # ============================================================
103
+ # 2. SELECTIVE STATE SPACE BLOCK (Pure PyTorch Mamba-style)
104
+ # ============================================================
105
+
106
+ class SelectiveSSM(nn.Module):
107
+ """
108
+ Simplified Selective State Space Model in pure PyTorch.
109
+
110
+ Key insight from Mamba: make B, C, Δ input-dependent (selective)
111
+ while keeping A fixed (diagonal, learned).
112
+
113
+ The discretized SSM:
114
+ h_i = Ā * h_{i-1} + B̄ * x_i
115
+ y_i = C * h_i
116
+ Where Ā = exp(Δ * A), B̄ ≈ Δ * B
117
+ """
118
+
119
+ def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
120
+ super().__init__()
121
+ self.d_model = d_model
122
+ self.d_state = d_state
123
+ self.d_inner = int(d_model * expand)
124
+
125
+ # Input projection (expand)
126
+ self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
127
+
128
+ # 1D convolution for local context
129
+ self.conv1d = nn.Conv1d(
130
+ in_channels=self.d_inner,
131
+ out_channels=self.d_inner,
132
+ kernel_size=d_conv,
133
+ padding=d_conv - 1,
134
+ groups=self.d_inner,
135
+ bias=True,
136
+ )
137
+
138
+ # SSM parameters
139
+ # A: diagonal state matrix (fixed, learned)
140
+ # Initialize A with negative values for stability (ensures exp(ΔA) < 1)
141
+ A = torch.arange(1, d_state + 1, dtype=torch.float32)
142
+ self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(self.d_inner, -1).clone())
143
+
144
+ # D: skip connection
145
+ self.D = nn.Parameter(torch.ones(self.d_inner))
146
+
147
+ # Input-dependent projections for B, C, Δ
148
+ self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False) # B, C, Δ
149
+ self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
150
+
151
+ # Output projection
152
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
153
+
154
+ # Initialize dt_proj bias for stable Δ range
155
+ with torch.no_grad():
156
+ dt_init = torch.exp(
157
+ torch.rand(self.d_inner) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
158
+ )
159
+ inv_dt = dt_init + torch.log(-torch.expm1(-dt_init))
160
+ self.dt_proj.bias.copy_(inv_dt)
161
+
162
+ def forward(self, x):
163
+ """
164
+ x: (B, L, d_model)
165
+ Returns: (B, L, d_model)
166
+ """
167
+ B, L, D = x.shape
168
+
169
+ # Input projection → split into x and z (gating)
170
+ xz = self.in_proj(x) # (B, L, 2*d_inner)
171
+ x_inner, z = xz.chunk(2, dim=-1) # each (B, L, d_inner)
172
+
173
+ # 1D convolution for local context
174
+ x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2)
175
+ x_conv = F.silu(x_conv)
176
+
177
+ # Compute input-dependent B, C, Δ
178
+ x_proj = self.x_proj(x_conv) # (B, L, 2*d_state + 1)
179
+ B_sel = x_proj[:, :, :self.d_state] # (B, L, d_state)
180
+ C_sel = x_proj[:, :, self.d_state:2*self.d_state] # (B, L, d_state)
181
+ dt = x_proj[:, :, -1:] # (B, L, 1)
182
+
183
+ # Project Δ to per-channel
184
+ dt = F.softplus(self.dt_proj(dt)) # (B, L, d_inner)
185
+
186
+ # Discretize: Ā = exp(Δ * A), B̄ = Δ * B
187
+ A = -torch.exp(self.A_log) # (d_inner, d_state), negative for stability
188
+
189
+ # SSM scan
190
+ y = self._selective_scan(x_conv, dt, A, B_sel, C_sel)
191
+
192
+ # Apply skip connection (D parameter)
193
+ y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0)
194
+
195
+ # Gate with z
196
+ y = y * F.silu(z)
197
+
198
+ # Output projection
199
+ return self.out_proj(y)
200
+
201
+ def _selective_scan(self, x, dt, A, B, C):
202
+ """
203
+ Sequential selective scan (PyTorch-compatible, works on CPU/GPU).
204
+ For short sequences (image patches), this is fast enough.
205
+ No custom CUDA kernels needed.
206
+ """
207
+ B_batch, L, d_inner = x.shape
208
+ d_state = A.shape[1]
209
+
210
+ # Compute discretized parameters
211
+ dA = torch.einsum('bld,dn->bldn', dt, A) # (B, L, d_inner, d_state)
212
+ dA = torch.exp(dA) # Ā
213
+ dB = torch.einsum('bld,bln->bldn', dt, B) # (B, L, d_inner, d_state)
214
+
215
+ # x contribution: dB * x
216
+ dBx = dB * x.unsqueeze(-1) # (B, L, d_inner, d_state)
217
+
218
+ # Sequential scan
219
+ h = torch.zeros(B_batch, d_inner, d_state, device=x.device, dtype=x.dtype)
220
+ ys = []
221
+
222
+ for i in range(L):
223
+ h = dA[:, i] * h + dBx[:, i] # (B, d_inner, d_state)
224
+ y_i = torch.einsum('bdn,bn->bd', h, C[:, i]) # (B, d_inner)
225
+ ys.append(y_i)
226
+
227
+ y = torch.stack(ys, dim=1) # (B, L, d_inner)
228
+ return y
229
+
230
+
231
+ # ============================================================
232
+ # 3. ZIGZAG SCAN PATTERNS
233
+ # ============================================================
234
+
235
+ def create_scan_patterns(H, W):
236
+ """
237
+ Create zigzag scan patterns for 2D spatial awareness.
238
+ Returns 4 patterns: row-major, reversed, column-major, zigzag.
239
+ """
240
+ total = H * W
241
+ indices = torch.arange(total)
242
+
243
+ row_major = indices.clone()
244
+ row_major_rev = indices.flip(0)
245
+
246
+ grid = indices.view(H, W)
247
+ col_major = grid.t().contiguous().view(-1)
248
+
249
+ zigzag = []
250
+ for i in range(H):
251
+ row = grid[i]
252
+ if i % 2 == 1:
253
+ row = row.flip(0)
254
+ zigzag.append(row)
255
+ zigzag = torch.cat(zigzag)
256
+
257
+ patterns = [row_major, row_major_rev, col_major, zigzag]
258
+
259
+ inverse_patterns = []
260
+ for p in patterns:
261
+ inv = torch.zeros_like(p)
262
+ inv[p] = torch.arange(total)
263
+ inverse_patterns.append(inv)
264
+
265
+ return patterns, inverse_patterns
266
+
267
+
268
+ # ============================================================
269
+ # 4. LIQUID-SSM BLOCK (Core Building Block)
270
+ # ============================================================
271
+
272
+ class LiquidSSMBlock(nn.Module):
273
+ """
274
+ Combines Liquid CfC dynamics with Selective SSM in one block.
275
+
276
+ Dual-path: SSM captures long-range spatial dependencies via scanning,
277
+ Liquid CfC adds continuous-time adaptive dynamics with bounded gates.
278
+ """
279
+
280
+ def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dropout=0.0):
281
+ super().__init__()
282
+
283
+ self.norm1 = nn.LayerNorm(d_model)
284
+ self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
285
+
286
+ self.norm2 = nn.LayerNorm(d_model)
287
+ self.liquid = LiquidCfCCell(d_model, d_model)
288
+
289
+ self.norm3 = nn.LayerNorm(d_model)
290
+ self.ff = nn.Sequential(
291
+ nn.Linear(d_model, d_model * 4),
292
+ nn.GELU(),
293
+ nn.Dropout(dropout),
294
+ nn.Linear(d_model * 4, d_model),
295
+ nn.Dropout(dropout),
296
+ )
297
+
298
+ self.mix_alpha = nn.Parameter(torch.tensor(0.5))
299
+
300
+ def forward(self, x, scan_idx=None, unscan_idx=None):
301
+ if scan_idx is not None:
302
+ x_scanned = x[:, scan_idx]
303
+ else:
304
+ x_scanned = x
305
+
306
+ ssm_out = self.ssm(self.norm1(x_scanned))
307
+
308
+ if unscan_idx is not None:
309
+ ssm_out = ssm_out[:, unscan_idx]
310
+
311
+ liquid_out = self.liquid(self.norm2(x))
312
+
313
+ alpha = torch.sigmoid(self.mix_alpha)
314
+ mixed = alpha * ssm_out + (1.0 - alpha) * liquid_out
315
+
316
+ x = x + mixed
317
+ x = x + self.ff(self.norm3(x))
318
+
319
+ return x
320
+
321
+
322
+ # ============================================================
323
+ # 5. TIMESTEP & CONDITION EMBEDDINGS
324
+ # ============================================================
325
+
326
+ class SinusoidalPosEmb(nn.Module):
327
+ def __init__(self, dim):
328
+ super().__init__()
329
+ self.dim = dim
330
+
331
+ def forward(self, t):
332
+ device = t.device
333
+ half_dim = self.dim // 2
334
+ emb = math.log(10000) / (half_dim - 1)
335
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
336
+ emb = t.unsqueeze(-1) * emb.unsqueeze(0)
337
+ emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
338
+ return emb
339
+
340
+
341
+ class AdaptiveLayerNorm(nn.Module):
342
+ """DiT-style Adaptive Layer Norm with scale and shift from condition."""
343
+ def __init__(self, d_model, cond_dim):
344
+ super().__init__()
345
+ self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
346
+ self.proj = nn.Sequential(
347
+ nn.SiLU(),
348
+ nn.Linear(cond_dim, d_model * 2),
349
+ )
350
+
351
+ def forward(self, x, cond):
352
+ scale_shift = self.proj(cond)
353
+ scale, shift = scale_shift.chunk(2, dim=-1)
354
+ scale = scale.unsqueeze(1)
355
+ shift = shift.unsqueeze(1)
356
+ return self.norm(x) * (1 + scale) + shift
357
+
358
+
359
+ # ============================================================
360
+ # 6. LIQUIDFLOW VELOCITY NETWORK (Full Architecture)
361
+ # ============================================================
362
+
363
+ class LiquidFlowNet(nn.Module):
364
+ """
365
+ LiquidFlow: The complete velocity field network for flow matching.
366
+
367
+ Training: ||v_θ(x_t, t) - (x_1 - x_0)||² (rectified flow)
368
+ Sampling: x_{t+dt} = x_t + v_θ(x_t, t) * dt (Euler method)
369
+ """
370
+
371
+ def __init__(
372
+ self,
373
+ img_size=128,
374
+ patch_size=4,
375
+ in_channels=3,
376
+ d_model=256,
377
+ depth=8,
378
+ d_state=16,
379
+ d_conv=4,
380
+ expand=2,
381
+ dropout=0.0,
382
+ num_classes=0,
383
+ ):
384
+ super().__init__()
385
+ self.img_size = img_size
386
+ self.patch_size = patch_size
387
+ self.in_channels = in_channels
388
+ self.d_model = d_model
389
+ self.depth = depth
390
+ self.num_classes = num_classes
391
+
392
+ self.num_patches_h = img_size // patch_size
393
+ self.num_patches_w = img_size // patch_size
394
+ self.num_patches = self.num_patches_h * self.num_patches_w
395
+ self.patch_dim = in_channels * patch_size * patch_size
396
+
397
+ self.patch_embed = nn.Sequential(
398
+ nn.Linear(self.patch_dim, d_model),
399
+ nn.LayerNorm(d_model),
400
+ )
401
+
402
+ self.pos_embed = nn.Parameter(
403
+ torch.randn(1, self.num_patches, d_model) * 0.02
404
+ )
405
+
406
+ self.time_embed = nn.Sequential(
407
+ SinusoidalPosEmb(d_model),
408
+ nn.Linear(d_model, d_model * 4),
409
+ nn.GELU(),
410
+ nn.Linear(d_model * 4, d_model),
411
+ )
412
+
413
+ if num_classes > 0:
414
+ self.class_embed = nn.Embedding(num_classes, d_model)
415
+ else:
416
+ self.class_embed = None
417
+
418
+ cond_dim = d_model
419
+
420
+ self.blocks = nn.ModuleList([
421
+ LiquidSSMBlock(d_model, d_state, d_conv, expand, dropout)
422
+ for _ in range(depth)
423
+ ])
424
+
425
+ self.adaln_blocks = nn.ModuleList([
426
+ AdaptiveLayerNorm(d_model, cond_dim)
427
+ for _ in range(depth)
428
+ ])
429
+
430
+ self.skip_projs = nn.ModuleList()
431
+ for i in range(depth // 2):
432
+ self.skip_projs.append(nn.Linear(d_model * 2, d_model))
433
+
434
+ self.final_norm = nn.LayerNorm(d_model)
435
+ self.final_proj = nn.Linear(d_model, self.patch_dim)
436
+
437
+ patterns, inv_patterns = create_scan_patterns(
438
+ self.num_patches_h, self.num_patches_w
439
+ )
440
+ for i, (p, ip) in enumerate(zip(patterns, inv_patterns)):
441
+ self.register_buffer(f'scan_{i}', p)
442
+ self.register_buffer(f'unscan_{i}', ip)
443
+
444
+ self.num_scan_patterns = len(patterns)
445
+
446
+ self.pre_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
447
+ self.post_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
448
+
449
+ self._init_weights()
450
+
451
+ def _init_weights(self):
452
+ for m in self.modules():
453
+ if isinstance(m, nn.Linear):
454
+ nn.init.xavier_uniform_(m.weight)
455
+ if m.bias is not None:
456
+ nn.init.zeros_(m.bias)
457
+ elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
458
+ nn.init.xavier_uniform_(m.weight)
459
+ if m.bias is not None:
460
+ nn.init.zeros_(m.bias)
461
+ nn.init.zeros_(self.final_proj.weight)
462
+ nn.init.zeros_(self.final_proj.bias)
463
+
464
+ def patchify(self, x):
465
+ B, C, H, W = x.shape
466
+ p = self.patch_size
467
+ x = x.unfold(2, p, p).unfold(3, p, p)
468
+ x = x.contiguous().view(B, C, self.num_patches_h, self.num_patches_w, p * p)
469
+ x = x.permute(0, 2, 3, 1, 4)
470
+ x = x.contiguous().view(B, self.num_patches, self.patch_dim)
471
+ return x
472
+
473
+ def unpatchify(self, x):
474
+ B = x.shape[0]
475
+ p = self.patch_size
476
+ C = self.in_channels
477
+ H = self.num_patches_h
478
+ W = self.num_patches_w
479
+ x = x.view(B, H, W, C, p, p)
480
+ x = x.permute(0, 3, 1, 4, 2, 5)
481
+ x = x.contiguous().view(B, C, H * p, W * p)
482
+ return x
483
+
484
+ def forward(self, x, t, class_label=None):
485
+ B = x.shape[0]
486
+
487
+ tokens = self.patchify(x)
488
+ tokens = self.patch_embed(tokens)
489
+ tokens = tokens + self.pos_embed
490
+
491
+ h_2d = tokens.view(B, self.num_patches_h, self.num_patches_w, self.d_model)
492
+ h_2d = h_2d.permute(0, 3, 1, 2)
493
+ h_2d = self.pre_conv(h_2d)
494
+ tokens = h_2d.permute(0, 2, 3, 1).contiguous().view(B, self.num_patches, self.d_model)
495
+
496
+ t_emb = self.time_embed(t)
497
+ if self.class_embed is not None and class_label is not None:
498
+ t_emb = t_emb + self.class_embed(class_label)
499
+
500
+ skips = []
501
+
502
+ for i, (block, adaln) in enumerate(zip(self.blocks, self.adaln_blocks)):
503
+ tokens = adaln(tokens, t_emb)
504
+
505
+ scan_pattern_idx = i % self.num_scan_patterns
506
+ scan_idx = getattr(self, f'scan_{scan_pattern_idx}')
507
+ unscan_idx = getattr(self, f'unscan_{scan_pattern_idx}')
508
+
509
+ if i < self.depth // 2:
510
+ skips.append(tokens)
511
+
512
+ tokens = block(tokens, scan_idx, unscan_idx)
513
+
514
+ if i >= self.depth // 2:
515
+ skip_idx = self.depth - 1 - i
516
+ if skip_idx < len(skips):
517
+ skip_proj = self.skip_projs[skip_idx]
518
+ tokens = skip_proj(torch.cat([tokens, skips[skip_idx]], dim=-1))
519
+
520
+ h_2d = tokens.view(B, self.num_patches_h, self.num_patches_w, self.d_model)
521
+ h_2d = h_2d.permute(0, 3, 1, 2)
522
+ h_2d = self.post_conv(h_2d)
523
+ tokens = h_2d.permute(0, 2, 3, 1).contiguous().view(B, self.num_patches, self.d_model)
524
+
525
+ tokens = self.final_norm(tokens)
526
+ velocity = self.final_proj(tokens)
527
+ velocity = self.unpatchify(velocity)
528
+
529
+ return velocity
530
+
531
+ def count_params(self):
532
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
533
+
534
+
535
+ # ============================================================
536
+ # 7. MODEL CONFIGURATIONS
537
+ # ============================================================
538
+
539
+ def liquidflow_tiny(img_size=128, num_classes=0):
540
+ """~5M params - for quick experiments and 128x128"""
541
+ return LiquidFlowNet(
542
+ img_size=img_size, patch_size=4, in_channels=3,
543
+ d_model=192, depth=6, d_state=8, d_conv=4, expand=2,
544
+ num_classes=num_classes,
545
+ )
546
+
547
+ def liquidflow_small(img_size=128, num_classes=0):
548
+ """~12M params - main model for 128x128"""
549
+ return LiquidFlowNet(
550
+ img_size=img_size, patch_size=4, in_channels=3,
551
+ d_model=256, depth=8, d_state=16, d_conv=4, expand=2,
552
+ num_classes=num_classes,
553
+ )
554
+
555
+ def liquidflow_base(img_size=256, num_classes=0):
556
+ """~25M params - for 256x256"""
557
+ return LiquidFlowNet(
558
+ img_size=img_size, patch_size=8, in_channels=3,
559
+ d_model=384, depth=10, d_state=16, d_conv=4, expand=2,
560
+ num_classes=num_classes,
561
+ )
562
+
563
+ def liquidflow_512(img_size=512, num_classes=0):
564
+ """~25M params - for 512x512"""
565
+ return LiquidFlowNet(
566
+ img_size=img_size, patch_size=16, in_channels=3,
567
+ d_model=384, depth=10, d_state=16, d_conv=4, expand=2,
568
+ num_classes=num_classes,
569
+ )
570
+
571
+
572
+ if __name__ == "__main__":
573
+ device = torch.device("cpu")
574
+ for name, factory in [
575
+ ("tiny-128", lambda: liquidflow_tiny(128)),
576
+ ("small-128", lambda: liquidflow_small(128)),
577
+ ("base-256", lambda: liquidflow_base(256)),
578
+ ("512", lambda: liquidflow_512(512)),
579
+ ]:
580
+ model = factory().to(device)
581
+ params = model.count_params()
582
+ print(f"\n{name}: {params/1e6:.2f}M params")
583
+ B = 2
584
+ img_size = model.img_size
585
+ x = torch.randn(B, 3, img_size, img_size, device=device)
586
+ t = torch.rand(B, device=device)
587
+ v = model(x, t)
588
+ print(f" Input: {x.shape} → Output: {v.shape}")
589
+ assert v.shape == x.shape
590
+ print(f" ✓ Forward pass OK")