| | """ |
| | @Author : Peike Li |
| | @Contact : peike.li@yahoo.com |
| | @File : lovasz_softmax.py |
| | @Time : 8/30/19 7:12 PM |
| | @Desc : Lovasz-Softmax and Jaccard hinge loss in PyTorch |
| | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) |
| | @License : This source code is licensed under the license found in the |
| | LICENSE file in the root directory of this source tree. |
| | """ |
|
from __future__ import print_function, division
|
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from torch import nn
|
try:
    from itertools import ifilterfalse  # Python 2
except ImportError:
    from itertools import filterfalse as ifilterfalse  # Python 3
|
|
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover the 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
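# A hand-worked sketch of what lovasz_grad computes (values are illustrative,
# not from the paper): for gt_sorted = [1, 1, 0, 1] we have gts = 3 and
#   intersection = 3 - cumsum(gt)     = [2, 1, 1, 0]
#   union        = 3 + cumsum(1 - gt) = [3, 3, 4, 4]
#   jaccard      = 1 - I/U            = [1/3, 2/3, 3/4, 1]
# and after the differencing step the returned gradient is
#   [1/3, 1/3, 1/12, 1/4],
# i.e. the marginal increase in Jaccard loss as each sorted error becomes
# active; for this example the entries sum to jaccard[-1] = 1.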
|
|
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / float(union)
        ious.append(iou)
    iou = mean(ious)  # mean across images if per_image
    return 100 * iou
|
|
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore:  # skip the void class
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
    return 100 * np.array(ious)
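# Usage sketch (made-up tensors, not part of the module API):
#
#   preds  = torch.tensor([[0, 1], [1, 1]])
#   labels = torch.tensor([[0, 1], [1, 0]])
#   iou_binary(preds, labels, per_image=False)   # -> 100 * 2/3, i.e. ~66.7
#
# Two foreground pixels agree and the union covers three pixels, so IoU = 2/3.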
|
# --------------------------- BINARY LOSSES ---------------------------
|
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss
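# A minimal usage sketch (not part of the original file; the helper name is
# illustrative only): random logits scored against a random binary mask.
def _example_lovasz_hinge():
    logits = torch.randn(2, 4, 4, requires_grad=True)   # [B, H, W] raw scores
    labels = (torch.rand(2, 4, 4) > 0.5).long()         # binary ground truth in {0, 1}
    loss = lovasz_hinge(logits, labels, per_image=True)
    loss.backward()                                     # gradients flow to logits
    return loss.item()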
|
|
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
|
|
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels
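# E.g. (illustrative values): scores = [0.2, 0.8, 0.5] with labels = [1, 255, 0]
# and ignore=255 yields ([0.2, 0.5], [1, 0]) -- the void pixel is dropped
# before the loss is evaluated.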
|
|
class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        # numerically stable BCE-with-logits: max(x, 0) - x * y + log(1 + exp(-|x|))
        neg_abs = - input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()
|
|
def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss
|
# --------------------------- MULTICLASS LOSSES ---------------------------
|
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255, weighted=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              If C = 1, interpreted as binary (sigmoid) output with probabilities of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
      weighted: optional per-class weights for the class-wise loss terms
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                                        classes=classes, weighted=weighted)
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes, weighted=weighted)
    return loss
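# A minimal usage sketch (not part of the original file; the helper name is
# illustrative only): random scores for C = 5 classes are softmaxed and
# scored against random labels.
def _example_lovasz_softmax():
    scores = torch.randn(2, 5, 4, 4, requires_grad=True)   # [B, C, H, W] raw network output
    probas = F.softmax(scores, dim=1)                      # the loss expects probabilities
    labels = torch.randint(0, 5, (2, 4, 4))                # [B, H, W] ground truth in 0..C-1
    loss = lovasz_softmax(probas, labels, classes='present', ignore=255)
    loss.backward()                                        # gradients flow back through softmax
    return loss.item()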
|
|
def lovasz_softmax_flat(probas, labels, classes='present', weighted=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      weighted: optional per-class weights for the class-wise loss terms
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if (classes == 'present' and fg.sum() == 0):
            continue
        if C == 1:
            # check the resolved class list, not the `classes` string itself
            if len(class_to_sum) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        if weighted is not None:
            losses.append(weighted[c] * torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
        else:
            losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
|
|
def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer: add a singleton class dimension
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # [B * H * W, C] = [P, C]
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid]  # boolean mask keeps a [P_valid, C] shape even for a single pixel
    vlabels = labels[valid]
    return vprobas, vlabels
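# Shape sketch (illustrative): probas of shape [2, 5, 4, 4] with labels of
# shape [2, 4, 4] and ignore=255 come out as probas [P, 5] and labels [P],
# where P <= 2 * 4 * 4 counts the non-void pixels.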
|
|
def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    """
    # honour the `ignore` argument rather than a hardcoded void label
    if ignore is None:
        return F.cross_entropy(logits, Variable(labels))
    return F.cross_entropy(logits, Variable(labels), ignore_index=ignore)
|
|
# --------------------------- HELPER FUNCTIONS ---------------------------

def isnan(x):
    return x != x
|
|
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
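# E.g. mean([1., 2., 3.]) == 2.0, mean([], empty=0) == 0, and with
# ignore_nan=True any NaN entries are skipped before averaging.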
|
class LovaszSoftmax(nn.Module):
    def __init__(self, per_image=False, ignore_index=255, weighted=None):
        super(LovaszSoftmax, self).__init__()
        self.lovasz_softmax = lovasz_softmax
        self.per_image = per_image
        self.ignore_index = ignore_index
        self.weighted = weighted

    def forward(self, pred, label):
        # pred holds raw logits: apply softmax first, since lovasz_softmax
        # expects class probabilities
        pred = F.softmax(pred, dim=1)
        return self.lovasz_softmax(pred, label, per_image=self.per_image,
                                   ignore=self.ignore_index, weighted=self.weighted)
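
if __name__ == '__main__':
    # Smoke-test sketch (assumed shapes, random data; not part of the original
    # file): run the module-style loss on raw logits, marking a pixel as void.
    criterion = LovaszSoftmax(per_image=False, ignore_index=255)
    logits = torch.randn(2, 5, 8, 8, requires_grad=True)
    labels = torch.randint(0, 5, (2, 8, 8))
    labels[0, 0, 0] = 255                  # a void pixel, excluded from the loss
    loss = criterion(logits, labels)
    loss.backward()
    print('LovaszSoftmax loss:', loss.item())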