Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import torch.nn.functional as F | |
def log_dur_loss(dur_pred_log, dur_target, mask, loss_type="l1"):
    """Masked duration loss computed in the log(1 + d) domain.

    Args:
        dur_pred_log: (B, N) predicted durations, already in log(1 + d) space.
        dur_target: (B, N) target durations in the linear domain.
        mask: (B, N); positions where mask is 0 are excluded from the loss
            (presumably padding — TODO confirm against callers).
        loss_type: "l1" (absolute error) or "l2" (squared error).

    Returns:
        Scalar tensor: sum of masked elementwise losses divided by the
        number of valid positions.

    Raises:
        NotImplementedError: for any loss_type other than "l1" / "l2".
    """
    # Compare in log(1 + d) space so long durations don't dominate the loss.
    dur_target_log = torch.log(1 + dur_target)
    if loss_type == "l1":
        loss = F.l1_loss(
            dur_pred_log, dur_target_log, reduction="none"
        ).float() * mask.to(dur_target.dtype)
    elif loss_type == "l2":
        loss = F.mse_loss(
            dur_pred_log, dur_target_log, reduction="none"
        ).float() * mask.to(dur_target.dtype)
    else:
        raise NotImplementedError()
    # Epsilon matches log_pitch_loss: an all-zero mask now yields 0
    # instead of NaN from 0/0.
    loss = loss.sum() / (mask.to(dur_target.dtype).sum() + 1e-8)
    return loss
def log_pitch_loss(pitch_pred_log, pitch_target, mask, loss_type="l1"):
    """Masked pitch loss against log(pitch_target).

    pitch_pred_log and pitch_target are (B, N); mask is (B, N) with 0 at
    positions to ignore. Returns the masked mean as a scalar tensor.
    NOTE(review): pitch_target is presumably strictly positive — a zero
    would make log() emit -inf even at masked positions; confirm callers.
    """
    target_log = torch.log(pitch_target)
    weights = mask.to(pitch_target.dtype)
    if loss_type == "l1":
        elementwise = F.l1_loss(pitch_pred_log, target_log, reduction="none")
    elif loss_type == "l2":
        elementwise = F.mse_loss(pitch_pred_log, target_log, reduction="none")
    else:
        raise NotImplementedError()
    masked = elementwise.float() * weights
    # Epsilon guards against division by zero when nothing is valid.
    return masked.sum() / (weights.sum() + 1e-8)
def diff_loss(pred, target, mask, loss_type="l1"):
    """Masked reconstruction loss over feature frames.

    Args:
        pred: (B, d, T) predicted features.
        target: (B, d, T) target features.
        mask: (B, T); frames where mask is 0 are excluded (presumably
            padding — TODO confirm against callers).
        loss_type: "l1" or "l2".

    Returns:
        Scalar tensor: per-frame loss (mean over the d feature channels)
        summed over valid frames, divided by the number of valid frames.

    Raises:
        NotImplementedError: for any loss_type other than "l1" / "l2".
    """
    if loss_type == "l1":
        # Broadcast the (B, T) mask over the d channel dimension.
        loss = F.l1_loss(pred, target, reduction="none").float() * (
            mask.to(pred.dtype).unsqueeze(1)
        )
    elif loss_type == "l2":
        loss = F.mse_loss(pred, target, reduction="none").float() * (
            mask.to(pred.dtype).unsqueeze(1)
        )
    else:
        raise NotImplementedError()
    # Epsilon matches log_pitch_loss: an all-zero mask now yields 0
    # instead of NaN from 0/0.
    loss = (torch.mean(loss, dim=1)).sum() / (mask.to(pred.dtype).sum() + 1e-8)
    return loss
def diff_ce_loss(pred_dist, gt_indices, mask):
    """Masked cross-entropy between codebook logits and target indices.

    pred_dist: (nq, B, T, V) logits per quantizer (V = codebook size,
    1024 in the caller's comment); gt_indices: (nq, B, T) target codes;
    mask: (B, T), 0 at positions to ignore. Returns a scalar: mean over
    quantizers, then sum over frames, divided by the valid-frame count.
    NOTE(review): unlike log_pitch_loss, the denominator has no epsilon,
    so an all-zero mask produces NaN — confirm whether that is intended.
    """
    # F.cross_entropy expects logits shaped (B, V, ...), so move the
    # codebook dimension to dim 1 and align targets as (B, nq, T).
    logits = pred_dist.permute(1, 3, 0, 2)        # (B, V, nq, T)
    targets = gt_indices.permute(1, 0, 2).long()  # (B, nq, T)
    ce = F.cross_entropy(logits, targets, reduction="none").float()  # (B, nq, T)
    valid = mask.to(ce.dtype)
    ce = ce * valid.unsqueeze(1)
    return ce.mean(dim=1).sum() / valid.sum()