# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
from ..attention_processor import (
    Attention,
    FusedJointAttnProcessor2_0,
    JointAttnProcessor2_0,
)
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
    r"""
    A single-stream transformer block: self-attention over the input tokens followed by a gated feed-forward,
    both modulated by `AdaLayerNormZero` conditioned on the timestep embedding.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=JointAttnProcessor2_0(),
            eps=1e-6,
        )
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
        # 1. Attention
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        # 2. Feed Forward
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = hidden_states + ff_output

        return hidden_states


class SD3Transformer2DModel(
    ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """
    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

    Parameters:
        sample_size (`int`, defaults to `128`):
            The width/height of the latents. This is fixed during training since it is used to learn a number of
            position embeddings.
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `16`):
            The number of latent channels in the input.
        num_layers (`int`, defaults to `18`):
            The number of layers of transformer blocks to use.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        num_attention_heads (`int`, defaults to `18`):
            The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, defaults to `4096`):
            The embedding dimension to use for joint text-image attention.
        caption_projection_dim (`int`, defaults to `1152`):
            The embedding dimension of caption embeddings.
        pooled_projection_dim (`int`, defaults to `2048`):
            The embedding dimension of pooled text projections.
        out_channels (`int`, defaults to `16`):
            The number of latent channels in the output.
        pos_embed_max_size (`int`, defaults to `96`):
            The maximum latent height/width of positional embeddings.
        dual_attention_layers (`tuple[int, ...]`, defaults to `()`):
            The indices of the transformer blocks that use dual-stream attention (`()` for SD3.0,
            `(0, 1, ..., 12)` for SD3.5).
        qk_norm (`str`, *optional*, defaults to `None`):
            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["JointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: tuple[
            int, ...
        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: str | None = None,
    ):
        super().__init__()
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
                )
                for i in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: int | None = None, dim: int = 0) -> None:
        """
        Enables [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on the
        feed-forward layers to reduce peak memory usage.

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
""" if dim not in [0, 1]: raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") # By default chunk size is 1 chunk_size = chunk_size or 1 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. > [!WARNING] > This API is 🧪 experimental. """ self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors for module in self.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) self.set_attn_processor(FusedJointAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. > [!WARNING] > This API is 🧪 experimental. """ if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, block_controlnet_hidden_states: list = None, joint_attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, skip_layers: list[int] | None = None, ) -> torch.Tensor | Transformer2DModelOutput: """ The [`SD3Transformer2DModel`] forward method. Args: hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep (`torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. 
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            skip_layers (`list` of `int`, *optional*):
                A list of layer indices to skip during the forward pass.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)

            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)

        for index_block, block in enumerate(self.transformer_blocks):
            # Skip specified layers
            is_skip = True if skip_layers is not None and index_block in skip_layers else False

            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    joint_attention_kwargs,
                )
            elif not is_skip:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # unpatchify
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
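
# The following is a minimal usage sketch (not part of the module) illustrating the tensor shapes the
# forward pass expects and produces. The tiny configuration values and random inputs below are
# illustrative assumptions for a quick shape check, not the settings of any released SD3 checkpoint.
#
#     import torch
#     from diffusers import SD3Transformer2DModel
#
#     model = SD3Transformer2DModel(
#         sample_size=32,
#         patch_size=2,
#         in_channels=16,
#         num_layers=2,
#         attention_head_dim=8,
#         num_attention_heads=4,      # inner_dim = 4 * 8 = 32
#         joint_attention_dim=64,
#         caption_projection_dim=32,  # must equal inner_dim so text tokens can join the image tokens
#         pooled_projection_dim=48,
#         pos_embed_max_size=64,
#     )
#     latents = torch.randn(1, 16, 32, 32)       # (batch, in_channels, height, width)
#     prompt_embeds = torch.randn(1, 77, 64)     # (batch, sequence_len, joint_attention_dim)
#     pooled_prompt_embeds = torch.randn(1, 48)  # (batch, pooled_projection_dim)
#     timestep = torch.tensor([500])
#
#     sample = model(
#         hidden_states=latents,
#         encoder_hidden_states=prompt_embeds,
#         pooled_projections=pooled_prompt_embeds,
#         timestep=timestep,
#     ).sample
#     assert sample.shape == (1, 16, 32, 32)     # unpatchified back to the input latent resolution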
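# A second hedged sketch, reusing the toy `model` and inputs from above: the inference-time helpers defined
# on this class can be combined as shown. Passing `skip_layers` bypasses the listed transformer blocks (the
# mechanism behind skip-layer-style guidance); the fused-QKV and forward-chunking calls are the experimental
# memory/speed toggles documented in the methods above.
#
#     model.enable_forward_chunking(chunk_size=1, dim=1)  # chunk the feed-forward over the sequence dimension
#     model.fuse_qkv_projections()                         # experimental: fuse Q/K/V projection matrices
#     with torch.no_grad():
#         skipped_sample = model(
#             hidden_states=latents,
#             encoder_hidden_states=prompt_embeds,
#             pooled_projections=pooled_prompt_embeds,
#             timestep=timestep,
#             skip_layers=[0],                             # bypass the first transformer block
#         ).sample
#     model.unfuse_qkv_projections()
#     model.disable_forward_chunking()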