# s3_net / scripts / lovasz_losses.py
# zzuxzt's picture
# Upload folder using huggingface_hub
# d9c5371 verified
#!/usr/bin/python
# -*- encoding: utf-8 -*-
#!/usr/bin/env python
#
# file: $ISIP_EXP/SOGMP/scripts/lovasz_losses.py
#
# revision history: xzt
# 20220824 (TE): first version
#
# usage:
#
# This script holds the loss functions for the Lovasz-Softmax loss.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
# grads = {}
##
# version 1: use torch.autograd
class LovaszSoftmax(nn.Module):
    '''
    Lovasz-Softmax loss, autograd version, for multi-category classification.

    Gradients flow through the softmax probabilities and the sorted errors.
    The Lovasz-extension weights (the discrete Jaccard gradient) are computed
    under ``torch.no_grad()`` and treated as constants.

    Args:
        reduction: 'mean' averages the per-class losses and 'sum' adds them.
            Any other value (e.g. 'none') returns the per-class vector of
            shape (C,).
        ignore_index: label value whose positions are excluded from the loss.
    '''

    def __init__(self, reduction='mean', ignore_index=-100):
        super(LovaszSoftmax, self).__init__()
        self.reduction = reduction
        self.lb_ignore = ignore_index

    def forward(self, logits, label):
        '''
        Compute the Lovasz-Softmax loss.

        Args:
            logits: float/half tensor of shape (N, C, *), e.g. (N, C, L) or
                (N, C, H, W), with class scores along dim 1.
            label: integer tensor of shape (N, *) matching the non-class
                dimensions of ``logits``.

        Returns:
            A tuple ``(losses, errs)``:
            - ``losses`` is the reduced loss. It is a scalar for
              'mean'/'sum' and has shape (C,) otherwise.
            - ``errs`` is the (C, M) tensor of |one_hot - prob| over the
              M non-ignored positions.

        >>> criteria = LovaszSoftmax()
        >>> logits = torch.randn(8, 19, 384, 384)  # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384))  # nhw, int64_t
        >>> loss, errs = criteria(logits, lbs)
        '''
        c = logits.size(1)
        # (N, C, *) -> (C, N * prod(*)); this matches label.reshape(-1)
        # element for element. Using fp32 avoids NaN with half inputs.
        logits = logits.transpose(0, 1).reshape(c, -1).float()
        label = label.reshape(-1)

        # Drop ignored positions. A boolean mask keeps probs 2-D (C, M)
        # even when exactly one position survives; an index from
        # nonzero().squeeze() would collapse it to 1-D.
        valid = label.ne(self.lb_ignore)
        probs = logits.softmax(dim=0)[:, valid]
        label = label[valid].long()  # scatter_ requires an int64 index
        lb_one_hot = torch.zeros_like(probs).scatter_(
            0, label.unsqueeze(0), 1).detach()
        errs = (lb_one_hot - probs).abs()
        errs_sort, errs_order = torch.sort(errs, dim=1, descending=True)
        n_samples = errs.size(1)

        # Lovasz extension grad (constant weights, no autograd)
        with torch.no_grad():
            # Ground truth of each class, permuted by that class's error order.
            lb_one_hot_sort = torch.gather(lb_one_hot, 1, errs_order)
            n_pos = lb_one_hot_sort.sum(dim=1, keepdim=True)
            inter = n_pos - lb_one_hot_sort.cumsum(dim=1)
            union = n_pos + (1. - lb_one_hot_sort).cumsum(dim=1)
            jacc = 1. - inter / union
            if n_samples > 1:
                # Discrete derivative of the Jaccard loss along the sorted axis.
                jacc[:, 1:] = jacc[:, 1:] - jacc[:, :-1]

        # Per-class dot product of sorted errors with the Lovasz weights.
        losses = torch.einsum('ab,ab->a', errs_sort, jacc)
        if self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean':
            losses = losses.mean()
        return losses, errs