Spaces:
Runtime error
Runtime error
"""
Loss module for FastSpeech 2 training.

Taken from ESPNet
"""
| import torch | |
| from Layers.DurationPredictor import DurationPredictorLoss | |
| from Utility.utils import make_non_pad_mask | |
class FastSpeech2Loss(torch.nn.Module):
    """
    Combined training loss for FastSpeech 2.

    Computes four terms: an L1 loss on the output features (before and,
    if present, after the postnet), a duration predictor loss, and MSE
    losses for the pitch and energy predictors.

    Padded positions can be handled in one of two mutually exclusive ways:

    * ``use_masking``: padded elements are dropped before the criteria are
      applied, so the mean is taken over all non-padded elements.
    * ``use_weighted_masking``: the criteria are applied element-wise and
      the result is weighted so that every sequence in the batch
      contributes equally, regardless of its length.

    If both flags are False, the losses are averaged over the padded
    tensors as they are.
    """

    def __init__(self, use_masking=True, use_weighted_masking=False):
        """
        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to apply weighted masking in loss calculation.

        Raises:
            AssertionError: If both ``use_masking`` and
                ``use_weighted_masking`` are True.
        """
        super().__init__()
        # The two masking modes cannot be combined: plain masking flattens
        # the tensors, while weighted masking needs their padded shapes.
        assert not (use_masking and use_weighted_masking), \
            "use_masking and use_weighted_masking are mutually exclusive"

        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # Weighted masking needs per-element losses so that the weights can
        # be applied before summing; otherwise the criteria average directly.
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

    def forward(self, after_outs, before_outs, d_outs, p_outs, e_outs, ys,
                ds, ps, es, ilens, olens, ):
        """
        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
                May be None, in which case only ``before_outs`` is scored.
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            d_outs (LongTensor): Batch of outputs of duration predictor (B, Tmax).
            p_outs (Tensor): Batch of outputs of pitch predictor (B, Tmax, 1).
            e_outs (Tensor): Batch of outputs of energy predictor (B, Tmax, 1).
            ys (Tensor): Batch of target features (B, Lmax, odim).
            ds (LongTensor): Batch of durations (B, Tmax).
            ps (Tensor): Batch of target token-averaged pitch (B, Tmax, 1).
            es (Tensor): Batch of target token-averaged energy (B, Tmax, 1).
            ilens (LongTensor): Batch of the lengths of each input (B,).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Duration predictor loss value.
            Tensor: Pitch predictor loss value.
            Tensor: Energy predictor loss value.
        """
        device = ys.device

        # apply mask to remove padded part
        # (make_non_pad_mask is a project helper; it is expected to mark the
        # non-padded positions of each sequence given its length)
        if self.use_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(device)
            before_outs = before_outs.masked_select(out_masks)
            if after_outs is not None:
                after_outs = after_outs.masked_select(out_masks)
            ys = ys.masked_select(out_masks)

            duration_masks = make_non_pad_mask(ilens).to(device)
            d_outs = d_outs.masked_select(duration_masks)
            ds = ds.masked_select(duration_masks)

            # Pitch and energy are per input token like the durations, with
            # one trailing feature axis, so the duration mask is reused.
            pitch_masks = duration_masks.unsqueeze(-1)
            p_outs = p_outs.masked_select(pitch_masks)
            e_outs = e_outs.masked_select(pitch_masks)
            ps = ps.masked_select(pitch_masks)
            es = es.masked_select(pitch_masks)

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        if after_outs is not None:
            l1_loss = l1_loss + self.l1_criterion(after_outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            # Each valid frame is weighted by 1 / (number of valid frames in
            # its sequence), then by 1 / (batch size * feature dim), so that
            # summing the weighted losses averages per sequence first.
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(device)
            out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)

            # Same scheme for the token-level losses, without a feature dim.
            duration_masks = make_non_pad_mask(ilens).to(device)
            duration_weights = duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
            duration_weights /= ds.size(0)

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
            pitch_masks = duration_masks.unsqueeze(-1)
            pitch_weights = duration_weights.unsqueeze(-1)
            pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
            energy_loss = energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()

        return l1_loss, duration_loss, pitch_loss, energy_loss