# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm


logger = logging.get_logger(__name__)


def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
    r"""
    Generates 2D patch coordinate indices for a batch of images.

    Args:
        batch_size (`int`):
            Number of images in the batch.
        height (`int`):
            Height of the input images (in pixels).
        width (`int`):
            Width of the input images (in pixels).
        patch_size (`int`):
            Size of the square patches that the image is divided into.
        device (`torch.device`):
            The device on which to create the tensor.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
            image grid.
    """
    img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
    img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
    img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
    return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)


def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    r"""
    Applies rotary positional embeddings (RoPE) to a query tensor.

    Args:
        xq (`torch.Tensor`):
            Input tensor of shape `(..., dim)` representing the queries.
        freqs_cis (`torch.Tensor`):
            Precomputed rotary frequency components of shape `(..., dim/2, 2, 2)` containing cosine/sine rotation
            pairs.

    Returns:
        `torch.Tensor`:
            Tensor of the same shape as `xq` with rotary embeddings applied.
    """
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
    freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq)


class PRXAttnProcessor2_0:
    r"""
    Processor for implementing PRX-style attention with multi-source tokens and RoPE.

    Supports multiple attention backends (Flash Attention, Sage Attention, etc.) via `dispatch_attention_fn`.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "PRXAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply PRX attention using the PRXAttention module.

        Args:
            attn: PRXAttention module containing the projection layers.
            hidden_states: Image tokens of shape [B, L_img, D].
            encoder_hidden_states: Text tokens of shape [B, L_txt, D].
            attention_mask: Boolean mask for the text tokens of shape [B, L_txt].
            image_rotary_emb: Rotary positional embeddings of shape [B, 1, L_img, head_dim//2, 2, 2].
        """
        if encoder_hidden_states is None:
            raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")

        # Project image tokens to Q, K, V
        img_qkv = attn.img_qkv_proj(hidden_states)
        B, L_img, _ = img_qkv.shape
        img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
        img_qkv = img_qkv.permute(2, 0, 3, 1, 4)  # [3, B, H, L_img, D]
        img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]

        # Apply QK normalization to image tokens
        img_q = attn.norm_q(img_q)
        img_k = attn.norm_k(img_k)

        # Project text tokens to K, V
        txt_kv = attn.txt_kv_proj(encoder_hidden_states)
        B, L_txt, _ = txt_kv.shape
        txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
        txt_kv = txt_kv.permute(2, 0, 3, 1, 4)  # [2, B, H, L_txt, D]
        txt_k, txt_v = txt_kv[0], txt_kv[1]

        # Apply K normalization to text tokens
        txt_k = attn.norm_added_k(txt_k)

        # Apply RoPE to image queries and keys
        if image_rotary_emb is not None:
            img_q = apply_rope(img_q, image_rotary_emb)
            img_k = apply_rope(img_k, image_rotary_emb)

        # Concatenate text and image keys/values
        k = torch.cat((txt_k, img_k), dim=2)  # [B, H, L_txt + L_img, D]
        v = torch.cat((txt_v, img_v), dim=2)  # [B, H, L_txt + L_img, D]

        # Build the attention mask if provided
        attn_mask_tensor = None
        if attention_mask is not None:
            bs, _, l_img, _ = img_q.shape
            l_txt = txt_k.shape[2]
            if attention_mask.dim() != 2:
                raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
            if attention_mask.shape[-1] != l_txt:
                raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
            device = img_q.device
            ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
            attention_mask = attention_mask.to(device=device, dtype=torch.bool)
            joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
            attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)

        # Apply attention using dispatch_attention_fn for backend support.
        # Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
        query = img_q.transpose(1, 2)  # [B, L_img, H, D]
        key = k.transpose(1, 2)  # [B, L_txt + L_img, H, D]
        value = v.transpose(1, 2)  # [B, L_txt + L_img, H, D]

        attn_output = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attn_mask_tensor,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

        # Reshape from [B, L_img, H, D] to [B, L_img, H*D]
        batch_size, seq_len, num_heads, head_dim = attn_output.shape
        attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)

        # Apply the output projection
        attn_output = attn.to_out[0](attn_output)
        if len(attn.to_out) > 1:
            attn_output = attn.to_out[1](attn_output)  # dropout, if present

        return attn_output
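

# Illustrative shape walkthrough for the processor above (comments only, not executed; the batch and
# sequence sizes are arbitrary examples). Queries come only from the image tokens, while keys and
# values are the concatenation of text and image tokens, so every image token attends to the full
# prompt and to all other image tokens in a single attention call.
#
#   Assuming B=2, L_txt=77, L_img=1024, heads=28, head_dim=64 (hidden_size 1792):
#     query: [2, 1024, 28, 64]
#     key:   [2, 77 + 1024, 28, 64]
#     value: [2, 77 + 1024, 28, 64]
#     out:   [2, 1024, 1792] after the final reshape and output projection
#
#   `image_rotary_emb` is expected as [B, 1, L_img, head_dim // 2, 2, 2], i.e. one 2x2 rotation
#   matrix per (position, frequency) pair, as produced by `PRXEmbedND` below from `get_image_ids`.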


class PRXAttention(nn.Module, AttentionModuleMixin):
    r"""
    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
    PRX's architecture.
    """

    _default_processor_cls = PRXAttnProcessor2_0
    _available_processors = [PRXAttnProcessor2_0]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        bias: bool = False,
        out_bias: bool = False,
        eps: float = 1e-6,
        processor=None,
    ):
        super().__init__()

        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.query_dim = query_dim

        self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
        self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
        self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)

        self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
        self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(0.0))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
            **kwargs,
        )


# Inspired by https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PRXEmbedND(nn.Module):
    r"""
    N-dimensional rotary positional embedding.

    This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
    dimension. The embeddings are combined and returned as a single tensor.

    Args:
        dim (`int`):
            Base embedding dimension (must be even).
        theta (`int`):
            Scaling factor that controls the frequency spectrum of the rotary embeddings.
        axes_dim (`list[int]`):
            List of embedding dimensions for each axis (each must be even).
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        assert dim % 2 == 0
        is_mps = pos.device.type == "mps"
        is_npu = pos.device.type == "npu"
        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
        omega = 1.0 / (theta**scale)
        out = pos.unsqueeze(-1) * omega.unsqueeze(0)
        out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
        # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
        # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
        out = out.reshape(*out.shape[:-1], 2, 2)
        return out.float()

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)


class MLPEmbedder(nn.Module):
    r"""
    A simple 2-layer MLP used for embedding inputs.

    Args:
        in_dim (`int`):
            Dimensionality of the input features.
        hidden_dim (`int`):
            Dimensionality of the hidden and output embedding space.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(..., hidden_dim)` containing the embedded representations.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class Modulation(nn.Module):
    r"""
    Modulation network that generates scale, shift, and gating parameters.

    Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped
    into two tuples `(shift, scale, gate)`.

    Args:
        dim (`int`):
            Dimensionality of the input vector. The output will have `6 * dim` features internally.

    Returns:
        ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
            Two tuples `(shift, scale, gate)`.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim, bias=True)
        nn.init.constant_(self.lin.weight, 0)
        nn.init.constant_(self.lin.bias, 0)

    def forward(
        self, vec: torch.Tensor
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
        return tuple(out[:3]), tuple(out[3:])


class PRXBlock(nn.Module):
    r"""
    Multimodal transformer block with text–image cross-attention, modulation, and MLP.

    Args:
        hidden_size (`int`):
            Dimension of the hidden representations.
        num_heads (`int`):
            Number of attention heads.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Expansion ratio for the hidden dimension inside the MLP.
        qk_scale (`float`, *optional*):
            Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.

    Attributes:
        img_pre_norm (`nn.LayerNorm`):
            Pre-normalization applied to image tokens before attention.
        attention (`PRXAttention`):
            Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
            image and text tokens.
        post_attention_layernorm (`nn.LayerNorm`):
            Normalization applied after attention.
        gate_proj / up_proj / down_proj (`nn.Linear`):
            Feedforward layers forming the gated MLP.
        mlp_act (`nn.GELU`):
            Nonlinear activation used in the MLP.
        modulation (`Modulation`):
            Produces scale/shift/gating parameters for the modulated layers.

    Methods:
        The forward method performs cross-attention and the MLP with modulation.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.hidden_size = hidden_size

        # Pre-attention normalization for image tokens
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        # PRXAttention module with built-in projections and norms
        self.attention = PRXAttention(
            query_dim=hidden_size,
            heads=num_heads,
            dim_head=self.head_dim,
            bias=False,
            out_bias=False,
            eps=1e-6,
            processor=PRXAttnProcessor2_0(),
        )

        # MLP
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
        self.mlp_act = nn.GELU(approximate="tanh")

        self.modulation = Modulation(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: dict[str, Any],
    ) -> torch.Tensor:
        r"""
        Runs modulation-gated cross-attention and MLP, with residual connections.

        Args:
            hidden_states (`torch.Tensor`):
                Image tokens of shape `(B, L_img, hidden_size)`.
            encoder_hidden_states (`torch.Tensor`):
                Text tokens of shape `(B, L_txt, hidden_size)`.
            temb (`torch.Tensor`):
                Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
                broadcastable).
            image_rotary_emb (`torch.Tensor`):
                Rotary positional embeddings applied inside attention.
            attention_mask (`torch.Tensor`, *optional*):
                Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
            **kwargs:
                Additional keyword arguments for API compatibility.

        Returns:
            `torch.Tensor`:
                Updated image tokens of shape `(B, L_img, hidden_size)`.
        """
        mod_attn, mod_mlp = self.modulation(temb)
        attn_shift, attn_scale, attn_gate = mod_attn
        mlp_shift, mlp_scale, mlp_gate = mod_mlp

        hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
        attn_out = self.attention(
            hidden_states=hidden_states_mod,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_gate * attn_out

        x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
        hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))

        return hidden_states


class FinalLayer(nn.Module):
    r"""
    Final projection layer with adaptive LayerNorm modulation.

    This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
    outputs.

    Args:
        hidden_size (`int`):
            Dimensionality of the input tokens.
        patch_size (`int`):
            Size of the square image patches.
        out_channels (`int`):
            Number of output channels per pixel (e.g. RGB = 3).

    Forward Inputs:
        x (`torch.Tensor`):
            Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
        vec (`torch.Tensor`):
            Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
            LayerNorm.

    Returns:
        `torch.Tensor`:
            Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x


def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""
    Flattens an image tensor into a sequence of non-overlapping patches.

    Args:
        img (`torch.Tensor`):
            Input image tensor of shape `(B, C, H, W)`.
        patch_size (`int`):
            Size of each square patch. Must evenly divide both `H` and `W`.

    Returns:
        `torch.Tensor`:
            Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where
            `L = (H // patch_size) * (W // patch_size)` is the number of patches.
    """
    b, c, h, w = img.shape
    p = patch_size
    # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
    img = img.reshape(b, c, h // p, p, w // p, p)
    # Permute to (B, H//p, W//p, C, p, p) using einsum
    # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
    img = torch.einsum("nchpwq->nhwcpq", img)
    # Flatten to (B, L, C * p * p)
    img = img.reshape(b, -1, c * p * p)
    return img


def seq2img(seq: torch.Tensor, patch_size: int, shape: tuple | torch.Tensor) -> torch.Tensor:
    r"""
    Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).

    Args:
        seq (`torch.Tensor`):
            Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where
            `L = (H // patch_size) * (W // patch_size)`.
        patch_size (`int`):
            Size of each square patch.
        shape (`tuple` or `torch.Tensor`):
            The original image shape. If a tuple (e.g. `torch.Size`) is provided, the last two values are used as
            `(H, W)`; if a tensor is provided, the first two values are interpreted as height and width.

    Returns:
        `torch.Tensor`:
            Reconstructed image tensor of shape `(B, C, H, W)`.
    """
    if isinstance(shape, tuple):
        h, w = shape[-2:]
    elif isinstance(shape, torch.Tensor):
        h, w = (int(shape[0]), int(shape[1]))
    else:
        raise NotImplementedError(f"shape type {type(shape)} not supported")
    b, l, d = seq.shape
    p = patch_size
    c = d // (p * p)
    # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
    seq = seq.reshape(b, h // p, w // p, c, p, p)
    # Permute back to image layout: (B, C, H//p, p, W//p, p)
    # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
    seq = torch.einsum("nhwcpq->nchpwq", seq)
    # Final reshape to (B, C, H, W)
    seq = seq.reshape(b, c, h, w)
    return seq
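

# Round-trip sketch for the patch helpers above (comments only, not executed; the sizes are arbitrary
# examples matching the model defaults of 16 latent channels and patch_size=2):
#
#   latent = torch.randn(1, 16, 64, 64)
#   seq = img2seq(latent, patch_size=2)                    # (1, 1024, 64): 32*32 patches of 16*2*2 values
#   back = seq2img(seq, patch_size=2, shape=latent.shape)  # (1, 16, 64, 64); torch.equal(back, latent) is True
#
#   ids = get_image_ids(1, 64, 64, patch_size=2, device=latent.device)  # (1, 1024, 2) patch (row, col) ids
#   pe = PRXEmbedND(dim=64, theta=10_000, axes_dim=[32, 32])(ids)       # (1, 1, 1024, 32, 2, 2) RoPE matrices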


class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    r"""
    Transformer-based 2D model for text-to-image generation.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            Number of input channels in the latent image.
        patch_size (`int`, *optional*, defaults to 2):
            Size of the square patches used to flatten the input image.
        context_in_dim (`int`, *optional*, defaults to 2304):
            Dimensionality of the text conditioning input.
        hidden_size (`int`, *optional*, defaults to 1792):
            Dimension of the hidden representation.
        mlp_ratio (`float`, *optional*, defaults to 3.5):
            Expansion ratio for the hidden dimension inside MLP blocks.
        num_heads (`int`, *optional*, defaults to 28):
            Number of attention heads.
        depth (`int`, *optional*, defaults to 16):
            Number of transformer blocks.
        axes_dim (`list[int]`, *optional*):
            List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
        theta (`int`, *optional*, defaults to 10000):
            Frequency scaling factor for rotary embeddings.
        time_factor (`float`, *optional*, defaults to 1000.0):
            Scaling factor applied in timestep embeddings.
        time_max_period (`int`, *optional*, defaults to 10000):
            Maximum frequency period for timestep embeddings.

    Attributes:
        pe_embedder (`PRXEmbedND`):
            Multi-axis rotary embedding generator for positional encodings.
        img_in (`nn.Linear`):
            Projection layer for image patch tokens.
        time_in (`MLPEmbedder`):
            Embedding layer for timestep embeddings.
        txt_in (`nn.Linear`):
            Projection layer for text conditioning.
        blocks (`nn.ModuleList`):
            Stack of transformer blocks (`PRXBlock`).
        final_layer (`FinalLayer`):
            Projection layer mapping hidden tokens back to patch outputs.

    Methods:
        attn_processors:
            Returns a dictionary of all attention processors in the model.
        set_attn_processor(processor):
            Replaces attention processors across all attention layers.
        forward(hidden_states, timestep, encoder_hidden_states, attention_mask=None, attention_kwargs=None,
        return_dict=True):
            Full forward pass from latent input to reconstructed output latent image.

    Returns:
        `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
            - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        patch_size: int = 2,
        context_in_dim: int = 2304,
        hidden_size: int = 1792,
        mlp_ratio: float = 3.5,
        num_heads: int = 28,
        depth: int = 16,
        axes_dim: list = None,
        theta: int = 10000,
        time_factor: float = 1000.0,
        time_max_period: int = 10000,
    ):
        super().__init__()

        if axes_dim is None:
            axes_dim = [32, 32]

        # Store parameters directly
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.out_channels = self.in_channels * self.patch_size**2
        self.time_factor = time_factor
        self.time_max_period = time_max_period

        if hidden_size % num_heads != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")

        pe_dim = hidden_size // num_heads
        if sum(axes_dim) != pe_dim:
            raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

        self.blocks = nn.ModuleList(
            [
                PRXBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=mlp_ratio,
                )
                for i in range(depth)
            ]
        )

        self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)

        self.gradient_checkpointing = False

    def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        return self.time_in(
            get_timestep_embedding(
                timesteps=timestep,
                embedding_dim=256,
                max_period=self.time_max_period,
                scale=self.time_factor,
                flip_sin_to_cos=True,  # Match original cos, sin order
                downscale_freq_shift=0.0,
            ).to(dtype)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        attention_kwargs: dict[str, Any] | None = None,
        return_dict: bool = True,
    ) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
        r"""
        Forward pass of the PRXTransformer2DModel.

        The latent image is split into patch tokens, combined with text conditioning, and processed through a stack
        of transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.

        Args:
            hidden_states (`torch.Tensor`):
                Input latent image tensor of shape `(B, C, H, W)`.
            timestep (`torch.Tensor`):
                Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
            encoder_hidden_states (`torch.Tensor`):
                Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
            attention_mask (`torch.Tensor`, *optional*):
                Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
            attention_kwargs (`dict`, *optional*):
                Additional arguments passed to attention layers.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` or a plain tuple.

        Returns:
            `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
                - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
        """
        # Process text conditioning
        txt = self.txt_in(encoder_hidden_states)

        # Convert image to sequence and embed
        img = img2seq(hidden_states, self.patch_size)
        img = self.img_in(img)

        # Generate positional embeddings
        bs, _, h, w = hidden_states.shape
        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
        pe = self.pe_embedder(img_ids)

        # Compute time embedding
        vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)

        # Apply transformer blocks
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                img = self._gradient_checkpointing_func(
                    block.__call__,
                    img,
                    txt,
                    vec,
                    pe,
                    attention_mask,
                )
            else:
                img = block(
                    hidden_states=img,
                    encoder_hidden_states=txt,
                    temb=vec,
                    image_rotary_emb=pe,
                    attention_mask=attention_mask,
                )

        # Final layer and convert back to image
        img = self.final_layer(img, vec)
        output = seq2img(img, self.patch_size, hidden_states.shape)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
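

# Minimal usage sketch for PRXTransformer2DModel (comments only, not executed; the prompt length and
# latent size are arbitrary, and the text embeddings would normally come from the pipeline's text
# encoder producing `context_in_dim`-dimensional features):
#
#   model = PRXTransformer2DModel()            # defaults: hidden_size=1792, num_heads=28, depth=16, patch_size=2
#   latents = torch.randn(1, 16, 64, 64)       # (B, in_channels, H, W)
#   prompt_embeds = torch.randn(1, 77, 2304)   # (B, L_txt, context_in_dim)
#   timestep = torch.tensor([0.5])             # scaled internally by time_factor
#   out = model(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds)
#   assert out.sample.shape == latents.shape   # (1, 16, 64, 64)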