# build-tools/diffusers/models/controlnets/controlnet_sana.py
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from ..transformers.sana_transformer import SanaTransformerBlock
from .controlnet import zero_module


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class SanaControlNetOutput(BaseOutput):
    controlnet_block_samples: tuple[torch.Tensor, ...]


class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
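    r"""
    ControlNet-style adapter for the Sana transformer.

    Input latents are patch-embedded, a patch-embedded conditioning image is injected through a
    zero-initialized linear projection, and the result is processed by a stack of
    `SanaTransformerBlock` layers. One zero-initialized linear projection of the hidden states is
    produced per block; these residuals, scaled by `conditioning_scale`, are intended to be added
    to the corresponding blocks of the base Sana transformer.
    """
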
_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

    @register_to_config
def __init__(
self,
in_channels: int = 32,
out_channels: int | None = 32,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
num_layers: int = 7,
num_cross_attention_heads: int | None = 20,
cross_attention_head_dim: int | None = 112,
cross_attention_dim: int | None = 2240,
caption_channels: int = 2304,
mlp_ratio: float = 2.5,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: int = 32,
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: int | None = None,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch Embedding
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
pos_embed_type="sincos" if interpolation_scale is not None else None,
        )

        # 2. Additional condition embeddings
self.time_embed = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

        # 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
SanaTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
num_cross_attention_heads=num_cross_attention_heads,
cross_attention_head_dim=cross_attention_head_dim,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
        )

        # 4. ControlNet blocks (zero-initialized so the ControlNet starts as a no-op on the base model)
self.controlnet_blocks = nn.ModuleList([])
self.input_block = zero_module(nn.Linear(inner_dim, inner_dim))
for _ in range(len(self.transformer_blocks)):
controlnet_block = nn.Linear(inner_dim, inner_dim)
controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.gradient_checkpointing = False

    @apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_attention_mask: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
    ) -> tuple[torch.Tensor, ...] | SanaControlNetOutput:
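        r"""
        Compute per-block ControlNet residuals for the given latents and conditioning image.

        Args:
            hidden_states: Input latents of shape `(batch, in_channels, height, width)`.
            encoder_hidden_states: Text encoder embeddings; projected and RMS-normalized below.
            timestep: Diffusion timesteps used for adaptive layer-norm conditioning.
            controlnet_cond: Conditioning image latents, patch-embedded and added to the input.
            conditioning_scale: Multiplier applied to every returned residual.
            encoder_attention_mask / attention_mask: Optional `(batch, key_tokens)` masks
                (1 = keep, 0 = discard), converted below into additive attention biases.
            attention_kwargs: Optional kwargs consumed by the `apply_lora_scale` decorator (e.g. a LoRA scale).
            return_dict: Whether to return a `SanaControlNetOutput` instead of a plain tuple.
        """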
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
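            # e.g. a mask row [1, 1, 0] becomes the bias row [0.0, 0.0, -10000.0], shaped (batch, 1, key_tokens)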
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
batch_size, num_channels, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height, post_patch_width = height // p, width // p
hidden_states = self.patch_embed(hidden_states)
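        # Patch-embed the conditioning input with the same embedder and inject it through the
        # zero-initialized `input_block`, so the conditioning contributes nothing before training.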
hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
        encoder_hidden_states = self.caption_norm(encoder_hidden_states)

        # 2. Transformer blocks
block_res_samples = ()
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
)
block_res_samples = block_res_samples + (hidden_states,)
else:
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
)
                block_res_samples = block_res_samples + (hidden_states,)

        # 3. ControlNet blocks
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
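        # Scale the per-block residuals by the requested conditioning strength before returning them.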
        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]

        if not return_dict:
            return (controlnet_block_res_samples,)

        return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
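

# A minimal, illustrative smoke-test sketch (not part of the library): it builds a tiny
# SanaControlNetModel and runs one forward pass on random tensors. The configuration below is
# deliberately small and does NOT correspond to any published Sana checkpoint; it only has to
# satisfy the module's shape constraints (inner_dim = num_attention_heads * attention_head_dim,
# and cross_attention_dim matching inner_dim because the caption projection maps text embeddings
# to inner_dim). Because this file uses relative imports, run it as a module, e.g.
#   python -m diffusers.models.controlnets.controlnet_sana
if __name__ == "__main__":
    model = SanaControlNetModel(
        in_channels=4,
        num_attention_heads=2,
        attention_head_dim=8,
        num_layers=2,
        num_cross_attention_heads=2,
        cross_attention_head_dim=8,
        cross_attention_dim=16,
        caption_channels=16,
        sample_size=8,
    )
    latents = torch.randn(1, 4, 8, 8)
    cond_image_latents = torch.randn(1, 4, 8, 8)
    prompt_embeds = torch.randn(1, 10, 16)
    timesteps = torch.tensor([500])

    output = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        timestep=timesteps,
        controlnet_cond=cond_image_latents,
        conditioning_scale=0.5,
    )
    # One residual per transformer block, each of shape (batch, num_patches, inner_dim).
    print([sample.shape for sample in output.controlnet_block_samples])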