File size: 30,974 Bytes
69e1a8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm


logger = logging.get_logger(__name__)


def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
    r"""
    Builds the (row, col) grid coordinate of every patch and replicates it for each batch element.

    Args:
        batch_size (`int`):
            Number of images in the batch.
        height (`int`):
            Height of the input; the grid has `height // patch_size` rows.
        width (`int`):
            Width of the input; the grid has `width // patch_size` columns.
        patch_size (`int`):
            Side length of the square patches.
        device (`torch.device`):
            Device on which all tensors are created.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(batch_size, num_patches, 2)` in the default floating dtype, where `[..., 0]` is the
            patch row index and `[..., 1]` is the patch column index. Patches are enumerated row by row.
    """
    n_rows = height // patch_size
    n_cols = width // patch_size

    row_index = torch.arange(n_rows, device=device)
    col_index = torch.arange(n_cols, device=device)

    # (rows, cols, 2) grid: channel 0 is broadcast down the columns, channel 1 across the rows.
    grid = torch.zeros(n_rows, n_cols, 2, device=device)
    grid[..., 0] = row_index[:, None]
    grid[..., 1] = col_index[None, :]

    flat_ids = grid.reshape(n_rows * n_cols, 2)
    return flat_ids.unsqueeze(0).repeat(batch_size, 1, 1)


def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    r"""
    Rotates consecutive feature pairs of `xq` with precomputed rotary (RoPE) matrices.

    Args:
        xq (`torch.Tensor`):
            Input tensor of shape `(..., dim)`; the last dimension is interpreted as `dim / 2` consecutive pairs.
        freqs_cis (`torch.Tensor`):
            Rotation matrices whose trailing dimensions are `(dim / 2, 2, 2)` (one 2x2 matrix per feature pair) and
            whose leading dimensions broadcast against those of `xq`.

    Returns:
        `torch.Tensor`:
            Tensor with the same shape and dtype as `xq`. The computation itself is carried out in float32.
    """
    # View the features as (..., dim/2, 1, 2) so each pair broadcasts against one 2x2 matrix.
    pairs = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # Move the rotation table next to the queries; it may live elsewhere when offloading is used.
    rotation = freqs_cis.to(device=xq.device, dtype=pairs.dtype)

    # Matrix-vector product written out column by column: R[:, 0] * x0 + R[:, 1] * x1.
    first_term = rotation[..., 0] * pairs[..., 0]
    second_term = rotation[..., 1] * pairs[..., 1]
    rotated = first_term + second_term

    return rotated.reshape(*xq.shape).type_as(xq)


class PRXAttnProcessor2_0:
    r"""
    Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
    backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "PRXAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply PRX attention using PRXAttention module.

        Args:
            attn: PRXAttention module containing projection layers
            hidden_states: Image tokens [B, L_img, D]
            encoder_hidden_states: Text tokens [B, L_txt, D]
            attention_mask: Boolean mask for text tokens [B, L_txt]
            image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
        """

        if encoder_hidden_states is None:
            raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")

        # Project image tokens to Q, K, V
        img_qkv = attn.img_qkv_proj(hidden_states)
        B, L_img, _ = img_qkv.shape
        img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
        img_qkv = img_qkv.permute(2, 0, 3, 1, 4)  # [3, B, H, L_img, D]
        img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]

        # Apply QK normalization to image tokens
        img_q = attn.norm_q(img_q)
        img_k = attn.norm_k(img_k)

        # Project text tokens to K, V
        txt_kv = attn.txt_kv_proj(encoder_hidden_states)
        B, L_txt, _ = txt_kv.shape
        txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
        txt_kv = txt_kv.permute(2, 0, 3, 1, 4)  # [2, B, H, L_txt, D]
        txt_k, txt_v = txt_kv[0], txt_kv[1]

        # Apply K normalization to text tokens
        txt_k = attn.norm_added_k(txt_k)

        # Apply RoPE to image queries and keys
        if image_rotary_emb is not None:
            img_q = apply_rope(img_q, image_rotary_emb)
            img_k = apply_rope(img_k, image_rotary_emb)

        # Concatenate text and image keys/values
        k = torch.cat((txt_k, img_k), dim=2)  # [B, H, L_txt + L_img, D]
        v = torch.cat((txt_v, img_v), dim=2)  # [B, H, L_txt + L_img, D]

        # Build attention mask if provided
        attn_mask_tensor = None
        if attention_mask is not None:
            bs, _, l_img, _ = img_q.shape
            l_txt = txt_k.shape[2]

            if attention_mask.dim() != 2:
                raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
            if attention_mask.shape[-1] != l_txt:
                raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")

            device = img_q.device
            ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
            attention_mask = attention_mask.to(device=device, dtype=torch.bool)
            joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
            attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)

        # Apply attention using dispatch_attention_fn for backend support
        # Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
        query = img_q.transpose(1, 2)  # [B, L_img, H, D]
        key = k.transpose(1, 2)  # [B, L_txt + L_img, H, D]
        value = v.transpose(1, 2)  # [B, L_txt + L_img, H, D]

        attn_output = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attn_mask_tensor,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

        # Reshape from [B, L_img, H, D] to [B, L_img, H*D]
        batch_size, seq_len, num_heads, head_dim = attn_output.shape
        attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)

        # Apply output projection
        attn_output = attn.to_out[0](attn_output)
        if len(attn.to_out) > 1:
            attn_output = attn.to_out[1](attn_output)  # dropout if present

        return attn_output


class PRXAttention(nn.Module, AttentionModuleMixin):
    r"""
    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
    PRX's architecture.

    The module only owns the layers (fused image QKV projection, fused text KV projection, per-head RMS norms and
    the output projection); the attention computation itself is performed by the configured processor.

    Args:
        query_dim (`int`):
            Feature size of the incoming image tokens and text tokens, and of the module output.
        heads (`int`, *optional*, defaults to 8):
            Number of attention heads.
        dim_head (`int`, *optional*, defaults to 64):
            Feature size of each head. `inner_dim = heads * dim_head`.
        bias (`bool`, *optional*, defaults to `False`):
            Whether the QKV / KV projections use a bias.
        out_bias (`bool`, *optional*, defaults to `False`):
            Whether the output projection uses a bias.
        eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the RMS norms.
        processor (*optional*):
            Attention processor instance. When `None`, `_default_processor_cls()` is instantiated.
    """

    _default_processor_cls = PRXAttnProcessor2_0
    _available_processors = [PRXAttnProcessor2_0]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        bias: bool = False,
        out_bias: bool = False,
        eps: float = 1e-6,
        processor=None,
    ):
        super().__init__()

        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.query_dim = query_dim

        # The processor reshapes the projections to (..., heads, head_dim) and `to_out` consumes `inner_dim`
        # features, so the fused projections must be sized from `inner_dim`. Sizing them from `query_dim` (as was
        # done previously) only works when `query_dim == heads * dim_head`; in that case the shapes are unchanged.
        self.img_qkv_proj = nn.Linear(query_dim, self.inner_dim * 3, bias=bias)

        self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
        self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)

        self.txt_kv_proj = nn.Linear(query_dim, self.inner_dim * 2, bias=bias)
        self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)

        # Output projection followed by a (no-op, p=0.0) dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(0.0))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Delegates to the configured attention processor.

        Args:
            hidden_states (`torch.Tensor`):
                Image tokens.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Text tokens.
            attention_mask (`torch.Tensor`, *optional*):
                Mask over the text tokens.
            image_rotary_emb (`torch.Tensor`, *optional*):
                Rotary positional embeddings for the image tokens.
            **kwargs:
                Forwarded unchanged to the processor.

        Returns:
            `torch.Tensor`: Whatever the processor returns for the image tokens.
        """
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
            **kwargs,
        )


# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PRXEmbedND(nn.Module):
    r"""
    N-dimensional rotary positional embedding.

    For every position axis a table of 2x2 rotation matrices is built, and the tables of all axes are concatenated
    along the frequency dimension into a single tensor.

    Args:
        dim (`int`):
            Base embedding dimension. Stored on the module; the per-axis sizes come from `axes_dim`.
        theta (`int`):
            Base of the geometric frequency progression of the rotary embeddings.
        axes_dim (`list[int]`):
            Embedding dimension used for each position axis (each must be even).
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        r"""
        Builds the rotation matrices for a single axis.

        Args:
            pos (`torch.Tensor`):
                Positions along one axis, shape `(b, n)`.
            dim (`int`):
                Even embedding dimension of this axis; `dim / 2` frequencies are produced.
            theta (`int`):
                Frequency base.

        Returns:
            `torch.Tensor`:
                float32 tensor of shape `(b, n, dim / 2, 2, 2)` holding `[[cos, -sin], [sin, cos]]` per frequency.
        """
        assert dim % 2 == 0

        # Frequencies are computed in float64, except on mps / npu devices where float32 is used instead.
        low_precision_device = pos.device.type in ("mps", "npu")
        freq_dtype = torch.float32 if low_precision_device else torch.float64

        exponents = torch.arange(0, dim, 2, dtype=freq_dtype, device=pos.device) / dim
        omega = 1.0 / (theta**exponents)
        angles = pos.unsqueeze(-1) * omega.unsqueeze(0)  # (b, n, dim / 2)

        cos, sin = torch.cos(angles), torch.sin(angles)
        # Row-major entries of the 2x2 rotation matrix, then unfold the last axis: (b, n, d, 4) -> (b, n, d, 2, 2)
        flat_rotation = torch.stack([cos, -sin, sin, cos], dim=-1)
        rotation = flat_rotation.reshape(*angles.shape, 2, 2)
        return rotation.float()

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            ids (`torch.Tensor`):
                Position ids of shape `(b, n, n_axes)`.

        Returns:
            `torch.Tensor`:
                Tensor of shape `(b, 1, n, sum(axes_dim[:n_axes]) / 2, 2, 2)`; the singleton dimension broadcasts
                over attention heads.
        """
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(self.rope(ids[:, :, axis], self.axes_dim[axis], self.theta))
        # Concatenate along the frequency dimension (third from last).
        stacked = torch.cat(per_axis, dim=-3)
        return stacked.unsqueeze(1)


class MLPEmbedder(nn.Module):
    r"""
    Two-layer MLP (`Linear -> SiLU -> Linear`) used to embed an input vector.

    Args:
        in_dim (`int`):
            Number of input features.
        hidden_dim (`int`):
            Number of features of both the hidden layer and the output.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(..., hidden_dim)`.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)


class Modulation(nn.Module):
    r"""
    Produces two `(shift, scale, gate)` triples from a conditioning vector.

    The input goes through SiLU and a single linear layer with `6 * dim` outputs; the result gets a singleton
    sequence dimension and is split into six equal chunks. The linear layer is zero-initialised, so every chunk is
    zero right after construction.

    Args:
        dim (`int`):
            Number of features of the conditioning vector and of each returned chunk.

    Returns:
        ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
            Two tuples `(shift, scale, gate)`, each tensor of shape `(B, 1, dim)`.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim, bias=True)
        for parameter in (self.lin.weight, self.lin.bias):
            nn.init.constant_(parameter, 0)

    def forward(
        self, vec: torch.Tensor
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        projected = self.lin(nn.functional.silu(vec))
        # Insert the sequence axis before splitting so each chunk broadcasts over tokens.
        chunks = projected[:, None, :].chunk(6, dim=-1)
        first_triple, second_triple = chunks[:3], chunks[3:]
        return tuple(first_triple), tuple(second_triple)


class PRXBlock(nn.Module):
    r"""
    Transformer block in which image tokens attend to text and image tokens, followed by a gated MLP. Both
    sub-layers are modulated (shift / scale / gate) by a conditioning vector.

    Args:
        hidden_size (`int`):
            Dimension of the hidden representations.
        num_heads (`int`):
            Number of attention heads; the per-head size is `hidden_size // num_heads`.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Expansion ratio for the hidden dimension inside the MLP.
        qk_scale (`float`, *optional*):
            Value stored as `self.scale`; when falsy, `head_dim**-0.5` is stored instead. It is not forwarded to the
            attention module by this block.

    Attributes:
        img_pre_norm (`nn.LayerNorm`):
            Non-affine normalization applied to image tokens before attention.
        attention (`PRXAttention`):
            Attention module with built-in QKV projections and normalizations.
        post_attention_layernorm (`nn.LayerNorm`):
            Non-affine normalization applied before the MLP.
        gate_proj / up_proj / down_proj (`nn.Linear`):
            Bias-free layers forming the gated MLP.
        mlp_act (`nn.GELU`):
            Tanh-approximated GELU applied to the gate branch.
        modulation (`Modulation`):
            Produces the shift / scale / gate parameters of both sub-layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()

        per_head = hidden_size // num_heads

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = per_head
        self.scale = qk_scale or per_head**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.hidden_size = hidden_size

        # --- attention sub-layer ---
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attention = PRXAttention(
            query_dim=hidden_size,
            heads=num_heads,
            dim_head=per_head,
            bias=False,
            out_bias=False,
            eps=1e-6,
            processor=PRXAttnProcessor2_0(),
        )

        # --- gated MLP sub-layer ---
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
        self.mlp_act = nn.GELU(approximate="tanh")

        self.modulation = Modulation(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: dict[str, Any],
    ) -> torch.Tensor:
        r"""
        Applies the modulated attention and the modulated gated MLP, each wrapped in a gated residual connection.

        Args:
            hidden_states (`torch.Tensor`):
                Image tokens of shape `(B, L_img, hidden_size)`.
            encoder_hidden_states (`torch.Tensor`):
                Text tokens of shape `(B, L_txt, hidden_size)`.
            temb (`torch.Tensor`):
                Conditioning vector fed to `Modulation`, shape `(B, hidden_size)` (or broadcastable).
            image_rotary_emb (`torch.Tensor`):
                Rotary positional embeddings applied inside attention.
            attention_mask (`torch.Tensor`, *optional*):
                Mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
            **kwargs:
                Accepted for API compatibility and ignored.

        Returns:
            `torch.Tensor`:
                Updated image tokens of shape `(B, L_img, hidden_size)`.
        """
        (attn_shift, attn_scale, attn_gate), (mlp_shift, mlp_scale, mlp_gate) = self.modulation(temb)

        # Attention sub-layer: normalize, modulate, attend, then add the gated result to the residual stream.
        attn_input = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
        attn_out = self.attention(
            hidden_states=attn_input,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_gate * attn_out

        # MLP sub-layer: act(gate_proj(x)) * up_proj(x), projected back down and gated.
        mlp_input = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
        gated_features = self.mlp_act(self.gate_proj(mlp_input)) * self.up_proj(mlp_input)
        hidden_states = hidden_states + mlp_gate * (self.down_proj(gated_features))
        return hidden_states


class FinalLayer(nn.Module):
    r"""
    Output head: adaptive (shift / scale) LayerNorm followed by a linear projection to patch values.

    Args:
        hidden_size (`int`):
            Dimensionality of the input tokens.
        patch_size (`int`):
            Size of the square image patches.
        out_channels (`int`):
            Number of output channels per patch position.

    Forward Inputs:
        x (`torch.Tensor`):
            Input tokens of shape `(B, L, hidden_size)`.
        vec (`torch.Tensor`):
            Conditioning vector of shape `(B, hidden_size)` from which the shift and scale are derived.

    Returns:
        `torch.Tensor`:
            Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        patch_features = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_features, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        # First half of the modulation output is the shift, second half the scale.
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # Add a sequence axis so both broadcast over all tokens.
        modulated = (1 + scale.unsqueeze(1)) * self.norm_final(x) + shift.unsqueeze(1)
        return self.linear(modulated)


def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""
    Cuts an image tensor into non-overlapping square patches and flattens each patch into one token.

    Args:
        img (`torch.Tensor`):
            Input image tensor of shape `(B, C, H, W)`.
        patch_size (`int`):
            Size of each square patch. Must evenly divide both `H` and `W`.

    Returns:
        `torch.Tensor`:
            Patch sequence of shape `(B, L, C * patch_size * patch_size)` with `L = (H // patch_size) * (W //
            patch_size)`. Patches are ordered row by row; inside a token the values are ordered channel, patch row,
            patch column.
    """
    batch, channels, height, width = img.shape
    grid_h = height // patch_size
    grid_w = width // patch_size

    # (B, C, H, W) -> (B, C, grid_h, p, grid_w, p): split each spatial axis into (grid, within-patch).
    patches = img.reshape(batch, channels, grid_h, patch_size, grid_w, patch_size)

    # Bring the grid axes to the front: (B, grid_h, grid_w, C, p, p).
    patches = patches.permute(0, 2, 4, 1, 3, 5)

    # One token per grid cell: (B, L, C * p * p).
    return patches.reshape(batch, -1, channels * patch_size * patch_size)


def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
    r"""
    Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).

    Args:
        seq (`torch.Tensor`):
            Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
            patch_size)`.
        patch_size (`int`):
            Size of each square patch.
        shape (`tuple` or `torch.Tensor`):
            The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
            height and width.

    Returns:
        `torch.Tensor`:
            Reconstructed image tensor of shape `(B, C, H, W)`.
    """
    if isinstance(shape, tuple):
        h, w = shape[-2:]
    elif isinstance(shape, torch.Tensor):
        h, w = (int(shape[0]), int(shape[1]))
    else:
        raise NotImplementedError(f"shape type {type(shape)} not supported")

    b, l, d = seq.shape
    p = patch_size
    c = d // (p * p)

    # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
    seq = seq.reshape(b, h // p, w // p, c, p, p)

    # Permute back to image layout: (B, C, H//p, p, W//p, p)
    # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
    seq = torch.einsum("nhwcpq->nchpwq", seq)

    # Final reshape to (B, C, H, W)
    seq = seq.reshape(b, c, h, w)
    return seq


class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    r"""
    Transformer-based 2D model for text to image generation.

    The latent image is cut into patch tokens, the text conditioning is projected to the same width, and a stack of
    `PRXBlock`s (modulated by a timestep embedding) updates the image tokens before they are projected back and
    re-assembled into a latent image.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            Number of input channels in the latent image.
        patch_size (`int`, *optional*, defaults to 2):
            Size of the square patches used to flatten the input image.
        context_in_dim (`int`, *optional*, defaults to 2304):
            Dimensionality of the text conditioning input.
        hidden_size (`int`, *optional*, defaults to 1792):
            Dimension of the hidden representation. Must be divisible by `num_heads`.
        mlp_ratio (`float`, *optional*, defaults to 3.5):
            Expansion ratio for the hidden dimension inside MLP blocks.
        num_heads (`int`, *optional*, defaults to 28):
            Number of attention heads.
        depth (`int`, *optional*, defaults to 16):
            Number of transformer blocks.
        axes_dim (`list[int]`, *optional*):
            list of dimensions for each positional embedding axis. Defaults to `[32, 32]`. The entries must sum to
            `hidden_size // num_heads`.
        theta (`int`, *optional*, defaults to 10000):
            Frequency scaling factor for rotary embeddings.
        time_factor (`float`, *optional*, defaults to 1000.0):
            Scaling factor applied in timestep embeddings.
        time_max_period (`int`, *optional*, defaults to 10000):
            Maximum frequency period for timestep embeddings.

    Raises:
        ValueError:
            If `hidden_size` is not divisible by `num_heads`, or if `sum(axes_dim)` differs from `hidden_size //
            num_heads`.

    Attributes:
        pe_embedder (`PRXEmbedND`):
            Multi-axis rotary embedding generator for positional encodings.
        img_in (`nn.Linear`):
            Projection layer for image patch tokens.
        time_in (`MLPEmbedder`):
            Embedding layer applied to the 256-dimensional timestep embedding.
        txt_in (`nn.Linear`):
            Projection layer for text conditioning.
        blocks (`nn.ModuleList`):
            Stack of transformer blocks (`PRXBlock`).
        final_layer (`FinalLayer`):
            Projection layer mapping hidden tokens back to patch outputs.

    Methods:
        _compute_timestep_embedding(timestep, dtype):
            Creates a timestep embedding of dimension 256, casts it to `dtype` and projects it with `time_in`.
        forward(hidden_states, timestep, encoder_hidden_states, attention_mask=None, attention_kwargs=None,
        return_dict=True):
            Full forward pass from latent input to reconstructed output latent.

    Returns:
        `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
            - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        patch_size: int = 2,
        context_in_dim: int = 2304,
        hidden_size: int = 1792,
        mlp_ratio: float = 3.5,
        num_heads: int = 28,
        depth: int = 16,
        axes_dim: list = None,
        theta: int = 10000,
        time_factor: float = 1000.0,
        time_max_period: int = 10000,
    ):
        super().__init__()

        # `None` stands for the default two-axis layout; avoids a mutable default argument.
        if axes_dim is None:
            axes_dim = [32, 32]

        # Store parameters directly
        self.in_channels = in_channels
        self.patch_size = patch_size
        # Number of values per patch token: one latent patch of `in_channels` channels.
        self.out_channels = self.in_channels * self.patch_size**2

        self.time_factor = time_factor
        self.time_max_period = time_max_period

        if hidden_size % num_heads != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")

        # Per-head dimension; the rotary embedding axes must add up to exactly this size.
        pe_dim = hidden_size // num_heads

        if sum(axes_dim) != pe_dim:
            raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

        self.blocks = nn.ModuleList(
            [
                PRXBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=mlp_ratio,
                )
                for i in range(depth)
            ]
        )

        # `out_channels` already contains the `patch_size**2` factor, hence the patch size of 1 here.
        self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)

        self.gradient_checkpointing = False

    def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        r"""
        Embeds the timestep: a 256-dimensional embedding from `get_timestep_embedding` (using `time_max_period` and
        `time_factor`) is cast to `dtype` and passed through `time_in`.
        """
        return self.time_in(
            get_timestep_embedding(
                timesteps=timestep,
                embedding_dim=256,
                max_period=self.time_max_period,
                scale=self.time_factor,
                flip_sin_to_cos=True,  # Match original cos, sin order
                downscale_freq_shift=0.0,
            ).to(dtype)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        attention_kwargs: dict[str, Any] | None = None,
        return_dict: bool = True,
    ) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
        r"""
        Forward pass of the PRXTransformer2DModel.

        The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
        transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.

        Args:
            hidden_states (`torch.Tensor`):
                Input latent image tensor of shape `(B, C, H, W)`.
            timestep (`torch.Tensor`):
                Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
            encoder_hidden_states (`torch.Tensor`):
                Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
            attention_mask (`torch.Tensor`, *optional*):
                Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
            attention_kwargs (`dict`, *optional*):
                Accepted for API compatibility; this forward pass does not read it.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` or a tuple.

        Returns:
            `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:

                - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
        """
        # Process text conditioning
        txt = self.txt_in(encoder_hidden_states)

        # Convert image to sequence and embed
        img = img2seq(hidden_states, self.patch_size)
        img = self.img_in(img)

        # Generate positional embeddings
        bs, _, h, w = hidden_states.shape
        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
        pe = self.pe_embedder(img_ids)

        # Compute time embedding
        vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)

        # Apply transformer blocks
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Positional arguments follow PRXBlock.forward's parameter order:
                # hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask.
                img = self._gradient_checkpointing_func(
                    block.__call__,
                    img,
                    txt,
                    vec,
                    pe,
                    attention_mask,
                )
            else:
                img = block(
                    hidden_states=img,
                    encoder_hidden_states=txt,
                    temb=vec,
                    image_rotary_emb=pe,
                    attention_mask=attention_mask,
                )

        # Final layer and convert back to image
        img = self.final_layer(img, vec)
        # `hidden_states.shape` is a `torch.Size`; seq2img uses its last two entries as (H, W).
        output = seq2img(img, self.patch_size, hidden_states.shape)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)