| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any, Dict, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin |
| from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers |
| from diffusers.models.attention_processor import ( |
| Attention, |
| AttentionProcessor, |
| SanaLinearAttnProcessor2_0, |
| ) |
| from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput |
| from diffusers.models.modeling_utils import ModelMixin |
| from diffusers.models.normalization import AdaLayerNormSingle, RMSNorm |
|
|
|
|
# Module-level logger from diffusers' logging utilities; used by `SanaTransformer2DModel.forward` to emit warnings.
logger = logging.get_logger(__name__)
|
|
|
|
class GLUMBConv(nn.Module):
    r"""
    Gated inverted-bottleneck convolution used as the feed-forward network of a Sana block.

    The input feature map is expanded by a 1x1 convolution to ``2 * hidden`` channels (with
    ``hidden = int(expand_ratio * in_channels)``) and passed through SiLU. It is then filtered by a 3x3 depthwise
    convolution and split in half along the channel axis into a value and a gate, combined as ``value * SiLU(gate)``.
    Finally it is projected to ``out_channels`` by a bias-free 1x1 convolution.

    Args:
        in_channels (`int`): Channels of the input feature map (dim 1 of a `(batch, channels, h, w)` tensor).
        out_channels (`int`): Channels of the output feature map.
        expand_ratio (`float`, defaults to `4`): Multiplier on `in_channels` giving the hidden width.
        norm_type (`str`, *optional*): `"rms_norm"` adds an RMSNorm over the channel axis after the projection;
            any other value (including `None`) disables output normalization.
        residual_connection (`bool`, defaults to `True`): Whether to add the block input to its output; this
            requires the output shape to match the input shape.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expand_ratio: float = 4,
        norm_type: Optional[str] = None,
        residual_connection: bool = True,
    ) -> None:
        super().__init__()

        hidden = int(expand_ratio * in_channels)
        doubled = 2 * hidden  # value and gate halves are produced together
        self.norm_type = norm_type
        self.residual_connection = residual_connection

        # Submodules are created in a fixed order so random initialisation is reproducible under a seed.
        self.nonlinearity = nn.SiLU()
        self.conv_inverted = nn.Conv2d(in_channels, doubled, kernel_size=1, stride=1, padding=0)
        self.conv_depth = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1, groups=doubled)
        self.conv_point = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0, bias=False)

        self.norm = (
            RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True) if norm_type == "rms_norm" else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated conv MLP to a channels-first feature map and return the same layout."""
        shortcut = hidden_states

        expanded = self.nonlinearity(self.conv_inverted(hidden_states))
        value, gate = self.conv_depth(expanded).chunk(2, dim=1)
        out = self.conv_point(value * self.nonlinearity(gate))

        if self.norm_type == "rms_norm":
            # RMSNorm normalizes the last axis, so move channels last and then back to dim 1.
            out = self.norm(out.movedim(1, -1)).movedim(-1, 1)

        if self.residual_connection:
            out = out + shortcut
        return out
|
|
|
|
class SanaModulatedNorm(nn.Module):
    r"""
    LayerNorm followed by a shift/scale modulation derived from an embedding and a learned table.

    Args:
        dim (`int`): Size of the last axis that is normalized.
        elementwise_affine (`bool`, defaults to `False`): Whether the inner LayerNorm has learnable affine params.
        eps (`float`, defaults to `1e-6`): Epsilon of the inner LayerNorm.
    """

    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
    ) -> torch.Tensor:
        """
        Return ``LayerNorm(hidden_states) * (1 + scale) + shift``.

        ``shift`` and ``scale`` are the two halves (split along dim 1) of
        ``scale_shift_table[None] + temb[:, None]``, where `temb` is moved to the table's device first.
        """
        modulation = scale_shift_table.unsqueeze(0) + temb.unsqueeze(1).to(scale_shift_table.device)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm(hidden_states)
        return normalized * (1 + scale) + shift
|
|
|
|
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
    r"""
    Conditioning embedder that sums a timestep embedding and a guidance embedding.

    The timestep and the guidance value each go through their own `Timesteps` projection (256 channels) and a
    `TimestepEmbedding` MLP to `embedding_dim`; the two results are summed.

    Args:
        embedding_dim (`int`): Width of the produced conditioning embedding.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        sinusoid_channels = 256

        # Timestep branch.
        self.time_proj = Timesteps(num_channels=sinusoid_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)

        # Guidance branch: same architecture as the timestep branch, separate weights.
        self.guidance_condition_proj = Timesteps(
            num_channels=sinusoid_channels, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.guidance_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)

        # Projection of the summed conditioning to 6 * embedding_dim values.
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
        """
        Embed `timestep` and `guidance` and combine them.

        `guidance` is passed to its projection unconditionally, so it must be provided despite the `None` default.

        Returns:
            A tuple ``(linear(silu(conditioning)), conditioning)`` where ``conditioning`` is the summed embedding.
        """

        def embed(projection, embedder, values):
            # The projected features are cast to `hidden_dtype` before entering the MLP.
            return embedder(projection(values).to(dtype=hidden_dtype))

        time_emb = embed(self.time_proj, self.timestep_embedder, timestep)
        guidance_emb = embed(self.guidance_condition_proj, self.guidance_embedder, guidance)
        conditioning = time_emb + guidance_emb

        modulation = self.linear(self.silu(conditioning))
        return modulation, conditioning
|
|
|
|
class SanaAttnProcessor2_0:
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (requires PyTorch 2.0).

    Queries come from `hidden_states`; keys/values come from `encoder_hidden_states` when given (cross-attention),
    otherwise from `hidden_states` (self-attention). The result goes through `attn.to_out[0]` and `attn.to_out[1]`
    and is divided by `attn.rescale_output_factor`.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Source of keys/values: encoder states for cross-attention, otherwise the queries' own states.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, context_len, _ = context.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, context_len, batch_size)
            # scaled_dot_product_attention expects the mask as (batch, heads, query_len or 1, key_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)

        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        return out_dropout(out_proj(attended)) / attn.rescale_output_factor
|
|
|
|
class SanaTransformerBlock(nn.Module):
    r"""
    Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).

    The block runs three sub-layers in order:
    1. adaLN-modulated self-attention using `SanaLinearAttnProcessor2_0`;
    2. optional cross-attention to the encoder states (skipped when `cross_attention_dim` is `None`);
    3. an adaLN-modulated `GLUMBConv` feed-forward applied on the tokens reshaped to a `(height, width)` grid.

    Args:
        dim (`int`, defaults to `2240`): Hidden size of the token stream.
        num_attention_heads (`int`, defaults to `70`): Heads of the self-attention.
        attention_head_dim (`int`, defaults to `32`): Per-head size of the self-attention.
        dropout (`float`, defaults to `0.0`): Dropout passed to both attention layers.
        num_cross_attention_heads (`int`, *optional*, defaults to `20`): Heads of the cross-attention.
        cross_attention_head_dim (`int`, *optional*, defaults to `112`): Per-head size of the cross-attention.
        cross_attention_dim (`int`, *optional*, defaults to `2240`): Channels of the encoder states used as
            cross-attention keys/values. `None` disables cross-attention.
        attention_bias (`bool`, defaults to `True`): Bias in the self-attention projections.
        norm_elementwise_affine (`bool`, defaults to `False`): Learnable affine params in `norm2`.
        norm_eps (`float`, defaults to `1e-6`): Epsilon of `norm1` and `norm2`.
        attention_out_bias (`bool`, defaults to `True`): Bias in the cross-attention output projection.
        mlp_ratio (`float`, defaults to `2.5`): Expansion ratio of the `GLUMBConv` feed-forward.
        qk_norm (`str`, *optional*): Query/key normalization passed to both attention layers.
    """

    def __init__(
        self,
        dim: int = 2240,
        num_attention_heads: int = 70,
        attention_head_dim: int = 32,
        dropout: float = 0.0,
        num_cross_attention_heads: Optional[int] = 20,
        cross_attention_head_dim: Optional[int] = 112,
        cross_attention_dim: Optional[int] = 2240,
        attention_bias: bool = True,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        attention_out_bias: bool = True,
        mlp_ratio: float = 2.5,
        qk_norm: Optional[str] = None,
    ) -> None:
        super().__init__()

        # 1. Self-attention (linear attention).
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            kv_heads=num_attention_heads if qk_norm is not None else None,
            qk_norm=qk_norm,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
            processor=SanaLinearAttnProcessor2_0(),
        )

        # 2. Cross-attention.
        # `norm2` is the pre-feed-forward norm used by `forward` regardless of cross-attention, so it is always
        # created. `attn2` is explicitly None when cross-attention is disabled so `forward` can skip it; previously
        # neither attribute existed in that case and `forward` raised AttributeError.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
        if cross_attention_dim is not None:
            self.attn2 = Attention(
                query_dim=dim,
                qk_norm=qk_norm,
                kv_heads=num_cross_attention_heads if qk_norm is not None else None,
                cross_attention_dim=cross_attention_dim,
                heads=num_cross_attention_heads,
                dim_head=cross_attention_head_dim,
                dropout=dropout,
                bias=True,
                out_bias=attention_out_bias,
                processor=SanaAttnProcessor2_0(),
            )
        else:
            self.attn2 = None

        # 3. Feed-forward.
        self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)

        # Learned base for the six modulation vectors; the per-sample conditioning is added in `forward`.
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        height: int = None,
        width: int = None,
    ) -> torch.Tensor:
        """
        Run the block on a token sequence.

        Args:
            hidden_states: Token stream; its second axis must equal `height * width` for the feed-forward reshape.
            attention_mask: Accepted for interface compatibility; not used by this block.
            encoder_hidden_states: Keys/values for cross-attention (ignored when cross-attention is disabled).
            encoder_attention_mask: Mask forwarded to the cross-attention.
            timestep: Conditioning reshaped to `(batch, 6, -1)` and added to `scale_shift_table`.
            height, width: Token-grid size used to reshape tokens for the convolutional feed-forward.

        Returns:
            The updated token stream, same shape as `hidden_states`.
        """
        batch_size = hidden_states.shape[0]

        # 1. Modulation parameters for the attention and feed-forward sub-layers.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)

        # 2. Self-attention with gated residual.
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

        attn_output = self.attn1(norm_hidden_states)
        hidden_states = hidden_states + gate_msa * attn_output

        # 3. Cross-attention (un-normalized queries, plain residual).
        if self.attn2 is not None:
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward on the 2D token grid with gated residual.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        # (batch, height*width, dim) -> (batch, dim, height, width) for the convolutions, then back.
        norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
        ff_output = self.ff(norm_hidden_states)
        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
        hidden_states = hidden_states + gate_mlp * ff_output

        return hidden_states
|
|
|
|
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.

    Args:
        in_channels (`int`, defaults to `32`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `32`):
            The number of channels in the output. Falls back to `in_channels` when falsy.
        num_attention_heads (`int`, defaults to `70`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `32`):
            The number of channels in each head. The model width is `num_attention_heads * attention_head_dim`.
        num_layers (`int`, defaults to `20`):
            The number of layers of Transformer blocks to use.
        num_cross_attention_heads (`int`, *optional*, defaults to `20`):
            The number of heads to use for cross-attention.
        cross_attention_head_dim (`int`, *optional*, defaults to `112`):
            The number of channels in each head for cross-attention.
        cross_attention_dim (`int`, *optional*, defaults to `2240`):
            Input channels of the cross-attention keys/values. The caption embeddings are projected to the model
            width (`num_attention_heads * attention_head_dim`) before cross-attention, so this should match it.
        caption_channels (`int`, defaults to `2304`):
            The number of channels in the caption embeddings.
        mlp_ratio (`float`, defaults to `2.5`):
            The expansion ratio to use in the GLUMBConv layer.
        dropout (`float`, defaults to `0.0`):
            The dropout probability.
        attention_bias (`bool`, defaults to `False`):
            Whether to use bias in the self-attention layer.
        sample_size (`int`, defaults to `32`):
            The base size of the input latent.
        patch_size (`int`, defaults to `1`):
            The size of the patches to use in the patch embedding layer.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether the normalization layers configured by this flag (passed to every block) have learnable
            affine parameters.
        norm_eps (`float`, defaults to `1e-6`):
            The epsilon value for the normalization layer.
        interpolation_scale (`int`, *optional*, defaults to `None`):
            Passed to `PatchEmbed`. When `None`, no positional embedding is used; otherwise sin-cos positional
            embeddings are used.
        guidance_embeds (`bool`, defaults to `False`):
            Use `SanaCombinedTimestepGuidanceEmbeddings` (timestep + guidance) instead of `AdaLayerNormSingle` for
            the conditioning embedding.
        guidance_embeds_scale (`float`, defaults to `0.1`):
            Stored in the config; not referenced by this module's code.
        qk_norm (`str`, *optional*, defaults to `None`):
            The normalization to use for the query and key.
        timestep_scale (`float`, defaults to `1.0`):
            The scale to use for the timesteps. Stored in the config; not applied inside this module's `forward`.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 32,
        out_channels: Optional[int] = 32,
        num_attention_heads: int = 70,
        attention_head_dim: int = 32,
        num_layers: int = 20,
        num_cross_attention_heads: Optional[int] = 20,
        cross_attention_head_dim: Optional[int] = 112,
        cross_attention_dim: Optional[int] = 2240,
        caption_channels: int = 2304,
        mlp_ratio: float = 2.5,
        dropout: float = 0.0,
        attention_bias: bool = False,
        sample_size: int = 32,
        patch_size: int = 1,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        guidance_embeds: bool = False,
        guidance_embeds_scale: float = 0.1,
        qk_norm: Optional[str] = None,
        timestep_scale: float = 1.0,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch embedding; positional embeddings only when an interpolation scale is given.
        self.patch_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
            pos_embed_type="sincos" if interpolation_scale is not None else None,
        )

        # 2. Conditioning: timestep (optionally + guidance) embedding and caption projection.
        if guidance_embeds:
            self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
        else:
            self.time_embed = AdaLayerNormSingle(inner_dim)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

        # 3. Transformer blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                SanaTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    num_cross_attention_heads=num_cross_attention_heads,
                    cross_attention_head_dim=cross_attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output head: modulated norm then projection back to patch pixels.
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # Any submodule exposing `get_processor` (i.e. an attention layer) contributes one entry.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Entries are consumed from the caller's dict as they are assigned.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def register_block_hooks(self, block_indices=None):
        """
        Register forward hooks that capture the outputs of selected transformer blocks.

        Args:
            block_indices (`Iterable[int]`, *optional*):
                Indices into `self.transformer_blocks` to monitor; `None` monitors every block. Indices outside
                `[0, len(self.transformer_blocks))` (including negative ones) are silently skipped.

        Returns:
            `tuple`: `(block_outputs, hooks)`. `block_outputs` is a dict filled during each forward pass with
            `{block_index: block_output}` (later passes overwrite earlier entries); `hooks` is the list of hook
            handles, to be passed to `remove_hooks` when monitoring is no longer needed.
        """
        block_outputs = {}
        hooks = []
        num_blocks = len(self.transformer_blocks)
        indices = block_indices if block_indices is not None else range(num_blocks)

        def make_hook(block_idx):
            # Bind the index per hook; a closure over the loop variable would only see its final value.
            def hook(module, inputs, output):
                block_outputs[block_idx] = output

            return hook

        for idx in indices:
            if not 0 <= idx < num_blocks:
                continue
            hooks.append(self.transformer_blocks[idx].register_forward_hook(make_hook(idx)))

        return block_outputs, hooks

    def remove_hooks(self, hooks):
        """Remove every hook handle returned by `register_block_hooks`."""
        for h in hooks:
            h.remove()

    @staticmethod
    def _mask_to_bias(mask: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
        """Convert a 2D keep-mask (1 = keep, 0 = discard) to an additive bias with a singleton dim 1; else pass through."""
        if mask is None or mask.ndim != 2:
            return mask
        return ((1 - mask.to(dtype)) * -10000.0).unsqueeze(1)

    def _forward_core(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        guidance: Optional[torch.Tensor],
        encoder_attention_mask: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Patchify, condition, run all blocks and unpatchify; return the output sample tensor."""
        # 1. Masks: 2D keep-masks become additive biases (computed in the input dtype, before patch embedding).
        attention_mask = self._mask_to_bias(attention_mask, hidden_states.dtype)
        encoder_attention_mask = self._mask_to_bias(encoder_attention_mask, hidden_states.dtype)

        # 2. Patch embedding and conditioning.
        batch_size, _, height, width = hidden_states.shape
        p = self.config.patch_size
        post_patch_height, post_patch_width = height // p, width // p

        hidden_states = self.patch_embed(hidden_states)

        if guidance is not None:
            timestep, embedded_timestep = self.time_embed(
                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
            )
        else:
            timestep, embedded_timestep = self.time_embed(
                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
        encoder_hidden_states = self.caption_norm(encoder_hidden_states)

        # 3. Transformer blocks (optionally gradient-checkpointed).
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        for block in self.transformer_blocks:
            block_args = (
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                timestep,
                post_patch_height,
                post_patch_width,
            )
            if use_checkpointing:
                hidden_states = self._gradient_checkpointing_func(block, *block_args)
            else:
                hidden_states = block(*block_args)

        # 4. Output head.
        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify: (batch, h*w, p*p*C) -> (batch, C, h*p, w*p).
        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
        return hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
        """
        Denoise a latent image conditioned on caption embeddings and a timestep (plus optional guidance).

        Args:
            hidden_states: Input latents of shape `(batch, channels, height, width)`.
            encoder_hidden_states: Caption embeddings with `caption_channels` features.
            timestep: Timesteps fed to the conditioning embedder.
            guidance: Guidance values; only valid when the model was built with `guidance_embeds=True`.
            encoder_attention_mask / attention_mask: 2D keep-masks (1 = keep, 0 = discard) or
                precomputed additive biases.
            attention_kwargs: Optional dict; its `"scale"` entry sets the LoRA scale (PEFT backend only).
            return_dict: Whether to wrap the result in `Transformer2DModelOutput`.

        Returns:
            `Transformer2DModelOutput` if `return_dict` is True, otherwise a one-element tuple with the sample.
        """
        lora_scale = 1.0
        scale_requested = False
        if attention_kwargs is not None:
            # Work on a copy so the caller's dict is not mutated by the pop below.
            attention_kwargs = attention_kwargs.copy()
            # Record this before popping: checking after the pop would never see the key, so the
            # non-PEFT warning below could never fire.
            scale_requested = attention_kwargs.get("scale", None) is not None
            lora_scale = attention_kwargs.pop("scale", 1.0)

        if USE_PEFT_BACKEND:
            # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        try:
            output = self._forward_core(
                hidden_states,
                encoder_hidden_states,
                timestep,
                guidance,
                encoder_attention_mask,
                attention_mask,
            )
        finally:
            if USE_PEFT_BACKEND:
                # Undo the scaling even if the forward pass raised, so it does not compound on the next call.
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)