import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse


class CELoss(nn.Module):
    """Cross-entropy loss that skips targets equal to ``ignore_index``.

    Args:
        ignore_index: target value whose positions contribute nothing to the
            loss. Defaults to 255.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` as accepted by
            ``nn.CrossEntropyLoss``. A falsy value (e.g. ``None``) is treated
            as ``'none'`` (no reduction).
    """

    def __init__(self, ignore_index=255, reduction='mean'):
        super(CELoss, self).__init__()
        self.ignore_index = ignore_index
        if not reduction:
            # nn.CrossEntropyLoss rejects None/'' when called; map a falsy
            # value to the explicit "no reduction" mode announced here.
            reduction = 'none'
            print("disabled the reduction.")
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                             reduction=reduction)

    def forward(self, pred, target):
        """Return the cross-entropy between logits ``pred`` and ``target``."""
        return self.criterion(pred, target)


class FocalLoss(nn.Module):
    """Focal loss computed from a softmax over dim 1.

    Args:
        gamma: focusing exponent; ``gamma=0`` reduces to (alpha-weighted)
            cross-entropy.
        alpha: optional per-class weights. A scalar ``a`` becomes
            ``[a, 1 - a]``; a list is converted to a tensor as-is.
        size_average: return the mean over all positions if True, otherwise
            the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Return the focal loss of logits ``input`` against class indices ``target``.

        ``input`` is [N, C] or [N, C, d1, d2, ...]; ``target`` holds one class
        index per sample/position and is flattened to match.
        """
        if input.dim() > 2:
            # N,C,d1,... => N,C,P => N,P,C => N*P,C: each position is a sample.
            num_classes = input.size(1)
            input = input.reshape(input.size(0), num_classes, -1)
            input = input.transpose(1, 2).reshape(-1, num_classes)
        target = target.reshape(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)
        # pt is detached on purpose: the modulating factor (1 - pt) ** gamma
        # acts as a constant weight, as in the original Variable(.data) code.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Match the class weights' dtype and device to the logits.
            self.alpha = self.alpha.to(input)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()


class dice_loss(nn.Module):
    """Sørensen–Dice loss (``1 - mean per-class Dice``) for segmentation logits.

    Args:
        eps: added to the denominator for numerical stability.
    """

    def __init__(self, eps=1e-7):
        super(dice_loss, self).__init__()
        self.eps = eps

    def forward(self, logits, true):
        """Compute the Sørensen–Dice loss.

        PyTorch optimizers minimize a loss, so this returns ``1 - dice``,
        where ``dice`` is the per-class Dice coefficient averaged over
        classes.

        Args:
            logits: a tensor of shape [B, C, H, W] with the raw model outputs.
                ``C == 1`` is treated as a single sigmoid channel for label 1.
            true: a tensor of shape [B, 1, H, W] holding integer class labels.

        Returns:
            dice_loss: a scalar tensor, ``1 - mean Dice coefficient``.
        """
        num_classes = logits.shape[1]
        labels = true.squeeze(1).long()
        if num_classes == 1:
            eye = torch.eye(num_classes + 1, device=true.device)
            true_1_hot = eye[labels].permute(0, 3, 1, 2).float()
            true_1_hot_f = true_1_hot[:, 0:1, :, :]
            true_1_hot_s = true_1_hot[:, 1:2, :, :]
            # Reorder to [label 1, label 0] so it lines up with [pos, neg].
            true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
            pos_prob = torch.sigmoid(logits)
            neg_prob = 1 - pos_prob
            probas = torch.cat([pos_prob, neg_prob], dim=1)
        else:
            # Build the one-hot table on the labels' device rather than forcing
            # CUDA, so the loss also runs on CPU.
            eye = torch.eye(num_classes, device=true.device)
            true_1_hot = eye[labels].permute(0, 3, 1, 2).float()
            probas = F.softmax(logits, dim=1)
        true_1_hot = true_1_hot.to(logits)

        # Sum over the batch and spatial dims, keeping one value per class.
        dims = (0,) + tuple(range(2, true.ndimension()))
        intersection = torch.sum(probas * true_1_hot, dims)
        cardinality = torch.sum(probas + true_1_hot, dims)
        dice = (2. * intersection / (cardinality + self.eps)).mean()
        return 1 - dice


class BCEDICE_loss(nn.Module):
    """Binary cross-entropy plus ``1 - Dice`` on probability maps.

    In ``forward(target, true)``, ``target`` holds predicted probabilities
    (``nn.BCELoss`` expects values in [0, 1], not logits). ``true`` is the
    binary ground truth with the same shape.
    """

    def __init__(self):
        super(BCEDICE_loss, self).__init__()
        self.bce = torch.nn.BCELoss()

    def forward(self, target, true):
        true = true.float()
        bce_loss = self.bce(target, true)
        eps = 1e-7
        inter = (true * target).sum()
        dice = (2 * inter + eps) / (true.sum() + target.sum() + eps)
        return bce_loss + 1 - dice


class LOVASZ(nn.Module):
    """Lovasz-Softmax loss module: applies softmax over dim 1, then ``lovasz_softmax``."""

    def __init__(self):
        super(LOVASZ, self).__init__()

    def forward(self, probas, labels):
        # ``probas`` is softmaxed here, so callers pass unnormalized scores.
        return lovasz_softmax(F.softmax(probas, dim=1), labels)


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """Multi-class Lovasz-Softmax loss.

    probas: [B, C, H, W] tensor of class probabilities at each prediction
        (between 0 and 1). A [B, H, W] tensor is interpreted as binary
        (sigmoid) output.
    labels: [B, H, W] tensor of ground-truth labels (between 0 and C - 1).
    classes: 'all' for all, 'present' for classes present in labels, or a
        list of classes to average.
    per_image: compute the loss per image instead of per batch.
    ignore: void class label to drop before computing the loss.
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                                        classes=classes)
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    return loss


def lovasz_softmax_flat(probas, labels, classes='present'):
    """Multi-class Lovasz-Softmax loss on flattened predictions.

    probas: [P, C] tensor of class probabilities at each prediction
        (between 0 and 1). With C == 1 the single channel is taken as the
        probability of label 1 (sigmoid output).
    labels: [P] tensor of ground-truth labels (between 0 and C - 1).
    classes: 'all' for all, 'present' for classes present in labels, or a
        list of classes to average.

    Raises:
        ValueError: if C == 1 and more than one class is requested.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    if classes in ['all', 'present']:
        # A single sigmoid channel scores the positive label 1; otherwise each
        # channel index is a candidate class.
        class_to_sum = [1] if C == 1 else list(range(C))
    else:
        class_to_sum = classes
    if C == 1 and len(class_to_sum) > 1:
        raise ValueError('Sigmoid output possible only with 1 class')

    losses = []
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        class_pred = probas[:, 0] if C == 1 else probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses)


def lovasz_grad(gt_sorted):
    """Compute the gradient of the Lovasz extension w.r.t. sorted errors.

    See Alg. 1 in the Lovasz-Softmax paper. ``gt_sorted`` is the 0/1
    foreground indicator ordered by decreasing error.
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def flatten_probas(probas, labels, ignore=None):
    """Flatten batch predictions to [P, C] and labels to [P].

    Positions whose label equals ``ignore`` are dropped when ``ignore`` is
    given.
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.reshape(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # Boolean-mask indexing always yields [P', C]; a squeezed nonzero() index
    # would collapse the single-valid-pixel case to a [C] vector.
    return probas[valid], labels[valid]


def isnan(x):
    """Return ``x != x``, which is truthy only for NaN values."""
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """nanmean compatible with generators.

    Args:
        l: iterable of numbers or tensors.
        ignore_nan: skip NaN values before averaging.
        empty: value returned when nothing is left to average; pass
            ``'raise'`` to raise ``ValueError`` instead.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        # Out-of-place add: ``acc += v`` would mutate the caller's first
        # element in place when it is a tensor or array.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n


if __name__ == "__main__":
    predict = torch.randn(4, 2, 10, 10)
    target = torch.randint(low=0, high=2, size=[4, 10, 10])
    func = CELoss()
    loss = func(predict, target)
    print(loss)