| from typing import Dict, Union |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from TTS.utils.audio.torch_transforms import TorchSTFT |
| from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss |
|
|
| |
| |
| |
|
|
|
|
class STFTLoss(nn.Module):
    """Single-resolution STFT loss.

    The generated and the reference waveform are both passed through the same
    ``TorchSTFT`` transform; the two outputs are then compared with an L1
    distance on their logarithms and with a spectral convergence term.
    Taken from the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf

    Args:
        n_fft: FFT size handed to ``TorchSTFT``.
        hop_length: hop length handed to ``TorchSTFT``.
        win_length: window length handed to ``TorchSTFT``.
    """

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft, self.hop_length, self.win_length = n_fft, hop_length, win_length
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        """Return ``(loss_mag, loss_sc)`` for generated ``y_hat`` against reference ``y``."""
        spec_fake = self.stft(y_hat)
        spec_real = self.stft(y)
        # magnitude term: mean absolute gap between the log outputs
        log_l1 = F.l1_loss(torch.log(spec_real), torch.log(spec_fake))
        # spectral convergence term: norm of the residual relative to the reference norm
        residual_norm = torch.norm(spec_real - spec_fake, p="fro")
        reference_norm = torch.norm(spec_real, p="fro")
        return log_l1, residual_norm / reference_norm
|
|
|
|
class MultiScaleSTFTLoss(torch.nn.Module):
    """Multi-resolution STFT loss.

    Holds one ``STFTLoss`` per ``(n_fft, hop_length, win_length)`` triple and
    averages their magnitude and spectral convergence terms.
    Taken from the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf

    Args:
        n_ffts: FFT sizes, one per resolution.
        hop_lengths: hop lengths, one per resolution.
        win_lengths: window lengths, one per resolution.
    """

    def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
        super().__init__()
        # zip() stops at the shortest of the three sequences
        per_scale = [STFTLoss(*scale) for scale in zip(n_ffts, hop_lengths, win_lengths)]
        self.loss_funcs = torch.nn.ModuleList(per_scale)

    def forward(self, y_hat, y):
        """Return ``(loss_mag, loss_sc)`` averaged over every configured resolution."""
        mag_terms = []
        sc_terms = []
        for scale_loss in self.loss_funcs:
            mag_term, sc_term = scale_loss(y_hat, y)
            mag_terms.append(mag_term)
            sc_terms.append(sc_term)
        num_scales = len(self.loss_funcs)
        mean_sc = sum(sc_terms) / num_scales
        mean_mag = sum(mag_terms) / num_scales
        return mean_mag, mean_sc
|
|
|
|
class L1SpecLoss(nn.Module):
    """L1 loss between log spectrograms, as described in the HiFiGAN paper
    https://arxiv.org/pdf/2010.05646.pdf

    All constructor arguments are forwarded to ``TorchSTFT``; ``use_mel`` is
    additionally kept on the module.
    """

    def __init__(
        self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
    ):
        super().__init__()
        self.use_mel = use_mel
        stft_kwargs = dict(
            sample_rate=sample_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            n_mels=n_mels,
            use_mel=use_mel,
        )
        self.stft = TorchSTFT(n_fft, hop_length, win_length, **stft_kwargs)

    def forward(self, y_hat, y):
        """Return the mean absolute difference between the log transforms of ``y`` and ``y_hat``."""
        spec_fake = self.stft(y_hat)
        spec_real = self.stft(y)
        return F.l1_loss(torch.log(spec_real), torch.log(spec_fake))
|
|
|
|
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
    """Multi-resolution STFT loss applied to multi-band model outputs.
    From the MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""

    def forward(self, y_hat, y):
        """Fold the leading dimensions together and delegate to the parent loss.

        Both inputs are indexed on ``shape[2]``, so they must have at least
        three dimensions (presumably batch, sub-band, time — confirm with caller).
        """
        # reshape to (-1, 1, T) and drop the singleton axis again -> (-1, T)
        flat_fake = y_hat.view(-1, 1, y_hat.shape[2]).squeeze(1)
        flat_real = y.view(-1, 1, y.shape[2]).squeeze(1)
        return super().forward(flat_fake, flat_real)
|
|
|
|
class MSEGLoss(nn.Module):
    """Mean squared generator loss: pushes the given scores towards 1."""

    def forward(self, score_real):
        """Return the MSE between ``score_real`` and an all-ones target of the same shape."""
        all_ones = torch.ones_like(score_real)
        return F.mse_loss(score_real, all_ones)
|
|
|
|
class HingeGLoss(nn.Module):
    """Hinge generator loss.

    NOTE(review): this computes ``mean(relu(1 - score))``; hinge-GAN generators
    are often trained with ``-mean(score)`` instead — confirm this form is intended.
    """

    def forward(self, score_real):
        """Return the mean of ``relu(1 - score_real)``."""
        margin_violation = F.relu(1.0 - score_real)
        return torch.mean(margin_violation)
|
|
|
|
| |
| |
| |
|
|
|
|
class MSEDLoss(nn.Module):
    """Mean squared discriminator loss: real scores are pulled to 1, fake scores to 0."""

    def __init__(self):
        super().__init__()
        self.loss_func = nn.MSELoss()

    def forward(self, score_fake, score_real):
        """Return ``(loss_d, loss_real, loss_fake)`` where ``loss_d`` is the sum of the other two."""
        real_target = torch.ones_like(score_real)
        real_term = self.loss_func(score_real, real_target)
        fake_target = torch.zeros_like(score_fake)
        fake_term = self.loss_func(score_fake, fake_target)
        return real_term + fake_term, real_term, fake_term
|
|
|
|
class HingeDLoss(nn.Module):
    """Hinge discriminator loss."""

    def forward(self, score_fake, score_real):
        """Return ``(loss_d, loss_real, loss_fake)``.

        ``loss_real`` is ``mean(relu(1 - score_real))``, ``loss_fake`` is
        ``mean(relu(1 + score_fake))`` and ``loss_d`` is their sum.
        """
        real_term = F.relu(1.0 - score_real).mean()
        fake_term = F.relu(1.0 + score_fake).mean()
        return real_term + fake_term, real_term, fake_term
|
|
|
|
class MelganFeatureLoss(nn.Module):
    """Average L1 distance over paired feature tensors.

    ``fake_feats`` and ``real_feats`` are indexable collections of feature
    groups (presumably per-discriminator feature maps — confirm with caller).
    """

    def __init__(self):
        super().__init__()
        self.loss_func = nn.L1Loss()

    def forward(self, fake_feats, real_feats):
        """Return the mean of the L1 losses over every (fake, real) feature pair."""
        pair_losses = []
        for group_idx, fake_group in enumerate(fake_feats):
            # real_feats is indexed in step with fake_feats
            real_group = real_feats[group_idx]
            for fake_feat, real_feat in zip(fake_group, real_group):
                pair_losses.append(self.loss_func(fake_feat, real_feat))
        return sum(pair_losses) / len(pair_losses)
|
|
|
|
| |
| |
| |
|
|
|
|
def _apply_G_adv_loss(scores_fake, loss_func):
    """Compute the generator adversarial loss.

    A list of scores (one entry per discriminator output) is averaged over its
    entries; any other input is passed to ``loss_func`` once and returned as is.
    """
    if not isinstance(scores_fake, list):
        return loss_func(scores_fake)
    per_score_losses = [loss_func(score_fake) for score_fake in scores_fake]
    return sum(per_score_losses) / len(scores_fake)
|
|
|
|
def _apply_D_loss(scores_fake, scores_real, loss_func):
    """Compute the discriminator loss and normalize it.

    ``loss_func`` must return a ``(total, real, fake)`` triple. When
    ``scores_fake`` is a list, the triples of all zipped (fake, real) score
    pairs are summed and divided by the list lengths; otherwise ``loss_func``
    is applied once to the two inputs.
    """
    if not isinstance(scores_fake, list):
        # single discriminator output
        total, real_part, fake_part = loss_func(scores_fake, scores_real)
        return total, real_part, fake_part

    # multi-scale discriminator outputs
    totals, real_parts, fake_parts = [], [], []
    for score_fake, score_real in zip(scores_fake, scores_real):
        total_i, real_i, fake_i = loss_func(score_fake=score_fake, score_real=score_real)
        totals.append(total_i)
        real_parts.append(real_i)
        fake_parts.append(fake_i)

    num_fake = len(scores_fake)
    num_real = len(scores_real)
    return sum(totals) / num_fake, sum(real_parts) / num_real, sum(fake_parts) / num_fake
|
|
|
|
| |
| |
| |
|
|
|
|
class GeneratorLoss(nn.Module):
    """Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes
    losses. It allows to experiment with different combinations of loss functions with different models by just
    changing configurations.

    Every ``use_*`` flag missing from the configuration defaults to ``False`` and
    every ``*_loss_weight`` missing from it defaults to ``0.0``.

    Args:
        C (AttrDict): model configuration. It must support ``key in C`` and
            attribute access.
    """

    def __init__(self, C):
        super().__init__()

        def cfg(key, default):
            # membership test first, then attribute access, so absent keys fall back to the default
            return getattr(C, key) if key in C else default

        # which loss terms are enabled
        self.use_stft_loss = cfg("use_stft_loss", False)
        self.use_subband_stft_loss = cfg("use_subband_stft_loss", False)
        self.use_mse_gan_loss = cfg("use_mse_gan_loss", False)
        self.use_hinge_gan_loss = cfg("use_hinge_gan_loss", False)
        self.use_feat_match_loss = cfg("use_feat_match_loss", False)
        self.use_l1_spec_loss = cfg("use_l1_spec_loss", False)

        # checked on the resolved flags so a config without these keys does not fail here
        assert not (
            self.use_mse_gan_loss and self.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        # weight of each loss term
        self.stft_loss_weight = cfg("stft_loss_weight", 0.0)
        self.subband_stft_loss_weight = cfg("subband_stft_loss_weight", 0.0)
        self.mse_gan_loss_weight = cfg("mse_G_loss_weight", 0.0)
        # the key was previously misspelled ("hinde_G_loss_weight"), which forced this weight to 0.0
        self.hinge_gan_loss_weight = cfg("hinge_G_loss_weight", 0.0)
        self.feat_match_loss_weight = cfg("feat_match_loss_weight", 0.0)
        self.l1_spec_loss_weight = cfg("l1_spec_loss_weight", 0.0)

        # build only the enabled loss modules, using the same flags that forward() checks
        if self.use_stft_loss:
            self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
        if self.use_subband_stft_loss:
            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
        if self.use_mse_gan_loss:
            self.mse_loss = MSEGLoss()
        if self.use_hinge_gan_loss:
            self.hinge_loss = HingeGLoss()
        if self.use_feat_match_loss:
            self.feat_match_loss = MelganFeatureLoss()
        if self.use_l1_spec_loss:
            # the spectrogram loss must run at the model's own sample rate
            assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
            self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)

    def forward(
        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
    ):
        """Compute every enabled generator loss term.

        Args:
            y_hat: generated waveform; indexed on three dimensions and squeezed on
                dim 1 for the STFT loss (presumably ``(batch, 1, time)`` — confirm).
            y: reference waveform, same layout as ``y_hat``.
            scores_fake: discriminator score(s) for ``y_hat``; a tensor or a list.
                Adversarial terms are skipped when this is ``None``.
            feats_fake: features for ``y_hat``; the feature matching term is
                skipped when this is ``None``.
            feats_real: features for ``y``.
            y_hat_sub: generated sub-band signal for the sub-band STFT loss.
            y_sub: reference sub-band signal for the sub-band STFT loss.

        Returns:
            dict: the individual terms plus ``"G_gen_loss"`` (weighted spectral
            terms), ``"G_adv_loss"`` (weighted adversarial and feature matching
            terms) and ``"loss"`` (their sum).
        """
        gen_loss = 0
        adv_loss = 0
        return_dict = {}

        # STFT loss; y_hat is cropped to the length of y first
        if self.use_stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
            return_dict["G_stft_loss_mg"] = stft_loss_mg
            return_dict["G_stft_loss_sc"] = stft_loss_sc
            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

        # L1 spectrogram loss
        if self.use_l1_spec_loss:
            l1_spec_loss = self.l1_spec_loss(y_hat, y)
            return_dict["G_l1_spec_loss"] = l1_spec_loss
            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

        # sub-band STFT loss
        if self.use_subband_stft_loss:
            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

        # MSE adversarial loss
        if self.use_mse_gan_loss and scores_fake is not None:
            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
            return_dict["G_mse_fake_loss"] = mse_fake_loss
            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

        # hinge adversarial loss; the old test `not scores_fake is not None` meant
        # `scores_fake is None`, so the term only ran when there were no scores to use
        if self.use_hinge_gan_loss and scores_fake is not None:
            hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
            return_dict["G_hinge_fake_loss"] = hinge_fake_loss
            adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

        # feature matching loss
        if self.use_feat_match_loss and feats_fake is not None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss

        return_dict["loss"] = gen_loss + adv_loss
        return_dict["G_gen_loss"] = gen_loss
        return_dict["G_adv_loss"] = adv_loss
        return return_dict
|
|
|
|
class DiscriminatorLoss(nn.Module):
    """Discriminator loss wrapper, the counterpart of ``GeneratorLoss``.

    Args:
        C (AttrDict): model configuration; ``use_mse_gan_loss`` and
            ``use_hinge_gan_loss`` are read from it and must not both be set.
    """

    def __init__(self, C):
        super().__init__()
        both_enabled = C.use_mse_gan_loss and C.use_hinge_gan_loss
        assert not both_enabled, " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_mse_gan_loss = C.use_mse_gan_loss
        self.use_hinge_gan_loss = C.use_hinge_gan_loss

        # only the selected loss module is created
        if C.use_mse_gan_loss:
            self.mse_loss = MSEDLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeDLoss()

    def forward(self, scores_fake, scores_real):
        """Return a dict with the enabled discriminator loss terms and their total under ``"loss"``."""
        return_dict = {}
        total = 0

        if self.use_mse_gan_loss:
            d_loss, d_real, d_fake = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
            )
            return_dict["D_mse_gan_loss"] = d_loss
            return_dict["D_mse_gan_real_loss"] = d_real
            return_dict["D_mse_gan_fake_loss"] = d_fake
            total = total + d_loss

        if self.use_hinge_gan_loss:
            d_loss, d_real, d_fake = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
            )
            return_dict["D_hinge_gan_loss"] = d_loss
            return_dict["D_hinge_gan_real_loss"] = d_real
            return_dict["D_hinge_gan_fake_loss"] = d_fake
            total = total + d_loss

        return_dict["loss"] = total
        return return_dict
|
|
|
|
class WaveRNNLoss(nn.Module):
    """Loss wrapper for WaveRNN that picks its loss function from the model mode.

    Args:
        wave_rnn_mode: ``"mold"`` selects ``discretized_mix_logistic_loss``,
            ``"gauss"`` selects ``gaussian_loss`` and any ``int`` selects
            cross entropy.

    Raises:
        ValueError: if ``wave_rnn_mode`` is none of the above.
    """

    def __init__(self, wave_rnn_mode: Union[str, int]):
        super().__init__()
        if isinstance(wave_rnn_mode, int):
            self.loss_func = torch.nn.CrossEntropyLoss()
            return
        if wave_rnn_mode == "mold":
            self.loss_func = discretized_mix_logistic_loss
            return
        if wave_rnn_mode == "gauss":
            self.loss_func = gaussian_loss
            return
        raise ValueError(" [!] Unknown mode for Wavernn.")

    def forward(self, y_hat, y) -> Dict:
        """Return ``{"loss": loss_func(y_hat, y)}``."""
        return {"loss": self.loss_func(y_hat, y)}
|
|