| | |
| | |
| |
|
| | """ |
| | @Author : Peike Li |
| | @Contact : peike.li@yahoo.com |
| | @File : soft_dice_loss.py |
| | @Time : 8/13/19 5:09 PM |
| | @Desc : |
| | @License : This source code is licensed under the license found in the |
| | LICENSE file in the root directory of this source tree. |
| | """ |
| |
|
| | from __future__ import print_function, division |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| | try: |
| | from itertools import ifilterfalse |
| | except ImportError: |
| | from itertools import filterfalse as ifilterfalse |
| |
|
| |
|
def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6):
    """
    Tversky loss function, averaged over the classes present in ``labels``.

    probas: [P, C] Tensor, class probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, ground truth labels (between 0 and C - 1)
    alpha: weight of the false-positive term in the denominator
    beta: weight of the false-negative term in the denominator
    epsilon: small constant added to the denominator to avoid division by zero

    Returns a scalar Tensor. When none of the C classes occurs in ``labels``
    the result is a zero Tensor that is still connected to ``probas``, so
    calling ``backward()`` on it works and yields zero gradients.

    Same as soft dice loss when alpha=beta=0.5.
    Same as Jaccard loss when alpha=beta=1.0.
    See `Tversky loss function for image segmentation using 3D fully convolutional deep networks`
    https://arxiv.org/pdf/1706.05721.pdf
    """
    num_classes = probas.size(1)
    losses = []
    for c in range(num_classes):
        fg = (labels == c).float()
        if fg.sum() == 0:
            # Class absent from the ground truth: it does not contribute to the average.
            continue
        class_pred = probas[:, c]
        true_pos = torch.sum(class_pred * fg)
        false_pos = torch.sum(class_pred * (1 - fg))
        false_neg = torch.sum((1 - class_pred) * fg)
        denominator = true_pos + alpha * false_pos + beta * false_neg
        losses.append(1 - true_pos / (denominator + epsilon))
    if not losses:
        # No class present (e.g. every pixel was ignored): return a tensor zero
        # rather than a Python number so callers can still call .backward()/.item().
        return probas.sum() * 0.0
    return torch.stack(losses).mean()
| |
|
| |
|
def flatten_probas(probas, labels, ignore=255):
    """
    Flattens predictions in the batch.

    probas: [B, C, H, W] Tensor of per-class scores
    labels: Tensor holding B * H * W ground truth labels
    ignore: label value whose positions are dropped; None keeps every position

    Returns (probas, labels) with shapes [P, C] and [P], where P is the number
    of positions kept. The first result is always 2-D, including when exactly
    one position (or none) is kept.
    """
    # Unpacking four values also rejects inputs that are not 4-D.
    _, num_classes, _, _ = probas.size()
    # Move the class axis last so that each row is one pixel's class vector.
    probas = probas.permute(0, 2, 3, 1).reshape(-1, num_classes)
    # reshape (unlike view) also accepts non-contiguous label tensors.
    labels = labels.reshape(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    # Boolean-mask indexing keeps the [P, C] shape even for a single valid pixel;
    # indexing with nonzero().squeeze() would collapse that case to shape [C].
    return probas[valid], labels[valid]
| |
|
| |
|
def isnan(x):
    """
    Return the result of comparing ``x`` with itself for inequality.

    NaN is the float value that is not equal to itself, so the comparison
    is truthy for NaN and falsy for ordinary numbers.
    """
    differs_from_itself = x != x
    return differs_from_itself
| |
|
| |
|
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

    l: iterable of values supporting ``+`` and division by an int
    ignore_nan: when true, values for which ``isnan`` is truthy are skipped
    empty: value returned when there is nothing to average; the string
        'raise' makes the function raise ValueError instead

    Returns the single value unchanged when only one value is averaged,
    otherwise the sum divided by the count. The input values are never
    modified in place.
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(isnan, values)
    try:
        total = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    count = 1
    for value in values:
        # Rebind instead of using ``+=``: an in-place add would mutate the
        # caller's first element when the values are tensors or arrays.
        total = total + value
        count += 1
    if count == 1:
        # Return the lone value as-is so its type is preserved.
        return total
    return total / count
| |
|
| |
|
class SoftDiceLoss(nn.Module):
    """
    Soft Dice loss module: softmax over dim 1, then the Tversky loss with
    alpha = beta = 0.5.
    """

    def __init__(self, ignore_index=255):
        super(SoftDiceLoss, self).__init__()
        # Label value passed to flatten_probas as ``ignore``.
        self.ignore_index = ignore_index

    def forward(self, pred, label):
        """
        Compute the loss for raw scores ``pred`` against ``label``.

        ``pred`` is normalised with a softmax over dim 1, flattened together
        with ``label`` by flatten_probas, and scored by tversky_loss.
        """
        probas = F.softmax(pred, dim=1)
        flat_probas, flat_labels = flatten_probas(
            probas, label, ignore=self.ignore_index
        )
        return tversky_loss(flat_probas, flat_labels, alpha=0.5, beta=0.5)
| |
|
| |
|
class SoftJaccordLoss(nn.Module):
    """
    Soft Jaccard loss module: softmax over dim 1, then the Tversky loss with
    alpha = beta = 1.0.
    """

    def __init__(self, ignore_index=255):
        super(SoftJaccordLoss, self).__init__()
        # Label value passed to flatten_probas as ``ignore``.
        self.ignore_index = ignore_index

    def forward(self, pred, label):
        """
        Compute the loss for raw scores ``pred`` against ``label``.

        ``pred`` is normalised with a softmax over dim 1, flattened together
        with ``label`` by flatten_probas, and scored by tversky_loss.
        """
        probas = F.softmax(pred, dim=1)
        flat_probas, flat_labels = flatten_probas(
            probas, label, ignore=self.ignore_index
        )
        return tversky_loss(flat_probas, flat_labels, alpha=1.0, beta=1.0)
| |
|