| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from typing import Any, Dict, Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.utils.checkpoint |
| |
|
| | from ...configuration_utils import ConfigMixin, register_to_config |
| | from ...models.attention import FeedForward |
| | from ...models.attention_processor import ( |
| | Attention, |
| | AttentionProcessor, |
| | StableAudioAttnProcessor2_0, |
| | ) |
| | from ...models.modeling_utils import ModelMixin |
| | from ...models.transformers.transformer_2d import Transformer2DModelOutput |
| | from ...utils import is_torch_version, logging |
| | from ...utils.torch_utils import maybe_allow_in_graph |
| |
|
| |
|
# Module-level logger for this file, obtained from the package's own logging utilities (`...utils.logging`).
logger = logging.get_logger(__name__)
| |
|
| |
|
class StableAudioGaussianFourierProjection(nn.Module):
    """
    Embed noise levels / timesteps with random Gaussian Fourier features.

    A fixed, non-trainable vector of `embedding_size` frequencies is sampled as `torch.randn(embedding_size) * scale`.
    Every input value `x` is turned into the angles `2 * pi * x * frequencies`, and the sines and cosines of those
    angles are concatenated along the last dimension, giving `2 * embedding_size` features per input value.

    Args:
        embedding_size (`int`, defaults to 256): Number of random frequencies.
        scale (`float`, defaults to 1.0): Multiplier applied to the standard-normal frequency samples.
        set_W_to_weight (`bool`, defaults to `True`):
            When `True`, the first frequency draw is thrown away and a second draw (from the global torch RNG) is
            stored as `weight` instead. This changes how many random numbers the constructor consumes.
        log (`bool`, defaults to `True`): When `True`, the natural log of the input is embedded instead of the input.
        flip_sin_to_cos (`bool`, defaults to `False`):
            When `True`, the output is laid out as `[cos, sin]`; otherwise as `[sin, cos]`.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # Draw a fresh set of frequencies; assigning a new Parameter under the same name replaces the first one.
            self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        # Optionally embed log(x) rather than x itself.
        if self.log:
            x = torch.log(x)

        # For a 1-D input of length N this is an outer product (N, 1) @ (1, E) -> (N, E), scaled by 2*pi.
        angles = 2 * np.pi * x[:, None] @ self.weight[None, :]

        sines, cosines = torch.sin(angles), torch.cos(angles)
        halves = (cosines, sines) if self.flip_sin_to_cos else (sines, cosines)
        return torch.cat(halves, dim=-1)
| |
|
| |
|
@maybe_allow_in_graph
class StableAudioDiTBlock(nn.Module):
    r"""
    Transformer block used in the Stable Audio model (https://github.com/Stability-AI/stable-audio-tools).

    The block is pre-norm: each of its three sub-layers (self-attention, cross-attention, feed-forward) is applied to a
    LayerNorm-ed copy of the hidden states, and its output is added back through a residual (skip) connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for the query states.
        num_key_value_attention_heads (`int`):
            The number of heads to use for the key and value states. Only the cross-attention layer (`attn2`) is
            given this value; the self-attention layer (`attn1`) is built without a separate key/value head count.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon used by the three LayerNorm layers.
        ff_inner_dim (`int`, *optional*): Inner dimension of the SwiGLU feed-forward network (passed to `FeedForward`).
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_key_value_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        norm_eps: float = 1e-5,
        ff_inner_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention (the rotary embedding is supplied at call time in `forward`).
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
            out_bias=False,
            processor=StableAudioAttnProcessor2_0(),
        )

        # 2. Cross-attention over `encoder_hidden_states`, with its own key/value head count.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            kv_heads=num_key_value_attention_heads,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
            out_bias=False,
            processor=StableAudioAttnProcessor2_0(),
        )

        # 3. Feed-forward.
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn="swiglu",
            final_dropout=False,
            inner_dim=ff_inner_dim,
            bias=True,
        )

        # Feed-forward chunking is off by default; see `set_chunk_feed_forward`.
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """
        Enable or disable chunked evaluation of the feed-forward sub-layer to reduce peak activation memory.

        Args:
            chunk_size (`int`, *optional*):
                Number of slices along `dim` fed through the feed-forward network at once. `None` disables chunking.
            dim (`int`, defaults to 0):
                Dimension of the normalized hidden states to split. This should not be the channel (last) dimension;
                for `(batch, seq, dim)` inputs use 0 (batch) or 1 (sequence).

        Raises:
            ValueError: If `chunk_size` is not `None` and not positive.
        """
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError(f"`chunk_size` must be a positive integer or None, got {chunk_size}.")
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        rotary_embedding: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Apply self-attention, cross-attention and feed-forward, each with pre-normalization and a residual add.

        Args:
            hidden_states (`torch.Tensor`): Block input (presumably `(batch, seq, dim)`).
            attention_mask (`torch.Tensor`, *optional*): Mask forwarded to the self-attention layer.
            encoder_hidden_states (`torch.Tensor`, *optional*): Key/value source for the cross-attention layer.
            encoder_attention_mask (`torch.Tensor`, *optional*): Mask forwarded to the cross-attention layer.
            rotary_embedding (`torch.FloatTensor`, *optional*):
                Rotary embedding forwarded to the self-attention layer as `rotary_emb`.

        Returns:
            `torch.Tensor`: The updated hidden states (each sub-layer output is added to its input).
        """
        # 1. Self-attention.
        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(
            norm_hidden_states,
            attention_mask=attention_mask,
            rotary_emb=rotary_embedding,
        )
        hidden_states = attn_output + hidden_states

        # 2. Cross-attention.
        norm_hidden_states = self.norm2(hidden_states)
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = attn_output + hidden_states

        # 3. Feed-forward, optionally evaluated chunk by chunk along `self._chunk_dim`.
        norm_hidden_states = self.norm3(hidden_states)
        if self._chunk_size is None:
            ff_output = self.ff(norm_hidden_states)
        else:
            ff_output = torch.cat(
                [self.ff(chunk) for chunk in norm_hidden_states.split(self._chunk_size, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        hidden_states = ff_output + hidden_states

        return hidden_states
| |
|
| |
|
class StableAudioDiTModel(ModelMixin, ConfigMixin):
    """
    The Diffusion Transformer model introduced in Stable Audio.

    Reference: https://github.com/Stability-AI/stable-audio-tools

    Parameters:
        sample_size (`int`, *optional*, defaults to 1024): The size of the input sample.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states.
        num_key_value_attention_heads (`int`, *optional*, defaults to 12):
            The number of heads to use for the key and value states of the cross-attention layers.
        out_channels (`int`, defaults to 64): Number of output channels.
        cross_attention_dim (`int`, *optional*, defaults to 768): Dimension of the cross-attention projection.
        time_proj_dim (`int`, *optional*, defaults to 256):
            Dimension of the timestep Fourier features (one cosine and one sine per frequency, so this is expected to
            be even).
        global_states_input_dim (`int`, *optional*, defaults to 1536):
            Input dimension of the global hidden states projection.
        cross_attention_input_dim (`int`, *optional*, defaults to 768):
            Input dimension of the cross-attention projection.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 1024,
        in_channels: int = 64,
        num_layers: int = 24,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        num_key_value_attention_heads: int = 12,
        out_channels: int = 64,
        cross_attention_dim: int = 768,
        time_proj_dim: int = 256,
        global_states_input_dim: int = 1536,
        cross_attention_input_dim: int = 768,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.out_channels = out_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # Random Fourier features of the (linear, not log) timestep: `time_proj_dim // 2` frequencies, each
        # contributing a cosine and a sine.
        self.time_proj = StableAudioGaussianFourierProjection(
            embedding_size=time_proj_dim // 2,
            flip_sin_to_cos=True,
            log=False,
            set_W_to_weight=False,
        )

        self.timestep_proj = nn.Sequential(
            nn.Linear(time_proj_dim, self.inner_dim, bias=True),
            nn.SiLU(),
            nn.Linear(self.inner_dim, self.inner_dim, bias=True),
        )

        self.global_proj = nn.Sequential(
            nn.Linear(global_states_input_dim, self.inner_dim, bias=False),
            nn.SiLU(),
            nn.Linear(self.inner_dim, self.inner_dim, bias=False),
        )

        self.cross_attention_proj = nn.Sequential(
            nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False),
            nn.SiLU(),
            nn.Linear(cross_attention_dim, cross_attention_dim, bias=False),
        )

        self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
        self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)

        self.transformer_blocks = nn.ModuleList(
            [
                StableAudioDiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    num_key_value_attention_heads=num_key_value_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
        self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)

        self.gradient_checkpointing = False

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: every attention processor used in the model, keyed by the dotted module
            path of its attention layer followed by `.processor`.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The dict passed
                in is not modified.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        if isinstance(processor, dict):
            # Pop from a private copy so the caller's dict is not emptied as a side effect.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(StableAudioAttnProcessor2_0())

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        timestep: torch.LongTensor = None,
        encoder_hidden_states: torch.FloatTensor = None,
        global_hidden_states: torch.FloatTensor = None,
        rotary_embedding: torch.FloatTensor = None,
        return_dict: bool = True,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`StableAudioDiTModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
                Input `hidden_states`.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step. A 0-d tensor is treated as a batch of one.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
                Global embeddings that will be prepended to the hidden states. They are conditioned on the timestep
                and removed again from the output.
            rotary_embedding (`torch.Tensor`):
                The rotary embeddings to apply on query and key tensors during self-attention. Presumably these must
                cover all `global_sequence_len + sequence_len` positions seen by self-attention.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
                Self-attention mask over the input sequence positions. Ones are prepended for the global tokens, so
                they are never masked. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask over the `encoder_hidden_states` tokens. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
        global_hidden_states = self.global_proj(global_hidden_states)

        # The Fourier projection indexes `timestep[:, None]`, so promote a scalar timestep to a 1-element batch.
        if timestep.ndim == 0:
            timestep = timestep[None]
        time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype)))

        # Condition every global token on the timestep embedding (broadcast over the global sequence axis).
        global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)

        # Residual 1x1 convolution, then (batch, channels, seq) -> (batch, seq, channels) and project to inner_dim.
        hidden_states = self.preprocess_conv(hidden_states) + hidden_states
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.proj_in(hidden_states)

        # Prepend the global tokens; the same count is stripped from the output below.
        num_global_tokens = global_hidden_states.shape[-2]
        hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
        if attention_mask is not None:
            # Global tokens are always attended to; match the caller's mask dtype so `torch.cat` needs no promotion.
            prepend_mask = torch.ones(
                (hidden_states.shape[0], num_global_tokens), device=hidden_states.device, dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)

        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for block in self.transformer_blocks:
            if use_checkpointing:
                # Positional order matches `StableAudioDiTBlock.forward`.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    cross_attention_hidden_states,
                    encoder_attention_mask,
                    rotary_embedding,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=cross_attention_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    rotary_embedding=rotary_embedding,
                )

        hidden_states = self.proj_out(hidden_states)

        # (batch, seq, channels) -> (batch, channels, seq), dropping the prepended global tokens.
        hidden_states = hidden_states.transpose(1, 2)[:, :, num_global_tokens:]
        hidden_states = self.postprocess_conv(hidden_states) + hidden_states

        if not return_dict:
            return (hidden_states,)

        return Transformer2DModelOutput(sample=hidden_states)
| |
|