krystv committed on
Commit
099d3c7
·
verified ·
1 Parent(s): c57f0d4

v0.4: Use mambapy.pscan (proven, grad-safe parallel scan) — no more torch.associative_scan

Browse files
Files changed (1) hide show
  1. liquidflow/model.py +58 -307
liquidflow/model.py CHANGED
@@ -1,12 +1,8 @@
1
  """
2
- LiquidFlow: A Novel Liquid-SSM Flow Matching Image Generator
3
- v0.3.0 — PARALLEL SSM scan via torch.associative_scan (O(log L) not O(L))
4
-
5
- Key changes from v0.2:
6
- - SSM uses torch.associative_scan for O(log L) parallel scan (no Python for-loop)
7
- - Fallback: Blelloch tree-scan in pure PyTorch for older PyTorch versions
8
- - patch_size=8 for 128×128 → L=256 tokens (not 1024)
9
- - patch_size=4 for 32/64 → fine at small sizes
10
  """
11
 
12
  import math
@@ -15,210 +11,35 @@ import torch.nn as nn
15
  import torch.nn.functional as F
16
  from torch.utils.checkpoint import checkpoint
17
 
18
- # ---- Parallel Scan Infrastructure ----
19
-
20
- HAS_NATIVE_SCAN = False
21
  try:
22
- from torch._higher_order_ops.associative_scan import associative_scan as _native_scan
23
- HAS_NATIVE_SCAN = True
24
  except ImportError:
25
- pass
26
-
27
-
28
- def _ssm_combine(left, right):
29
- """Associative operator for SSM: (a,b) ⊕ (a',b') = (a'*a, a'*b + b')"""
30
- return (left[0] * right[0], right[0] * left[1] + right[1])
31
-
32
-
33
- def parallel_scan_native(A, X, dim=1):
34
- """Use PyTorch built-in associative_scan (≥2.4). O(log L) parallel depth."""
35
- mode = 'pointwise' if A.is_cuda else 'generic'
36
- _, h_all = _native_scan(_ssm_combine, (A, X), dim=dim, combine_mode=mode)
37
- return h_all
38
 
39
-
40
- def parallel_scan_blelloch(A, X):
41
- """
42
- Blelloch tree-scan fallback for older PyTorch.
43
- Pure tensor ops, O(L log L) work, O(log L) depth.
44
- A, X: (B, L, D) — L must be power of 2.
45
- Returns: H (B, L, D) — all prefix scan results.
46
- """
47
- B, L, D = A.shape
48
- assert L & (L - 1) == 0, f"L must be power of 2, got {L}"
49
-
50
- Aa = A.clone()
51
- Xa = X.clone()
52
- num_steps = int(math.log2(L))
53
-
54
- # Up-sweep (reduce): merge pairs → quads → ...
55
- for k in range(num_steps):
56
- s = 2 ** (k + 1)
57
- half = s // 2
58
- # right = op(left, right) for all pairs in parallel
59
- Xa[:, s - 1::s] = Aa[:, s - 1::s] * Xa[:, half - 1::s] + Xa[:, s - 1::s]
60
- Aa[:, s - 1::s] = Aa[:, s - 1::s] * Aa[:, half - 1::s]
61
-
62
- # Clear last element (it has the total reduction, not needed for scan)
63
- Xa[:, -1] = 0
64
- Aa[:, -1] = 0
65
-
66
- # Down-sweep: distribute prefix sums back
67
- for k in range(num_steps - 1, -1, -1):
68
- s = 2 ** (k + 1)
69
- half = s // 2
70
- # Save left child
71
- tmp_a = Aa[:, half - 1::s].clone()
72
- tmp_x = Xa[:, half - 1::s].clone()
73
- # Left child ← parent
74
- Aa[:, half - 1::s] = Aa[:, s - 1::s]
75
- Xa[:, half - 1::s] = Xa[:, s - 1::s]
76
- # Right child ← op(parent, saved left)
77
- Xa[:, s - 1::s] = Aa[:, s - 1::s] * tmp_x + Xa[:, s - 1::s] # WRONG — needs old right
78
- Aa[:, s - 1::s] = Aa[:, s - 1::s] * tmp_a
79
-
80
- # The Blelloch scan gives exclusive prefix sums. Convert to inclusive:
81
- # h_t = A_t * prefix_{t-1} + X_t
82
- # For inclusive: shift and apply one more step
83
- # Actually, let's use the simpler Hillis-Steele approach which gives inclusive directly:
84
- pass # Blelloch is tricky to get right — use Hillis-Steele instead
85
-
86
-
87
- def parallel_scan_hillis_steele(A, X):
88
  """
89
- Hillis-Steele inclusive parallel scan. Simpler than Blelloch.
90
- O(L log L) work, O(log L) depth. All tensor operations.
91
- A, X: (B, L, D). Returns H: (B, L, D) = all hidden states.
92
  """
93
- B, L, D = A.shape
94
-
95
- # Pad to power of 2 if needed
96
- orig_L = L
97
- next_pow2 = 1 << (L - 1).bit_length()
98
- if next_pow2 != L:
99
- pad = next_pow2 - L
100
- A = F.pad(A, (0, 0, 0, pad), value=1.0) # pad A with 1 (identity for mult)
101
- X = F.pad(X, (0, 0, 0, pad), value=0.0) # pad X with 0 (identity for add)
102
- L = next_pow2
103
-
104
- h = X.clone() # (B, L, D)
105
- a = A.clone()
106
-
107
- num_steps = int(math.log2(L))
108
- for d in range(num_steps):
109
- stride = 2 ** d
110
- # h[i] = a_shifted[i] * h[i-stride] + h[i] (for i >= stride)
111
- h_shifted = F.pad(h[:, :-stride], (0, 0, stride, 0)) # shift right by stride
112
- a_shifted = F.pad(a[:, :-stride], (0, 0, stride, 0), value=1.0)
113
-
114
- h = a_shifted * h_shifted + h # this is wrong for multi-step...
115
-
116
- # Actually Hillis-Steele doesn't directly work for (a,b) pairs.
117
- # Let me implement the correct parallel prefix approach.
118
- return h[:, :orig_L]
119
-
120
-
121
- def parallel_scan_correct(A, X):
122
- """
123
- Work-efficient parallel prefix scan for SSM recurrence.
124
- h_t = A_t * h_{t-1} + X_t
125
-
126
- Uses up-sweep + down-sweep on the (A, X) pair.
127
- A, X: (B, L, D). Returns H: (B, L, D).
128
- """
129
- B, L, D = A.shape
130
-
131
- # Pad L to power of 2
132
- orig_L = L
133
- next_pow2 = 1 << (L - 1).bit_length()
134
- if next_pow2 != L:
135
- pad = next_pow2 - L
136
- A = F.pad(A, (0, 0, 0, pad), value=1.0)
137
- X = F.pad(X, (0, 0, 0, pad), value=0.0)
138
- L = next_pow2
139
-
140
- # Work on clones
141
- a = A.clone()
142
- x = X.clone()
143
-
144
- # Store intermediate values for down-sweep
145
- a_levels = []
146
- x_levels = []
147
-
148
- # UP-SWEEP: reduce pairs
149
- num_levels = int(math.log2(L))
150
- for level in range(num_levels):
151
- # Current length
152
- cur_len = a.shape[1]
153
- a_even = a[:, 0::2] # left children
154
- a_odd = a[:, 1::2] # right children
155
- x_even = x[:, 0::2]
156
- x_odd = x[:, 1::2]
157
-
158
- # Save for down-sweep
159
- a_levels.append((a_even.clone(), a_odd.clone()))
160
- x_levels.append((x_even.clone(), x_odd.clone()))
161
-
162
- # Merge: right = right ⊕ left → (a_r*a_l, a_r*x_l + x_r)
163
- a = a_odd * a_even
164
- x = a_odd * x_even + x_odd
165
-
166
- # After up-sweep, a and x have length 1 containing the full reduction.
167
- # We need the inclusive prefix scan, not just the total.
168
-
169
- # DOWN-SWEEP: propagate prefix sums
170
- # Start with identity prefix (for position before the first element)
171
- prefix_a = torch.ones(B, 1, D, device=A.device, dtype=A.dtype)
172
- prefix_x = torch.zeros(B, 1, D, device=A.device, dtype=A.dtype)
173
-
174
- for level in range(num_levels - 1, -1, -1):
175
- a_even, a_odd = a_levels[level]
176
- x_even, x_odd = x_levels[level]
177
-
178
- # For each pair (even, odd) with prefix:
179
- # Result for even = prefix ⊕ even
180
- # Result for odd = (prefix ⊕ even) ⊕ odd
181
-
182
- # prefix ⊕ even: (prefix_a * a_even, prefix_a * x_even + prefix_x)
183
- new_a_even = prefix_a * a_even
184
- new_x_even = prefix_a * x_even + prefix_x
185
-
186
- # (prefix ⊕ even) ⊕ odd: (new_a_even * a_odd, a_odd * new_x_even + x_odd)
187
- # Wait, the operator order matters. SSM recurrence: h_t = A_t * h_{t-1} + X_t
188
- # So element t is (A_t, X_t), and the scan computes h_t = result_x of prefix up to t.
189
- # The operator is: (a_l, x_l) ⊕ (a_r, x_r) = (a_r * a_l, a_r * x_l + x_r)
190
- new_a_odd = a_odd * new_a_even
191
- new_x_odd = a_odd * new_x_even + x_odd
192
-
193
- # Interleave back: [even_0, odd_0, even_1, odd_1, ...]
194
- out_a = torch.stack([new_a_even, new_a_odd], dim=2).view(B, -1, D)
195
- out_x = torch.stack([new_x_even, new_x_odd], dim=2).view(B, -1, D)
196
-
197
- prefix_a = out_a
198
- prefix_x = out_x
199
-
200
- return prefix_x[:, :orig_L]
201
-
202
-
203
- def parallel_ssm_scan(A, X):
204
- """
205
- Top-level SSM parallel scan dispatcher.
206
- A: (B, L, D) — discretized diagonal A (decay) per timestep
207
- X: (B, L, D) — B_bar * u (input contribution) per timestep
208
- Returns: H (B, L, D) — all hidden states h_1..h_L
209
- """
210
- if HAS_NATIVE_SCAN:
211
- return parallel_scan_native(A, X, dim=1)
212
  else:
213
- return parallel_scan_correct(A, X)
 
 
 
 
 
 
 
214
 
215
 
216
- # ============================================================
217
- # 1. LIQUID TIME-CONSTANT CELL
218
- # ============================================================
219
-
220
  class LiquidCfCCell(nn.Module):
221
- """CfC: gate=σ(-f_τ), out = gate*h + (1-gate)*f_x. Bounded by sigmoid."""
222
  def __init__(self, input_dim, hidden_dim):
223
  super().__init__()
224
  self.backbone = nn.Linear(input_dim, hidden_dim)
@@ -231,16 +52,8 @@ class LiquidCfCCell(nn.Module):
231
  return gate * h + (1.0 - gate) * f_x
232
 
233
 
234
- # ============================================================
235
- # 2. SELECTIVE SSM — PARALLEL SCAN
236
- # ============================================================
237
-
238
  class SelectiveSSM(nn.Module):
239
- """
240
- Selective SSM with PARALLEL scan. No Python for-loops over L.
241
- Uses torch.associative_scan on GPU, tree-scan fallback on CPU.
242
- Training speed: O(L log L) parallel vs O(L) sequential.
243
- """
244
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
245
  super().__init__()
246
  self.d_model = d_model
@@ -267,83 +80,48 @@ class SelectiveSSM(nn.Module):
267
  B, L, _ = x.shape
268
  xz = self.in_proj(x)
269
  x_inner, z = xz.chunk(2, dim=-1)
270
-
271
  x_conv = F.silu(self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2))
272
 
273
  x_ssm = self.x_proj(x_conv)
274
- B_sel = x_ssm[:, :, :self.d_state] # (B, L, N)
275
- C_sel = x_ssm[:, :, self.d_state:2*self.d_state] # (B, L, N)
276
- dt = F.softplus(self.dt_proj(x_ssm[:, :, -1:])) # (B, L, d_inner)
277
 
278
  A = -torch.exp(self.A_log) # (d_inner, N)
279
 
280
- y = self._parallel_ssm(x_conv, dt, A, B_sel, C_sel)
281
-
282
- y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0)
283
- return self.out_proj(y * F.silu(z))
284
-
285
- def _parallel_ssm(self, x, dt, A, B, C):
286
- """
287
- Parallel selective scan. No Python for-loop.
288
- x: (B, L, d_inner)
289
- dt: (B, L, d_inner)
290
- A: (d_inner, N) — negative
291
- B: (B, L, N)
292
- C: (B, L, N)
293
- Returns: y (B, L, d_inner)
294
- """
295
- Bs, L, d_inner = x.shape
296
- N = A.shape[1]
297
-
298
- # Discretize: A_bar = exp(dt * A) — per (batch, pos, channel, state)
299
- # dt: (B, L, d_inner) → (B, L, d_inner, 1)
300
- # A: (d_inner, N) → (1, 1, d_inner, N)
301
  A_bar = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)) # (B, L, d_inner, N)
 
302
 
303
- # B_bar * x: dt * B * x → (B, L, d_inner, N)
304
- BX = dt.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1) # (B, L, d_inner, N)
305
-
306
- # Flatten (d_inner, N) → D for the scan
307
- D = d_inner * N
308
- A_flat = A_bar.reshape(Bs, L, D) # (B, L, D)
309
- BX_flat = BX.reshape(Bs, L, D) # (B, L, D)
310
-
311
- # PARALLEL SCAN: h_t = A_t * h_{t-1} + BX_t
312
- h_flat = parallel_ssm_scan(A_flat, BX_flat) # (B, L, D)
313
 
314
- # Unflatten and apply C
315
- h = h_flat.reshape(Bs, L, d_inner, N) # (B, L, d_inner, N)
 
316
 
317
- # y_t = sum_n(C_t_n * h_t_n) → (B, L, d_inner)
318
- y = (h * C.unsqueeze(2)).sum(-1)
319
-
320
- return y
321
 
 
 
322
 
323
- # ============================================================
324
- # 3. ZIGZAG SCAN PATTERNS
325
- # ============================================================
326
 
327
  def create_scan_patterns(H, W):
328
- total = H * W
329
- idx = torch.arange(total)
330
- grid = idx.view(H, W)
331
- patterns = [
332
- idx.clone(),
333
- idx.flip(0),
334
- grid.t().contiguous().view(-1),
335
- torch.cat([grid[i].flip(0) if i % 2 else grid[i] for i in range(H)]),
336
- ]
337
  inv = []
338
  for p in patterns:
339
  i = torch.zeros_like(p); i[p] = torch.arange(total); inv.append(i)
340
  return patterns, inv
341
 
342
 
343
- # ============================================================
344
- # 4. LIQUID-SSM BLOCK
345
- # ============================================================
346
-
347
  class LiquidSSMBlock(nn.Module):
348
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dropout=0.0):
349
  super().__init__()
@@ -356,35 +134,27 @@ class LiquidSSMBlock(nn.Module):
356
  nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout),
357
  nn.Linear(d_model * 4, d_model), nn.Dropout(dropout))
358
  self.mix_alpha = nn.Parameter(torch.tensor(0.5))
359
-
360
  def _ssm_fwd(self, x): return self.ssm(self.norm1(x))
361
  def _liq_fwd(self, x): return self.liquid(self.norm2(x))
362
-
363
  def forward(self, x, scan_idx=None, unscan_idx=None):
364
  xs = x[:, scan_idx] if scan_idx is not None else x
365
  if self.training and x.requires_grad:
366
  so = checkpoint(self._ssm_fwd, xs, use_reentrant=False)
367
  lo = checkpoint(self._liq_fwd, x, use_reentrant=False)
368
  else:
369
- so = self._ssm_fwd(xs)
370
- lo = self._liq_fwd(x)
371
  if unscan_idx is not None: so = so[:, unscan_idx]
372
  a = torch.sigmoid(self.mix_alpha)
373
  x = x + a * so + (1 - a) * lo
374
  return x + self.ff(self.norm3(x))
375
 
376
 
377
- # ============================================================
378
- # 5. EMBEDDINGS
379
- # ============================================================
380
-
381
  class SinusoidalPosEmb(nn.Module):
382
  def __init__(self, dim): super().__init__(); self.dim = dim
383
  def forward(self, t):
384
  h = self.dim // 2; e = math.log(10000)/(h-1)
385
  e = torch.exp(torch.arange(h, device=t.device)*-e)
386
- e = t.unsqueeze(-1)*e.unsqueeze(0)
387
- return torch.cat([e.sin(), e.cos()], -1)
388
 
389
  class AdaptiveLayerNorm(nn.Module):
390
  def __init__(self, d, c):
@@ -395,10 +165,6 @@ class AdaptiveLayerNorm(nn.Module):
395
  return self.norm(x) * (1+s.unsqueeze(1)) + b.unsqueeze(1)
396
 
397
 
398
- # ============================================================
399
- # 6. LIQUIDFLOW VELOCITY NETWORK
400
- # ============================================================
401
-
402
  class LiquidFlowNet(nn.Module):
403
  def __init__(self, img_size=128, patch_size=8, in_channels=3, d_model=256,
404
  depth=8, d_state=16, d_conv=4, expand=2, dropout=0.0, num_classes=0):
@@ -406,11 +172,10 @@ class LiquidFlowNet(nn.Module):
406
  self.img_size = img_size; self.patch_size = patch_size
407
  self.in_channels = in_channels; self.d_model = d_model
408
  self.depth = depth; self.num_classes = num_classes
409
- self.nph = img_size // patch_size; self.npw = img_size // patch_size
410
- self.num_patches = self.nph * self.npw
 
411
  self.patch_dim = in_channels * patch_size * patch_size
412
- # Alias for backward compat
413
- self.num_patches_h = self.nph; self.num_patches_w = self.npw
414
 
415
  self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, d_model), nn.LayerNorm(d_model))
416
  self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, d_model) * 0.02)
@@ -420,15 +185,13 @@ class LiquidFlowNet(nn.Module):
420
  self.blocks = nn.ModuleList([LiquidSSMBlock(d_model, d_state, d_conv, expand, dropout) for _ in range(depth)])
421
  self.adaln = nn.ModuleList([AdaptiveLayerNorm(d_model, d_model) for _ in range(depth)])
422
  self.skips = nn.ModuleList([nn.Linear(d_model*2, d_model) for _ in range(depth//2)])
423
-
424
  self.final_norm = nn.LayerNorm(d_model)
425
  self.final_proj = nn.Linear(d_model, self.patch_dim)
426
 
427
- pats, ipats = create_scan_patterns(self.nph, self.npw)
428
  for i,(p,ip) in enumerate(zip(pats, ipats)):
429
  self.register_buffer(f'scan_{i}', p); self.register_buffer(f'unscan_{i}', ip)
430
  self.n_scans = len(pats)
431
-
432
  self.pre_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
433
  self.post_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
434
  self._init_weights()
@@ -445,32 +208,28 @@ class LiquidFlowNet(nn.Module):
445
 
446
  def patchify(self, x):
447
  B,C,H,W = x.shape; p = self.patch_size
448
- return x.unfold(2,p,p).unfold(3,p,p).contiguous().view(B,C,self.nph,self.npw,p*p).permute(0,2,3,1,4).contiguous().view(B,self.num_patches,self.patch_dim)
449
 
450
  def unpatchify(self, x):
451
  B=x.shape[0]; p=self.patch_size
452
- return x.view(B,self.nph,self.npw,self.in_channels,p,p).permute(0,3,1,4,2,5).contiguous().view(B,self.in_channels,self.nph*p,self.npw*p)
453
 
454
  def forward(self, x, t, class_label=None):
455
  B = x.shape[0]
456
  tok = self.patch_embed(self.patchify(x)) + self.pos_embed
457
- h = tok.view(B,self.nph,self.npw,self.d_model).permute(0,3,1,2)
458
  tok = self.pre_conv(h).permute(0,2,3,1).contiguous().view(B,self.num_patches,self.d_model)
459
-
460
  te = self.time_embed(t)
461
  if self.class_embed is not None and class_label is not None: te = te + self.class_embed(class_label)
462
-
463
  sk = []
464
  for i,(blk,aln) in enumerate(zip(self.blocks, self.adaln)):
465
- tok = aln(tok, te)
466
- si = i % self.n_scans
467
  if i < self.depth//2: sk.append(tok)
468
  tok = blk(tok, getattr(self,f'scan_{si}'), getattr(self,f'unscan_{si}'))
469
  if i >= self.depth//2:
470
  j = self.depth-1-i
471
  if j < len(sk): tok = self.skips[j](torch.cat([tok, sk[j]], -1))
472
-
473
- h = tok.view(B,self.nph,self.npw,self.d_model).permute(0,3,1,2)
474
  tok = self.post_conv(h).permute(0,2,3,1).contiguous().view(B,self.num_patches,self.d_model)
475
  return self.unpatchify(self.final_proj(self.final_norm(tok)))
476
 
@@ -478,24 +237,16 @@ class LiquidFlowNet(nn.Module):
478
  return sum(p.numel() for p in self.parameters() if p.requires_grad)
479
 
480
 
481
- # ============================================================
482
- # 7. MODEL CONFIGURATIONS
483
- # ============================================================
484
-
485
  def liquidflow_tiny(img_size=128, num_classes=0):
486
- """~4M params — Colab free tier"""
487
  ps = 4 if img_size <= 64 else 8
488
  return LiquidFlowNet(img_size=img_size, patch_size=ps, d_model=192, depth=6, d_state=8, expand=2, num_classes=num_classes)
489
 
490
  def liquidflow_small(img_size=128, num_classes=0):
491
- """~10M params — production 128×128"""
492
  ps = 4 if img_size <= 64 else 8
493
  return LiquidFlowNet(img_size=img_size, patch_size=ps, d_model=256, depth=8, d_state=16, expand=2, num_classes=num_classes)
494
 
495
  def liquidflow_base(img_size=256, num_classes=0):
496
- """~25M params — 256×256"""
497
  return LiquidFlowNet(img_size=img_size, patch_size=8, d_model=384, depth=10, d_state=16, expand=2, num_classes=num_classes)
498
 
499
  def liquidflow_512(img_size=512, num_classes=0):
500
- """~25M params — 512×512"""
501
  return LiquidFlowNet(img_size=img_size, patch_size=16, d_model=384, depth=10, d_state=16, expand=2, num_classes=num_classes)
 
1
  """
2
+ LiquidFlow v0.4 — Parallel SSM scan via mambapy.pscan (proven, grad-safe)
3
+
4
+ Uses the battle-tested pscan from alxndrTL/mamba.py — O(log L) parallel,
5
+ full autograd support, no custom CUDA kernels, works on any GPU.
 
 
 
 
6
  """
7
 
8
  import math
 
11
  import torch.nn.functional as F
12
  from torch.utils.checkpoint import checkpoint
13
 
14
# ---- Parallel scan: mambapy (primary) or sequential fallback ----
try:
    from mambapy.pscan import pscan as _pscan
    HAS_PSCAN = True
except ImportError:
    # mambapy is optional; without it parallel_scan() uses a plain Python loop.
    HAS_PSCAN = False


def parallel_scan(A, X):
    """Evaluate the linear recurrence h_t = A_t * h_{t-1} + X_t with h_0 = 0.

    Args:
        A: multiplicative coefficient for every step. The sequential fallback
            unpacks it as a 4-D tensor (B, L, ED, N) and scans along dim 1.
        X: additive input for every step, laid out like ``A``.

    Returns:
        A tensor with the states h_1..h_L stacked along dim 1. For a
        zero-length sequence an empty tensor shaped like ``A`` is returned.

    When ``mambapy`` is importable the work is delegated to its ``pscan``
    (described by this file as an O(log L) parallel scan); otherwise the
    recurrence is unrolled step by step, which is O(L) sequential but needs
    no extra dependency and is differentiable through ordinary autograd.
    """
    if A.size(1) == 0:
        # Nothing to scan. Also avoids torch.stack([]) in the fallback, which
        # raises on an empty list.
        return A.new_zeros(A.shape)

    if HAS_PSCAN:
        # The original code notes that pscan works on X in place, so it is
        # handed a private copy to keep the caller's tensor intact.
        return _pscan(A, X.clone())

    batch, _, inner, state = A.shape
    h = torch.zeros(batch, inner, state, device=A.device, dtype=A.dtype)
    states = []
    for a_t, x_t in zip(A.unbind(dim=1), X.unbind(dim=1)):
        h = a_t * h + x_t
        states.append(h)
    return torch.stack(states, dim=1)
39
 
40
 
 
 
 
 
41
  class LiquidCfCCell(nn.Module):
42
+ """CfC cell: sigmoid gating guarantees bounded dynamics."""
43
  def __init__(self, input_dim, hidden_dim):
44
  super().__init__()
45
  self.backbone = nn.Linear(input_dim, hidden_dim)
 
52
  return gate * h + (1.0 - gate) * f_x
53
 
54
 
 
 
 
 
55
  class SelectiveSSM(nn.Module):
56
+ """Selective SSM with parallel scan via mambapy.pscan."""
 
 
 
 
57
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
58
  super().__init__()
59
  self.d_model = d_model
 
80
  B, L, _ = x.shape
81
  xz = self.in_proj(x)
82
  x_inner, z = xz.chunk(2, dim=-1)
 
83
  x_conv = F.silu(self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2))
84
 
85
  x_ssm = self.x_proj(x_conv)
86
+ B_sel = x_ssm[:, :, :self.d_state]
87
+ C_sel = x_ssm[:, :, self.d_state:2*self.d_state]
88
+ dt = F.softplus(self.dt_proj(x_ssm[:, :, -1:]))
89
 
90
  A = -torch.exp(self.A_log) # (d_inner, N)
91
 
92
+ # Discretize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  A_bar = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)) # (B, L, d_inner, N)
94
+ BX = dt.unsqueeze(-1) * B_sel.unsqueeze(2) * x_conv.unsqueeze(-1) # (B, L, d_inner, N)
95
 
96
+ # Pad L to power of 2 for pscan
97
+ orig_L = L
98
+ next_pow2 = 1 << (L - 1).bit_length()
99
+ if next_pow2 != L:
100
+ pad = next_pow2 - L
101
+ A_bar = F.pad(A_bar, (0,0, 0,0, 0,pad), value=1.0)
102
+ BX = F.pad(BX, (0,0, 0,0, 0,pad), value=0.0)
 
 
 
103
 
104
+ # PARALLEL SCAN O(log L), full grad support
105
+ h_all = parallel_scan(A_bar, BX) # (B, L_padded, d_inner, N)
106
+ h_all = h_all[:, :orig_L]
107
 
108
+ # Output: y_t = C_t · h_t
109
+ y = (h_all * C_sel.unsqueeze(2)).sum(-1) # (B, L, d_inner)
 
 
110
 
111
+ y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0)
112
+ return self.out_proj(y * F.silu(z))
113
 
 
 
 
114
 
def create_scan_patterns(H, W):
    """Build four traversal orders over an H x W grid, plus their inverses.

    Grid cells are numbered row-major from 0 to H*W - 1. The four orders are:

    0. row-major (identity),
    1. reversed row-major,
    2. column-major,
    3. "snake": rows top to bottom, with odd-numbered rows read right-to-left.

    Args:
        H: number of grid rows.
        W: number of grid columns.

    Returns:
        ``(patterns, inv)``: two lists of four 1-D index tensors of length
        ``H * W``. ``patterns[k][j]`` is the cell visited at step ``j`` and
        ``inv[k]`` is the inverse permutation, i.e. ``x[patterns[k]][inv[k]]``
        restores ``x``.
    """
    total = H * W
    idx = torch.arange(total)
    grid = idx.view(H, W)

    # Reverse every odd row in one vectorised step instead of looping over
    # rows; flip() returns a copy, so the in-place assignment is safe.
    snake = grid.clone()
    snake[1::2] = snake[1::2].flip(1)

    patterns = [
        idx.clone(),
        idx.flip(0),
        grid.t().reshape(-1),
        snake.reshape(-1),
    ]
    # Each pattern is a permutation of 0..total-1, so argsort is its inverse.
    inv = [torch.argsort(p) for p in patterns]
    return patterns, inv
123
 
124
 
 
 
 
 
125
  class LiquidSSMBlock(nn.Module):
126
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dropout=0.0):
127
  super().__init__()
 
134
  nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout),
135
  nn.Linear(d_model * 4, d_model), nn.Dropout(dropout))
136
  self.mix_alpha = nn.Parameter(torch.tensor(0.5))
 
137
  def _ssm_fwd(self, x): return self.ssm(self.norm1(x))
138
  def _liq_fwd(self, x): return self.liquid(self.norm2(x))
 
139
  def forward(self, x, scan_idx=None, unscan_idx=None):
140
  xs = x[:, scan_idx] if scan_idx is not None else x
141
  if self.training and x.requires_grad:
142
  so = checkpoint(self._ssm_fwd, xs, use_reentrant=False)
143
  lo = checkpoint(self._liq_fwd, x, use_reentrant=False)
144
  else:
145
+ so = self._ssm_fwd(xs); lo = self._liq_fwd(x)
 
146
  if unscan_idx is not None: so = so[:, unscan_idx]
147
  a = torch.sigmoid(self.mix_alpha)
148
  x = x + a * so + (1 - a) * lo
149
  return x + self.ff(self.norm3(x))
150
 
151
 
 
 
 
 
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar input such as a timestep.

    The output concatenates ``sin`` and ``cos`` of the input scaled by
    ``dim // 2`` geometrically spaced frequencies running from 1 down to
    1/10000, so its last dimension has size ``2 * (dim // 2)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t``; a trailing axis of size ``2 * (dim // 2)`` is appended."""
        half = self.dim // 2
        # With dim in {2, 3} there is a single frequency and half - 1 == 0;
        # clamping the divisor avoids a ZeroDivisionError and leaves that
        # frequency at exp(0) == 1.
        step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -step)
        # Compute the phase once and reuse it for both sin and cos.
        angles = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], -1)
 
158
 
159
  class AdaptiveLayerNorm(nn.Module):
160
  def __init__(self, d, c):
 
165
  return self.norm(x) * (1+s.unsqueeze(1)) + b.unsqueeze(1)
166
 
167
 
 
 
 
 
168
  class LiquidFlowNet(nn.Module):
169
  def __init__(self, img_size=128, patch_size=8, in_channels=3, d_model=256,
170
  depth=8, d_state=16, d_conv=4, expand=2, dropout=0.0, num_classes=0):
 
172
  self.img_size = img_size; self.patch_size = patch_size
173
  self.in_channels = in_channels; self.d_model = d_model
174
  self.depth = depth; self.num_classes = num_classes
175
+ self.num_patches_h = img_size // patch_size
176
+ self.num_patches_w = img_size // patch_size
177
+ self.num_patches = self.num_patches_h * self.num_patches_w
178
  self.patch_dim = in_channels * patch_size * patch_size
 
 
179
 
180
  self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, d_model), nn.LayerNorm(d_model))
181
  self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, d_model) * 0.02)
 
185
  self.blocks = nn.ModuleList([LiquidSSMBlock(d_model, d_state, d_conv, expand, dropout) for _ in range(depth)])
186
  self.adaln = nn.ModuleList([AdaptiveLayerNorm(d_model, d_model) for _ in range(depth)])
187
  self.skips = nn.ModuleList([nn.Linear(d_model*2, d_model) for _ in range(depth//2)])
 
188
  self.final_norm = nn.LayerNorm(d_model)
189
  self.final_proj = nn.Linear(d_model, self.patch_dim)
190
 
191
+ pats, ipats = create_scan_patterns(self.num_patches_h, self.num_patches_w)
192
  for i,(p,ip) in enumerate(zip(pats, ipats)):
193
  self.register_buffer(f'scan_{i}', p); self.register_buffer(f'unscan_{i}', ip)
194
  self.n_scans = len(pats)
 
195
  self.pre_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
196
  self.post_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
197
  self._init_weights()
 
208
 
209
  def patchify(self, x):
210
  B,C,H,W = x.shape; p = self.patch_size
211
+ return x.unfold(2,p,p).unfold(3,p,p).contiguous().view(B,C,self.num_patches_h,self.num_patches_w,p*p).permute(0,2,3,1,4).contiguous().view(B,self.num_patches,self.patch_dim)
212
 
213
  def unpatchify(self, x):
214
  B=x.shape[0]; p=self.patch_size
215
+ return x.view(B,self.num_patches_h,self.num_patches_w,self.in_channels,p,p).permute(0,3,1,4,2,5).contiguous().view(B,self.in_channels,self.num_patches_h*p,self.num_patches_w*p)
216
 
217
  def forward(self, x, t, class_label=None):
218
  B = x.shape[0]
219
  tok = self.patch_embed(self.patchify(x)) + self.pos_embed
220
+ h = tok.view(B,self.num_patches_h,self.num_patches_w,self.d_model).permute(0,3,1,2)
221
  tok = self.pre_conv(h).permute(0,2,3,1).contiguous().view(B,self.num_patches,self.d_model)
 
222
  te = self.time_embed(t)
223
  if self.class_embed is not None and class_label is not None: te = te + self.class_embed(class_label)
 
224
  sk = []
225
  for i,(blk,aln) in enumerate(zip(self.blocks, self.adaln)):
226
+ tok = aln(tok, te); si = i % self.n_scans
 
227
  if i < self.depth//2: sk.append(tok)
228
  tok = blk(tok, getattr(self,f'scan_{si}'), getattr(self,f'unscan_{si}'))
229
  if i >= self.depth//2:
230
  j = self.depth-1-i
231
  if j < len(sk): tok = self.skips[j](torch.cat([tok, sk[j]], -1))
232
+ h = tok.view(B,self.num_patches_h,self.num_patches_w,self.d_model).permute(0,3,1,2)
 
233
  tok = self.post_conv(h).permute(0,2,3,1).contiguous().view(B,self.num_patches,self.d_model)
234
  return self.unpatchify(self.final_proj(self.final_norm(tok)))
235
 
 
237
  return sum(p.numel() for p in self.parameters() if p.requires_grad)
238
 
239
 
 
 
 
 
def liquidflow_tiny(img_size=128, num_classes=0):
    """Build the smallest LiquidFlowNet preset (d_model=192, depth=6, d_state=8).

    Images up to 64 px per side use 4-px patches; larger images use 8-px patches.
    """
    patch = 8 if img_size > 64 else 4
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=patch,
        d_model=192,
        depth=6,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )
243
 
def liquidflow_small(img_size=128, num_classes=0):
    """Build the small LiquidFlowNet preset (d_model=256, depth=8, d_state=16).

    Images up to 64 px per side use 4-px patches; larger images use 8-px patches.
    """
    patch = 8 if img_size > 64 else 4
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=patch,
        d_model=256,
        depth=8,
        d_state=16,
        expand=2,
        num_classes=num_classes,
    )
247
 
def liquidflow_base(img_size=256, num_classes=0):
    """Build the base LiquidFlowNet preset (d_model=384, depth=10, 8-px patches)."""
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=8,
        d_model=384,
        depth=10,
        d_state=16,
        expand=2,
        num_classes=num_classes,
    )
250
 
def liquidflow_512(img_size=512, num_classes=0):
    """Build the 512-px LiquidFlowNet preset (d_model=384, depth=10, 16-px patches)."""
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=16,
        d_model=384,
        depth=10,
        d_state=16,
        expand=2,
        num_classes=num_classes,
    )