| import torch |
| from torch import nn |
| from torch.nn.modules.conv import Conv1d |
|
|
| from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator |
|
|
|
|
| class DiscriminatorS(torch.nn.Module): |
| """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN. |
| |
| Args: |
| use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. |
| """ |
|
|
| def __init__(self, use_spectral_norm=False): |
| super().__init__() |
| norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm |
| self.convs = nn.ModuleList( |
| [ |
| norm_f(Conv1d(1, 16, 15, 1, padding=7)), |
| norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), |
| norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), |
| norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), |
| norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), |
| norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), |
| ] |
| ) |
| self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) |
|
|
| def forward(self, x): |
| """ |
| Args: |
| x (Tensor): input waveform. |
| |
| Returns: |
| Tensor: discriminator scores. |
| List[Tensor]: list of features from the convolutiona layers. |
| """ |
| feat = [] |
| for l in self.convs: |
| x = l(x) |
| x = torch.nn.functional.leaky_relu(x, 0.1) |
| feat.append(x) |
| x = self.conv_post(x) |
| feat.append(x) |
| x = torch.flatten(x, 1, -1) |
| return x, feat |
|
|
|
|
class VitsDiscriminator(nn.Module):
    """VITS discriminator: one scale discriminator plus a stack of period discriminators.

    ::
        waveform -> DiscriminatorS() ------------------> score_s, feats_s  --> collected -> scores, feats
               |--> DiscriminatorP(p) for p in periods -> score_p, feats_p  ^

    Args:
        periods (tuple of int): one ``DiscriminatorP`` is created for each entry,
            in the given order.
        use_spectral_norm (bool): if `True`, use spectral norm instead of weight
            norm in every sub-discriminator.
    """

    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
        super().__init__()
        # Sub-discriminators are built in the same order they are stored:
        # the scale discriminator first, then one period discriminator per period.
        scale_disc = DiscriminatorS(use_spectral_norm=use_spectral_norm)
        period_discs = [DiscriminatorP(period, use_spectral_norm=use_spectral_norm) for period in periods]
        self.nets = nn.ModuleList([scale_disc, *period_discs])

    def forward(self, x, x_hat=None):
        """Run every sub-discriminator on the real and (optionally) generated waveform.

        Args:
            x (Tensor): ground truth waveform.
            x_hat (Tensor, optional): predicted waveform. When `None`, only ``x``
                is scored.

        Returns:
            List[Tensor]: scores of ``x``, one entry per sub-discriminator.
            List[List[Tensor]]: features of ``x``, one list per sub-discriminator.
            List[Tensor] or None: scores of ``x_hat`` (`None` if ``x_hat`` is `None`).
            List[List[Tensor]] or None: features of ``x_hat`` (`None` if ``x_hat``
                is `None`).
        """
        score_fake = x_hat is not None
        real_scores, real_feats = [], []
        fake_scores = [] if score_fake else None
        fake_feats = [] if score_fake else None
        # NOTE: each net scores x and then x_hat before moving on, so the call
        # order matches the original; with spectral norm in training mode every
        # forward call can update the normalization state.
        for disc in self.nets:
            score, feats = disc(x)
            real_scores.append(score)
            real_feats.append(feats)
            if score_fake:
                score, feats = disc(x_hat)
                fake_scores.append(score)
                fake_feats.append(feats)
        return real_scores, real_feats, fake_scores, fake_feats
|
|