Spaces:
Running on Zero
Running on Zero
File size: 29,739 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 | """Flux2 transformer layers for LightDiffusion-Next.
Core building blocks for the Flux2 architecture:
- Attention mechanisms
- Modulation layers
- Transformer blocks (double and single stream)
- Embedding layers
Adapted from ComfyUI's Flux implementation for LightDiffusion-Next.
"""
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from src.cond import cast as ops_module
from src.Device import Device
# Get operations module
def get_ops():
    """Return the default layer-factory namespace for this module.

    Layer constructors here fall back to this object (``ops_module.disable_weight_init``)
    whenever their ``operations`` argument is None.
    """
    factory = ops_module.disable_weight_init
    return factory
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization with a learnable per-channel ``scale``.

    Uses native PyTorch ``rms_norm`` when available (and not tracing/scripting) for
    numerical consistency with ComfyUI. The fallback computes the RMS statistics in
    at least float32 so that fp16/bf16 inputs do not overflow when squared.
    """

    def __init__(self, dim: int, eps: float = 1e-6, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        # Use 'scale' to match Flux2 checkpoint naming convention
        self.scale = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
        # Check if native rms_norm is available
        self._use_native = hasattr(torch.nn.functional, 'rms_norm')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and multiply by ``scale``.

        Returns a tensor with the same shape and dtype as ``x``.
        """
        # Ensure scale is on the same device/dtype as the input
        scale = self.scale.to(x.device, x.dtype)
        if self._use_native and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
            # Use native PyTorch rms_norm for better precision (matches ComfyUI)
            return torch.nn.functional.rms_norm(x, scale.shape, weight=scale, eps=self.eps)
        # Fallback: upcast half-precision inputs (the native op does the same internally).
        # In fp16, x ** 2 overflows to inf for |x| > ~256, rsqrt(inf) == 0, and the
        # whole row would collapse to zeros.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        normed = x_c * torch.rsqrt(x_c.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return normed.type_as(x) * scale
class EmbedND(nn.Module):
    """Multi-axis rotary positional embedding (RoPE).

    Every column of the position-ID tensor is embedded on its own with ``rope``,
    using ``axes_dim[i]`` channels for axis ``i``. The per-axis rotation matrices
    are then concatenated along the frequency axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Compute rotary positional embeddings.

        Args:
            ids: Position IDs of shape [batch, seq_len, num_axes].

        Returns:
            Rotation matrices of shape
            [batch, 1, seq_len, sum(axes_dim[:num_axes]) // 2, 2, 2]. The singleton
            dim 1 broadcasts over attention heads.
        """
        per_axis = tuple(
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(ids.shape[-1])
        )
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute rotary position embeddings for one position axis.

    Matches ComfyUI's implementation: the frequency table is built in float64 and
    the result is returned as float32.

    Args:
        pos: Position indices, typically [batch, seq] (any leading shape works).
        dim: Number of embedding channels for this axis; must be even.
        theta: Base frequency.

    Returns:
        float32 tensor of shape [*pos.shape, dim // 2, 2, 2] on ``pos.device``.
        Each trailing 2x2 block is the rotation matrix [[cos, -sin], [sin, cos]]
        for one frequency.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"RoPE dimension must be even, got dim={dim}")
    # MPS has no float64 support, so build the table on the CPU there and move
    # the result back to pos.device at the end.
    device = torch.device("cpu") if pos.device.type == "mps" else pos.device
    # ComfyUI uses float64 for scale calculation for maximum precision
    scale = torch.linspace(0, (dim - 2) / dim, dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta ** scale)
    # Position x frequency outer product; pos is cast to float32 like ComfyUI
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    cos, sin = torch.cos(out), torch.sin(out)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)
    # [..., d, 4] -> [..., d, 2, 2] (row-major rotation matrix)
    out = out.reshape(*out.shape[:-1], 2, 2)
    # ComfyUI always returns float32 for RoPE embeddings
    return out.to(dtype=torch.float32, device=pos.device)
class MLPEmbedder(nn.Module):
    """Two-layer perceptron ``out_layer(silu(in_layer(x)))``.

    Used for timestep and guidance embeddings.
    """

    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        ops = get_ops() if operations is None else operations
        linear_kwargs = dict(bias=ops_bias, dtype=dtype, device=device)
        self.in_layer = ops.Linear(in_dim, hidden_dim, **linear_kwargs)
        self.out_layer = ops.Linear(hidden_dim, hidden_dim, **linear_kwargs)
        self.silu = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
class GatedMLP(nn.Module):
    """Gated (SwiGLU) feed-forward block for Klein models.

    ``gate_up_proj`` maps hidden -> 2*intermediate. The result is split into
    ``gate`` and ``up`` halves, combined as ``silu(gate) * up``, and
    ``down_proj`` maps intermediate -> hidden.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        ops = get_ops() if operations is None else operations
        # Single projection that yields both gate and value halves
        self.gate_up_proj = ops.Linear(hidden_size, 2 * intermediate_size, bias=ops_bias, dtype=dtype, device=device)
        self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=ops_bias, dtype=dtype, device=device)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(self.act(gate) * up)
class QKNorm(nn.Module):
    """RMS-normalize queries and keys with separate learnable scales."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        # ``operations`` is accepted for signature parity with sibling layers but
        # not used: both norms are this module's own RMSNorm ('scale' parameter).
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Return (normed q, normed k), both cast to v's dtype and device.

        The cast matches ComfyUI and is needed for numerical consistency.
        """
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
class SelfAttention(nn.Module):
    """Multi-head self-attention with QK RMS-norm and rotary position embedding.

    Q/K/V come from one fused ``qkv`` linear. ``norm`` holds the per-head
    query/key norms, and ``proj`` is the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        dtype=None,
        device=None,
        operations=None,
        ops_bias: bool = True,
    ):
        super().__init__()
        ops = get_ops() if operations is None else operations
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.qkv = ops.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(per_head, dtype=dtype, device=device, operations=ops)
        self.proj = ops.Linear(dim, dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` ([B, L, dim]) with RoPE ``pe`` and project the result."""
        # [B, L, 3*H*D] -> three tensors of [B, H, L, D]
        q, k, v = rearrange(self.qkv(x), "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        return self.proj(attention(q, k, v, pe=pe))
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: Optional[torch.Tensor], mask: torch.Tensor = None) -> torch.Tensor:
    """Apply rotary position embeddings to q/k, then scaled dot-product attention.

    Args:
        q: Query tensor [batch, heads, seq, head_dim]
        k: Key tensor [batch, heads, seq, head_dim]
        v: Value tensor [batch, heads, seq, head_dim]
        pe: RoPE rotation matrices [batch, 1, seq, head_dim // 2, 2, 2]. A
            sequence length of 1 is broadcast to every token. If None, q and k
            are used without rotary embedding.
        mask: Optional keep-mask (1 = attend, 0 = padding) forwarded to
            ``optimized_attention``.

    Returns:
        Attention output [batch, seq, heads*head_dim]

    Raises:
        ValueError: If ``pe`` has fewer than 3 dims, or its sequence length is
            neither 1 nor ``q.shape[2]``.
    """
    if pe is not None:
        # Validate the RoPE sequence length up front, so a mismatched token grid
        # fails with an actionable message instead of a shape error deep inside RoPE.
        pe_seq = pe.shape[2] if pe.ndim >= 3 else None
        if pe_seq not in (1, q.shape[2]):
            raise ValueError(
                f"RoPE sequence length mismatch: pe.seq={pe_seq} != q.seq={q.shape[2]}. "
                "Transformer options (img_h/img_w) may not match the input token grid; check calc_cond_batch merging of transformer_options."
            )
        q, k = apply_rope(q, k, pe)
    # Efficient attention implementation
    heads = q.shape[1]
    return optimized_attention(q, k, v, heads, mask=mask)
def apply_rope1(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent channel pairs of ``x`` by the RoPE matrices in ``freqs_cis``.

    Each pair ``(a, b)`` of the last dimension becomes
    ``(a*cos - b*sin, a*sin + b*cos)``. The factors are read from
    ``freqs_cis[..., 0, 0]`` (cos), ``[..., 0, 1]`` (-sin) and ``[..., 1, 0]`` (sin).

    The sequence axis (dim 2) of ``freqs_cis`` is reconciled with x's:
    - if it is longer, it is truncated;
    - if it is shorter, each position is repeated ``ceil(seq_x / seq_pe)`` times
      and the result is truncated.

    Args:
        x: Input tensor [batch, heads, seq, dim]
        freqs_cis: Rotation matrices [batch, 1, seq, dim//2, 2, 2]

    Returns:
        Rotated tensor shaped like ``x`` and cast back to its dtype.

    Raises:
        ValueError: If ``freqs_cis`` does not provide ``dim // 2`` frequencies.
    """
    # View x as [batch, heads, seq, dim//2, 2] channel pairs
    pairs = x.reshape(*x.shape[:-1], -1, 2)

    tokens, pe_tokens = x.shape[2], freqs_cis.shape[2]
    if pe_tokens > tokens:
        freqs_cis = freqs_cis[..., :tokens, :, :, :]
    elif pe_tokens < tokens:
        reps = -(-tokens // pe_tokens)  # ceil division
        freqs_cis = freqs_cis.repeat_interleave(reps, dim=2)[..., :tokens, :, :, :]

    feat_half = x.shape[-1] // 2
    if freqs_cis.shape[-3] != feat_half:
        raise ValueError(
            f"RoPE feature-dim mismatch: freqs_cis.dim={freqs_cis.shape[-3]} != x.dim/2={feat_half}. "
            f"x.shape={x.shape}, freqs_cis.shape={freqs_cis.shape}"
        )

    cos = freqs_cis[..., 0, 0]
    neg_sin = freqs_cis[..., 0, 1]
    sin = freqs_cis[..., 1, 0]
    first, second = pairs[..., 0], pairs[..., 1]

    rotated = torch.stack(
        [first * cos + second * neg_sin, first * sin + second * cos],
        dim=-1,
    )
    return rotated.reshape(*x.shape).type_as(x)
def apply_rope(q: torch.Tensor, k: torch.Tensor, pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with the same RoPE matrices.

    Args:
        q: Query tensor [batch, heads, seq, dim]
        k: Key tensor [batch, heads, seq, dim]
        pe: Rotation matrices [..., dim//2, 2, 2]

    Returns:
        ``(rotated_q, rotated_k)``, each shaped like its input.
    """
    rotated_q = apply_rope1(q, pe)
    rotated_k = apply_rope1(k, pe)
    return rotated_q, rotated_k
def _additive_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a keep-mask (1 = attend, 0 = padding) into an additive SDPA bias of ``dtype``.

    [B, Lk] becomes [B, 1, 1, Lk] and [B, Lq, Lk] becomes [B, 1, Lq, Lk], so the
    bias broadcasts over heads. Masks of any other rank are used unchanged.
    """
    if mask.ndim == 2:
        mask = mask.unsqueeze(1).unsqueeze(1)
    elif mask.ndim == 3:
        mask = mask.unsqueeze(1)
    mask = mask.to(dtype=dtype)
    # 0 where attended, the most negative finite value where masked out
    return (1.0 - mask) * torch.finfo(dtype).min


def _math_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Plain softmax(q k^T / sqrt(d) + mask) v, used only when SDPA is unavailable."""
    scores = torch.matmul(q, k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
    if attn_mask is not None:
        scores = scores + attn_mask
    return torch.matmul(scores.softmax(dim=-1), v)


def optimized_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor = None) -> torch.Tensor:
    """Scaled dot-product attention with backend fallbacks.

    Tried in order:
      1. ``F.scaled_dot_product_attention`` inside ``Device.get_sdpa_context()``
         (the project's preferred SDPA backend priority).
      2. xformers ``memory_efficient_attention``, only when ``mask`` is None,
         because no mask is forwarded to that path.
      3. ``F.scaled_dot_product_attention`` with PyTorch's default backend
         selection, or a plain matmul/softmax implementation if SDPA is missing.
    Every path that runs with a mask honors it. Earlier versions silently dropped
    it on the fallbacks.

    Args:
        q, k, v: [batch, heads, seq, head_dim].
        heads: Unused; the output layout is derived from ``q``'s shape.
        mask: Optional keep-mask (1 = attend, 0 = padding), shaped
            [batch, seq_kv] or [batch, seq_q, seq_kv], or already broadcastable
            to [batch, heads, seq_q, seq_kv].

    Returns:
        [batch, seq_q, heads * head_dim].
    """
    b, _, seq_q, _ = q.shape
    attn_mask = _additive_attention_mask(mask, q.dtype) if mask is not None else None
    has_sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    # Method 1: SDPA with the project's backend priority (fastest on modern PyTorch)
    if has_sdpa:
        try:
            with Device.get_sdpa_context():
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            # [batch, heads, seq, dim] -> [batch, seq, heads*dim]
            return out.transpose(1, 2).reshape(b, seq_q, -1)
        except Exception:
            # Best effort: e.g. no kernel among the restricted backends; retry below.
            pass

    # Method 2: xformers memory-efficient attention (mask is not converted for it)
    if attn_mask is None and Device.xformers_enabled():
        try:
            import xformers.ops as xops
            # xformers expects [batch, seq, heads, dim]
            q_xf = q.transpose(1, 2).contiguous()
            k_xf = k.transpose(1, 2).contiguous()
            v_xf = v.transpose(1, 2).contiguous()
            out = xops.memory_efficient_attention(q_xf, k_xf, v_xf)
            del q_xf, k_xf, v_xf  # Free memory early
            # [batch, seq, heads, dim] -> [batch, seq, heads*dim]
            return out.reshape(b, seq_q, -1)
        except Exception:
            pass  # Fall through to the last resort

    # Method 3: unrestricted SDPA, or explicit math when SDPA does not exist
    if has_sdpa:
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    else:
        out = _math_attention(q, k, v, attn_mask)
    return out.transpose(1, 2).reshape(b, seq_q, -1)
@dataclass
class ModulationOut:
    """One adaptive-norm modulation triple.

    Consumers apply it as ``(1 + scale) * norm(x) + shift`` and weight the
    resulting residual branch by ``gate``.
    """

    shift: torch.Tensor  # additive offset after normalization
    scale: torch.Tensor  # multiplicative factor, applied as (1 + scale)
    gate: torch.Tensor  # weight of the residual branch output
class Modulation(nn.Module):
    """Per-block adaptive layer-norm modulation.

    Projects ``silu(vec)`` to one (shift, scale, gate) triple, or to two triples
    when ``double`` is True.
    """

    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        ops = get_ops() if operations is None else operations
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = ops.Linear(dim, self.multiplier * dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        """Return ``(first, second)``; ``second`` is None unless ``double``.

        Each tensor has a singleton axis inserted at dim 1 so it broadcasts over tokens.
        """
        projected = self.lin(nn.functional.silu(vec))
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)
        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:6]) if self.is_double else None
        return first, second
class GlobalModulation(nn.Module):
    """Model-level modulation for Flux2 (Klein) double stream blocks.

    One projection yields four (shift, scale, gate) triples: two for the image
    stream and two for the text stream.
    """

    def __init__(self, dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        ops = get_ops() if operations is None else operations
        # 12 outputs = 2 streams x 2 sub-layers x (shift, scale, gate)
        self.lin = ops.Linear(dim, 12 * dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut, ModulationOut, ModulationOut]:
        """Return ``(mod1_img, mod2_img, mod1_txt, mod2_txt)``."""
        chunks = self.lin(nn.functional.silu(vec))[:, None, :].chunk(12, dim=-1)
        mod1_img, mod2_img, mod1_txt, mod2_txt = (
            ModulationOut(*chunks[start:start + 3]) for start in range(0, 12, 3)
        )
        return mod1_img, mod2_img, mod1_txt, mod2_txt
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams.

    Each stream has its own modulation, QKV/output projections and MLP. The two
    streams attend jointly: their Q/K/V are concatenated (text first, then image)
    for one attention call, and the output is split back by text length.

    Modulation comes either from per-block ``Modulation`` layers (Flux1 style) or,
    when ``global_modulation=True``, from ``(mod1, mod2)`` tuples that the caller
    passes to ``forward``.
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
        global_modulation: bool = False,
        dtype=None,
        device=None,
        operations=None,
        flax_compatible: bool = False,
        silu_mlp: bool = False,
        gated_mlp: bool = False,
        ops_bias: bool = True, # Whether to use bias in linear layers
    ):
        """Build both streams.

        Args:
            hidden_size: Token width of both streams.
            num_heads: Number of attention heads.
            mlp_ratio: MLP width multiplier. With ``gated_mlp`` it sets the
                intermediate size, and the first MLP linear is twice that wide.
            qkv_bias: Bias on the QKV projections.
            global_modulation: If True, no per-block modulation layers are
                created and ``forward`` must receive ``img_mod``/``txt_mod``.
            dtype, device: Forwarded to every layer constructor.
            operations: Layer factory namespace; defaults to ``get_ops()``.
            flax_compatible: Stored on the instance; not read elsewhere in this class.
            silu_mlp: Use SiLU instead of tanh-GELU in the non-gated MLP.
            gated_mlp: Use the gated MLP variant (see ``_forward_mlp``).
            ops_bias: Bias on the remaining linear layers.
        """
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.flax_compatible = flax_compatible
        self.silu_mlp = silu_mlp
        self.gated_mlp = gated_mlp
        # For gated MLP (Klein), mlp_ratio is the true ratio
        # First layer outputs 2x for gating: hidden -> 2*intermediate
        # Second layer: intermediate -> hidden
        if gated_mlp:
            mlp_intermediate = int(hidden_size * mlp_ratio)
            mlp_hidden_dim = mlp_intermediate * 2 # Double for gate+up projection
        else:
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            mlp_intermediate = mlp_hidden_dim
        if global_modulation:
            # When using global modulation at model level, don't create per-block modulation
            self.double_stream_modulation = None
            self.img_mod = None
            self.txt_mod = None
            self.use_global_modulation = True
        else:
            # Per-block modulation: each produces (mod1, mod2) = attention and MLP triples
            self.double_stream_modulation = None
            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.use_global_modulation = False
        # Image stream
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        if gated_mlp:
            # Gated MLP with naming compatible with checkpoint: .0, .1 (identity), .2
            self.img_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.Identity(), # Placeholder for index 1
                operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        else:
            self.img_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
                operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        # Text stream
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        if gated_mlp:
            # Same .0 / .1 (identity) / .2 layout as img_mlp for checkpoint naming
            self.txt_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.Identity(),
                operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        else:
            self.txt_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
                operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask=None,
        img_mod: tuple = None, # (img_mod1, img_mod2) from global modulation
        txt_mod: tuple = None, # (txt_mod1, txt_mod2) from global modulation
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run joint attention and the per-stream MLPs.

        Args:
            img: Image tokens [batch, img_seq, hidden_size].
            txt: Text tokens [batch, txt_seq, hidden_size].
            vec: Conditioning vector for per-block modulation. It is not read
                when global modulation tuples are used.
            pe: RoPE tensor for the concatenated (text, then image) sequence.
            attn_mask: Optional keep-mask forwarded to ``attention``.
            img_mod, txt_mod: ``(mod1, mod2)`` pairs from global modulation.
                mod1 modulates the attention input and residual, mod2 the MLP.
                They are ignored unless the block was built with
                ``global_modulation=True``.

        Returns:
            Updated ``(img, txt)``.

        Raises:
            ValueError: If neither global nor per-block modulation is available.
        """
        # Get modulation parameters
        if self.use_global_modulation and img_mod is not None and txt_mod is not None:
            # Use global modulation passed from model level
            img_mod1, img_mod2 = img_mod
            txt_mod1, txt_mod2 = txt_mod
        elif self.img_mod is not None and self.txt_mod is not None:
            # Use per-block modulation (Flux1 style)
            img_mod1, img_mod2 = self.img_mod(vec)
            txt_mod1, txt_mod2 = self.txt_mod(vec)
        else:
            raise ValueError("No modulation available - either provide global or use per-block modulation")
        # Prepare normed inputs: (1 + scale) * LayerNorm(x) + shift
        img_normed = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_normed + img_mod1.shift
        del img_normed # Free memory early
        txt_normed = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_normed + txt_mod1.shift
        del txt_normed # Free memory early
        # Run joint attention - use view+permute for efficiency instead of rearrange
        # [B, L, 3*H*D] -> [3, B, H, L, D], unpacked into q, k, v
        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated
        q_img, k_img, v_img = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del img_qkv
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated
        q_txt, k_txt, v_txt = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del txt_qkv
        q_img, k_img = self.img_attn.norm(q_img, k_img, v_img)
        q_txt, k_txt = self.txt_attn.norm(q_txt, k_txt, v_txt)
        # Concatenate for joint attention (text tokens first along the sequence dim)
        q = torch.cat((q_txt, q_img), dim=2)
        del q_txt, q_img
        k = torch.cat((k_txt, k_img), dim=2)
        del k_txt, k_img
        v = torch.cat((v_txt, v_img), dim=2)
        del v_txt, v_img
        attn_out = attention(q, k, v, pe=pe, mask=attn_mask)
        del q, k, v
        # Split [B, txt_seq + img_seq, H*D] back into the two streams
        txt_attn, img_attn = attn_out[:, : txt.shape[1]], attn_out[:, txt.shape[1] :]
        del attn_out
        # Apply residual connections with gating
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        del img_attn
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        del txt_attn
        # MLP with modulation
        img_mlp_in = (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * self._forward_mlp(self.img_mlp, img_mlp_in)
        del img_mlp_in
        txt_mlp_in = (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * self._forward_mlp(self.txt_mlp, txt_mlp_in)
        del txt_mlp_in
        # Handle fp16 numerical issues (matches ComfyUI exactly); only txt is clamped here
        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
        return img, txt
    def _forward_mlp(self, mlp: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
        """Forward through MLP, handling both standard and gated variants.

        For the gated variant, index 1 (the Identity placeholder) is skipped. The
        output of ``mlp[0]`` is split into gate/up halves and combined as
        ``silu(gate) * up`` before ``mlp[2]``.
        """
        if self.gated_mlp:
            # Gated MLP: split into gate and up, apply SiLU to gate, multiply, project
            gate_up = mlp[0](x)
            gate, up = gate_up.chunk(2, dim=-1)
            hidden = F.silu(gate) * up
            return mlp[2](hidden)
        else:
            return mlp(x)
class SingleStreamBlock(nn.Module):
    """Transformer block with merged image and text stream.

    Used after the double stream blocks have processed both modalities.
    One ``linear1`` matmul produces both the fused QKV and the MLP input. The
    attention output and the activated MLP hidden state are then concatenated
    and projected back by ``linear2``.
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None,
        silu_mlp: bool = False,
        gated_mlp: bool = False,
        ops_bias: bool = True,
        global_modulation: bool = False,
    ):
        """Build the block.

        Args:
            hidden_size: Token width.
            num_heads: Number of attention heads.
            mlp_ratio: MLP width multiplier (intermediate size for ``gated_mlp``).
            qk_scale: Stored as ``self.scale``; see the note below.
            dtype, device: Forwarded to every layer constructor.
            operations: Layer factory namespace; defaults to ``get_ops()``.
            silu_mlp: Use SiLU instead of tanh-GELU for the non-gated MLP.
            gated_mlp: Use ``silu(gate) * up`` on a doubled MLP projection.
            ops_bias: Bias on the linear layers.
            global_modulation: If True, no per-block ``Modulation`` is created and
                ``forward`` must receive ``modulation``.
        """
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE(review): self.scale is not read by forward(); attention() presumably
        # applies its own default 1/sqrt(head_dim) scaling — confirm before relying on qk_scale.
        self.scale = qk_scale or head_dim ** -0.5
        self.silu_mlp = silu_mlp
        self.gated_mlp = gated_mlp
        self.use_global_modulation = global_modulation
        # For gated MLP, mlp_ratio gives intermediate size
        # linear1 outputs gate+up (2x intermediate), linear2 takes intermediate
        if gated_mlp:
            self.mlp_intermediate = int(hidden_size * mlp_ratio)
            self.mlp_gate_up_dim = self.mlp_intermediate * 2
            linear1_out = hidden_size * 3 + self.mlp_gate_up_dim
            linear2_in = hidden_size + self.mlp_intermediate
        else:
            self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
            linear1_out = hidden_size * 3 + self.mlp_hidden_dim
            linear2_in = hidden_size + self.mlp_hidden_dim
        # Joint QKV and MLP projection
        self.linear1 = operations.Linear(
            hidden_size, linear1_out, bias=ops_bias, dtype=dtype, device=device
        )
        # Projects concat(attention output, MLP hidden) back to hidden_size
        self.linear2 = operations.Linear(
            linear2_in, hidden_size, bias=ops_bias, dtype=dtype, device=device
        )
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # Only create per-block modulation if not using global modulation
        if not global_modulation:
            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        else:
            self.modulation = None
    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask=None,
        modulation=None, # ModulationOut from global modulation
    ) -> torch.Tensor:
        """Apply one single-stream step with a gated residual.

        Args:
            x: Merged text+image tokens [batch, seq, hidden_size].
            vec: Conditioning vector for per-block modulation. It is not read when
                a global ``modulation`` is used.
            pe: RoPE tensor for the merged sequence.
            attn_mask: Optional keep-mask forwarded to ``attention``.
            modulation: ModulationOut from global modulation. It is ignored unless
                the block was built with ``global_modulation=True``.

        Returns:
            ``x + gate * linear2(concat(attention, mlp))``, same shape as ``x``.

        Raises:
            ValueError: If neither global nor per-block modulation is available.
        """
        # Get modulation
        if self.use_global_modulation and modulation is not None:
            mod = modulation
        elif self.modulation is not None:
            mod, _ = self.modulation(vec)
        else:
            raise ValueError("No modulation available - either provide global or use per-block modulation")
        x_normed = self.pre_norm(x)
        x_mod = (1 + mod.scale) * x_normed + mod.shift
        del x_normed # Free memory early
        # Joint projection - split QKV from MLP part
        qkv_mlp = self.linear1(x_mod)
        del x_mod
        if self.gated_mlp:
            qkv, mlp_gate_up = qkv_mlp.split([self.hidden_size * 3, self.mlp_gate_up_dim], dim=-1)
            del qkv_mlp
            # Gated MLP: split into gate and up, apply SiLU to gate, multiply
            gate, up = mlp_gate_up.chunk(2, dim=-1)
            del mlp_gate_up
            mlp = F.silu(gate) * up
            del gate, up
        else:
            qkv, mlp = qkv_mlp.split([self.hidden_size * 3, self.mlp_hidden_dim], dim=-1)
            del qkv_mlp
            # Standard activation
            if self.silu_mlp:
                mlp = F.silu(mlp)
            else:
                mlp = F.gelu(mlp, approximate="tanh")
        # Attention - use view+permute for efficiency instead of rearrange
        # [B, L, 3*H*D] -> [3, B, H, L, D], unpacked into q, k, v
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del qkv
        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        del q, k, v
        # Combine and project
        output = self.linear2(torch.cat((attn, mlp), dim=-1))
        del attn, mlp
        result = x + mod.gate * output
        # Handle fp16 numerical issues (matches ComfyUI exactly)
        if result.dtype == torch.float16:
            result = torch.nan_to_num(result, nan=0.0, posinf=65504, neginf=-65504)
        return result
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden tokens to patch values.

    ``adaLN_modulation`` maps the conditioning vector to (shift, scale). The
    normalized tokens are modulated with them and then linearly projected to
    ``patch_size * patch_size * out_channels`` features per token.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        ops = get_ops() if operations is None else operations
        patch_features = patch_size * patch_size * out_channels
        self.norm_final = ops.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = ops.Linear(hidden_size, patch_features, bias=ops_bias, dtype=dtype, device=device)
        # Index 1 holds the weights (checkpoint key "adaLN_modulation.1.*")
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            ops.Linear(hidden_size, 2 * hidden_size, bias=ops_bias, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        # Insert a token axis so the per-sample modulation broadcasts over tokens
        shift = shift[:, None, :]
        scale = scale[:, None, :]
        modulated = (1 + scale) * self.norm_final(x) + shift
        return self.linear(modulated)
|