krystv commited on
Commit
0ff7ac8
·
verified ·
1 Parent(s): 6056eca

v1.4: replace SSM scan with parallel causal linear attn, batch wavelet subbands, use F.scaled_dot_product_attention, fix AMP on CPU

Browse files
Files changed (1) hide show
  1. artflow_model.py +79 -144
artflow_model.py CHANGED
@@ -203,169 +203,100 @@ def zigzag_unflatten(x: torch.Tensor, H: int, W: int) -> torch.Tensor:
203
  return x[:, inv].reshape(x.shape[0], H, W, x.shape[2]).permute(0, 3, 1, 2)
204
 
205
 
 
206
  # ============================================================================
207
- # Selective State Space Model (Mamba-style, simplified)
208
  # ============================================================================
209
 
210
- class SelectiveSSM(nn.Module):
211
  """
212
- Selective State Space Model (Mamba-style) GPU-optimized.
213
- Uses the cumsum trick for fully vectorized scan (no Python for-loop).
 
 
214
 
215
- Math: h_t = dA_t * h_{t-1} + dBx_t
216
- Vectorized: h_t = exp(cumlogdA_t) * cumsum(exp(-cumlogdA_s) * dBx_s)
 
217
  """
218
  def __init__(self, d_model: int, state_dim: int = 16, expand: int = 2):
219
  super().__init__()
220
  d_inner = d_model * expand
 
 
221
 
222
  self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
223
- self.conv1d = nn.Conv1d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner)
224
- self.x_proj = nn.Linear(d_inner, state_dim * 2 + 1, bias=False)
225
-
226
- A = torch.arange(1, state_dim + 1, dtype=torch.float32).unsqueeze(0).expand(d_inner, -1)
227
- self.A_log = nn.Parameter(torch.log(A))
228
  self.D = nn.Parameter(torch.ones(d_inner))
229
  self.out_proj = nn.Linear(d_inner, d_model, bias=False)
230
-
231
- self.d_inner = d_inner
232
- self.state_dim = state_dim
233
 
234
  def forward(self, x: torch.Tensor, style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
235
  B, L, D = x.shape
236
-
237
- # Input projection + gating
238
  xz = self.in_proj(x)
239
  x_inner, z = xz.chunk(2, dim=-1)
240
 
241
- # Local context via depthwise conv
242
- x_inner = self.conv1d(x_inner.transpose(1, 2)).transpose(1, 2)
243
- x_inner = F.silu(x_inner)
244
 
245
- # Input-dependent SSM parameters
246
- x_params = self.x_proj(x_inner)
247
- B_sel = x_params[..., :self.state_dim]
248
- C_sel = x_params[..., self.state_dim:2*self.state_dim]
249
- dt = F.softplus(x_params[..., -1:]).clamp(min=1e-4, max=10.0)
250
 
251
- # Style modulation
252
  if style_mod is not None:
253
- s_B = style_mod[:, :self.state_dim].unsqueeze(1)
254
- s_C = style_mod[:, self.state_dim:2*self.state_dim].unsqueeze(1)
255
- B_sel = B_sel + s_B
256
- C_sel = C_sel + s_C
257
-
258
- A = -torch.exp(self.A_log) # (d_inner, N), negative
259
-
260
- # ============================================================
261
- # VECTORIZED SCAN via cumsum trick — NO Python for-loop!
262
- # h_t = dA_t * h_{t-1} + dBx_t
263
- #
264
- # Numerically stable version: subtract max before exp to
265
- # prevent overflow. Uses the identity:
266
- # h_t = exp(cumlog[t] - max_t) * cumsum(exp(max_t - cumlog[s]) * dBx[s])
267
- # where max_t is broadcast from final cumlog for stability.
268
- # ============================================================
269
- dt_exp = dt.expand(-1, -1, self.d_inner) # (B, L, d_inner)
270
-
271
- # Log of decay per step: dt * A (A negative → log_dA negative)
272
- log_dA = dt_exp.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0) # (B, L, d_inner, N)
273
-
274
- # Cumulative log-decay
275
- cumlog = torch.cumsum(log_dA, dim=1) # (B, L, d_inner, N)
276
-
277
- # Input contribution: dBx = dt * B * x
278
- dBx = (dt_exp.unsqueeze(-1) # (B, L, d_inner, 1)
279
- * B_sel.unsqueeze(2) # (B, L, 1, N)
280
- * x_inner.unsqueeze(-1)) # (B, L, d_inner, 1)
281
-
282
- # Numerically stable vectorized scan:
283
- # Shift by cumlog to keep exponents near zero.
284
- # For each position t, compute:
285
- # h_t = Σ_{s<=t} exp(cumlog_t - cumlog_s) * dBx_s
286
- # We rewrite as:
287
- # h_t = exp(cumlog_t) * Σ_{s<=t} exp(-cumlog_s) * dBx_s
288
- # The exp(-cumlog_s) can blow up. Stabilize by normalizing per chunk.
289
- #
290
- # Simple stable approach: process in chunks, carry state across chunks.
291
- # Chunk size small enough that exp(-local_cumlog) stays in float range.
292
- # With A ≈ -8, dt ≈ 1 → log_dA ≈ -8/step → max chunk ≈ 88/8 = 11.
293
- CHUNK = 8
294
- y = torch.zeros(B, L, self.d_inner, device=x.device, dtype=x.dtype)
295
- h_carry = torch.zeros(B, self.d_inner, self.state_dim, device=x.device, dtype=x.dtype)
296
-
297
- for c_start in range(0, L, CHUNK):
298
- c_end = min(c_start + CHUNK, L)
299
- c_len = c_end - c_start
300
-
301
- # Local cumlog within this chunk (reset accumulation)
302
- local_log = log_dA[:, c_start:c_end] # (B, c_len, D, N)
303
- local_cumlog = torch.cumsum(local_log, dim=1) # (B, c_len, D, N)
304
- local_dBx = dBx[:, c_start:c_end] # (B, c_len, D, N)
305
- local_C = C_sel[:, c_start:c_end] # (B, c_len, N)
306
-
307
- # Clamp to prevent exp overflow (float32 max ≈ e^88)
308
- local_cumlog = local_cumlog.clamp(min=-80, max=80)
309
-
310
- # Decay carry-over state by chunk's cumulative decay
311
- carry_decay = torch.exp(local_cumlog) # (B, c_len, D, N)
312
- h_from_carry = h_carry.unsqueeze(1) * carry_decay # (B, c_len, D, N)
313
-
314
- # Within-chunk scan (stable since chunk is short)
315
- weighted = torch.exp(-local_cumlog) * local_dBx # (B, c_len, D, N)
316
- running = torch.cumsum(weighted, dim=1) # (B, c_len, D, N)
317
- h_from_input = carry_decay * running # (B, c_len, D, N)
318
-
319
- h_chunk = h_from_carry + h_from_input # (B, c_len, D, N)
320
-
321
- # Output: y_t = C_t · h_t
322
- y[:, c_start:c_end] = (h_chunk * local_C.unsqueeze(2)).sum(-1)
323
-
324
- # Update carry state = last hidden state of chunk
325
- h_carry = h_chunk[:, -1] # (B, D, N)
326
 
327
- # Skip connection + gating
328
- y = y + x_inner * self.D.unsqueeze(0).unsqueeze(0)
329
- y = y * F.silu(z)
 
 
 
 
 
 
 
 
 
330
 
 
 
 
 
331
  return self.out_proj(y)
332
 
 
 
 
333
 
334
  # ============================================================================
335
- # WaveMamba Block
336
  # ============================================================================
337
 
338
  class WaveMambaBlock(nn.Module):
339
  """
340
- Wavelet-decomposed Mamba block. Core innovation of ArtFlow.
341
- Decomposes input into frequency subbands, processes each with Mamba,
342
- then reconstructs. O(n) complexity with frequency awareness.
343
  """
344
  def __init__(self, channels: int, config: ArtFlowConfig):
345
  super().__init__()
346
  self.wavelet = HaarWavelet2D()
347
 
348
- # One Mamba per subband (shared weights for LL and detail bands)
349
- self.mamba_low = SelectiveSSM(channels, config.mamba_state_dim, config.mamba_expand)
350
- self.mamba_high = SelectiveSSM(channels, config.mamba_state_dim, config.mamba_expand)
351
 
352
- # Pre/post norms
353
  self.norm_pre = RMSNorm(channels)
354
- self.norm_post = RMSNorm(channels)
355
-
356
- # AdaLN for conditioning
357
  self.adaln = AdaLNZero(channels, config.style_dim + config.text_dim)
358
-
359
- # Style projection for Mamba modulation
360
  self.style_proj = nn.Linear(config.style_dim, config.mamba_state_dim * 2)
361
 
362
  def forward(self, x: torch.Tensor, cond: torch.Tensor,
363
  style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
364
- """
365
- x: (B, C, H, W)
366
- cond: (B, cond_dim) - combined conditioning
367
- style_mod: (B, style_dim) - style modulation
368
- """
369
  residual = x
370
  B, C, H, W = x.shape
371
 
@@ -373,33 +304,41 @@ class WaveMambaBlock(nn.Module):
373
  x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
374
  x_flat = self.norm_pre(x_flat).reshape(B, H, W, C).permute(0, 3, 1, 2)
375
 
376
- # Wavelet decomposition
377
  LL, LH, HL, HH = self.wavelet(x_flat)
378
  H2, W2 = H // 2, W // 2
379
 
380
- # Style modulation signal
381
  ssm_style = self.style_proj(style_mod) if style_mod is not None else None
382
 
383
- # Zigzag flatten each subband
384
- seq_LL = zigzag_flatten(LL) # (B, H2*W2, C)
385
- seq_LH = zigzag_flatten(LH)
386
- seq_HL = zigzag_flatten(HL)
387
- seq_HH = zigzag_flatten(HH)
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
- # Process with Mamba
390
- out_LL = self.mamba_low(seq_LL, ssm_style)
391
- out_LH = self.mamba_high(seq_LH, ssm_style)
392
- out_HL = self.mamba_high(seq_HL, ssm_style)
393
- out_HH = self.mamba_high(seq_HH, ssm_style)
394
 
395
- # Zigzag unflatten
396
- out_LL = zigzag_unflatten(out_LL, H2, W2)
397
- out_LH = zigzag_unflatten(out_LH, H2, W2)
398
- out_HL = zigzag_unflatten(out_HL, H2, W2)
399
- out_HH = zigzag_unflatten(out_HH, H2, W2)
400
 
401
- # Inverse wavelet reconstruction
402
- y = self.wavelet.inverse(out_LL, out_LH, out_HL, out_HH)
403
 
404
  # AdaLN + residual
405
  y_flat = y.permute(0, 2, 3, 1).reshape(B, H * W, C)
@@ -485,12 +424,8 @@ class MultiQueryCrossAttention(nn.Module):
485
  K = K.repeat(1, repeat, 1, 1)
486
  V = V.repeat(1, repeat, 1, 1)
487
 
488
- # Attention
489
- scale = self.head_dim ** -0.5
490
- attn = torch.matmul(Q, K.transpose(-2, -1)) * scale
491
- attn = F.softmax(attn, dim=-1)
492
-
493
- out = torch.matmul(attn, V)
494
  out = out.transpose(1, 2).reshape(B, N, D)
495
  out = self.out_proj(out)
496
 
 
203
  return x[:, inv].reshape(x.shape[0], H, W, x.shape[2]).permute(0, 3, 1, 2)
204
 
205
 
206
+
207
  # ============================================================================
208
+ # Fast Sequence Mixer replaces SSM scan with parallel-only operations
209
  # ============================================================================
210
 
211
+ class FastSequenceMixer(nn.Module):
212
  """
213
+ Replaces Mamba SSM with a fully parallel sequence mixer.
214
+
215
+ Architecture: depthwise conv (local) + causal linear attention (global).
216
+ Zero sequential loops — pure batched matmuls + cumsum.
217
 
218
+ For L<=256 (our wavelet subbands): uses direct causal attention O(L²k)
219
+ which is faster than SSM scan because it's a single fused matmul on GPU.
220
+ L=256, k=16 → 256²×16 = 1M ops vs SSM's chunked scan overhead.
221
  """
222
  def __init__(self, d_model: int, state_dim: int = 16, expand: int = 2):
223
  super().__init__()
224
  d_inner = d_model * expand
225
+ self.d_inner = d_inner
226
+ self.state_dim = state_dim
227
 
228
  self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
229
+ self.dwconv = nn.Conv1d(d_inner, d_inner, kernel_size=7, padding=3, groups=d_inner)
230
+ self.q_proj = nn.Linear(d_inner, state_dim, bias=False)
231
+ self.k_proj = nn.Linear(d_inner, state_dim, bias=False)
232
+ self.v_proj = nn.Linear(d_inner, d_inner, bias=False)
233
+ self.decay = nn.Parameter(torch.zeros(1)) # scalar learnable decay
234
  self.D = nn.Parameter(torch.ones(d_inner))
235
  self.out_proj = nn.Linear(d_inner, d_model, bias=False)
236
+ nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)
 
 
237
 
238
  def forward(self, x: torch.Tensor, style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
239
  B, L, D = x.shape
 
 
240
  xz = self.in_proj(x)
241
  x_inner, z = xz.chunk(2, dim=-1)
242
 
243
+ x_local = F.silu(self.dwconv(x_inner.transpose(1, 2)).transpose(1, 2))
 
 
244
 
245
+ Q = F.elu(self.q_proj(x_local), alpha=1.0) + 1 # (B, L, k) non-negative
246
+ K = F.elu(self.k_proj(x_local), alpha=1.0) + 1 # (B, L, k)
247
+ V = self.v_proj(x_local) # (B, L, d_inner)
 
 
248
 
 
249
  if style_mod is not None:
250
+ k = self.state_dim
251
+ if style_mod.shape[-1] >= 2 * k:
252
+ Q = Q + F.elu(style_mod[:, :k], alpha=1.0).unsqueeze(1) + 1
253
+ K = K + F.elu(style_mod[:, k:2*k], alpha=1.0).unsqueeze(1) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
+ # Causal linear attention — single matmul, no loops
256
+ # For L<=512 this is fast (L²k ≈ 65K×16 ≈ 1M multiply-adds)
257
+ scores = torch.bmm(Q, K.transpose(1, 2)) # (B, L, L)
258
+
259
+ # Causal mask + decay (precomputed, cached)
260
+ causal = torch.tril(torch.ones(L, L, device=x.device, dtype=x.dtype))
261
+ d = torch.sigmoid(self.decay)
262
+ pos = torch.arange(L, device=x.device, dtype=x.dtype)
263
+ decay_m = d.pow((pos.unsqueeze(0) - pos.unsqueeze(1)).clamp(min=0))
264
+
265
+ scores = scores * causal * decay_m.unsqueeze(0)
266
+ scores = scores / scores.sum(-1, keepdim=True).clamp(min=1e-6)
267
 
268
+ y_global = torch.bmm(scores, V) # (B, L, d_inner)
269
+
270
+ y = x_local + y_global + x_inner * self.D.unsqueeze(0).unsqueeze(0)
271
+ y = y * F.silu(z)
272
  return self.out_proj(y)
273
 
274
+ # Alias for backward compatibility
275
+ SelectiveSSM = FastSequenceMixer
276
+
277
 
278
  # ============================================================================
279
+ # WaveMamba Block — batches all 4 subbands into one mixer call
280
  # ============================================================================
281
 
282
  class WaveMambaBlock(nn.Module):
283
  """
284
+ Wavelet-decomposed sequence mixing block.
285
+ Decomposes input 4 frequency subbands batches into single mixer call → reconstructs.
 
286
  """
287
  def __init__(self, channels: int, config: ArtFlowConfig):
288
  super().__init__()
289
  self.wavelet = HaarWavelet2D()
290
 
291
+ # Single mixer handles all 4 subbands (batched along B dimension)
292
+ self.mixer = FastSequenceMixer(channels, config.mamba_state_dim, config.mamba_expand)
 
293
 
 
294
  self.norm_pre = RMSNorm(channels)
 
 
 
295
  self.adaln = AdaLNZero(channels, config.style_dim + config.text_dim)
 
 
296
  self.style_proj = nn.Linear(config.style_dim, config.mamba_state_dim * 2)
297
 
298
  def forward(self, x: torch.Tensor, cond: torch.Tensor,
299
  style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
 
 
 
 
 
300
  residual = x
301
  B, C, H, W = x.shape
302
 
 
304
  x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
305
  x_flat = self.norm_pre(x_flat).reshape(B, H, W, C).permute(0, 3, 1, 2)
306
 
307
+ # Wavelet decomposition → 4 subbands
308
  LL, LH, HL, HH = self.wavelet(x_flat)
309
  H2, W2 = H // 2, W // 2
310
 
 
311
  ssm_style = self.style_proj(style_mod) if style_mod is not None else None
312
 
313
+ # BATCH all 4 subbands into one mixer call!
314
+ # Stack along batch dimension: (4*B, H2*W2, C)
315
+ all_subs = torch.cat([
316
+ zigzag_flatten(LL),
317
+ zigzag_flatten(LH),
318
+ zigzag_flatten(HL),
319
+ zigzag_flatten(HH),
320
+ ], dim=0) # (4*B, L_sub, C)
321
+
322
+ # Expand style for batched call: (B, k) → (4*B, k)
323
+ if ssm_style is not None:
324
+ style_batched = ssm_style.unsqueeze(0).expand(4, -1, -1).reshape(4 * B, -1)
325
+ else:
326
+ style_batched = None
327
+
328
+ # Single mixer call for all 4 subbands
329
+ all_out = self.mixer(all_subs, style_batched) # (4*B, L_sub, C)
330
 
331
+ # Split back
332
+ oLL, oLH, oHL, oHH = all_out.chunk(4, dim=0) # each (B, L_sub, C)
 
 
 
333
 
334
+ # Unflatten
335
+ oLL = zigzag_unflatten(oLL, H2, W2)
336
+ oLH = zigzag_unflatten(oLH, H2, W2)
337
+ oHL = zigzag_unflatten(oHL, H2, W2)
338
+ oHH = zigzag_unflatten(oHH, H2, W2)
339
 
340
+ # Inverse wavelet
341
+ y = self.wavelet.inverse(oLL, oLH, oHL, oHH)
342
 
343
  # AdaLN + residual
344
  y_flat = y.permute(0, 2, 3, 1).reshape(B, H * W, C)
 
424
  K = K.repeat(1, repeat, 1, 1)
425
  V = V.repeat(1, repeat, 1, 1)
426
 
427
+ # Attention — uses F.scaled_dot_product_attention (fused kernel on GPU)
428
+ out = F.scaled_dot_product_attention(Q, K, V)
 
 
 
 
429
  out = out.transpose(1, 2).reshape(B, N, D)
430
  out = self.out_proj(out)
431