import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from packaging import version

# weight_norm moved to torch.nn.utils.parametrizations in PyTorch 2.1; older
# releases only expose it from torch.nn.utils.
is_pytorch2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")

if is_pytorch2_1:
    from torch.nn.utils.parametrizations import spectral_norm, weight_norm
else:
    from torch.nn.utils.parametrizations import spectral_norm
    from torch.nn.utils import weight_norm

from .commons import get_padding
from .residuals import LRELU_SLOPE


class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of one scale, several period and optional resolution discriminators.

    The ensemble always starts with a single ``DiscriminatorS``, followed by one
    ``DiscriminatorP`` per period and one ``DiscriminatorR`` per resolution. The
    set of periods/resolutions is selected by ``version``.

    Args:
        use_spectral_norm: Wrap convolutions with spectral norm instead of
            weight norm in every sub-discriminator.
        checkpointing: When True and the module is in training mode, run each
            sub-discriminator through ``torch.utils.checkpoint.checkpoint``.
        version: Configuration name, one of ``"v1"``, ``"v2"`` or ``"v3"``.

    Raises:
        ValueError: If ``version`` is not one of the supported names.
    """

    def __init__(
        self,
        use_spectral_norm: bool = False,
        checkpointing: bool = False,
        version: str = "v2",
    ):
        super().__init__()

        # NOTE: the ``version`` parameter shadows the module-level
        # ``packaging.version`` import inside this method only.
        if version == "v1":
            periods = [2, 3, 5, 7, 11, 17]
            resolutions = []
        elif version == "v2":
            periods = [2, 3, 5, 7, 11, 17, 23, 37]
            resolutions = []
        elif version == "v3":
            periods = [2, 3, 5, 7, 11]
            # Each entry is [n_fft, hop_length, win_length] for DiscriminatorR.
            resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]]
        else:
            # Without this branch an unknown version surfaced later as an
            # UnboundLocalError on ``periods``.
            raise ValueError(
                f"Unsupported discriminator version {version!r}; "
                "expected 'v1', 'v2' or 'v3'."
            )

        self.checkpointing = checkpointing
        # Order matters: it fixes the ModuleList indices used in state_dict keys.
        self.discriminators = torch.nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
            + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods]
            + [
                DiscriminatorR(r, use_spectral_norm=use_spectral_norm)
                for r in resolutions
            ]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on both inputs.

        Args:
            y: First input batch (presumably the real audio - confirm with caller).
            y_hat: Second input batch (presumably the generated audio - confirm
                with caller).

        Returns:
            Four lists, each with one entry per sub-discriminator:
            scores for ``y``, scores for ``y_hat``, feature maps for ``y`` and
            feature maps for ``y_hat``.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        use_checkpoint = self.training and self.checkpointing

        for d in self.discriminators:
            if use_checkpoint:
                y_d_r, fmap_r = checkpoint(d, y, use_reentrant=False)
                y_d_g, fmap_g = checkpoint(d, y_hat, use_reentrant=False)
            else:
                y_d_r, fmap_r = d(y)
                y_d_g, fmap_g = d(y_hat)

            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    """Stack of strided, grouped 1-D convolutions applied to a 1-channel input.

    Args:
        use_spectral_norm: Use spectral norm instead of weight norm on every
            convolution.
    """

    def __init__(self, use_spectral_norm: bool = False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.convs = torch.nn.ModuleList(
            [
                norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
        self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)

    def forward(self, x):
        """Return ``(scores, fmap)`` for ``x``.

        ``fmap`` holds the activation after every convolution, including the
        un-activated output of ``conv_post``; ``scores`` is that final output
        flattened over all non-batch dimensions.
        """
        fmap = []
        for conv in self.convs:
            x = self.lrelu(conv(x))
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap


class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the time axis by ``period`` and applies 2-D convs.

    Args:
        period: Width of the folded axis; the input is padded so that its
            length is a multiple of this value.
        kernel_size: Kernel height of the five main convolutions.
        stride: Stride (along the folded time axis) of the first four
            convolutions. The fifth convolution always uses stride 1. The
            default of 3 reproduces the previously hard-coded strides.
        use_spectral_norm: Use spectral norm instead of weight norm on every
            convolution.
    """

    def __init__(
        self,
        period: int,
        kernel_size: int = 5,
        stride: int = 3,
        use_spectral_norm: bool = False,
    ):
        super().__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        in_channels = [1, 32, 128, 512, 1024]
        out_channels = [32, 128, 512, 1024, 1024]
        # Honour the ``stride`` argument (it used to be silently ignored);
        # with the default this is [3, 3, 3, 3, 1] exactly as before.
        strides = [stride] * 4 + [1]

        self.convs = torch.nn.ModuleList(
            [
                norm_f(
                    torch.nn.Conv2d(
                        in_ch,
                        out_ch,
                        (kernel_size, 1),
                        (s, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                )
                for in_ch, out_ch, s in zip(in_channels, out_channels, strides)
            ]
        )
        self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)

    def forward(self, x):
        """Return ``(scores, fmap)`` for a 3-D input ``x``.

        The last axis of ``x`` is right-padded (reflect mode) up to a multiple
        of ``self.period`` and then reshaped to ``(b, c, -1, period)`` before
        the 2-D convolutions. ``fmap`` holds the activation after every
        convolution, including the un-activated output of ``conv_post``.
        """
        fmap = []

        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
        x = x.view(b, c, -1, self.period)

        for conv in self.convs:
            x = self.lrelu(conv(x))
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap


class DiscriminatorR(torch.nn.Module):
    """Resolution discriminator: 2-D convolutions over an STFT magnitude.

    Args:
        resolution: Sequence ``(n_fft, hop_length, win_length)`` used for the
            STFT in :meth:`spectrogram`.
        use_spectral_norm: Use spectral norm instead of weight norm on every
            convolution.
    """

    def __init__(self, resolution, use_spectral_norm=False):
        super().__init__()
        self.resolution = resolution
        # This class uses its own slope rather than the shared LRELU_SLOPE.
        self.lrelu_slope = 0.1
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.convs = torch.nn.ModuleList(
            [
                norm_f(torch.nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
                norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(torch.nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
            ]
        )
        self.conv_post = norm_f(torch.nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, x):
        """Return ``(scores, fmap)`` computed from the magnitude spectrogram of ``x``.

        ``fmap`` holds the activation after every convolution, including the
        un-activated output of ``conv_post``; ``scores`` is that final output
        flattened over all non-batch dimensions.
        """
        fmap = []
        # Add a channel axis so the spectrogram matches Conv2d's 1 input channel.
        x = self.spectrogram(x).unsqueeze(1)

        for layer in self.convs:
            x = F.leaky_relu(layer(x), self.lrelu_slope)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap

    def spectrogram(self, x):
        """Return the STFT magnitude of ``x`` at this discriminator's resolution.

        The input is reflect-padded by ``(n_fft - hop_length) / 2`` samples on
        both sides, axis 1 is squeezed (assumes ``x`` is ``(batch, 1, time)`` -
        TODO confirm), and an un-centred STFT with an all-ones (rectangular)
        window is taken.
        """
        n_fft, hop_length, win_length = self.resolution
        pad = int((n_fft - hop_length) / 2)

        x = F.pad(x, (pad, pad), mode="reflect").squeeze(1)
        x = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=torch.ones(win_length, device=x.device),
            center=False,
            return_complex=True,
        )

        # Magnitude of the complex STFT, via the L2 norm of (real, imag).
        mag = torch.norm(torch.view_as_real(x), p=2, dim=-1)
        return mag