noblebarkrr's picture
Updated to Dzeta
4f175c5
import numpy as np
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F
from packaging import version
is_pytorch2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")
is_pytorchaudio2_0 = version.parse(torchaudio.__version__) >= version.parse("2.0.1")
if is_pytorch2_1:
from torch.nn.utils.parametrizations import weight_norm
else:
from torch.nn.utils import weight_norm
from torch.nn.utils import remove_weight_norm
from torch.utils.checkpoint import checkpoint
from ..commons import init_weights, get_padding
class ResBlock(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int = 7,
dilation: tuple[int] = (1, 3, 5),
leaky_relu_slope: float = 0.2,
):
super().__init__()
self.leaky_relu_slope = leaky_relu_slope
self.convs1 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=get_padding(kernel_size, d),
)
)
for d in dilation
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
)
for d in dilation
]
)
self.convs2.apply(init_weights)
def forward(self, x: torch.Tensor):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, self.leaky_relu_slope)
xt = c1(xt)
xt = F.leaky_relu(xt, self.leaky_relu_slope)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for c1, c2 in zip(self.convs1, self.convs2):
remove_weight_norm(c1)
remove_weight_norm(c2)
class AdaIN(nn.Module):
def __init__(
self,
*,
channels: int,
leaky_relu_slope: float = 0.2,
):
super().__init__()
self.weight = nn.Parameter(torch.ones(channels) * 1e-4)
self.activation = nn.LeakyReLU(leaky_relu_slope)
def forward(self, x: torch.Tensor):
gaussian = torch.randn_like(x) * self.weight[None, :, None]
return self.activation(x + gaussian)
class ParallelResBlock(nn.Module):
def __init__(
self,
*,
in_channels: int,
out_channels: int,
kernel_sizes: tuple[int] = (3, 7, 11),
dilation: tuple[int] = (1, 3, 5),
leaky_relu_slope: float = 0.2,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.input_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=7,
stride=1,
padding=3,
)
self.input_conv.apply(init_weights)
self.blocks = nn.ModuleList(
[
nn.Sequential(
AdaIN(channels=out_channels),
ResBlock(
out_channels,
kernel_size=kernel_size,
dilation=dilation,
leaky_relu_slope=leaky_relu_slope,
),
AdaIN(channels=out_channels),
)
for kernel_size in kernel_sizes
]
)
def forward(self, x: torch.Tensor):
x = self.input_conv(x)
return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
def remove_weight_norm(self):
remove_weight_norm(self.input_conv)
for block in self.blocks:
block[1].remove_weight_norm()
class SineGenerator(nn.Module):
def __init__(
self,
samp_rate,
harmonic_num=0,
sine_amp=0.1,
noise_std=0.003,
voiced_threshold=0,
):
super(SineGenerator, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
self.merge = nn.Sequential(
nn.Linear(self.dim, 1, bias=False),
nn.Tanh(),
)
def _f02uv(self, f0):
uv = torch.ones_like(f0)
uv = uv * (f0 > self.voiced_threshold)
return uv
def _f02sine(self, f0_values):
rad_values = (f0_values / self.sampling_rate) % 1
rand_ini = torch.rand(
f0_values.shape[0], f0_values.shape[2], device=f0_values.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(rad_values, 1) % 1
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
return sines
def forward(self, f0):
with torch.no_grad():
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in np.arange(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
sine_waves = self._f02sine(f0_buf) * self.sine_amp
uv = self._f02uv(f0)
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
sine_waves = sine_waves * uv + noise
return self.merge(sine_waves)
class RefineGANGenerator(nn.Module):
def __init__(
self,
*,
sample_rate: int = 44100,
downsample_rates: tuple[int] = (2, 2, 8, 8),
upsample_rates: tuple[int] = (8, 8, 2, 2),
leaky_relu_slope: float = 0.2,
num_mels: int = 128,
start_channels: int = 16,
gin_channels: int = 256,
checkpointing: bool = False,
upsample_initial_channel=512,
):
super().__init__()
self.upsample_rates = upsample_rates
self.leaky_relu_slope = leaky_relu_slope
self.checkpointing = checkpointing
self.upp = np.prod(upsample_rates)
self.m_source = SineGenerator(sample_rate)
self.pre_conv = weight_norm(
nn.Conv1d(
1,
16,
7,
1,
padding=3,
)
)
channels = start_channels
size = self.upp
self.downsample_blocks = nn.ModuleList([])
self.df0 = []
for i, u in enumerate(upsample_rates):
new_size = int(size / upsample_rates[-i - 1])
self.df0.append([size, new_size])
size = new_size
new_channels = channels * 2
self.downsample_blocks.append(
weight_norm(nn.Conv1d(channels, new_channels, 7, 1, padding=3))
)
channels = new_channels
channels = upsample_initial_channel
self.mel_conv = weight_norm(
nn.Conv1d(
num_mels,
channels // 2,
7,
1,
padding=3,
)
)
self.mel_conv.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(256, channels // 2, 1)
self.upsample_blocks = nn.ModuleList([])
self.upsample_conv_blocks = nn.ModuleList([])
for rate in upsample_rates:
new_channels = channels // 2
self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
self.upsample_conv_blocks.append(
ParallelResBlock(
in_channels=channels + channels // 4,
out_channels=new_channels,
kernel_sizes=(3, 7, 11),
dilation=(1, 3, 5),
leaky_relu_slope=leaky_relu_slope,
)
)
channels = new_channels
self.conv_post = weight_norm(
nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False)
)
self.conv_post.apply(init_weights)
def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
f0_size = mel.shape[-1]
f0 = F.interpolate(f0.unsqueeze(1), size=f0_size * self.upp, mode="linear")
har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
x = self.pre_conv(har_source)
downs = []
for block, (old_size, new_size) in zip(self.downsample_blocks, self.df0):
x = F.leaky_relu(x, self.leaky_relu_slope)
downs.append(x)
if is_pytorchaudio2_0:
x = torchaudio.functional.resample(
x.contiguous(),
orig_freq=int(f0_size * old_size),
new_freq=int(f0_size * new_size),
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="sinc_interp_kaiser",
beta=14.769656459379492,
)
else:
x = torchaudio.functional.resample(
x.contiguous(),
orig_freq=int(f0_size * old_size),
new_freq=int(f0_size * new_size),
resampling_method="kaiser_window",
beta=9.0,
)
x = block(x)
mel = self.mel_conv(mel)
if g is not None:
mel = mel + self.cond(g)
x = torch.cat([mel, x], dim=1)
for ups, res, down in zip(
self.upsample_blocks,
self.upsample_conv_blocks,
reversed(downs),
):
x = F.leaky_relu(x, self.leaky_relu_slope)
if self.training and self.checkpointing:
x = checkpoint(ups, x, use_reentrant=False)
x = torch.cat([x, down], dim=1)
x = checkpoint(res, x, use_reentrant=False)
else:
x = ups(x)
x = torch.cat([x, down], dim=1)
x = res(x)
x = F.leaky_relu(x, self.leaky_relu_slope)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
remove_weight_norm(self.pre_conv)
remove_weight_norm(self.mel_conv)
remove_weight_norm(self.conv_post)
for block in self.downsample_blocks:
block.remove_weight_norm()
for block in self.upsample_conv_blocks:
block.remove_weight_norm()