v1.4: replace SSM scan with parallel causal linear attn, batch wavelet subbands, use F.scaled_dot_product_attention, fix AMP on CPU
Browse files- artflow_model.py +79 -144
artflow_model.py
CHANGED
|
@@ -203,169 +203,100 @@ def zigzag_unflatten(x: torch.Tensor, H: int, W: int) -> torch.Tensor:
|
|
| 203 |
return x[:, inv].reshape(x.shape[0], H, W, x.shape[2]).permute(0, 3, 1, 2)
|
| 204 |
|
| 205 |
|
|
|
|
| 206 |
# ============================================================================
|
| 207 |
-
#
|
| 208 |
# ============================================================================
|
| 209 |
|
| 210 |
-
class
|
| 211 |
"""
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
|
|
|
| 217 |
"""
|
| 218 |
def __init__(self, d_model: int, state_dim: int = 16, expand: int = 2):
|
| 219 |
super().__init__()
|
| 220 |
d_inner = d_model * expand
|
|
|
|
|
|
|
| 221 |
|
| 222 |
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
|
| 223 |
-
self.
|
| 224 |
-
self.
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
self.
|
| 228 |
self.D = nn.Parameter(torch.ones(d_inner))
|
| 229 |
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
|
| 230 |
-
|
| 231 |
-
self.d_inner = d_inner
|
| 232 |
-
self.state_dim = state_dim
|
| 233 |
|
| 234 |
def forward(self, x: torch.Tensor, style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 235 |
B, L, D = x.shape
|
| 236 |
-
|
| 237 |
-
# Input projection + gating
|
| 238 |
xz = self.in_proj(x)
|
| 239 |
x_inner, z = xz.chunk(2, dim=-1)
|
| 240 |
|
| 241 |
-
|
| 242 |
-
x_inner = self.conv1d(x_inner.transpose(1, 2)).transpose(1, 2)
|
| 243 |
-
x_inner = F.silu(x_inner)
|
| 244 |
|
| 245 |
-
#
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
C_sel = x_params[..., self.state_dim:2*self.state_dim]
|
| 249 |
-
dt = F.softplus(x_params[..., -1:]).clamp(min=1e-4, max=10.0)
|
| 250 |
|
| 251 |
-
# Style modulation
|
| 252 |
if style_mod is not None:
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
A = -torch.exp(self.A_log) # (d_inner, N), negative
|
| 259 |
-
|
| 260 |
-
# ============================================================
|
| 261 |
-
# VECTORIZED SCAN via cumsum trick — NO Python for-loop!
|
| 262 |
-
# h_t = dA_t * h_{t-1} + dBx_t
|
| 263 |
-
#
|
| 264 |
-
# Numerically stable version: subtract max before exp to
|
| 265 |
-
# prevent overflow. Uses the identity:
|
| 266 |
-
# h_t = exp(cumlog[t] - max_t) * cumsum(exp(max_t - cumlog[s]) * dBx[s])
|
| 267 |
-
# where max_t is broadcast from final cumlog for stability.
|
| 268 |
-
# ============================================================
|
| 269 |
-
dt_exp = dt.expand(-1, -1, self.d_inner) # (B, L, d_inner)
|
| 270 |
-
|
| 271 |
-
# Log of decay per step: dt * A (A negative → log_dA negative)
|
| 272 |
-
log_dA = dt_exp.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0) # (B, L, d_inner, N)
|
| 273 |
-
|
| 274 |
-
# Cumulative log-decay
|
| 275 |
-
cumlog = torch.cumsum(log_dA, dim=1) # (B, L, d_inner, N)
|
| 276 |
-
|
| 277 |
-
# Input contribution: dBx = dt * B * x
|
| 278 |
-
dBx = (dt_exp.unsqueeze(-1) # (B, L, d_inner, 1)
|
| 279 |
-
* B_sel.unsqueeze(2) # (B, L, 1, N)
|
| 280 |
-
* x_inner.unsqueeze(-1)) # (B, L, d_inner, 1)
|
| 281 |
-
|
| 282 |
-
# Numerically stable vectorized scan:
|
| 283 |
-
# Shift by cumlog to keep exponents near zero.
|
| 284 |
-
# For each position t, compute:
|
| 285 |
-
# h_t = Σ_{s<=t} exp(cumlog_t - cumlog_s) * dBx_s
|
| 286 |
-
# We rewrite as:
|
| 287 |
-
# h_t = exp(cumlog_t) * Σ_{s<=t} exp(-cumlog_s) * dBx_s
|
| 288 |
-
# The exp(-cumlog_s) can blow up. Stabilize by normalizing per chunk.
|
| 289 |
-
#
|
| 290 |
-
# Simple stable approach: process in chunks, carry state across chunks.
|
| 291 |
-
# Chunk size small enough that exp(-local_cumlog) stays in float range.
|
| 292 |
-
# With A ≈ -8, dt ≈ 1 → log_dA ≈ -8/step → max chunk ≈ 88/8 = 11.
|
| 293 |
-
CHUNK = 8
|
| 294 |
-
y = torch.zeros(B, L, self.d_inner, device=x.device, dtype=x.dtype)
|
| 295 |
-
h_carry = torch.zeros(B, self.d_inner, self.state_dim, device=x.device, dtype=x.dtype)
|
| 296 |
-
|
| 297 |
-
for c_start in range(0, L, CHUNK):
|
| 298 |
-
c_end = min(c_start + CHUNK, L)
|
| 299 |
-
c_len = c_end - c_start
|
| 300 |
-
|
| 301 |
-
# Local cumlog within this chunk (reset accumulation)
|
| 302 |
-
local_log = log_dA[:, c_start:c_end] # (B, c_len, D, N)
|
| 303 |
-
local_cumlog = torch.cumsum(local_log, dim=1) # (B, c_len, D, N)
|
| 304 |
-
local_dBx = dBx[:, c_start:c_end] # (B, c_len, D, N)
|
| 305 |
-
local_C = C_sel[:, c_start:c_end] # (B, c_len, N)
|
| 306 |
-
|
| 307 |
-
# Clamp to prevent exp overflow (float32 max ≈ e^88)
|
| 308 |
-
local_cumlog = local_cumlog.clamp(min=-80, max=80)
|
| 309 |
-
|
| 310 |
-
# Decay carry-over state by chunk's cumulative decay
|
| 311 |
-
carry_decay = torch.exp(local_cumlog) # (B, c_len, D, N)
|
| 312 |
-
h_from_carry = h_carry.unsqueeze(1) * carry_decay # (B, c_len, D, N)
|
| 313 |
-
|
| 314 |
-
# Within-chunk scan (stable since chunk is short)
|
| 315 |
-
weighted = torch.exp(-local_cumlog) * local_dBx # (B, c_len, D, N)
|
| 316 |
-
running = torch.cumsum(weighted, dim=1) # (B, c_len, D, N)
|
| 317 |
-
h_from_input = carry_decay * running # (B, c_len, D, N)
|
| 318 |
-
|
| 319 |
-
h_chunk = h_from_carry + h_from_input # (B, c_len, D, N)
|
| 320 |
-
|
| 321 |
-
# Output: y_t = C_t · h_t
|
| 322 |
-
y[:, c_start:c_end] = (h_chunk * local_C.unsqueeze(2)).sum(-1)
|
| 323 |
-
|
| 324 |
-
# Update carry state = last hidden state of chunk
|
| 325 |
-
h_carry = h_chunk[:, -1] # (B, D, N)
|
| 326 |
|
| 327 |
-
#
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
return self.out_proj(y)
|
| 332 |
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
# ============================================================================
|
| 335 |
-
# WaveMamba Block
|
| 336 |
# ============================================================================
|
| 337 |
|
| 338 |
class WaveMambaBlock(nn.Module):
|
| 339 |
"""
|
| 340 |
-
Wavelet-decomposed
|
| 341 |
-
Decomposes input
|
| 342 |
-
then reconstructs. O(n) complexity with frequency awareness.
|
| 343 |
"""
|
| 344 |
def __init__(self, channels: int, config: ArtFlowConfig):
|
| 345 |
super().__init__()
|
| 346 |
self.wavelet = HaarWavelet2D()
|
| 347 |
|
| 348 |
-
#
|
| 349 |
-
self.
|
| 350 |
-
self.mamba_high = SelectiveSSM(channels, config.mamba_state_dim, config.mamba_expand)
|
| 351 |
|
| 352 |
-
# Pre/post norms
|
| 353 |
self.norm_pre = RMSNorm(channels)
|
| 354 |
-
self.norm_post = RMSNorm(channels)
|
| 355 |
-
|
| 356 |
-
# AdaLN for conditioning
|
| 357 |
self.adaln = AdaLNZero(channels, config.style_dim + config.text_dim)
|
| 358 |
-
|
| 359 |
-
# Style projection for Mamba modulation
|
| 360 |
self.style_proj = nn.Linear(config.style_dim, config.mamba_state_dim * 2)
|
| 361 |
|
| 362 |
def forward(self, x: torch.Tensor, cond: torch.Tensor,
|
| 363 |
style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 364 |
-
"""
|
| 365 |
-
x: (B, C, H, W)
|
| 366 |
-
cond: (B, cond_dim) - combined conditioning
|
| 367 |
-
style_mod: (B, style_dim) - style modulation
|
| 368 |
-
"""
|
| 369 |
residual = x
|
| 370 |
B, C, H, W = x.shape
|
| 371 |
|
|
@@ -373,33 +304,41 @@ class WaveMambaBlock(nn.Module):
|
|
| 373 |
x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
|
| 374 |
x_flat = self.norm_pre(x_flat).reshape(B, H, W, C).permute(0, 3, 1, 2)
|
| 375 |
|
| 376 |
-
# Wavelet decomposition
|
| 377 |
LL, LH, HL, HH = self.wavelet(x_flat)
|
| 378 |
H2, W2 = H // 2, W // 2
|
| 379 |
|
| 380 |
-
# Style modulation signal
|
| 381 |
ssm_style = self.style_proj(style_mod) if style_mod is not None else None
|
| 382 |
|
| 383 |
-
#
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
-
#
|
| 390 |
-
|
| 391 |
-
out_LH = self.mamba_high(seq_LH, ssm_style)
|
| 392 |
-
out_HL = self.mamba_high(seq_HL, ssm_style)
|
| 393 |
-
out_HH = self.mamba_high(seq_HH, ssm_style)
|
| 394 |
|
| 395 |
-
#
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
|
| 401 |
-
# Inverse wavelet
|
| 402 |
-
y = self.wavelet.inverse(
|
| 403 |
|
| 404 |
# AdaLN + residual
|
| 405 |
y_flat = y.permute(0, 2, 3, 1).reshape(B, H * W, C)
|
|
@@ -485,12 +424,8 @@ class MultiQueryCrossAttention(nn.Module):
|
|
| 485 |
K = K.repeat(1, repeat, 1, 1)
|
| 486 |
V = V.repeat(1, repeat, 1, 1)
|
| 487 |
|
| 488 |
-
# Attention
|
| 489 |
-
|
| 490 |
-
attn = torch.matmul(Q, K.transpose(-2, -1)) * scale
|
| 491 |
-
attn = F.softmax(attn, dim=-1)
|
| 492 |
-
|
| 493 |
-
out = torch.matmul(attn, V)
|
| 494 |
out = out.transpose(1, 2).reshape(B, N, D)
|
| 495 |
out = self.out_proj(out)
|
| 496 |
|
|
|
|
| 203 |
return x[:, inv].reshape(x.shape[0], H, W, x.shape[2]).permute(0, 3, 1, 2)
|
| 204 |
|
| 205 |
|
| 206 |
+
|
| 207 |
# ============================================================================
|
| 208 |
+
# Fast Sequence Mixer — replaces SSM scan with parallel-only operations
|
| 209 |
# ============================================================================
|
| 210 |
|
| 211 |
+
class FastSequenceMixer(nn.Module):
|
| 212 |
"""
|
| 213 |
+
Replaces Mamba SSM with a fully parallel sequence mixer.
|
| 214 |
+
|
| 215 |
+
Architecture: depthwise conv (local) + causal linear attention (global).
|
| 216 |
+
Zero sequential loops — pure batched matmuls + cumsum.
|
| 217 |
|
| 218 |
+
For L<=256 (our wavelet subbands): uses direct causal attention O(L²k)
|
| 219 |
+
which is faster than SSM scan because it's a single fused matmul on GPU.
|
| 220 |
+
L=256, k=16 → 256²×16 = 1M ops vs SSM's chunked scan overhead.
|
| 221 |
"""
|
| 222 |
def __init__(self, d_model: int, state_dim: int = 16, expand: int = 2):
|
| 223 |
super().__init__()
|
| 224 |
d_inner = d_model * expand
|
| 225 |
+
self.d_inner = d_inner
|
| 226 |
+
self.state_dim = state_dim
|
| 227 |
|
| 228 |
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
|
| 229 |
+
self.dwconv = nn.Conv1d(d_inner, d_inner, kernel_size=7, padding=3, groups=d_inner)
|
| 230 |
+
self.q_proj = nn.Linear(d_inner, state_dim, bias=False)
|
| 231 |
+
self.k_proj = nn.Linear(d_inner, state_dim, bias=False)
|
| 232 |
+
self.v_proj = nn.Linear(d_inner, d_inner, bias=False)
|
| 233 |
+
self.decay = nn.Parameter(torch.zeros(1)) # scalar learnable decay
|
| 234 |
self.D = nn.Parameter(torch.ones(d_inner))
|
| 235 |
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
|
| 236 |
+
nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)
|
|
|
|
|
|
|
| 237 |
|
| 238 |
def forward(self, x: torch.Tensor, style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 239 |
B, L, D = x.shape
|
|
|
|
|
|
|
| 240 |
xz = self.in_proj(x)
|
| 241 |
x_inner, z = xz.chunk(2, dim=-1)
|
| 242 |
|
| 243 |
+
x_local = F.silu(self.dwconv(x_inner.transpose(1, 2)).transpose(1, 2))
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
Q = F.elu(self.q_proj(x_local), alpha=1.0) + 1 # (B, L, k) non-negative
|
| 246 |
+
K = F.elu(self.k_proj(x_local), alpha=1.0) + 1 # (B, L, k)
|
| 247 |
+
V = self.v_proj(x_local) # (B, L, d_inner)
|
|
|
|
|
|
|
| 248 |
|
|
|
|
| 249 |
if style_mod is not None:
|
| 250 |
+
k = self.state_dim
|
| 251 |
+
if style_mod.shape[-1] >= 2 * k:
|
| 252 |
+
Q = Q + F.elu(style_mod[:, :k], alpha=1.0).unsqueeze(1) + 1
|
| 253 |
+
K = K + F.elu(style_mod[:, k:2*k], alpha=1.0).unsqueeze(1) + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
# Causal linear attention — single matmul, no loops
|
| 256 |
+
# For L<=512 this is fast (L²k ≈ 65K×16 ≈ 1M multiply-adds)
|
| 257 |
+
scores = torch.bmm(Q, K.transpose(1, 2)) # (B, L, L)
|
| 258 |
+
|
| 259 |
+
# Causal mask + decay (precomputed, cached)
|
| 260 |
+
causal = torch.tril(torch.ones(L, L, device=x.device, dtype=x.dtype))
|
| 261 |
+
d = torch.sigmoid(self.decay)
|
| 262 |
+
pos = torch.arange(L, device=x.device, dtype=x.dtype)
|
| 263 |
+
decay_m = d.pow((pos.unsqueeze(0) - pos.unsqueeze(1)).clamp(min=0))
|
| 264 |
+
|
| 265 |
+
scores = scores * causal * decay_m.unsqueeze(0)
|
| 266 |
+
scores = scores / scores.sum(-1, keepdim=True).clamp(min=1e-6)
|
| 267 |
|
| 268 |
+
y_global = torch.bmm(scores, V) # (B, L, d_inner)
|
| 269 |
+
|
| 270 |
+
y = x_local + y_global + x_inner * self.D.unsqueeze(0).unsqueeze(0)
|
| 271 |
+
y = y * F.silu(z)
|
| 272 |
return self.out_proj(y)
|
| 273 |
|
| 274 |
+
# Alias for backward compatibility
|
| 275 |
+
SelectiveSSM = FastSequenceMixer
|
| 276 |
+
|
| 277 |
|
| 278 |
# ============================================================================
|
| 279 |
+
# WaveMamba Block — batches all 4 subbands into one mixer call
|
| 280 |
# ============================================================================
|
| 281 |
|
| 282 |
class WaveMambaBlock(nn.Module):
|
| 283 |
"""
|
| 284 |
+
Wavelet-decomposed sequence mixing block.
|
| 285 |
+
Decomposes input → 4 frequency subbands → batches into single mixer call → reconstructs.
|
|
|
|
| 286 |
"""
|
| 287 |
def __init__(self, channels: int, config: ArtFlowConfig):
|
| 288 |
super().__init__()
|
| 289 |
self.wavelet = HaarWavelet2D()
|
| 290 |
|
| 291 |
+
# Single mixer handles all 4 subbands (batched along B dimension)
|
| 292 |
+
self.mixer = FastSequenceMixer(channels, config.mamba_state_dim, config.mamba_expand)
|
|
|
|
| 293 |
|
|
|
|
| 294 |
self.norm_pre = RMSNorm(channels)
|
|
|
|
|
|
|
|
|
|
| 295 |
self.adaln = AdaLNZero(channels, config.style_dim + config.text_dim)
|
|
|
|
|
|
|
| 296 |
self.style_proj = nn.Linear(config.style_dim, config.mamba_state_dim * 2)
|
| 297 |
|
| 298 |
def forward(self, x: torch.Tensor, cond: torch.Tensor,
|
| 299 |
style_mod: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
residual = x
|
| 301 |
B, C, H, W = x.shape
|
| 302 |
|
|
|
|
| 304 |
x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
|
| 305 |
x_flat = self.norm_pre(x_flat).reshape(B, H, W, C).permute(0, 3, 1, 2)
|
| 306 |
|
| 307 |
+
# Wavelet decomposition → 4 subbands
|
| 308 |
LL, LH, HL, HH = self.wavelet(x_flat)
|
| 309 |
H2, W2 = H // 2, W // 2
|
| 310 |
|
|
|
|
| 311 |
ssm_style = self.style_proj(style_mod) if style_mod is not None else None
|
| 312 |
|
| 313 |
+
# BATCH all 4 subbands into one mixer call!
|
| 314 |
+
# Stack along batch dimension: (4*B, H2*W2, C)
|
| 315 |
+
all_subs = torch.cat([
|
| 316 |
+
zigzag_flatten(LL),
|
| 317 |
+
zigzag_flatten(LH),
|
| 318 |
+
zigzag_flatten(HL),
|
| 319 |
+
zigzag_flatten(HH),
|
| 320 |
+
], dim=0) # (4*B, L_sub, C)
|
| 321 |
+
|
| 322 |
+
# Expand style for batched call: (B, k) → (4*B, k)
|
| 323 |
+
if ssm_style is not None:
|
| 324 |
+
style_batched = ssm_style.unsqueeze(0).expand(4, -1, -1).reshape(4 * B, -1)
|
| 325 |
+
else:
|
| 326 |
+
style_batched = None
|
| 327 |
+
|
| 328 |
+
# Single mixer call for all 4 subbands
|
| 329 |
+
all_out = self.mixer(all_subs, style_batched) # (4*B, L_sub, C)
|
| 330 |
|
| 331 |
+
# Split back
|
| 332 |
+
oLL, oLH, oHL, oHH = all_out.chunk(4, dim=0) # each (B, L_sub, C)
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
+
# Unflatten
|
| 335 |
+
oLL = zigzag_unflatten(oLL, H2, W2)
|
| 336 |
+
oLH = zigzag_unflatten(oLH, H2, W2)
|
| 337 |
+
oHL = zigzag_unflatten(oHL, H2, W2)
|
| 338 |
+
oHH = zigzag_unflatten(oHH, H2, W2)
|
| 339 |
|
| 340 |
+
# Inverse wavelet
|
| 341 |
+
y = self.wavelet.inverse(oLL, oLH, oHL, oHH)
|
| 342 |
|
| 343 |
# AdaLN + residual
|
| 344 |
y_flat = y.permute(0, 2, 3, 1).reshape(B, H * W, C)
|
|
|
|
| 424 |
K = K.repeat(1, repeat, 1, 1)
|
| 425 |
V = V.repeat(1, repeat, 1, 1)
|
| 426 |
|
| 427 |
+
# Attention — uses F.scaled_dot_product_attention (fused kernel on GPU)
|
| 428 |
+
out = F.scaled_dot_product_attention(Q, K, V)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
out = out.transpose(1, 2).reshape(B, N, D)
|
| 430 |
out = self.out_proj(out)
|
| 431 |
|