from dataclasses import dataclass
from typing import Any

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from ..transformers.sana_transformer import SanaTransformerBlock
from .controlnet import zero_module


logger = logging.get_logger(__name__)


@dataclass
class SanaControlNetOutput(BaseOutput):
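    """
    The output of [`SanaControlNetModel`].

    Args:
        controlnet_block_samples (`tuple[torch.Tensor, ...]`):
            One residual tensor per transformer block, already scaled by `conditioning_scale`, to be
            added to the hidden states of the corresponding blocks of the base transformer.
    """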
    controlnet_block_samples: tuple[torch.Tensor, ...]


class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
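    r"""
    A ControlNet-style auxiliary model for [`SanaTransformer2DModel`]-like backbones. It re-uses the
    Sana transformer block layout; the conditioning input is embedded with the same patch embedding
    as the latents and merged through a zero-initialized linear projection, and the model returns one
    zero-initialized residual per transformer block for the base model to add to its hidden states.

    Example (a minimal sketch with the default configuration; `from diffusers import
    SanaControlNetModel` is the assumed import path, and the tensor shapes follow the defaults
    `in_channels=32`, `sample_size=32`, `caption_channels=2304`):

    ```py
    import torch

    from diffusers import SanaControlNetModel

    controlnet = SanaControlNetModel()
    latents = torch.randn(1, 32, 32, 32)  # (batch, in_channels, height, width)
    cond = torch.randn(1, 32, 32, 32)  # spatial condition, already in latent space
    text_embeds = torch.randn(1, 300, 2304)  # (batch, seq_len, caption_channels)
    timestep = torch.tensor([999])

    out = controlnet(latents, text_embeds, timestep, controlnet_cond=cond)
    assert len(out.controlnet_block_samples) == 7  # one residual per transformer block
    ```
    """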
    _supports_gradient_checkpointing = True
    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 32,
        out_channels: int | None = 32,
        num_attention_heads: int = 70,
        attention_head_dim: int = 32,
        num_layers: int = 7,
        num_cross_attention_heads: int | None = 20,
        cross_attention_head_dim: int | None = 112,
        cross_attention_dim: int | None = 2240,
        caption_channels: int = 2304,
        mlp_ratio: float = 2.5,
        dropout: float = 0.0,
        attention_bias: bool = False,
        sample_size: int = 32,
        patch_size: int = 1,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: int | None = None,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch embedding (also used to embed `controlnet_cond` in `forward`)
        self.patch_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
            pos_embed_type="sincos" if interpolation_scale is not None else None,
        )

        # 2. Additional condition embeddings
        self.time_embed = AdaLayerNormSingle(inner_dim)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

        # 3. Transformer blocks (same layout as the base Sana transformer)
        self.transformer_blocks = nn.ModuleList(
            [
                SanaTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    num_cross_attention_heads=num_cross_attention_heads,
                    cross_attention_head_dim=cross_attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. ControlNet blocks, zero-initialized so the model starts as a no-op
        self.controlnet_blocks = nn.ModuleList([])

        # projects the embedded conditioning input before it is added to the latent embedding
        self.input_block = zero_module(nn.Linear(inner_dim, inner_dim))
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = nn.Linear(inner_dim, inner_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.gradient_checkpointing = False

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        encoder_attention_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        attention_kwargs: dict[str, Any] | None = None,
        return_dict: bool = True,
    ) -> tuple[tuple[torch.Tensor, ...]] | SanaControlNetOutput:
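        r"""
        Embeds `hidden_states` together with `controlnet_cond`, runs the transformer blocks, and
        returns one scaled residual per block.

        Args:
            hidden_states (`torch.Tensor`):
                Noisy input latents of shape `(batch, in_channels, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Text encoder embeddings with `caption_channels` features.
            timestep (`torch.LongTensor`):
                The current diffusion timestep(s).
            controlnet_cond (`torch.Tensor`):
                The spatial conditioning input; it is embedded with the same patch embedding as
                `hidden_states` and merged through the zero-initialized `input_block` projection.
            conditioning_scale (`float`, defaults to `1.0`):
                Multiplier applied to every returned residual.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                2D text-token mask, converted below to an additive attention bias.
            attention_mask (`torch.Tensor`, *optional*):
                2D latent-token mask, converted below to an additive attention bias.
            attention_kwargs (`dict`, *optional*):
                Extra kwargs for the attention layers (e.g. a LoRA `scale`, consumed by the
                `apply_lora_scale` decorator).
            return_dict (`bool`, defaults to `True`):
                Whether to return a [`SanaControlNetOutput`] instead of a plain tuple.
        """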
        if attention_mask is not None and attention_mask.ndim == 2:
            # convert a 0/1 key-padding mask of shape [batch, key_tokens] into an additive
            # bias of shape [batch, 1, key_tokens]: zero where kept, -10000.0 where masked
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert the 2D text-encoder mask into an additive bias the same way
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input: patchify the latents, add the zero-projected conditioning, embed timestep and caption
        batch_size, num_channels, height, width = hidden_states.shape
        p = self.config.patch_size
        post_patch_height, post_patch_width = height // p, width // p

        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))

        timestep, embedded_timestep = self.time_embed(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        encoder_hidden_states = self.caption_norm(encoder_hidden_states)

        # 2. Transformer blocks, collecting the hidden states after every block
        block_res_samples = ()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.transformer_blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                )
                block_res_samples = block_res_samples + (hidden_states,)
        else:
            for block in self.transformer_blocks:
                hidden_states = block(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                )
                block_res_samples = block_res_samples + (hidden_states,)

        # 3. ControlNet blocks: project each collected hidden state through its zero-initialized linear layer
        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        controlnet_block_res_samples = tuple(sample * conditioning_scale for sample in controlnet_block_res_samples)

        if not return_dict:
            return (controlnet_block_res_samples,)

        return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)