|
|
from typing import Optional |
|
|
|
|
|
import math |
|
|
import random |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
from torch.nn.utils.parametrizations import weight_norm |
|
|
|
|
|
from typing import Optional, Tuple |
|
|
from scipy.signal import get_window |
|
|
|
|
|
class AdaIN1d(nn.Module):
    """Adaptive Instance Normalization for 1D feature maps.

    A style vector is projected to per-channel scale (gamma) and shift (beta)
    parameters that modulate the instance-normalized input.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Affine parameters come from the style projection, not the norm itself.
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Project style to (B, 2*C, 1) so it broadcasts over time.
        style = self.fc(s).unsqueeze(-1)
        gamma, beta = style.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
|
|
|
|
|
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float): Initial value for the layer scale;
            a non-positive value disables layer scaling.
        style_dim (int): Dimension of the conditioning style vector passed to AdaIN.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
        style_dim: int,
    ):
        super().__init__()
        # Depthwise conv preserves temporal resolution (kernel 7, padding 3).
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = AdaIN1d(style_dim, dim)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        residual = x
        y = self.dwconv(x)
        y = self.norm(y, s)
        # Pointwise convs are implemented as Linear layers on (B, T, C).
        y = y.transpose(1, 2)
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.transpose(1, 2)
        return residual + y
|
|
|
|
|
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Element-wise logarithm with the input clipped from below to avoid log(0).

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: log(max(x, clip_val)) element-wise.
    """
    return torch.clip(x, min=clip_val).log()
|
|
|
|
|
|
|
|
def symlog(x: torch.Tensor) -> torch.Tensor:
    """Symmetric log: sign-preserving log1p of the magnitude (inverse of symexp)."""
    return torch.sign(x) * torch.log1p(torch.abs(x))
|
|
|
|
|
|
|
|
def symexp(x: torch.Tensor) -> torch.Tensor:
    """Symmetric exp: sign-preserving exp(|x|) - 1 (inverse of symlog)."""
    magnitude = torch.exp(torch.abs(x)) - 1
    return torch.sign(x) * magnitude
|
|
|
|
|
class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Map input features to hidden representations.

        Args:
            x (Tensor): Input of shape (B, C, L) — batch, feature channels, sequence length.

        Returns:
            Tensor: Output of shape (B, L, H) — batch, sequence length, model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
|
|
|
|
|
|
|
|
class Generator(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        style_dim (int): Dimension of the conditioning style vector.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        gen_istft_n_fft (int): FFT size used by the ISTFT head.
        gen_istft_hop_size (int): Hop size used by the ISTFT head.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        style_dim: int,
        intermediate_dim: int,
        num_layers: int,
        gen_istft_n_fft: int,
        gen_istft_hop_size: int,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        # Explicit None check: the former `value or 1 / num_layers` expression
        # silently replaced an explicit 0.0 (meaning "disable layer scale")
        # with the default.
        if layer_scale_init_value is None:
            layer_scale_init_value = 1 / num_layers

        self.convnext = nn.ModuleList(
            ConvNeXtBlock(
                dim=dim,
                intermediate_dim=intermediate_dim,
                layer_scale_init_value=layer_scale_init_value,
                style_dim=style_dim,
            )
            for _ in range(num_layers)
        )

        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        # NOTE: applied before self.stft exists, so the ISTFT head keeps its
        # default initialization (original ordering preserved).
        self.apply(self._init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.stft = ISTFTHead(dim=dim, n_fft=gen_istft_n_fft, hop_length=gen_istft_hop_size, padding="same")

    def _init_weights(self, m):
        """Truncated-normal weight init for conv/linear layers; zero biases."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Guard against layers constructed with bias=False.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, s) -> torch.Tensor:
        """Run the ConvNeXt stack conditioned on style `s`, then synthesize audio.

        Args:
            x (Tensor): Input features of shape (B, C, L).
            s (Tensor): Style embedding consumed by each block's AdaIN.

        Returns:
            Tensor: Time-domain waveform produced by the ISTFT head.
        """
        for conv_block in self.convnext:
            x = conv_block(x, s)
        x = self.final_layer_norm(x.transpose(1, 2))
        x = self.stft(x)
        return x
|
|
|
|
|
class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # Registered as a buffer so the window follows the module across
        # devices/dtypes and is saved in the state dict.
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                            N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Built-in istft handles center padding natively.
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            # Number of samples trimmed from each edge to emulate "same" padding.
            pad = (self.win_length - self.hop_length) // 2
        else:
            # Unreachable: padding was validated in __init__.
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Per-frame inverse real FFT -> (B, n_fft, T), then apply the synthesis window.
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap-add via fold, then trim `pad` samples from both edges.
        # NOTE(review): when win_length == hop_length, pad == 0 and `pad:-pad`
        # yields an empty tensor — this assumes overlapping windows; confirm callers.
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope: overlap-added squared window, used to renormalize y.
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # NOLA check: the envelope must be strictly positive everywhere after trimming.
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y
|
|
|
|
|
class FourierHead(nn.Module):
    """Base class for inverse fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert hidden features to a waveform.

        Args:
            x (Tensor): Input of shape (B, L, H) — batch, sequence length, model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T).
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
|
|
|
|
|
class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
                          the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        self.filter_length = n_fft
        self.win_length = n_fft
        self.hop_length = hop_length
        # Plain tensor (not a registered buffer); transform() moves it to the
        # input's device on each call.
        self.window = torch.from_numpy(get_window("hann", self.win_length, fftbins=True).astype(np.float32))

        # One magnitude and one phase value per rfft bin: 2 * (n_fft/2 + 1).
        self.out = torch.nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        pred = self.out(x).transpose(1, 2)
        mag, phase = pred.chunk(2, dim=1)
        # exp then clip keeps predicted magnitudes bounded for stability.
        mag = torch.clip(torch.exp(mag), max=1e2)

        real = torch.cos(phase)
        imag = torch.sin(phase)

        # Assemble the complex spectrogram and invert it.
        spec = mag * (real + 1j * imag)
        return self.istft(spec)

    def transform(self, input_data):
        # Forward STFT, returned as (magnitude, phase).
        spectrum = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)

        return torch.abs(spectrum), torch.angle(spectrum)
|
|
|
|
|
|
|
|
class AdainResBlk1d(nn.Module):
    """Residual block with style-conditioned AdaIN normalization.

    The residual path applies norm -> activation -> (optional upsample) ->
    conv twice; the shortcut is upsampled and, when channel counts differ,
    projected by a learned 1x1 conv. The sum is scaled by 1/sqrt(2) to keep
    the output variance comparable to the inputs.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        # The 1x1 shortcut projection is only needed when channels change.
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            # Depthwise transposed conv doubles the temporal resolution.
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        # Two weight-normed 3x3 convs, each preceded by AdaIN in _residual.
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        out = self.upsample(x)
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        out = self.pool(out)
        out = self.conv1(self.dropout(out))
        out = self.actv(self.norm2(out, s))
        return self.conv2(self.dropout(out))

    def forward(self, x, s):
        merged = self._residual(x, s) + self._shortcut(x)
        return merged / math.sqrt(2)
|
|
|
|
|
class UpSample1d(nn.Module):
    """Nearest-neighbor 2x temporal upsampler; identity when layer_type is 'none'."""

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        # Guard clause: anything other than 'none' triggers upsampling.
        if self.layer_type != 'none':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        return x
|
|
|
|
|
class Decoder(nn.Module):
    """Style-conditioned decoder: fuses aligned text features (asr), F0 and
    energy (N) curves, passes them through AdaIN residual blocks, and renders
    a waveform with the ISTFT-based Generator.

    Args:
        dim_in (int): Channel count of the asr features.
        style_dim (int): Dimension of the style vector `s`.
        dim_out (int): Number of generator input feature channels.
        intermediate_dim (int): Intermediate dim of the generator's ConvNeXt blocks.
        num_layers (int): Number of ConvNeXt blocks in the generator.
        gen_istft_n_fft (int): FFT size of the generator's ISTFT head.
        gen_istft_hop_size (int): Hop size of the generator's ISTFT head.
    """

    def __init__(self, dim_in=512, style_dim=64, dim_out=80,
                 intermediate_dim=1536,
                 num_layers=8,
                 gen_istft_n_fft=1024, gen_istft_hop_size=256):
        super().__init__()

        self.decode = nn.ModuleList()

        # +2 channels: downsampled F0 and N curves concatenated to asr.
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        # Each decode block re-receives the asr residual (64ch) plus F0/N (2ch).
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

        # Stride-2 convs halve the temporal resolution of the F0/N curves.
        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(input_channels=dim_out, dim=dim_in, style_dim=style_dim,
                                   intermediate_dim=intermediate_dim, num_layers=num_layers,
                                   gen_istft_n_fft=gen_istft_n_fft, gen_istft_hop_size=gen_istft_hop_size)

    def forward(self, asr, F0_curve, N, s):
        """
        Args:
            asr (Tensor): Aligned text features of shape (B, dim_in, L).
            F0_curve (Tensor): Pitch curve, shape (B, L) — unsqueezed to (B, 1, L).
            N (Tensor): Energy curve, same shape as F0_curve.
            s (Tensor): Style vector of shape (B, style_dim).

        Returns:
            Tensor: Waveform of shape (B, 1, T).
        """
        if self.training:
            # Data augmentation: randomly smooth F0/N with a box filter of
            # random (odd) width; 0 means no smoothing.
            downlist = [0, 3, 7]
            F0_down = downlist[random.randint(0, 2)]
            downlist = [0, 3, 7, 15]
            N_down = downlist[random.randint(0, 3)]
            # Fix: build the smoothing kernel on the input's device instead of
            # hard-coding 'cuda', so CPU/MPS training also works.
            if F0_down:
                kernel = torch.ones(1, 1, F0_down, device=F0_curve.device)
                F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), kernel, padding=F0_down // 2).squeeze(1) / F0_down
            if N_down:
                kernel = torch.ones(1, 1, N_down, device=N.device)
                N = nn.functional.conv1d(N.unsqueeze(1), kernel, padding=N_down // 2).squeeze(1) / N_down

        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], dim=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        # Re-inject the asr residual and prosody curves until the first
        # upsampling block changes the temporal resolution.
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], dim=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s)
        x = x.unsqueeze(1)
        return x