salmankhanpm's picture
Add files using upload-large-folder tool
69e1a8d verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class ResBlock(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int = 3,
stride: int = 1,
dilations: tuple[int, ...] = (1, 3, 5),
leaky_relu_negative_slope: float = 0.1,
padding_mode: str = "same",
):
super().__init__()
self.dilations = dilations
self.negative_slope = leaky_relu_negative_slope
self.convs1 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
for dilation in dilations
]
)
self.convs2 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
for _ in range(len(dilations))
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
xt = conv1(xt)
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
xt = conv2(xt)
x = x + xt
return x
class LTX2Vocoder(ModelMixin, ConfigMixin):
r"""
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
hidden_channels: int = 1024,
out_channels: int = 2,
upsample_kernel_sizes: list[int] = [16, 15, 8, 4, 4],
upsample_factors: list[int] = [6, 5, 2, 2, 2],
resnet_kernel_sizes: list[int] = [3, 7, 11],
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
leaky_relu_negative_slope: float = 0.1,
output_sampling_rate: int = 24000,
):
super().__init__()
self.num_upsample_layers = len(upsample_kernel_sizes)
self.resnets_per_upsample = len(resnet_kernel_sizes)
self.out_channels = out_channels
self.total_upsample_factor = math.prod(upsample_factors)
self.negative_slope = leaky_relu_negative_slope
if self.num_upsample_layers != len(upsample_factors):
raise ValueError(
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
)
if self.resnets_per_upsample != len(resnet_dilations):
raise ValueError(
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
)
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
self.upsamplers = nn.ModuleList()
self.resnets = nn.ModuleList()
input_channels = hidden_channels
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
output_channels = input_channels // 2
self.upsamplers.append(
nn.ConvTranspose1d(
input_channels, # hidden_channels // (2 ** i)
output_channels, # hidden_channels // (2 ** (i + 1))
kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
)
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
self.resnets.append(
ResBlock(
output_channels,
kernel_size,
dilations=dilations,
leaky_relu_negative_slope=leaky_relu_negative_slope,
)
)
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
Forward pass of the vocoder.
Args:
hidden_states (`torch.Tensor`):
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
`True`.
time_last (`bool`, *optional*, defaults to `False`):
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
Returns:
`torch.Tensor`:
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
"""
# Ensure that the time/frame dimension is last
if not time_last:
hidden_states = hidden_states.transpose(2, 3)
# Combine channels and frequency (mel bins) dimensions
hidden_states = hidden_states.flatten(1, 2)
hidden_states = self.conv_in(hidden_states)
for i in range(self.num_upsample_layers):
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
hidden_states = self.upsamplers[i](hidden_states)
# Run all resnets in parallel on hidden_states
start = i * self.resnets_per_upsample
end = (i + 1) * self.resnets_per_upsample
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
hidden_states = torch.mean(resnet_outputs, dim=0)
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
hidden_states = self.conv_out(hidden_states)
hidden_states = torch.tanh(hidden_states)
return hidden_states