# NOTE(review): removed non-code capture artifacts ("Spaces:" / "Runtime error")
# that preceded the module source; they were not part of the program.
| from typing import Optional | |
| import math | |
| import random | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from torch.nn.utils.parametrizations import weight_norm | |
| from typing import Optional, Tuple | |
| from scipy.signal import get_window | |
class AdaIN1d(nn.Module):
    """Adaptive Instance Normalization over 1D feature maps.

    A style vector ``s`` is projected to per-channel scale/shift parameters
    that modulate an (affine-free) instance-normalized input.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Project style to (B, 2*C, 1): first half scales, second half shifts.
        style = self.fc(s).unsqueeze(-1)
        gamma, beta = style.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale.
            None (or a non-positive value) means no scaling.
        style_dim (int): Dimension of the style vector consumed by the AdaIN norm.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float],
        style_dim: int,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = AdaIN1d(style_dim, dim)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Fix: the docstring allows layer_scale_init_value=None ("no scaling"),
        # but `None > 0` raises TypeError on Python 3 — guard None explicitly.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Apply the block to x of shape (B, C, T), conditioned on style s.

        Returns a tensor of the same shape as x (residual connection included).
        """
        residual = x
        x = self.dwconv(x)
        x = self.norm(x, s)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x  # learned per-channel layer scale
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        x = residual + x
        return x
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """Element-wise natural logarithm with the input clamped below at ``clip_val``.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Lower bound applied before the log. Defaults to 1e-7.

    Returns:
        Tensor: ``log(max(x, clip_val))`` computed element-wise.
    """
    return torch.clamp(x, min=clip_val).log()
def symlog(x: torch.Tensor) -> torch.Tensor:
    """Symmetric log transform: ``sign(x) * log(1 + |x|)`` (odd function, finite at 0)."""
    return x.sign() * x.abs().log1p()
def symexp(x: torch.Tensor) -> torch.Tensor:
    """Inverse of :func:`symlog`: ``sign(x) * (exp(|x|) - 1)``."""
    return x.sign() * (x.abs().exp() - 1)
class Backbone(nn.Module):
    """Abstract base for generator backbones; temporal resolution is preserved end to end."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Map features (B, C, L) to hidden states (B, L, H).

        Args:
            x (Tensor): Input of shape (B, C, L) — batch, feature channels, sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), H being the model dimension.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
class Generator(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning
    with adaptive normalization via a style vector ``s``.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        style_dim (int): Dimension of the conditioning style vector.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        gen_istft_n_fft (int): FFT size forwarded to the ISTFT head.
        gen_istft_hop_size (int): Hop size forwarded to the ISTFT head.
        layer_scale_init_value (float, optional): Initial value for layer scaling.
            Defaults to ``1 / num_layers`` (note: an explicit 0 also falls back to it).
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        style_dim: int,
        intermediate_dim: int,
        num_layers: int,
        gen_istft_n_fft: int,
        gen_istft_hop_size: int,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            ConvNeXtBlock(
                dim=dim,
                intermediate_dim=intermediate_dim,
                layer_scale_init_value=layer_scale_init_value,
                style_dim=style_dim,
            )
            for _ in range(num_layers)
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)
        # NOTE(review): this pad layer is never used in forward(); kept so the
        # module's attribute surface stays unchanged for external code.
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.stft = ISTFTHead(dim=dim, n_fft=gen_istft_n_fft, hop_length=gen_istft_hop_size, padding="same")

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights; zero the bias when present.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Fix: guard against layers created with bias=False (bias is None there).
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, s) -> torch.Tensor:
        """Run the ConvNeXt stack on x (B, C, T) conditioned on style s, then
        synthesize a time-domain waveform with the ISTFT head."""
        for conv_block in self.convnext:
            x = conv_block(x, s)
        x = self.final_layer_norm(x.transpose(1, 2))  # (B, C, T) -> (B, T, C)
        x = self.stft(x)
        return x
class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        if self.padding != "same":
            raise ValueError("Padding must be 'center' or 'same'.")
        pad = (self.win_length - self.hop_length) // 2

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        num_frames = spec.shape[2]

        # Per-frame inverse FFT, then apply the synthesis window.
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window[None, :, None]

        # Overlap-add the windowed frames via fold, trimming the "same" padding.
        total_len = (num_frames - 1) * self.hop_length + self.win_length
        signal = F.fold(
            frames,
            output_size=(1, total_len),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Accumulated squared-window envelope used for amplitude normalization.
        win_sq = self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        envelope = F.fold(
            win_sq,
            output_size=(1, total_len),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # NOLA check: every retained sample must receive nonzero window energy.
        assert (envelope > 1e-11).all()
        return signal / envelope
class FourierHead(nn.Module):
    """Abstract base for inverse-Fourier synthesis heads."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map hidden states (B, L, H) to a time-domain waveform (B, T).

        Args:
            x (Tensor): Input of shape (B, L, H) — batch, sequence length, model dimension.

        Returns:
            Tensor: Reconstructed audio of shape (B, T).

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        self.filter_length = n_fft
        self.win_length = n_fft
        self.hop_length = hop_length
        # Fix: register the analysis window as a buffer so it follows the module
        # across .to(device)/.cuda() calls. persistent=False keeps it out of the
        # state_dict, preserving checkpoint compatibility with the old plain
        # tensor attribute.
        window = torch.from_numpy(get_window("hann", self.win_length, fftbins=True).astype(np.float32))
        self.register_buffer("window", window, persistent=False)
        out_dim = n_fft + 2  # (n_fft/2 + 1) magnitude bins + the same number of phase bins
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)  # (B, L, H) -> (B, n_fft + 2, L)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # Build the complex spectrogram directly from magnitude and phase;
        # cos/sin produce the real and imaginary parts (phase wrapping happens here).
        S = mag * (torch.cos(p) + 1j * torch.sin(p))
        audio = self.istft(S)
        return audio

    def transform(self, input_data):
        """Forward STFT of a waveform batch; returns (magnitude, phase) tensors."""
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length,
            window=self.window.to(input_data.device),
            return_complex=True)
        return torch.abs(forward_transform), torch.angle(forward_transform)
class AdainResBlk1d(nn.Module):
    """Residual block with AdaIN conditioning and optional 2x temporal upsampling.

    Output is (residual + shortcut) / sqrt(2) to keep the variance roughly constant.
    When `upsample` is anything other than 'none', both paths are upsampled 2x
    (the residual path via a depthwise transposed conv, the shortcut via
    nearest-neighbour interpolation inside UpSample1d).
    """

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(
                nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2,
                                   groups=dim_in, padding=1, output_padding=1)
            )

    def _build_weights(self, dim_in, dim_out, style_dim):
        # Two weight-normalized 3x1 convs, each preceded by AdaIN conditioning.
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            # 1x1 projection so the shortcut matches dim_out when dims differ.
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        out = self.upsample(x)
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        out = self.pool(out)
        out = self.conv1(self.dropout(out))
        out = self.actv(self.norm2(out, s))
        return self.conv2(self.dropout(out))

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / math.sqrt(2)
class UpSample1d(nn.Module):
    """Nearest-neighbour 2x upsampler; acts as identity when layer_type == 'none'."""

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        return F.interpolate(x, scale_factor=2, mode='nearest')
class Decoder(nn.Module):
    """Decoder that fuses acoustic features with F0 and energy/noise curves and a
    style vector, then synthesizes a waveform through the ISTFT-based Generator.

    Args:
        dim_in (int): Channel count of the `asr` feature input.  # presumably aligned text/acoustic features — TODO confirm
        style_dim (int): Dimension of the style vector `s`.
        dim_out (int): Hidden dimension handed to the Generator.
        intermediate_dim (int): ConvNeXt intermediate dimension in the Generator.
        num_layers (int): Number of ConvNeXt layers in the Generator.
        gen_istft_n_fft (int): FFT size of the Generator's ISTFT head.
        gen_istft_hop_size (int): Hop size of the Generator's ISTFT head.
    """

    def __init__(self, dim_in=512, style_dim=64, dim_out=80,
                 intermediate_dim=1536,
                 num_layers=8,
                 gen_istft_n_fft=1024, gen_istft_hop_size=256):
        super().__init__()
        self.decode = nn.ModuleList()
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
        # Stride-2 convs halve the F0/N curves to match the encoded feature rate.
        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )
        self.generator = Generator(input_channels=dim_out, dim=dim_in, style_dim=style_dim,
                                   intermediate_dim=intermediate_dim, num_layers=num_layers,
                                   gen_istft_n_fft=gen_istft_n_fft, gen_istft_hop_size=gen_istft_hop_size)

    def forward(self, asr, F0_curve, N, s):
        """Decode features to audio.

        Args:
            asr (Tensor): Acoustic feature map (B, dim_in, T).
            F0_curve (Tensor): Fundamental-frequency curve (B, T).
            N (Tensor): Noise/energy curve (B, T).
            s (Tensor): Style vector (B, style_dim).

        Returns:
            Tensor: Generated waveform with a singleton channel dim inserted.
        """
        if self.training:
            # Data augmentation: randomly smooth F0/N with a box filter (0 = no smoothing).
            F0_down = random.choice([0, 3, 7])
            N_down = random.choice([0, 3, 7, 15])
            if F0_down:
                # Fix: build the smoothing kernel on the input's device/dtype
                # instead of hard-coding 'cuda', so CPU training also works.
                kernel = torch.ones(1, 1, F0_down, device=F0_curve.device, dtype=F0_curve.dtype)
                F0_curve = nn.functional.conv1d(
                    F0_curve.unsqueeze(1), kernel, padding=F0_down // 2
                ).squeeze(1) / F0_down
            if N_down:
                kernel = torch.ones(1, 1, N_down, device=N.device, dtype=N.dtype)
                N = nn.functional.conv1d(
                    N.unsqueeze(1), kernel, padding=N_down // 2
                ).squeeze(1) / N_down

        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], dim=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        # Re-inject the (asr residual, F0, N) conditioning before each block until
        # the first upsampling block changes the temporal resolution.
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], dim=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s)
        x = x.unsqueeze(1)
        return x