import numpy as np
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F
from packaging import version

is_pytorch2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")
is_pytorchaudio2_0 = version.parse(torchaudio.__version__) >= version.parse("2.0.1")

if is_pytorch2_1:
    from torch.nn.utils.parametrizations import weight_norm
else:
    from torch.nn.utils import weight_norm
from torch.nn.utils import remove_weight_norm
from torch.utils.checkpoint import checkpoint

from ..commons import init_weights, get_padding


def _strip_weight_norm(module: nn.Module) -> None:
    """Remove weight normalization from ``module``, whichever API applied it.

    On torch >= 2.1 this file wraps layers with the parametrization-based
    ``weight_norm``; the legacy ``torch.nn.utils.remove_weight_norm`` only
    understands the older hook-based variant and raises on parametrized
    modules. This helper removes the parametrization when one is present and
    falls back to the legacy function otherwise.
    """
    if is_pytorch2_1:
        # Local import: only needed (and only relied upon) on the new API path.
        from torch.nn.utils import parametrize

        if parametrize.is_parametrized(module, "weight"):
            # leave_parametrized=True keeps the current effective weight.
            parametrize.remove_parametrizations(
                module, "weight", leave_parametrized=True
            )
            return
    remove_weight_norm(module)


class ResBlock(nn.Module):
    """Stack of residual units, each made of two weight-normed 1-D convs.

    For every entry ``d`` of ``dilation`` the block applies
    ``x = conv2(lrelu(conv1(lrelu(x)))) + x`` where ``conv1`` uses dilation
    ``d`` and ``conv2`` always uses dilation 1. Channel count is preserved.

    Args:
        channels: Number of input and output channels of every conv.
        kernel_size: Kernel size shared by all convs.
        dilation: Dilation of the first conv of each residual unit; its
            length is the number of residual units.
        leaky_relu_slope: Negative slope of the LeakyReLU activations.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 7,
        dilation: tuple[int, ...] = (1, 3, 5),
        leaky_relu_slope: float = 0.2,
    ):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        # First conv of each unit: dilated.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        # Second conv of each unit: always dilation 1; one per dilation entry.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x: torch.Tensor):
        """Run all residual units in sequence and return the result."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, self.leaky_relu_slope)
            xt = c1(xt)
            xt = F.leaky_relu(xt, self.leaky_relu_slope)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv in the block."""
        for c1, c2 in zip(self.convs1, self.convs2):
            _strip_weight_norm(c1)
            _strip_weight_norm(c2)


class AdaIN(nn.Module):
    """Learned per-channel noise injection followed by LeakyReLU.

    Despite the name, the module as written adds Gaussian noise scaled by a
    learnable per-channel weight (initialised to 1e-4) and then applies the
    activation; no instance normalization is performed here.

    Args:
        channels: Number of channels (size of dim 1 of the input).
        leaky_relu_slope: Negative slope of the LeakyReLU activation.
    """

    def __init__(
        self,
        *,
        channels: int,
        leaky_relu_slope: float = 0.2,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels) * 1e-4)
        self.activation = nn.LeakyReLU(leaky_relu_slope)

    def forward(self, x: torch.Tensor):
        """Return ``activation(x + noise)``; ``x`` is indexed as 3-D
        (batch, channels, length) by the weight broadcast below."""
        gaussian = torch.randn_like(x) * self.weight[None, :, None]
        return self.activation(x + gaussian)


class ParallelResBlock(nn.Module):
    """Input conv followed by several ResBlocks run in parallel and averaged.

    Each parallel branch is ``AdaIN -> ResBlock(kernel_size) -> AdaIN``; the
    branch outputs are averaged element-wise.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the input conv and every branch.
        kernel_sizes: One branch is created per kernel size.
        dilation: Dilations forwarded to every ResBlock.
        leaky_relu_slope: Negative slope forwarded to every ResBlock.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int,
        kernel_sizes: tuple[int, ...] = (3, 7, 11),
        dilation: tuple[int, ...] = (1, 3, 5),
        leaky_relu_slope: float = 0.2,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # NOTE: this conv is intentionally NOT weight-normed.
        self.input_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=7,
            stride=1,
            padding=3,
        )
        self.input_conv.apply(init_weights)

        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    AdaIN(channels=out_channels),
                    ResBlock(
                        out_channels,
                        kernel_size=kernel_size,
                        dilation=dilation,
                        leaky_relu_slope=leaky_relu_slope,
                    ),
                    AdaIN(channels=out_channels),
                )
                for kernel_size in kernel_sizes
            ]
        )

    def forward(self, x: torch.Tensor):
        """Project ``x`` with the input conv, then average all branches."""
        x = self.input_conv(x)
        return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)

    def remove_weight_norm(self):
        """Strip weight normalization from the ResBlock of every branch.

        ``input_conv`` carries no weight norm, so it is skipped; trying to
        remove a non-existent weight norm raises ``ValueError``.
        """
        for block in self.blocks:
            # Index 1 of each Sequential is the ResBlock (see __init__).
            block[1].remove_weight_norm()


class SineGenerator(nn.Module):
    """Harmonic sine-plus-noise source driven by an F0 contour.

    Args:
        samp_rate: Sampling rate used to turn F0 (Hz) into phase increments.
        harmonic_num: Number of overtones generated on top of the fundamental.
        sine_amp: Amplitude of the generated sines.
        noise_std: Std of the noise added in voiced regions.
        voiced_threshold: F0 values strictly above this are treated as voiced.
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
    ):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1  # fundamental + overtones
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

        # Collapses the harmonic axis into a single excitation channel.
        self.merge = nn.Sequential(
            nn.Linear(self.dim, 1, bias=False),
            nn.Tanh(),
        )

    def _f02uv(self, f0):
        """Return a voiced/unvoiced mask (1 where ``f0 > voiced_threshold``)."""
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def _f02sine(self, f0_values):
        """Turn per-sample frequencies into sine waves.

        ``f0_values`` is indexed as (batch, length, dim); phase is accumulated
        along dim 1.
        """
        # Per-sample phase increment, in cycles, wrapped to [0, 1).
        rad_values = (f0_values / self.sampling_rate) % 1

        # Random initial phase per (batch, harmonic); fundamental starts at 0.
        rand_ini = torch.rand(
            f0_values.shape[0], f0_values.shape[2], device=f0_values.device
        )
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # Detect where the wrapped cumulative phase rolls over, and subtract
        # one full cycle at those steps so the running sum stays bounded.
        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
        return sines

    def forward(self, f0):
        """Generate the merged excitation signal for ``f0``.

        ``f0`` is indexed as 3-D with the fundamental in ``f0[:, :, 0]``.
        Returns the output of ``self.merge`` (last dim of size 1).
        """
        with torch.no_grad():
            f0_buf = torch.zeros(
                f0.shape[0], f0.shape[1], self.dim, device=f0.device
            )
            f0_buf[:, :, 0] = f0[:, :, 0]
            # Overtone k (1-based) sits at (k + 1) times the fundamental.
            for idx in range(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp
            uv = self._f02uv(f0)

            # Voiced: sines plus low-level noise. Unvoiced: noise only, at a
            # third of the sine amplitude.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            sine_waves = sine_waves * uv + noise

        # NOTE(review): source indentation was lost; the merge is placed
        # outside no_grad so its Linear weight can receive gradients —
        # confirm against the intended upstream behavior.
        return self.merge(sine_waves)


class RefineGANGenerator(nn.Module):
    """U-Net style vocoder generator conditioned on mel-spectrogram and F0.

    A sine excitation derived from F0 is progressively downsampled through
    weight-normed convs, concatenated with the projected mel features, and
    then upsampled back through ``ParallelResBlock`` stages with skip
    connections from the downsampling path.

    Args:
        sample_rate: Sampling rate handed to the sine generator.
        downsample_rates: Accepted for interface compatibility; not read by
            this implementation (downsampling sizes are derived from
            ``upsample_rates`` in reverse order).
        upsample_rates: Upsampling factor of each decoder stage.
        leaky_relu_slope: Negative slope of all LeakyReLU activations.
        num_mels: Number of mel channels of the ``mel`` input.
        start_channels: Channels produced by ``pre_conv``; doubled at every
            downsampling stage.
        gin_channels: Channels of the optional conditioning input ``g``;
            0 disables the conditioning layer.
        checkpointing: Use activation checkpointing in the decoder while
            training.
        upsample_initial_channel: Channel count at the decoder input. The
            skip-connection widths only line up when this matches the
            encoder widths (as it does for the defaults).
    """

    def __init__(
        self,
        *,
        sample_rate: int = 44100,
        downsample_rates: tuple[int, ...] = (2, 2, 8, 8),
        upsample_rates: tuple[int, ...] = (8, 8, 2, 2),
        leaky_relu_slope: float = 0.2,
        num_mels: int = 128,
        start_channels: int = 16,
        gin_channels: int = 256,
        checkpointing: bool = False,
        upsample_initial_channel=512,
    ):
        super().__init__()
        self.upsample_rates = upsample_rates
        self.leaky_relu_slope = leaky_relu_slope
        self.checkpointing = checkpointing

        # Total upsampling factor (samples per mel frame), as a plain int.
        self.upp = int(np.prod(upsample_rates))
        self.m_source = SineGenerator(sample_rate)

        # Was hard-coded to 16 output channels; follow start_channels so the
        # first downsample conv always receives the width it expects.
        self.pre_conv = weight_norm(
            nn.Conv1d(
                1,
                start_channels,
                7,
                1,
                padding=3,
            )
        )

        # Encoder: walk the upsample rates backwards, halving resolution by
        # each rate and doubling the channel count.
        channels = start_channels
        size = self.upp
        self.downsample_blocks = nn.ModuleList([])
        self.df0 = []  # [old_size, new_size] pairs used as resample ratios
        for rate in reversed(upsample_rates):
            new_size = int(size / rate)
            self.df0.append([size, new_size])
            size = new_size

            new_channels = channels * 2
            self.downsample_blocks.append(
                weight_norm(nn.Conv1d(channels, new_channels, 7, 1, padding=3))
            )
            channels = new_channels

        channels = upsample_initial_channel

        self.mel_conv = weight_norm(
            nn.Conv1d(
                num_mels,
                channels // 2,
                7,
                1,
                padding=3,
            )
        )
        self.mel_conv.apply(init_weights)

        if gin_channels != 0:
            # Was hard-coded to 256 input channels; follow gin_channels.
            self.cond = nn.Conv1d(gin_channels, channels // 2, 1)

        # Decoder: upsample, concatenate the matching skip, refine.
        self.upsample_blocks = nn.ModuleList([])
        self.upsample_conv_blocks = nn.ModuleList([])
        for rate in upsample_rates:
            new_channels = channels // 2
            self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
            self.upsample_conv_blocks.append(
                ParallelResBlock(
                    # current width plus the skip connection's width
                    in_channels=channels + channels // 4,
                    out_channels=new_channels,
                    kernel_sizes=(3, 7, 11),
                    dilation=(1, 3, 5),
                    leaky_relu_slope=leaky_relu_slope,
                )
            )
            channels = new_channels

        self.conv_post = weight_norm(
            nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False)
        )
        self.conv_post.apply(init_weights)

    def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
        """Synthesize a waveform.

        Args:
            mel: Mel features; fed to a Conv1d with ``num_mels`` input
                channels, its last dim is the frame count.
            f0: F0 contour; unsqueezed to 3-D and linearly interpolated to
                ``frames * upp`` samples, so it is expected to be 2-D
                (batch, frames).
            g: Optional conditioning tensor added to the projected mel
                features (requires ``gin_channels != 0`` at construction).

        Returns:
            Tensor with one channel, squashed to [-1, 1] by ``tanh``.
        """
        f0_size = mel.shape[-1]

        # Bring F0 to sample resolution and build the harmonic excitation.
        f0 = F.interpolate(f0.unsqueeze(1), size=f0_size * self.upp, mode="linear")
        har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
        x = self.pre_conv(har_source)

        # Encoder path; keep the pre-resample activations as skips.
        downs = []
        for block, (old_size, new_size) in zip(self.downsample_blocks, self.df0):
            x = F.leaky_relu(x, self.leaky_relu_slope)
            downs.append(x)
            # Only the ratio of orig_freq to new_freq matters to resample.
            if is_pytorchaudio2_0:
                x = torchaudio.functional.resample(
                    x.contiguous(),
                    orig_freq=int(f0_size * old_size),
                    new_freq=int(f0_size * new_size),
                    lowpass_filter_width=64,
                    rolloff=0.9475937167399596,
                    resampling_method="sinc_interp_kaiser",
                    beta=14.769656459379492,
                )
            else:
                x = torchaudio.functional.resample(
                    x.contiguous(),
                    orig_freq=int(f0_size * old_size),
                    new_freq=int(f0_size * new_size),
                    resampling_method="kaiser_window",
                    beta=9.0,
                )
            x = block(x)

        # Fuse mel (plus optional conditioning) with the encoded excitation.
        mel = self.mel_conv(mel)
        if g is not None:
            mel = mel + self.cond(g)
        x = torch.cat([mel, x], dim=1)

        # Decoder path with skip connections, deepest skip first.
        for ups, res, down in zip(
            self.upsample_blocks,
            self.upsample_conv_blocks,
            reversed(downs),
        ):
            x = F.leaky_relu(x, self.leaky_relu_slope)
            if self.training and self.checkpointing:
                x = checkpoint(ups, x, use_reentrant=False)
                x = torch.cat([x, down], dim=1)
                x = checkpoint(res, x, use_reentrant=False)
            else:
                x = ups(x)
                x = torch.cat([x, down], dim=1)
                x = res(x)

        x = F.leaky_relu(x, self.leaky_relu_slope)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every weight-normed layer."""
        _strip_weight_norm(self.pre_conv)
        _strip_weight_norm(self.mel_conv)
        _strip_weight_norm(self.conv_post)
        # Downsample blocks are bare weight-normed Conv1d modules, so they
        # have no remove_weight_norm() method of their own.
        for block in self.downsample_blocks:
            _strip_weight_norm(block)
        for block in self.upsample_conv_blocks:
            block.remove_weight_norm()