|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.autograd import Variable |
|
|
try: |
|
|
from itertools import ifilterfalse |
|
|
except ImportError: |
|
|
from itertools import filterfalse as ifilterfalse |
|
|
|
|
|
class CELoss(nn.Module):
    """Cross-entropy loss that skips targets equal to ``ignore_index``.

    Args:
        ignore_index: target value that contributes neither to the loss nor
            to the gradient.
        reduction: reduction mode forwarded to ``nn.CrossEntropyLoss``
            (``'mean'``, ``'sum'`` or ``'none'``). A falsy value (``None``,
            ``''``) is treated as ``'none'``.
    """

    def __init__(self, ignore_index=255, reduction='mean'):
        super(CELoss, self).__init__()

        self.ignore_index = ignore_index
        if not reduction:
            print("disabled the reduction.")
            # A falsy reduction is not accepted by the criterion itself, so
            # translate it into the explicit "no reduction" mode.
            reduction = 'none'
        # ignore_index must be handed to the criterion; storing it on the
        # module alone has no effect on the computed loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                             reduction=reduction)

    def forward(self, pred, target):
        """Return the cross-entropy between ``pred`` scores and ``target`` labels."""
        return self.criterion(pred, target)
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by ``(1 - pt) ** gamma``.

    Args:
        gamma: focusing exponent; ``0`` gives plain (optionally
            alpha-weighted) cross-entropy.
        alpha: optional class weights. A scalar ``a`` is expanded to the
            two-class weights ``[a, 1 - a]``; a list supplies one weight per
            class.
        size_average: return the mean over all elements when True, otherwise
            the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: raw class scores, ``[N, C]`` or ``[N, C, *spatial]``.
            target: integer class indices with one entry per scored element.

        Returns:
            A scalar tensor (mean or sum, depending on ``size_average``).
        """
        if input.dim() > 2:
            # N, C, H, W => N, C, H*W
            input = input.view(input.size(0), input.size(1), -1)
            # N, C, H*W => N, H*W, C
            input = input.transpose(1, 2)
            # N, H*W, C => N*H*W, C
            input = input.contiguous().view(-1, input.size(2))
        # reshape (rather than view) also accepts non-contiguous targets.
        target = target.reshape(-1, 1)

        # The class axis is given explicitly; relying on the implicit
        # dimension is deprecated and emits a warning.
        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)
        # The modulating factor is computed from a detached probability, so
        # no gradient flows through (1 - pt) ** gamma.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Keep the cached weights on the same device/dtype as the scores.
            self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt

        if self.size_average:
            return loss.mean()
        return loss.sum()
|
|
|
|
|
class dice_loss(nn.Module):
    """Soft Sørensen–Dice loss for binary (C == 1) or multi-class logits.

    Args:
        eps: added to the denominator for numerical stability.
    """

    def __init__(self, eps=1e-7):
        super(dice_loss, self).__init__()
        self.eps = eps

    def forward(self, logits, true):
        """
        Computes the Sørensen–Dice loss.

        PyTorch optimizers minimize a loss, while the Dice coefficient
        should be maximized, so ``1 - dice`` is returned, where ``dice`` is
        the coefficient averaged over classes.

        Args:
            true: a tensor of shape [B, 1, H, W] holding integer class
                labels.
            logits: a tensor of shape [B, C, H, W]. Corresponds to
                the raw output or logits of the model.

        Returns:
            dice_loss: the Sørensen–Dice loss.
        """
        num_classes = logits.shape[1]
        labels = true.squeeze(1).long()
        if num_classes == 1:
            # Build the lookup table on the logits' device so the same code
            # runs on CPU and GPU.
            identity = torch.eye(num_classes + 1, device=logits.device)
            true_1_hot = identity[labels].permute(0, 3, 1, 2).float()
            # Put the foreground (label 1) channel first so it lines up
            # with [pos_prob, neg_prob] below.
            background = true_1_hot[:, 0:1, :, :]
            foreground = true_1_hot[:, 1:2, :, :]
            true_1_hot = torch.cat([foreground, background], dim=1)
            pos_prob = torch.sigmoid(logits)
            neg_prob = 1 - pos_prob
            probas = torch.cat([pos_prob, neg_prob], dim=1)
        else:
            # No hard-coded .cuda(): follow the device of the logits.
            identity = torch.eye(num_classes, device=logits.device)
            true_1_hot = identity[labels].permute(0, 3, 1, 2).float()
            probas = F.softmax(logits, dim=1)
        true_1_hot = true_1_hot.to(logits.dtype)
        # Reduce over the batch and spatial axes, keeping one value per class.
        dims = (0,) + tuple(range(2, true.ndimension()))
        intersection = torch.sum(probas * true_1_hot, dims)
        cardinality = torch.sum(probas + true_1_hot, dims)
        dice = (2. * intersection / (cardinality + self.eps)).mean()
        return (1 - dice)
|
|
|
|
|
class BCEDICE_loss(nn.Module):
    """Sum of binary cross-entropy and a soft Dice loss over the whole batch."""

    def __init__(self):
        super(BCEDICE_loss, self).__init__()
        self.bce = torch.nn.BCELoss()

    def forward(self, target, true):
        """Return ``BCE + (1 - dice)``.

        Args:
            target: the prediction (despite its name); it is passed as the
                input of ``BCELoss``, so it must hold probabilities in
                [0, 1].
            true: ground-truth mask with the same shape as ``target``.

        Returns:
            A scalar tensor.
        """
        bce_loss = self.bce(target, true.float())

        # Global soft Dice coefficient; eps keeps it defined when both
        # prediction and ground truth are empty.
        eps = 1e-7
        inter = (true * target).sum()
        dice = (2 * inter + eps) / (true.sum() + target.sum() + eps)

        return bce_loss + 1 - dice
|
|
|
|
|
class LOVASZ(nn.Module):
    """Module wrapper around the multi-class Lovasz-Softmax loss."""

    def __init__(self):
        super(LOVASZ, self).__init__()

    def forward(self, probas, labels):
        """Apply softmax over dim 1 of ``probas``, then the Lovasz-Softmax loss.

        Note that ``probas`` is normalised here before being handed to
        ``lovasz_softmax``.
        """
        class_probabilities = F.softmax(probas, dim=1)
        return lovasz_softmax(class_probabilities, labels)
|
|
|
|
|
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss.

    probas: [B, C, H, W] tensor of class probabilities (between 0 and 1).
            A [B, H, W] tensor is interpreted as binary (sigmoid) output.
    labels: [B, H, W] tensor of ground truth labels (between 0 and C - 1)
    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    per_image: compute the loss per image instead of per batch
    ignore: void class label
    """
    if not per_image:
        # Whole batch flattened and scored at once.
        flat_probas, flat_labels = flatten_probas(probas, labels, ignore)
        return lovasz_softmax_flat(flat_probas, flat_labels, classes=classes)

    def _image_losses():
        # Each image is re-expanded to a batch of one before flattening.
        for prob, lab in zip(probas, labels):
            flat = flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore)
            yield lovasz_softmax_flat(*flat, classes=classes)

    return mean(_image_losses())
|
|
|
|
|
|
|
|
def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss on flattened predictions.

    probas: [P, C] tensor of class probabilities (between 0 and 1)
    labels: [P] tensor of ground truth labels (between 0 and C - 1)
    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.

    Raises ValueError when a single-channel (sigmoid) input is combined with
    more than one class to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    # Count the classes actually iterated. Measuring len(classes) would take
    # the length of the string 'present' and reject every sigmoid input.
    if C == 1 and len(class_to_sum) > 1:
        raise ValueError('Sigmoid output possible only with 1 class')

    losses = []
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground mask for class c
        # String comparison by value; identity ("is") is not reliable here.
        if classes == 'present' and fg.sum() == 0:
            continue
        class_pred = probas[:, 0] if C == 1 else probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses)
|
|
|
|
|
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors.

    gt_sorted: 1-D 0/1 ground-truth tensor, ordered like the sorted errors.
    Returns a tensor of the same length: the Jaccard loss of the first
    prefix followed by the successive differences between prefixes.
    See Alg. 1 in the paper.
    """
    length = len(gt_sorted)
    total_fg = gt_sorted.sum()
    fg_seen = gt_sorted.float().cumsum(0)
    bg_seen = (1 - gt_sorted).float().cumsum(0)
    remaining_overlap = total_fg - fg_seen
    running_union = total_fg + bg_seen
    jaccard = 1. - remaining_overlap / running_union
    if length > 1:
        # Turn cumulative values into per-position increments.
        increments = jaccard[1:length] - jaccard[0:-1]
        jaccard[1:length] = increments
    return jaccard
|
|
|
|
|
def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch.

    probas: [B, C, H, W] tensor, or [B, H, W] which is treated as a single
            channel.
    labels: tensor with B * H * W entries.
    ignore: label value whose pixels are dropped; None keeps every pixel.

    Returns (probas, labels) with shapes [P, C] and [P], where P is the
    number of kept pixels.
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    # reshape (rather than view) also accepts non-contiguous labels.
    labels = labels.reshape(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # Boolean-mask indexing always keeps the row axis. An index built with
    # nonzero().squeeze() collapses to a scalar when exactly one pixel is
    # valid, which would drop that axis.
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
|
|
|
|
|
def isnan(x):
    """Return ``x != x``; NaN is the float value that is unequal to itself."""
    differs_from_itself = x != x
    return differs_from_itself
|
|
|
|
|
|
|
|
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

    l: iterable of values supporting ``+`` and division by an int.
    ignore_nan: drop items for which ``isnan`` is true before averaging.
    empty: value returned for an empty input, or 'raise' to raise
           ValueError instead.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    n = 1
    for n, v in enumerate(l, 2):
        # Out-of-place add: "acc += v" would modify the caller's first
        # element in place when the items are tensors.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: score random 2-class logits against random labels with
    # CELoss and print the resulting loss.
    logits = torch.randn(4, 2, 10, 10)
    labels = torch.randint(low=0, high=2, size=[4, 10, 10])
    criterion = CELoss()
    print(criterion(logits, labels))
|
|
|