from typing import Optional import math import random import numpy as np import torch from torch import nn import torch.nn.functional as F from torch.nn.utils.parametrizations import weight_norm from typing import Optional, Tuple from scipy.signal import get_window class AdaIN1d(nn.Module): def __init__(self, style_dim, num_features): super().__init__() self.norm = nn.InstanceNorm1d(num_features, affine=False) self.fc = nn.Linear(style_dim, num_features*2) def forward(self, x, s): h = self.fc(s) h = h.view(h.size(0), h.size(1), 1) gamma, beta = torch.chunk(h, chunks=2, dim=1) return (1 + gamma) * self.norm(x) + beta class ConvNeXtBlock(nn.Module): """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. Args: dim (int): Number of input channels. intermediate_dim (int): Dimensionality of the intermediate layer. layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. Defaults to None. """ def __init__( self, dim: int, intermediate_dim: int, layer_scale_init_value: float, style_dim: int, ): super().__init__() self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv self.norm = AdaIN1d(style_dim, dim) self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) self.gamma = ( nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None ) def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: residual = x x = self.dwconv(x) x = self.norm(x, s) x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) x = residual + x return x def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: """ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. Args: x (Tensor): Input tensor. clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. Returns: Tensor: Element-wise logarithm of the input tensor with clipping applied. """ return torch.log(torch.clip(x, min=clip_val)) def symlog(x: torch.Tensor) -> torch.Tensor: return torch.sign(x) * torch.log1p(x.abs()) def symexp(x: torch.Tensor) -> torch.Tensor: return torch.sign(x) * (torch.exp(x.abs()) - 1) class Backbone(nn.Module): """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, C denotes output features, and L is the sequence length. Returns: Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. """ raise NotImplementedError("Subclasses must implement the forward method.") class Generator(Backbone): """ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization Args: input_channels (int): Number of input features channels. dim (int): Hidden dimension of the model. intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. num_layers (int): Number of ConvNeXtBlock layers. layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. """ def __init__( self, input_channels: int, dim: int, style_dim: int, intermediate_dim: int, num_layers: int, gen_istft_n_fft: int, gen_istft_hop_size: int, layer_scale_init_value: Optional[float] = None, ): super().__init__() self.input_channels = input_channels layer_scale_init_value = layer_scale_init_value or 1 / num_layers self.convnext = nn.ModuleList() for i in range(num_layers): self.convnext.append( ConvNeXtBlock( dim=dim, intermediate_dim=intermediate_dim, layer_scale_init_value=layer_scale_init_value, style_dim=style_dim, ) ) self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) self.apply(self._init_weights) self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) self.stft = ISTFTHead(dim=dim, n_fft=gen_istft_n_fft, hop_length=gen_istft_hop_size, padding="same") def _init_weights(self, m): if isinstance(m, (nn.Conv1d, nn.Linear)): nn.init.trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) def forward(self, x, s) -> torch.Tensor: for i, conv_block in enumerate(self.convnext): x = conv_block(x, s) x = self.final_layer_norm(x.transpose(1, 2)) x = self.stft(x) return x class ISTFT(nn.Module): """ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. See issue: https://github.com/pytorch/pytorch/issues/62323 Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. The NOLA constraint is met as we trim padded samples anyway. Args: n_fft (int): Size of Fourier transform. hop_length (int): The distance between neighboring sliding window frames. win_length (int): The size of window frame and STFT filter. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". """ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): super().__init__() if padding not in ["center", "same"]: raise ValueError("Padding must be 'center' or 'same'.") self.padding = padding self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length window = torch.hann_window(win_length) self.register_buffer("window", window) def forward(self, spec: torch.Tensor) -> torch.Tensor: """ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. Args: spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, N is the number of frequency bins, and T is the number of time frames. Returns: Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. """ if self.padding == "center": # Fallback to pytorch native implementation return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) elif self.padding == "same": pad = (self.win_length - self.hop_length) // 2 else: raise ValueError("Padding must be 'center' or 'same'.") assert spec.dim() == 3, "Expected a 3D tensor as input" B, N, T = spec.shape # Inverse FFT ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") ifft = ifft * self.window[None, :, None] # Overlap and Add output_size = (T - 1) * self.hop_length + self.win_length y = torch.nn.functional.fold( ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), )[:, 0, 0, pad:-pad] # Window envelope window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) window_envelope = torch.nn.functional.fold( window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), ).squeeze()[pad:-pad] # Normalize assert (window_envelope > 1e-11).all() y = y / window_envelope return y class FourierHead(nn.Module): """Base class for inverse fourier modules.""" def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. """ raise NotImplementedError("Subclasses must implement the forward method.") class ISTFTHead(FourierHead): """ ISTFT Head module for predicting STFT complex coefficients. Args: dim (int): Hidden dimension of the model. n_fft (int): Size of Fourier transform. hop_length (int): The distance between neighboring sliding window frames, which should align with the resolution of the input features. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". """ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): super().__init__() self.filter_length = n_fft self.win_length = n_fft self.hop_length = hop_length self.window = torch.from_numpy(get_window("hann", self.win_length, fftbins=True).astype(np.float32)) out_dim = n_fft + 2 self.out = torch.nn.Linear(dim, out_dim) self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the ISTFTHead module. Args: x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. """ x = self.out(x).transpose(1, 2) mag, p = x.chunk(2, dim=1) mag = torch.exp(mag) mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes # wrapping happens here. These two lines produce real and imaginary value x = torch.cos(p) y = torch.sin(p) # recalculating phase here does not produce anything new # only costs time # phase = torch.atan2(y, x) # S = mag * torch.exp(phase * 1j) # better directly produce the complex value S = mag * (x + 1j * y) audio = self.istft(S) return audio def transform(self, input_data): forward_transform = torch.stft( input_data, self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device), return_complex=True) return torch.abs(forward_transform), torch.angle(forward_transform) class AdainResBlk1d(nn.Module): def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0): super().__init__() self.actv = actv self.upsample_type = upsample self.upsample = UpSample1d(upsample) self.learned_sc = dim_in != dim_out self._build_weights(dim_in, dim_out, style_dim) self.dropout = nn.Dropout(dropout_p) if upsample == 'none': self.pool = nn.Identity() else: self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1)) def _build_weights(self, dim_in, dim_out, style_dim): self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) self.norm1 = AdaIN1d(style_dim, dim_in) self.norm2 = AdaIN1d(style_dim, dim_out) if self.learned_sc: self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) def _shortcut(self, x): x = self.upsample(x) if self.learned_sc: x = self.conv1x1(x) return x def _residual(self, x, s): x = self.norm1(x, s) x = self.actv(x) x = self.pool(x) x = self.conv1(self.dropout(x)) x = self.norm2(x, s) x = self.actv(x) x = self.conv2(self.dropout(x)) return x def forward(self, x, s): out = self._residual(x, s) out = (out + self._shortcut(x)) / math.sqrt(2) return out class UpSample1d(nn.Module): def __init__(self, layer_type): super().__init__() self.layer_type = layer_type def forward(self, x): if self.layer_type == 'none': return x else: return F.interpolate(x, scale_factor=2, mode='nearest') class Decoder(nn.Module): def __init__(self, dim_in=512, style_dim=64, dim_out=80, intermediate_dim=1536, num_layers=8, gen_istft_n_fft=1024, gen_istft_hop_size=256): super().__init__() self.decode = nn.ModuleList() self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True)) self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) self.asr_res = nn.Sequential( weight_norm(nn.Conv1d(512, 64, kernel_size=1)), ) self.generator = Generator(input_channels=dim_out, dim=dim_in, style_dim=style_dim, intermediate_dim=intermediate_dim, num_layers=num_layers, gen_istft_n_fft=gen_istft_n_fft, gen_istft_hop_size=gen_istft_hop_size) def forward(self, asr, F0_curve, N, s): if self.training: downlist = [0, 3, 7] F0_down = downlist[random.randint(0, 2)] downlist = [0, 3, 7, 15] N_down = downlist[random.randint(0, 3)] if F0_down: F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down if N_down: N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down F0 = self.F0_conv(F0_curve.unsqueeze(1)) N = self.N_conv(N.unsqueeze(1)) x = torch.cat([asr, F0, N], axis=1) x = self.encode(x, s) asr_res = self.asr_res(asr) res = True for block in self.decode: if res: x = torch.cat([x, asr_res, F0, N], axis=1) x = block(x, s) if block.upsample_type != "none": res = False x = self.generator(x, s) x = x.unsqueeze(1) return x