| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass |
| | from math import pi |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.utils.checkpoint |
| |
|
| | from ...configuration_utils import ConfigMixin, register_to_config |
| | from ...models.modeling_utils import ModelMixin |
| | from ...utils import BaseOutput, logging |
| |
|
| |
|
# Module-level logger for this file, created through the package's logging utilities.
logger = logging.get_logger(name=__name__)
| |
|
| |
|
class StableAudioPositionalEmbedding(nn.Module):
    """
    Learnable Fourier-feature embedding used for continuous (time-like) scalar values.

    Each input value `t` is mapped to `[t, sin(2*pi*w*t), cos(2*pi*w*t)]`, where `w` is a learnable frequency
    vector of length `dim // 2`. The output therefore has `dim + 1` features.

    Args:
        dim (`int`):
            Number of sinusoidal features. Must be even: half are sines, half are cosines.

    Raises:
        ValueError: If `dim` is odd.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Explicit check instead of `assert`, which is removed when Python runs with `-O`.
        if dim % 2 != 0:
            raise ValueError(f"`dim` must be even, but got {dim}.")
        half_dim = dim // 2
        # Frequencies start from a standard normal draw and are trained with the model.
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        """
        Embed `times` (a tensor of rank >= 1) into a tensor of shape `(*times.shape, dim + 1)`.
        """
        times = times[..., None]
        freqs = times * self.weights[None] * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # The raw input value is prepended, which accounts for the extra `+ 1` output feature.
        return torch.cat((times, fouriered), dim=-1)
| |
|
| |
|
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
    """
    Class for StableAudio projection layer's outputs.

    Each field is `None` when the corresponding input was not given to the projection model.

    Args:
        text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Text-encoder hidden states after projection to the conditioning width (the projection is an
            identity when the widths already match).
        seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Embedding of the audio start time in seconds (clamped, normalized, Fourier-embedded, then linearly
            projected).
        seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Embedding of the audio end time in seconds (clamped, normalized, Fourier-embedded, then linearly
            projected).
    """

    text_hidden_states: Optional[torch.Tensor] = None
    seconds_start_hidden_states: Optional[torch.Tensor] = None
    seconds_end_hidden_states: Optional[torch.Tensor] = None
| |
|
| |
|
class StableAudioNumberConditioner(nn.Module):
    """
    A simple linear projection model to map numbers to a latent space.

    Inputs are clamped to `[min_value, max_value]` and rescaled to `[0, 1]`. They are then embedded with a
    `StableAudioPositionalEmbedding` of width `internal_dim` and linearly projected to `number_embedding_dim`.

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the number embeddings.
        min_value (`int` or `float`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int` or `float`):
            The maximum value of the seconds number conditioning modules. Must be greater than `min_value`.
        internal_dim (`int`, *optional*, defaults to 256):
            Dimensionality of the intermediate number hidden states. Must be even.

    Raises:
        ValueError: If `max_value` is not greater than `min_value`.
    """

    def __init__(
        self,
        number_embedding_dim,
        min_value,
        max_value,
        internal_dim: int = 256,
    ):
        super().__init__()
        # A degenerate range would make every embedding wrong: equal bounds divide 0 by 0 (NaN) in
        # `forward`, and inverted bounds make `clamp` collapse all inputs to `max_value`.
        if max_value <= min_value:
            raise ValueError(f"`max_value` ({max_value}) must be greater than `min_value` ({min_value}).")
        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            # The positional embedding emits `internal_dim + 1` features (the raw value is prepended).
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )

        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value

    def forward(
        self,
        floats: torch.Tensor,
    ) -> torch.Tensor:
        """
        Embed numeric conditioning values.

        Args:
            floats (`torch.Tensor`):
                Values to embed. Presumably shape `(batch_size,)`; TODO confirm against callers.

        Returns:
            `torch.Tensor` of shape `(N, 1, number_embedding_dim)`, where `N` is the number of elements
            in `floats`.
        """
        floats = floats.clamp(self.min_value, self.max_value)
        normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)

        # Match the embedder's parameter dtype (e.g. when the module has been cast to half precision).
        embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        embedding = self.time_positional_embedding(normalized_floats)
        float_embeds = embedding.view(-1, 1, self.number_embedding_dim)

        return float_embeds
| |
|
| |
|
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
    """
    Maps the Stable Audio conditioning inputs (text hidden states plus start/end seconds) into one shared
    latent space of width `conditioning_dim`.

    Args:
        text_encoder_dim (`int`):
            Width of the text-encoder (T5) hidden states. If it already equals `conditioning_dim`, the text
            states pass through unchanged; otherwise a linear layer projects them.
        conditioning_dim (`int`):
            Width of every output conditioning tensor.
        min_value (`int`):
            Lower bound shared by the start- and end-seconds conditioners.
        max_value (`int`):
            Upper bound shared by the start- and end-seconds conditioners.
    """

    @register_to_config
    def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
        super().__init__()
        if text_encoder_dim == conditioning_dim:
            self.text_projection = nn.Identity()
        else:
            self.text_projection = nn.Linear(text_encoder_dim, conditioning_dim)
        self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
        self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)

    def forward(
        self,
        text_hidden_states: Optional[torch.Tensor] = None,
        start_seconds: Optional[torch.Tensor] = None,
        end_seconds: Optional[torch.Tensor] = None,
    ):
        """
        Project whichever conditioning inputs are provided; missing inputs stay `None` in the output.

        Returns:
            `StableAudioProjectionModelOutput` holding the projected text, start and end hidden states.
        """

        def _apply(module, value):
            # Skip absent inputs; run present ones through their dedicated module.
            return None if value is None else module(value)

        projected_text = _apply(self.text_projection, text_hidden_states)
        start_embeds = _apply(self.start_number_conditioner, start_seconds)
        end_embeds = _apply(self.end_number_conditioner, end_seconds)

        return StableAudioProjectionModelOutput(
            text_hidden_states=projected_text,
            seconds_start_hidden_states=start_embeds,
            seconds_end_hidden_states=end_embeds,
        )
| |
|