| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ...configuration_utils import ConfigMixin, register_to_config |
| from ...models.modeling_utils import ModelMixin |
|
|
|
|
class ResBlock(nn.Module):
    """Sequential stack of residual branches built from 1D convolutions.

    Branch ``k`` computes ``conv2_k(leaky_relu(conv1_k(leaky_relu(x))))`` and adds the result back onto ``x``.
    Here ``conv1_k`` uses ``dilations[k]`` and ``conv2_k`` always uses dilation 1. Branches are applied one after
    another, and every convolution keeps the channel count at ``channels``.

    Args:
        channels: Number of input and output channels.
        kernel_size: Kernel size shared by every convolution.
        stride: Stride forwarded to every convolution. NOTE(review): PyTorch rejects ``padding="same"`` for
            strided convolutions, so with the default ``padding_mode`` only ``stride=1`` is usable.
        dilations: Dilation of the first convolution in each branch (one branch per entry).
        leaky_relu_negative_slope: Negative slope used by every LeakyReLU in the block.
        padding_mode: Forwarded as the ``padding`` argument of ``nn.Conv1d`` (e.g. ``"same"``).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        dilations: tuple[int, ...] = (1, 3, 5),
        leaky_relu_negative_slope: float = 0.1,
        padding_mode: str = "same",
    ):
        super().__init__()
        self.dilations = dilations
        self.negative_slope = leaky_relu_negative_slope

        def make_conv(dilation: int) -> nn.Conv1d:
            # Every convolution shares channels/kernel/stride/padding; only the dilation varies.
            return nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding_mode,
            )

        # All dilated convs are created before all undilated ones. Keep this order so that random
        # initialisation and state-dict layout stay identical.
        dilated_convs = [make_conv(d) for d in dilations]
        plain_convs = [make_conv(1) for _ in range(len(dilations))]
        self.convs1 = nn.ModuleList(dilated_convs)
        self.convs2 = nn.ModuleList(plain_convs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every residual branch in turn and return the accumulated tensor (same shape as ``x``
        when the convolutions preserve length, as with ``padding="same"``)."""
        slope = self.negative_slope
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(F.leaky_relu(x, negative_slope=slope))
            branch = plain_conv(F.leaky_relu(branch, negative_slope=slope))
            x = x + branch
        return x
|
|
|
|
class LTX2Vocoder(ModelMixin, ConfigMixin):
    r"""
    LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.

    Architecture (from the code below):

    - an input `nn.Conv1d`;
    - a stack of `nn.ConvTranspose1d` upsamplers, each halving the channel count (floor division);
    - after every upsampler, a group of `ResBlock`s run on the same input, with their outputs averaged;
    - a final LeakyReLU, an output `nn.Conv1d`, and a `tanh`.

    Args:
        in_channels (`int`, defaults to 128):
            Channel count expected by the input convolution. This is `num_channels * num_mel_bins` of the
            tensor passed to `forward`, after the channel and mel-bin axes are flattened together.
        hidden_channels (`int`, defaults to 1024):
            Channel count after the input convolution; halved by every upsampling layer.
        out_channels (`int`, defaults to 2):
            Number of output waveform channels.
        upsample_kernel_sizes (`list[int]`, defaults to `[16, 15, 8, 4, 4]`):
            Kernel size of each upsampler. Must have the same length as `upsample_factors`.
        upsample_factors (`list[int]`, defaults to `[6, 5, 2, 2, 2]`):
            Stride, i.e. the temporal upsampling factor, of each upsampler.
        resnet_kernel_sizes (`list[int]`, defaults to `[3, 7, 11]`):
            Kernel size of each `ResBlock` in the group that follows every upsampler.
        resnet_dilations (`list[list[int]]`, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            Dilations of each `ResBlock` in a group. Must have the same length as `resnet_kernel_sizes`.
        leaky_relu_negative_slope (`float`, defaults to 0.1):
            Negative slope of the LeakyReLU activations. The final activation before the output convolution
            uses 0.01 instead.
        output_sampling_rate (`int`, defaults to 24000):
            Recorded in the config by `register_to_config`; this module's own code does not read it.

    Raises:
        ValueError: If `upsample_kernel_sizes` / `upsample_factors` or `resnet_kernel_sizes` /
            `resnet_dilations` have mismatched lengths.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        hidden_channels: int = 1024,
        out_channels: int = 2,
        # The list defaults below are recorded by `register_to_config` and are never mutated here.
        upsample_kernel_sizes: list[int] = [16, 15, 8, 4, 4],
        upsample_factors: list[int] = [6, 5, 2, 2, 2],
        resnet_kernel_sizes: list[int] = [3, 7, 11],
        resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_negative_slope: float = 0.1,
        output_sampling_rate: int = 24000,
    ):
        super().__init__()
        self.num_upsample_layers = len(upsample_kernel_sizes)
        self.resnets_per_upsample = len(resnet_kernel_sizes)
        self.out_channels = out_channels
        self.total_upsample_factor = math.prod(upsample_factors)
        self.negative_slope = leaky_relu_negative_slope

        if self.num_upsample_layers != len(upsample_factors):
            raise ValueError(
                f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
                f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
            )

        if self.resnets_per_upsample != len(resnet_dilations):
            # `resnets_per_upsample` is already an int. Calling `len()` on it would raise TypeError
            # and hide this message.
            raise ValueError(
                f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
                f" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively."
            )

        # Kernel 7 with padding 3 (stride 1) preserves the sequence length.
        self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)

        self.upsamplers = nn.ModuleList()
        # Flat list of ResBlocks. Group i (for upsampler i) occupies indices
        # [i * resnets_per_upsample, (i + 1) * resnets_per_upsample).
        self.resnets = nn.ModuleList()
        channels = hidden_channels
        for stride, kernel_size in zip(upsample_factors, upsample_kernel_sizes):
            output_channels = channels // 2
            # With padding (k - s) // 2, the output length is exactly L * s whenever (k - s) is even.
            self.upsamplers.append(
                nn.ConvTranspose1d(
                    channels,
                    output_channels,
                    kernel_size,
                    stride=stride,
                    padding=(kernel_size - stride) // 2,
                )
            )

            for resnet_kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
                self.resnets.append(
                    ResBlock(
                        output_channels,
                        resnet_kernel_size,
                        dilations=dilations,
                        leaky_relu_negative_slope=leaky_relu_negative_slope,
                    )
                )
            channels = output_channels

        # `channels` is the channel count after the last upsampler. It is `hidden_channels` when there
        # are no upsamplers, a case in which a loop-local name would be unbound.
        self.conv_out = nn.Conv1d(channels, out_channels, 7, stride=1, padding=3)

    def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
        r"""
        Forward pass of the vocoder.

        Args:
            hidden_states (`torch.Tensor`):
                Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
                is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
                `True`. `num_channels * num_mel_bins` must equal the configured `in_channels`.
            time_last (`bool`, *optional*, defaults to `False`):
                Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.

        Returns:
            `torch.Tensor`:
                Audio waveform tensor of shape `(batch_size, out_channels, audio_length)`, with values in `[-1, 1]`
                (`tanh` output). When every `kernel_size - upsample_factor` is even (true for the defaults),
                `audio_length == time * self.total_upsample_factor`.
        """
        # Bring the input to (batch_size, num_channels, num_mel_bins, time). Then merge channels and mel bins
        # into the single Conv1d channel axis.
        if not time_last:
            hidden_states = hidden_states.transpose(2, 3)
        hidden_states = hidden_states.flatten(1, 2)

        hidden_states = self.conv_in(hidden_states)

        for i, upsampler in enumerate(self.upsamplers):
            hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
            hidden_states = upsampler(hidden_states)

            # Run this upsampler's ResBlock group on the same input and average their outputs.
            start = i * self.resnets_per_upsample
            end = start + self.resnets_per_upsample
            resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
            hidden_states = torch.mean(resnet_outputs, dim=0)

        # NOTE(review): the final activation deliberately uses slope 0.01 (PyTorch's LeakyReLU default), not
        # `self.negative_slope`. This presumably matches the reference implementation / pretrained weights,
        # so confirm before changing it.
        hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
        hidden_states = self.conv_out(hidden_states)
        return torch.tanh(hidden_states)
|
|