Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from torchvision.ops import sigmoid_focal_loss | |
class CCCLoss(nn.Module):
    '''Concordance correlation coefficient loss, returned as ``1 - CCC``.

    Statistics are reduced over dim 0, which is treated as the sample/batch
    axis. A 2-D ``(N, F)`` input yields one loss per column ``(F,)``; a 1-D
    ``(N,)`` input yields a single 0-d loss.
    '''

    def __init__(self, eps=1e-7):
        '''
        Args:
            eps (float, optional): stabilizing term added to the denominator
        '''
        super(CCCLoss, self).__init__()
        self.eps = eps

    def forward(self, y_hat, y):
        '''Compute ``1 - CCC`` between predictions and targets.

        Args:
            y_hat: predictions, same shape as ``y``.
            y: targets, ``(N,)`` or ``(N, F)``.

        Returns:
            Tensor of shape ``(F,)`` for 2-D inputs, or a 0-d tensor for
            1-D inputs.
        '''
        gold_mean = torch.mean(y, dim=0, keepdim=True)
        pred_mean = torch.mean(y_hat, dim=0, keepdim=True)
        gold_centered = y - gold_mean
        pred_centered = y_hat - pred_mean
        # Biased (divide-by-N) statistics, all kept at shape (1, ...) so they
        # broadcast against each other.
        covariance = torch.mean(gold_centered * pred_centered, dim=0, keepdim=True)
        gold_var = torch.mean(torch.square(gold_centered), dim=0, keepdim=True)
        pred_var = torch.mean(torch.square(pred_centered), dim=0, keepdim=True)
        ccc = 2 * covariance / (
            gold_var + pred_var + torch.square(gold_mean - pred_mean) + self.eps
        )
        # Drop the kept reduction axis: (1, F) -> (F,), (1,) -> 0-d.
        return (1 - ccc).squeeze(0)
class SigmoidFocalLoss(nn.Module):
    '''Module wrapper around ``torchvision.ops.sigmoid_focal_loss``.

    Args:
        reduction: one of ``'none'``, ``'mean'``, ``'sum'``. ``None`` (the
            default) is accepted for backward compatibility and treated as
            ``'none'``.
        alpha: class-balancing weight forwarded to ``sigmoid_focal_loss``.
        gamma: focusing exponent forwarded to ``sigmoid_focal_loss``.
    '''

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, reduction=None, alpha=0.25, gamma=2):
        super(SigmoidFocalLoss, self).__init__()
        # sigmoid_focal_loss only understands 'none' / 'mean' / 'sum';
        # recent torchvision releases raise ValueError for anything else,
        # including None, so map the old None default to 'none'.
        if reduction is None:
            reduction = 'none'
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS} or None, got {reduction!r}"
            )
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, y_hat, y):
        '''Return the sigmoid focal loss of logits ``y_hat`` against targets ``y``.'''
        return sigmoid_focal_loss(
            y_hat, y, alpha=self.alpha, gamma=self.gamma, reduction=self.reduction
        )
class StutterLoss(nn.Module):
    '''SEP-28k loss: sigmoid focal loss on one half of the targets plus
    CCC loss on the other half.

    Args:
        alpha: weight of the CCC term.
        beta: weight of the focal-loss term.
        stutter_weights: optional tensor of per-class weights for the
            focal-loss term. It is stored as a ``(1, -1)`` row vector.
            Weighting needs per-class losses, i.e. ``reduction='none'``.
        reduction: reduction passed to ``SigmoidFocalLoss``.
    '''

    def __init__(self, alpha=1, beta=1, stutter_weights=None, reduction='mean'):
        super(StutterLoss, self).__init__()
        self.stutter_loss = CCCLoss()
        self.disfluency_loss = SigmoidFocalLoss(reduction=reduction)
        self.alpha = alpha
        self.beta = beta
        self.stutter_weights = stutter_weights
        if isinstance(self.stutter_weights, torch.Tensor):
            self.stutter_weights = self.stutter_weights.reshape((1, -1))

    def forward(self, y_hat, y):
        '''Combine the focal and CCC losses for a batch.

        ``y`` and ``y_hat`` are tensors whose last dimension is split into
        two 6-wide halves. The first half (``*_class``) is scored with
        ``SigmoidFocalLoss``; the second half (``*_bin``) is scored with
        ``CCCLoss``. Assumes the last dimension is exactly 12 (required by
        ``torch.split``).

        Raises:
            ValueError: if ``stutter_weights`` is set but the focal loss was
                already reduced to a scalar (``reduction`` other than
                ``'none'``), so there are no per-class losses to weight.
        '''
        y_class, y_bin = torch.split(y, [6, 6], dim=-1)
        y_hat_class, y_hat_bin = torch.split(y_hat, [6, 6], dim=-1)
        disfluency_loss = self.disfluency_loss(y_hat_class, y_class)
        stutter_loss = torch.mean(self.stutter_loss(y_hat_bin, y_bin))
        # Mean over the batch axis. With reduction='none' this yields the
        # per-class losses; an already-reduced scalar passes through unchanged.
        per_class_loss = torch.mean(disfluency_loss, dim=0)
        if not isinstance(self.stutter_weights, torch.Tensor):
            return self.alpha * stutter_loss + self.beta * per_class_loss
        if per_class_loss.dim() == 0:
            raise ValueError(
                "stutter_weights require reduction='none' so that per-class "
                "disfluency losses are available"
            )
        # Match the loss's dtype/device; the weights are a plain attribute,
        # so they do not follow .to(device) / .cuda() on the module.
        weights = self.stutter_weights.to(per_class_loss).reshape(-1)
        weighted_loss = torch.sum(per_class_loss * weights, dim=-1)
        return self.alpha * stutter_loss + self.beta * weighted_loss