import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultiCrossEntropyLoss(nn.Module):
    def __init__(self, focal=False, weight=None, reduce=True):
        super().__init__()
        self.num_classes = 23  # hard-coded; the last class is excluded from the gradient statistics below
        self.focal = focal
        self.weight = weight
        self.reduce = reduce
        # Base per-class focusing parameter; gamma_f scales an extra term for
        # classes whose accumulated positive/negative gradient ratio is low.
        self.gamma_ = torch.zeros(self.num_classes).to(device) + 0.025
        self.gamma_f = 0.05
        # Running per-class gradient statistics, updated by collect_grad.
        self.register_buffer('pos_grad', torch.zeros(self.num_classes - 1).to(device))
        self.register_buffer('neg_grad', torch.zeros(self.num_classes - 1).to(device))
        self.register_buffer('pos_neg', torch.ones(self.num_classes - 1).to(device))

    def forward(self, input, target):
        # Normalize multi-hot targets to a distribution; all-zero rows stay zero.
        target_sum = torch.sum(target, dim=1)
        target_div = torch.where(target_sum != 0, target_sum, torch.ones_like(target_sum)).unsqueeze(1)
        target = target / target_div
        log_p = F.log_softmax(input, dim=1)
        # Raise the focusing parameter for classes dominated by negative gradients.
        gamma = self.gamma_.clone()
        gamma[:-1] = gamma[:-1] + self.gamma_f * (1 - self.pos_neg)
        if not self.focal:
            if self.weight is None:
                output = torch.sum(-target * log_p, 1)
            else:
                output = torch.sum(-target * log_p / self.weight, 1)
        else:
            # Focal variant: down-weight well-classified examples by (1 - p) ** gamma.
            p = F.softmax(input, dim=1)
            output = torch.sum(-target * (1 - p) ** gamma * log_p, 1)
        if self.reduce:
            return torch.mean(output)
        return output

    def map_func(self, x, s):
        # Min-max normalize x, then squash through a sigmoid of slope s centered
        # at the mean (note: the mean is taken before the normalization).
        min_val = torch.min(x)
        max_val = torch.max(x)
        mu = torch.mean(x)
        x = (x - min_val) / (max_val - min_val)
        return torch.sigmoid(s * (x - mu))

    def collect_grad(self, target, grad):
        # Accumulate logit-gradient magnitudes, split into the part flowing
        # through positive labels and the part through negatives, for every
        # class except the last.
        grad = torch.abs(grad.reshape(-1, grad.shape[-1])).to(device)
        target = target.reshape(-1, target.shape[-1]).to(device)
        self.pos_grad += torch.sum(grad * target, dim=0)[:-1]
        self.neg_grad += torch.sum(grad * (1 - target), dim=0)[:-1]
        # pos_neg in [0, 1]: small values mark classes dominated by negative gradients.
        self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)
        self.pos_neg = self.map_func(self.pos_neg, 1)
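
# A minimal usage sketch (the shapes are assumptions: a batch of 4 over the 23
# classes hard-coded above). The gradient fed to collect_grad would normally
# come from a backward hook on the classification logits.
def _demo_multi_cross_entropy():
    criterion = MultiCrossEntropyLoss(focal=True)
    logits = torch.randn(4, 23, device=device, requires_grad=True)
    targets = (torch.rand(4, 23, device=device) > 0.8).float()  # multi-hot labels
    loss = criterion(logits, targets)
    loss.backward()
    criterion.collect_grad(targets, logits.grad)  # update pos_grad / neg_grad / pos_neg
    return loss.item()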

def cls_loss_func(y, output, use_focal=False, weight=None, reduce=True):
    input_size = y.size()
    y = y.float().to(device)
    if weight is not None:
        weight = weight.to(device)
    loss_func = MultiCrossEntropyLoss(focal=use_focal, weight=weight, reduce=reduce)
    # Flatten any leading dimensions so the loss sees (N, num_classes).
    y = y.reshape(-1, y.size(-1))
    output = output.reshape(-1, output.size(-1))
    loss = loss_func(output, y)
    if not reduce:
        loss = loss.reshape(input_size[:-1])
    return loss
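
# Example with assumed per-frame inputs of shape (batch, T, 23): with
# reduce=False the per-element losses come back reshaped to (batch, T).
def _demo_cls_loss():
    scores = torch.randn(2, 8, 23, device=device)
    labels = (torch.rand(2, 8, 23, device=device) > 0.8).float()
    return cls_loss_func(labels, scores, use_focal=True, reduce=False)  # shape (2, 8)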

def cls_loss_func_(loss_func, y, output, use_focal=False, weight=None, reduce=True):
    # Variant of cls_loss_func taking a pre-built loss module so its
    # pos_grad/neg_grad buffers persist across calls; use_focal and weight are
    # unused here, and reduce must match the flag loss_func was built with.
    input_size = y.size()
    y = y.float().to(device)
    y = y.reshape(-1, y.size(-1))
    output = output.reshape(-1, output.size(-1))
    loss = loss_func(output, y)
    if not reduce:
        loss = loss.reshape(input_size[:-1])
    return loss
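
# Sketch of the persistent variant: building the module once lets its gradient
# statistics accumulate across training iterations (shapes assumed as above).
def _demo_cls_loss_persistent():
    criterion = MultiCrossEntropyLoss(focal=True)
    scores = torch.randn(2, 8, 23, device=device)
    labels = (torch.rand(2, 8, 23, device=device) > 0.8).float()
    return cls_loss_func_(criterion, labels, scores)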

def regress_loss_func(y, output):
    y = y.float().to(device)
    y = y.reshape(-1, y.size(-1))
    output = output.reshape(-1, output.size(-1))
    # Background rows are marked by a large negative sentinel in column 1 and
    # are excluded from the regression loss.
    bgmask = y[:, 1] < -1e2
    fg_logits = output[~bgmask]
    fg_target = y[~bgmask]
    loss = F.l1_loss(fg_logits, fg_target)
    # F.l1_loss over an empty foreground set yields NaN; fall back to zero.
    if torch.isnan(loss):
        return torch.tensor([0.0], requires_grad=True, device=device)
    return loss
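
# Sketch of the expected regression targets (inferred from the sentinel check
# above): background rows carry a large negative value in column 1, here -1e4,
# and are dropped from the L1 loss.
def _demo_regress_loss():
    target = torch.tensor([[0.2, 0.5], [0.0, -1e4]], device=device)  # row 1 is background
    pred = torch.randn(2, 2, device=device)
    return regress_loss_func(target, pred)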

def suppress_loss_func(y, output):
    # output must already hold probabilities in [0, 1]: this uses
    # binary_cross_entropy, not the with-logits variant.
    y = y.float().to(device)
    y = y.reshape(-1, y.size(-1))
    output = output.reshape(-1, output.size(-1))
    return F.binary_cross_entropy(output, y)
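
# suppress_loss_func expects probabilities, so apply a sigmoid (or similar)
# first; a minimal sketch with an assumed shape of (batch, T, 1):
def _demo_suppress_loss():
    probs = torch.sigmoid(torch.randn(2, 8, 1, device=device))
    labels = torch.randint(0, 2, (2, 8, 1), device=device).float()
    return suppress_loss_func(labels, probs)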