# SAT / loss_func.py
# (repository page metadata: author "Darknsu", commit "Update loss_func.py", 93e84b8 verified)
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from functools import partial
# Global compute device: CUDA when available, otherwise CPU. Tensors and
# buffers created in this module are moved here at construction time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MultiCrossEntropyLoss(nn.Module):
    """Soft-target cross entropy with an optional focal term.

    Targets may be multi-hot; each row is normalized into a distribution
    before the cross entropy is taken against a log-softmax of the logits.
    In focal mode, a per-class exponent ``gamma`` is adapted from the
    accumulated positive/negative gradient statistics collected via
    :meth:`collect_grad` (the last class is treated as background and
    excluded from the statistics).
    """

    def __init__(self, focal=False, weight=None, reduce=True, num_classes=23):
        """
        Args:
            focal: if True, apply a ``(1 - p) ** gamma`` focal modulation.
            weight: optional per-class weight tensor; the plain CE term is
                *divided* by it. NOTE(review): division, not multiplication —
                confirm this is the intended weighting scheme.
            reduce: if True, return the mean loss; otherwise per-sample losses.
            num_classes: total number of classes including background
                (default 23, the previously hard-coded value).
        """
        super(MultiCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.focal = focal
        self.weight = weight
        self.reduce = reduce
        # Resolve the device locally instead of relying on a module-level global,
        # so the class works standalone; the resolution is identical.
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Base focal gamma per class and the scale of its adaptive part.
        self.gamma_ = torch.zeros(num_classes, device=dev) + 0.025
        self.gamma_f = 0.05
        # Accumulated |grad| mass for positive / negative targets per
        # foreground class, and their mapped ratio in (0, 1).
        self.register_buffer('pos_grad', torch.zeros(num_classes - 1, device=dev))
        self.register_buffer('neg_grad', torch.zeros(num_classes - 1, device=dev))
        self.register_buffer('pos_neg', torch.ones(num_classes - 1, device=dev))

    def forward(self, input, target):
        """Compute the loss for logits ``input`` and (multi-hot) ``target``,
        both of shape ``(batch, num_classes)``."""
        # Normalize multi-hot targets into a distribution; all-zero rows are
        # divided by 1 so they stay zero instead of producing NaN.
        target_sum = torch.sum(target, dim=1)
        target_div = torch.where(target_sum != 0, target_sum, torch.ones_like(target_sum)).unsqueeze(1)
        target = target / target_div
        # Functional calls avoid building fresh nn.Module objects every forward.
        log_p = F.log_softmax(input, dim=1)
        # Adaptive gamma: foreground classes whose negative gradients dominate
        # (small pos_neg) get a larger focal exponent.
        gamma = self.gamma_.clone()
        gamma[:-1] = gamma[:-1] + self.gamma_f * (1 - self.pos_neg)
        if not self.focal:
            if self.weight is None:
                output = torch.sum(-target * log_p, 1)
            else:
                output = torch.sum(-target * log_p / self.weight, 1)
        else:
            p = F.softmax(input, dim=1)
            output = torch.sum(-target * (1 - p) ** gamma * log_p, 1)
        return torch.mean(output) if self.reduce else output

    def map_func(self, x, s):
        """Min-max normalize ``x`` and squash it through a sigmoid of slope ``s``
        centered at the (pre-normalization) mean."""
        min_val = torch.min(x)
        max_val = torch.max(x)
        # NOTE(review): mu is computed before normalization but subtracted from
        # the normalized values; both live in [0, 1] for clamped inputs, but
        # confirm this mixed scale is intentional.
        mu = torch.mean(x)
        # eps guards against division by zero when all entries are equal.
        x = (x - min_val) / (max_val - min_val + 1e-10)
        return 1 / (1 + torch.exp(-s * (x - mu)))

    def collect_grad(self, target, grad):
        """Accumulate |grad| mass split by positive/negative targets and refresh
        the ``pos_neg`` ratio used to adapt the focal gamma."""
        dev = self.pos_grad.device  # keep statistics on the buffers' device
        grad = torch.abs(grad.reshape(-1, grad.shape[-1])).to(dev)
        target = target.reshape(-1, target.shape[-1]).to(dev)
        # Background (last) class is excluded from the statistics.
        pos_grad = torch.sum(grad * target, dim=0)[:-1]
        neg_grad = torch.sum(grad * (1 - target), dim=0)[:-1]
        self.pos_grad += pos_grad
        self.neg_grad += neg_grad
        self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)
        self.pos_neg = self.map_func(self.pos_neg, 1)
def cls_loss_func(y, output, use_focal=False, weight=None, reduce=True):
    """Flatten targets/logits to 2-D and score them with MultiCrossEntropyLoss.

    NOTE(review): ``use_focal`` is accepted but ignored — the criterion is
    always built with ``focal=True``; confirm whether that is intentional.
    When ``reduce`` is False, the per-sample losses are reshaped back to the
    target's leading dimensions.
    """
    orig_shape = y.size()
    y = y.float().to(device)
    if weight is not None:
        weight = weight.to(device)
    criterion = MultiCrossEntropyLoss(focal=True, weight=weight, reduce=reduce)
    flat_targets = y.reshape(-1, y.size(-1))
    flat_logits = output.reshape(-1, output.size(-1))
    loss = criterion(flat_logits, flat_targets)
    return loss if reduce else loss.reshape(orig_shape[:-1])
def cls_loss_func_(loss_func, y, output, use_focal=False, weight=None, reduce=True):
    """Apply a pre-built classification criterion to flattened targets/logits.

    ``use_focal`` is unused here — the behaviour is fully determined by the
    supplied ``loss_func``. ``weight`` is only moved to the device; the
    criterion must already incorporate it. When ``reduce`` is False, the
    per-sample losses are reshaped back to the target's leading dimensions.
    """
    orig_shape = y.size()
    y = y.float().to(device)
    if weight is not None:
        weight = weight.to(device)
    flat_targets = y.reshape(-1, y.size(-1))
    flat_logits = output.reshape(-1, output.size(-1))
    loss = loss_func(flat_logits, flat_targets)
    return loss if reduce else loss.reshape(orig_shape[:-1])
def regress_loss_func(y, output):
    """L1 regression loss computed over foreground rows only.

    Rows whose second target column is below -1e2 are treated as background
    sentinels and excluded. If no foreground row exists, ``l1_loss`` over the
    empty selection is NaN, so a zero tensor (shape ``(1,)``, attached to
    autograd) is returned instead.
    """
    # Resolve the device locally (identical to the module-level global) so the
    # function does not depend on import-time state.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    y = y.float().to(dev)
    y = y.reshape(-1, y.size(-1))
    output = output.reshape(-1, output.size(-1))
    bg_mask = y[:, 1] < -1e2  # background sentinel targets
    # Only the foreground selection is scored; background rows are dropped.
    loss = F.l1_loss(output[~bg_mask], y[~bg_mask])
    if loss.isnan():
        # No foreground rows: return a zero loss that still participates in autograd.
        return torch.tensor([0.0], requires_grad=True).to(dev)
    return loss
def suppress_loss_func(y, output):
    """Binary cross entropy between ``output`` (assumed to already be
    probabilities in [0, 1]) and float targets, both flattened to 2-D."""
    targets = y.float().to(device).reshape(-1, y.size(-1))
    predictions = output.reshape(-1, output.size(-1))
    return nn.functional.binary_cross_entropy(predictions, targets)