"""Flux2 transformer layers for LightDiffusion-Next.

Core building blocks for the Flux2 architecture:
- Attention mechanisms
- Modulation layers  
- Transformer blocks (double and single stream)
- Embedding layers

Adapted from ComfyUI's Flux implementation for LightDiffusion-Next.
"""

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from src.cond import cast as ops_module
from src.Device import Device


# Get operations module
def get_ops():
    """Get the operations module for weight initialization."""
    return ops_module.disable_weight_init


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.
    
    Uses native PyTorch rms_norm when available for numerical consistency with ComfyUI.
    """
    
    def __init__(self, dim: int, eps: float = 1e-6, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        # Use 'scale' to match Flux2 checkpoint naming convention
        self.scale = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
        # Check if native rms_norm is available
        self._use_native = hasattr(torch.nn.functional, 'rms_norm')
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Ensure scale is on the same device as input
        scale = self.scale.to(x.device, x.dtype)
        
        if self._use_native and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
            # Use native PyTorch rms_norm for better precision (matches ComfyUI)
            return torch.nn.functional.rms_norm(x, scale.shape, weight=scale, eps=self.eps)
        else:
            # Fallback implementation
            rms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
            return x * rms * scale


class EmbedND(nn.Module):
    """N-dimensional positional embedding using RoPE."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Compute rotary positional embeddings.
        
        Args:
            ids: Position IDs tensor of shape [batch, seq_len, num_axes]
            
        Returns:
            Rotary embeddings of shape [batch, 1, seq_len, sum(axes_dim)//2, 2, 2]
        """
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
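        # emb: [batch, seq_len, sum(axes_dim)//2, 2, 2]; unsqueeze(1) adds a dimension
        # that broadcasts over attention heads when the RoPE is applied in apply_rope1.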
        return emb.unsqueeze(1)


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute rotary position embeddings.
    
    Matches ComfyUI's implementation exactly for numerical precision.
    
    Args:
        pos: Position indices
        dim: Embedding dimension
        theta: Base frequency
        
    Returns:
        Rotary embeddings as float32 rotation matrices of shape [..., dim//2, 2, 2]
    """
    assert dim % 2 == 0
    device = pos.device
    
    # ComfyUI uses float64 for scale calculation for maximum precision
    scale = torch.linspace(0, (dim - 2) / dim, dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta ** scale)
    
    # Einsum for position-frequency interaction - cast pos to float32 like ComfyUI
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
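    # Worked shape example (hypothetical sizes): pos [1, 4608] with dim=64
    # -> out [1, 4608, 32, 2, 2], i.e. one 2x2 rotation matrix per (position, frequency) pair.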
    
    # ComfyUI always returns float32 for RoPE embeddings
    return out.to(dtype=torch.float32, device=pos.device)


class MLPEmbedder(nn.Module):
    """MLP for timestep and guidance embeddings."""

    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=ops_bias, dtype=dtype, device=device)
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=ops_bias, dtype=dtype, device=device)
        self.silu = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class GatedMLP(nn.Module):
    """Gated MLP (SwiGLU) for Klein models.
    
    Structure: hidden -> 2*intermediate -> SiLU gate -> intermediate -> hidden
    The first linear produces gate and value activations,
    SiLU is applied to gate, then gate * value, then final projection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        # First layer outputs 2x intermediate for gating
        self.gate_up_proj = operations.Linear(hidden_size, intermediate_size * 2, bias=ops_bias, dtype=dtype, device=device)
        self.down_proj = operations.Linear(intermediate_size, hidden_size, bias=ops_bias, dtype=dtype, device=device)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)
        gate, up = gate_up.chunk(2, dim=-1)
        return self.down_proj(self.act(gate) * up)


class QKNorm(nn.Module):
    """Query-Key normalization layer."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        # Use native RMSNorm instead of operations.RMSNorm
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        q = self.query_norm(q)
        k = self.key_norm(k)
        # Cast to v's dtype and device to match ComfyUI (crucial for numerical consistency)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    """Self-attention with rotary position embedding (RoPE)."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        dtype=None,
        device=None,
        operations=None,
        ops_bias: bool = True,
    ):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply attention with rotary position embeddings.
    
    Args:
        q: Query tensor [batch, heads, seq, dim]
        k: Key tensor [batch, heads, seq, dim]  
        v: Value tensor [batch, heads, seq, dim]
        pe: Positional embeddings
        mask: Optional attention mask for padding tokens
        
    Returns:
        Attention output [batch, seq, heads*dim]
    """
    # Validate positional embedding sequence length to prevent RoPE shape errors
    if pe is not None:
        pe_seq = pe.shape[2] if pe.ndim >= 3 else None
        if pe_seq not in (1, q.shape[2]):
            raise ValueError(
                f"RoPE sequence length mismatch: pe.seq={pe_seq} != q.seq={q.shape[2]}. "
                "Transformer options (img_h/img_w) may not match the input token grid; "
                "check calc_cond_batch merging of transformer_options."
            )

    q, k = apply_rope(q, k, pe)
    
    # Efficient attention implementation
    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, mask=mask)
    return x


def apply_rope1(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to a single tensor.
    
    Correctly applies the 2x2 rotation matrix:
    y1 = x1 * cos - x2 * sin
    y2 = x1 * sin + x2 * cos
    
    Args:
        x: Input tensor [batch, heads, seq, dim]
        freqs_cis: Frequency tensor [batch, 1, seq, dim//2, 2, 2]
        
    Returns:
        Rotated tensor [batch, heads, seq, dim]
    """
    # Reshape x to match RoPE components [batch, heads, seq, dim//2, 2]
    x_reshaped = x.reshape(*x.shape[:-1], -1, 2)

    # Handle differing sequence lengths between x and freqs_cis
    # freqs_cis shape: [batch, 1, seq_pe, dim//2, 2, 2]
    seq_x = x.shape[2]
    seq_pe = freqs_cis.shape[2]
    if seq_pe != seq_x:
        if seq_pe < seq_x:
            # Upsample by repeating along sequence dimension then slice to exact length
            repeat = (seq_x + seq_pe - 1) // seq_pe
            freqs_cis = freqs_cis.repeat_interleave(repeat, dim=2)[..., :seq_x, :, :, :]
        else:
            # Slice to match x sequence length
            freqs_cis = freqs_cis[..., :seq_x, :, :, :]

    # Sanity-check: feature dimension (half of head dim) must match freqs_cis
    feat_half = x.shape[-1] // 2
    if freqs_cis.shape[-3] != feat_half:
        raise ValueError(
            f"RoPE feature-dim mismatch: freqs_cis.dim={freqs_cis.shape[-3]} != x.dim/2={feat_half}. "
            f"x.shape={x.shape}, freqs_cis.shape={freqs_cis.shape}"
        )

    # Extract rotation matrix components
    # freqs_cis is [..., dim//2, row, col]
    # row 0: [cos, -sin]
    # row 1: [sin, cos]
    cos = freqs_cis[..., 0, 0]
    msin = freqs_cis[..., 0, 1] # -sin
    sin = freqs_cis[..., 1, 0]

    x1 = x_reshaped[..., 0]
    x2 = x_reshaped[..., 1]

    # Apply rotation
    out1 = x1 * cos + x2 * msin
    out2 = x1 * sin + x2 * cos
    
    # Combine and reshape back to original
    return torch.stack([out1, out2], dim=-1).reshape(*x.shape).type_as(x)


def apply_rope(q: torch.Tensor, k: torch.Tensor, pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to queries and keys.
    
    Args:
        q: Query tensor [batch, heads, seq, dim]
        k: Key tensor [batch, heads, seq, dim]
        pe: Positional embeddings [..., dim//2, 2, 2]
        
    Returns:
        Rotated (q, k) tensors
    """
    return apply_rope1(q, pe), apply_rope1(k, pe)


def optimized_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Optimized attention using scaled_dot_product_attention with fallback to xformers.

    Dispatch order: SDPA with backend priority from the Device module (cuDNN/Flash preferred),
    then xformers memory-efficient attention, then plain SDPA as a last resort.
    """
    b, _, seq_q, dim = q.shape
    _, _, seq_kv, _ = k.shape
    
    # Method 1: Use native scaled_dot_product_attention with backend priority
    # This is the fastest path on modern PyTorch with GPU support
    if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
        try:
            # Get SDPA backend priority context manager from Device
            sdpa_context = Device.get_sdpa_context()
            
            # Process attention mask for SDPA if provided
            attn_mask = None
            if mask is not None:
                # Add dimensions as needed: [B, L] -> [B, 1, 1, L] for broadcasting
                if mask.ndim == 2:
                    attn_mask = mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, L]
                elif mask.ndim == 3:
                    attn_mask = mask.unsqueeze(1)  # [B, 1, L, L]
                else:
                    attn_mask = mask
                # Convert mask to additive form (0 for attend, -inf for mask)
                # Input mask is 1 for valid, 0 for invalid (padding)
                attn_mask = attn_mask.to(dtype=q.dtype)
                attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min
            
            # SDPA expects [batch, heads, seq, dim] - q/k/v are already in this format
            with sdpa_context:
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            
            # Reshape: [batch, heads, seq, dim] -> [batch, seq, heads*dim]
            # Use transpose + view for efficiency (avoid copy)
            out = out.transpose(1, 2).reshape(b, seq_q, -1)
            return out
        except Exception:
            pass  # Fall through to xformers
    
    # Method 2: Use xformers memory-efficient attention
    if Device.xformers_enabled():
        try:
            import xformers.ops as xops
            # xformers expects [batch, seq, heads, dim]
            q_xf = q.transpose(1, 2).contiguous()
            k_xf = k.transpose(1, 2).contiguous()
            v_xf = v.transpose(1, 2).contiguous()
            # Note: xformers has different mask format, conversion would be needed
            out = xops.memory_efficient_attention(q_xf, k_xf, v_xf)
            del q_xf, k_xf, v_xf  # Free memory early
            # Reshape: [batch, seq, heads, dim] -> [batch, seq, heads*dim]
            out = out.reshape(b, seq_q, -1)
            return out
        except Exception:
            pass  # Fall through to naive
    
    # Method 3: Plain SDPA fallback without backend priority (note: the mask is not applied here)
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(b, seq_q, -1)
    return out


@dataclass
class ModulationOut:
    """Output of modulation layer."""
    shift: torch.Tensor
    scale: torch.Tensor
    gate: torch.Tensor


class Modulation(nn.Module):
    """Adaptive layer normalization modulation.
    
    Applies shift, scale, and gate from conditioning vector.
    """

    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        
        mod1 = ModulationOut(shift=out[0], scale=out[1], gate=out[2])
        mod2 = ModulationOut(shift=out[3], scale=out[4], gate=out[5]) if self.is_double else None
        return mod1, mod2


class GlobalModulation(nn.Module):
    """Global modulation for Flux2 (Klein) double stream blocks."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        # 12 outputs: 6 for img stream, 6 for txt stream
        self.lin = operations.Linear(dim, 12 * dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut, ModulationOut, ModulationOut]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(12, dim=-1)
        
        mod1_img = ModulationOut(shift=out[0], scale=out[1], gate=out[2])
        mod2_img = ModulationOut(shift=out[3], scale=out[4], gate=out[5])
        mod1_txt = ModulationOut(shift=out[6], scale=out[7], gate=out[8])
        mod2_txt = ModulationOut(shift=out[9], scale=out[10], gate=out[11])
        
        return mod1_img, mod2_img, mod1_txt, mod2_txt


class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams.
    
    Uses joint attention but separate MLPs for image and text.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
        global_modulation: bool = False,
        dtype=None,
        device=None,
        operations=None,
        flax_compatible: bool = False,
        silu_mlp: bool = False,
        gated_mlp: bool = False,
        ops_bias: bool = True,  # Whether to use bias in linear layers
    ):
        super().__init__()
        if operations is None:
            operations = get_ops()
            
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.flax_compatible = flax_compatible
        self.silu_mlp = silu_mlp
        self.gated_mlp = gated_mlp
        
        # For gated MLP (Klein), mlp_ratio is the true ratio
        # First layer outputs 2x for gating: hidden -> 2*intermediate
        # Second layer: intermediate -> hidden
        if gated_mlp:
            mlp_intermediate = int(hidden_size * mlp_ratio)
            mlp_hidden_dim = mlp_intermediate * 2  # Double for gate+up projection
        else:
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            mlp_intermediate = mlp_hidden_dim

        if global_modulation:
            # When using global modulation at model level, don't create per-block modulation
            self.double_stream_modulation = None
            self.img_mod = None
            self.txt_mod = None
            self.use_global_modulation = True
        else:
            self.double_stream_modulation = None
            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.use_global_modulation = False

        # Image stream
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        
        if gated_mlp:
            # Gated MLP with naming compatible with checkpoint: .0, .1 (identity), .2
            self.img_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.Identity(),  # Placeholder for index 1
                operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        else:
            self.img_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
                operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )

        # Text stream
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        
        if gated_mlp:
            self.txt_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.Identity(),
                operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        else:
            self.txt_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
                operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask=None,
        img_mod: tuple = None,  # (img_mod1, img_mod2) from global modulation
        txt_mod: tuple = None,  # (txt_mod1, txt_mod2) from global modulation
    ) -> tuple[torch.Tensor, torch.Tensor]:
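        # Shape sketch (illustrative; H = hidden_size): img [B, N_img, H], txt [B, N_txt, H],
        # vec [B, H]. pe is expected to cover the joint sequence of N_txt + N_img tokens
        # (text positions first), matching the (txt, img) concatenation used below.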
        # Get modulation parameters
        if self.use_global_modulation and img_mod is not None and txt_mod is not None:
            # Use global modulation passed from model level
            img_mod1, img_mod2 = img_mod
            txt_mod1, txt_mod2 = txt_mod
        elif self.img_mod is not None and self.txt_mod is not None:
            # Use per-block modulation (Flux1 style)
            img_mod1, img_mod2 = self.img_mod(vec)
            txt_mod1, txt_mod2 = self.txt_mod(vec)
        else:
            raise ValueError("No modulation available - either provide global or use per-block modulation")

        # Prepare normed inputs
        img_normed = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_normed + img_mod1.shift
        del img_normed  # Free memory early
        
        txt_normed = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_normed + txt_mod1.shift
        del txt_normed  # Free memory early

        # Run joint attention - use view+permute for efficiency instead of rearrange
        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated
        q_img, k_img, v_img = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del img_qkv
        
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated
        q_txt, k_txt, v_txt = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del txt_qkv

        q_img, k_img = self.img_attn.norm(q_img, k_img, v_img)
        q_txt, k_txt = self.txt_attn.norm(q_txt, k_txt, v_txt)

        # Concatenate for joint attention
        q = torch.cat((q_txt, q_img), dim=2)
        del q_txt, q_img
        k = torch.cat((k_txt, k_img), dim=2)
        del k_txt, k_img
        v = torch.cat((v_txt, v_img), dim=2)
        del v_txt, v_img

        attn_out = attention(q, k, v, pe=pe, mask=attn_mask)
        del q, k, v
        txt_attn, img_attn = attn_out[:, : txt.shape[1]], attn_out[:, txt.shape[1] :]
        del attn_out

        # Apply residual connections with gating
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        del img_attn
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        del txt_attn

        # MLP with modulation
        img_mlp_in = (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * self._forward_mlp(self.img_mlp, img_mlp_in)
        del img_mlp_in

        txt_mlp_in = (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * self._forward_mlp(self.txt_mlp, txt_mlp_in)
        del txt_mlp_in

        # Handle fp16 numerical issues (matches ComfyUI exactly)
        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt

    def _forward_mlp(self, mlp: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
        """Forward through MLP, handling both standard and gated variants."""
        if self.gated_mlp:
            # Gated MLP: split into gate and up, apply SiLU to gate, multiply, project
            gate_up = mlp[0](x)
            gate, up = gate_up.chunk(2, dim=-1)
            hidden = F.silu(gate) * up
            return mlp[2](hidden)
        else:
            return mlp(x)


class SingleStreamBlock(nn.Module):
    """Transformer block with merged image and text stream.
    
    Used after the double stream blocks have processed both modalities.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: Optional[float] = None,
        dtype=None,
        device=None,
        operations=None,
        silu_mlp: bool = False,
        gated_mlp: bool = False,
        ops_bias: bool = True,
        global_modulation: bool = False,
    ):
        super().__init__()
        if operations is None:
            operations = get_ops()
            
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.silu_mlp = silu_mlp
        self.gated_mlp = gated_mlp
        self.use_global_modulation = global_modulation

        # For gated MLP, mlp_ratio gives intermediate size
        # linear1 outputs gate+up (2x intermediate), linear2 takes intermediate
        if gated_mlp:
            self.mlp_intermediate = int(hidden_size * mlp_ratio)
            self.mlp_gate_up_dim = self.mlp_intermediate * 2
            linear1_out = hidden_size * 3 + self.mlp_gate_up_dim
            linear2_in = hidden_size + self.mlp_intermediate
        else:
            self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
            linear1_out = hidden_size * 3 + self.mlp_hidden_dim
            linear2_in = hidden_size + self.mlp_hidden_dim
        
        # Joint QKV and MLP projection
        self.linear1 = operations.Linear(
            hidden_size, linear1_out, bias=ops_bias, dtype=dtype, device=device
        )
        self.linear2 = operations.Linear(
            linear2_in, hidden_size, bias=ops_bias, dtype=dtype, device=device
        )

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        # Only create per-block modulation if not using global modulation
        if not global_modulation:
            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        else:
            self.modulation = None

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask=None,
        modulation=None,  # ModulationOut from global modulation
    ) -> torch.Tensor:
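        # x is the merged (txt, img) token sequence [B, N_txt + N_img, hidden_size]; pe is
        # expected to cover the same joint sequence (see DoubleStreamBlock.forward).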
        # Get modulation
        if self.use_global_modulation and modulation is not None:
            mod = modulation
        elif self.modulation is not None:
            mod, _ = self.modulation(vec)
        else:
            raise ValueError("No modulation available - either provide global or use per-block modulation")
        
        x_normed = self.pre_norm(x)
        x_mod = (1 + mod.scale) * x_normed + mod.shift
        del x_normed  # Free memory early
        
        # Joint projection - split QKV from MLP part
        qkv_mlp = self.linear1(x_mod)
        del x_mod
        
        if self.gated_mlp:
            qkv, mlp_gate_up = qkv_mlp.split([self.hidden_size * 3, self.mlp_gate_up_dim], dim=-1)
            del qkv_mlp
            # Gated MLP: split into gate and up, apply SiLU to gate, multiply
            gate, up = mlp_gate_up.chunk(2, dim=-1)
            del mlp_gate_up
            mlp = F.silu(gate) * up
            del gate, up
        else:
            qkv, mlp = qkv_mlp.split([self.hidden_size * 3, self.mlp_hidden_dim], dim=-1)
            del qkv_mlp
            # Standard activation
            if self.silu_mlp:
                mlp = F.silu(mlp)
            else:
                mlp = F.gelu(mlp, approximate="tanh")
        
        # Attention - use view+permute for efficiency instead of rearrange
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del qkv
        q, k = self.norm(q, k, v)
        
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        del q, k, v
        
        # Combine and project
        output = self.linear2(torch.cat((attn, mlp), dim=-1))
        del attn, mlp
        
        result = x + mod.gate * output
        
        # Handle fp16 numerical issues (matches ComfyUI exactly)
        if result.dtype == torch.float16:
            result = torch.nan_to_num(result, nan=0.0, posinf=65504, neginf=-65504)
        
        return result


class LastLayer(nn.Module):
    """Final layer for unpatchifying and producing output."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
            
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=ops_bias, dtype=dtype, device=device
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=ops_bias, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
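

# Typical composition of these blocks (a minimal sketch only; the real wiring lives in the
# Flux2 model class, and every size below is illustrative rather than taken from a checkpoint):
#
#   pe = EmbedND(dim=128, theta=10000, axes_dim=[16, 56, 56])(ids)  # ids: [B, N_txt + N_img, 3]
#   vec = MLPEmbedder(256, hidden_size)(timestep_embedding)         # conditioning vector
#   for block in double_blocks:          # DoubleStreamBlock instances
#       img, txt = block(img, txt, vec, pe)
#   x = torch.cat((txt, img), dim=1)
#   for block in single_blocks:          # SingleStreamBlock instances
#       x = block(x, vec, pe)
#   out = LastLayer(hidden_size, patch_size=1, out_channels=out_channels)(x[:, txt.shape[1]:], vec)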