import torch
from torch.nn import functional as F
|
|
|
|
class FocalLoss(torch.nn.Module):
    """Multi-class focal loss implementation.

    Scales the cross-entropy of each example by (1 - p_t) ** gamma, so that
    well-classified examples contribute less and training focuses on hard ones.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma                # focusing parameter
        self.weight = weight              # optional per-class weights, [C]
        self.ignore_index = ignore_index  # target value to skip

    def forward(self, input, target):
        """
        input: [N, C] raw (unnormalized) logits
        target: [N, ] integer class indices
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        # Modulate the log-probabilities by the focal factor, then take the
        # negative log-likelihood of the target class.
        logpt = (1 - pt) ** self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
        return loss
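
# --- Usage sketch (illustrative only) ---
# A minimal sanity check for FocalLoss on dummy data; the batch size, class
# count, and gamma below are arbitrary assumptions, not values from the
# original code.
def _demo_focal_loss():
    logits = torch.randn(8, 4)            # [N, C] raw logits
    targets = torch.randint(0, 4, (8,))   # [N] integer class indices
    criterion = FocalLoss(gamma=2)
    loss = criterion(logits, targets)
    print("focal loss:", loss.item())     # a non-negative scalar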
|
|
|
|
|
|
class LabelSmoothingCorrectionCrossEntropy(torch.nn.Module):
    """Label-smoothing cross-entropy with an additional correction term
    computed from the hard predictions (see forward)."""

    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCorrectionCrossEntropy, self).__init__()
        self.eps = eps                    # smoothing factor
        self.reduction = reduction        # 'mean', 'sum', or 'none'
        self.ignore_index = ignore_index

    def forward(self, output, target):
        c = output.size()[-1]             # number of classes C
        log_preds = F.log_softmax(output, dim=-1)
        # Smoothing term: the (reduced) sum of -log p over all classes.
        if self.reduction == 'sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
        if self.reduction == 'mean':
            loss = loss.mean()
|
|
        # Correction term: compare hard predictions with the targets. The
        # branching below assumes binary (0/1) labels: a sum of 0 means both
        # are negative, 1 means they disagree, 2 means both are positive.
        labels_hat = torch.argmax(output, dim=1)
        lt_sum = labels_hat + target
        abs_lt_sub = abs(labels_hat - target)
        correction_loss = 0
        n = target.size(0)  # iterate over the batch, not over the classes
        for i in range(n):
            if lt_sum[i] == 0:
                pass
            elif lt_sum[i] == 1:
                if abs_lt_sub[i] == 1:
                    pass
                else:
                    correction_loss -= self.eps * 0.5945275813408382
            else:
                correction_loss += self.eps * (1 / 0.32447699714575207)
        correction_loss /= n  # average the correction over the batch

        # Smoothed loss = eps * uniform term + (1 - eps) * NLL, plus correction.
        return loss * self.eps / c + (1 - self.eps) * \
            F.nll_loss(log_preds, target, reduction=self.reduction,
                       ignore_index=self.ignore_index) + correction_loss
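
# --- Usage sketch (illustrative only) ---
# A minimal sanity check for LabelSmoothingCorrectionCrossEntropy on dummy
# data. The correction term assumes binary labels, so C = 2 here; the batch
# size and eps are arbitrary assumptions, not values from the original code.
def _demo_label_smoothing_correction():
    logits = torch.randn(8, 2)            # [N, C] raw logits for a binary task
    targets = torch.randint(0, 2, (8,))   # [N] labels in {0, 1}
    criterion = LabelSmoothingCorrectionCrossEntropy(eps=0.1)
    loss = criterion(logits, targets)
    print("smoothed + corrected loss:", loss.item())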
|
|