import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Not used by anything in this module; kept in case other modules import
    # it from here. Guarded so the losses stay importable without scikit-image.
    from skimage import measure  # noqa: F401
except ImportError:  # pragma: no cover - optional dependency
    measure = None


def _overlap_sums(pred, target):
    """Return per-sample ``(intersection, pred, target)`` sums over dims 1-3.

    ``pred`` should already be a probability map (post-sigmoid). Both tensors
    need at least 4 dims with the batch first -- presumably (B, C, H, W).
    """
    dims = (1, 2, 3)
    intersection_sum = torch.sum(pred * target, dim=dims)
    pred_sum = torch.sum(pred, dim=dims)
    target_sum = torch.sum(target, dim=dims)
    return intersection_sum, pred_sum, target_sum


def SoftIoULoss(pred, target, *, smooth=1):
    """Soft (differentiable) IoU loss.

    Args:
        pred: raw logits; ``torch.sigmoid`` is applied here.
        target: tensor broadcast-compatible with ``pred`` (presumably a
            binary mask -- confirm against callers).
        smooth: additive smoothing in numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean_b(IoU_b)``, where each IoU is computed per
        sample over dims 1-3.
    """
    pred = torch.sigmoid(pred)
    intersection_sum, pred_sum, target_sum = _overlap_sums(pred, target)
    iou = (intersection_sum + smooth) / (
        pred_sum + target_sum - intersection_sum + smooth
    )
    return 1 - iou.mean()


def Dice(pred, target, warm_epoch=1, epoch=1, layer=0, *, smooth=1):
    """Soft Dice loss.

    Args:
        pred: raw logits; ``torch.sigmoid`` is applied here.
        target: tensor broadcast-compatible with ``pred``.
        warm_epoch, epoch, layer: unused. Kept so existing call sites keep
            working (presumably to mirror other loss signatures).
        smooth: additive smoothing in numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean_b(Dice_b)`` with
        ``Dice = (2*I + smooth) / (|P| + |T| + smooth)``.
    """
    pred = torch.sigmoid(pred)
    intersection_sum, pred_sum, target_sum = _overlap_sums(pred, target)
    # The Dice denominator is |P| + |T|. Adding the intersection here (as the
    # old code did) capped a perfect prediction's score at 2/3.
    dice = (2 * intersection_sum + smooth) / (pred_sum + target_sum + smooth)
    return 1 - dice.mean()


class SLSIoULoss(nn.Module):
    """IoU loss with a size-agreement weight and an optional location term.

    During warm-up (``epoch <= warm_epoch``) this is plain ``1 - mean(IoU)``.
    Afterwards, each sample's IoU is multiplied by ``alpha``, a weight that
    reaches 1 only when predicted and target areas match. When
    ``with_shape`` is true, ``LLoss`` is added.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_log, target, warm_epoch, epoch, with_shape=True):
        """Compute the loss.

        Args:
            pred_log: raw logits; ``torch.sigmoid`` is applied here.
            target: tensor broadcast-compatible with the prediction.
            warm_epoch: last epoch that uses the plain IoU loss.
            epoch: current epoch.
            with_shape: add ``LLoss`` after warm-up.

        Returns:
            Scalar loss tensor.
        """
        pred = torch.sigmoid(pred_log)
        smooth = 0.0
        intersection_sum, pred_sum, target_sum = _overlap_sums(pred, target)
        iou = (intersection_sum + smooth) / (
            pred_sum + target_sum - intersection_sum + smooth
        )

        if epoch <= warm_epoch:
            return 1 - iou.mean()

        # alpha lies in (0, 1] and equals 1 only when the areas agree. The
        # squared half-difference term softens the ratio for large mismatches.
        dis = torch.pow((pred_sum - target_sum) / 2, 2)
        alpha = (torch.minimum(pred_sum, target_sum) + dis + smooth) / (
            torch.maximum(pred_sum, target_sum) + dis + smooth
        )
        loss = 1 - (alpha * iou).mean()
        if with_shape:
            # Only computed when actually used (it used to run every call).
            loss = loss + LLoss(pred, target)
        return loss


def LLoss(pred, target):
    """Location loss comparing per-sample "centre" vectors of pred and target.

    For each sample, a centre ``(cx, cy)`` is the mean over all C*H*W positions
    of ``coord * value``. Coordinates are normalised as ``x/W`` and ``y/H``.
    NOTE(review): this centre is not divided by the total mass, so it also
    scales with the amount of foreground. That matches the original
    implementation; confirm it is intended before switching to a
    mass-normalised centroid.

    The per-sample loss is ``1 - min(|p|, |t|) / max(|p|, |t|)`` (length
    agreement) plus ``(4/pi^2) * (atan(py/px) - atan(ty/tx))^2`` (angle
    agreement, scaled so a pi/2 gap costs 1). The result is averaged over the
    batch.

    Args:
        pred: 4-D tensor (presumably probabilities, (B, C, H, W)).
        target: 4-D tensor with the same B, H, W as ``pred``.

    Returns:
        Scalar tensor; 0 for an empty batch.
    """
    smooth = 1e-8  # keeps divisions and sqrt finite when a map is all zero
    batch_size, _, h, w = pred.shape
    if batch_size == 0:
        return pred.new_zeros(())

    # Broadcastable coordinate grids, instead of materialising (1, H, W) copies.
    x_index = (torch.arange(w).to(pred) / w).view(1, 1, 1, w)
    y_index = (torch.arange(h).to(pred) / h).view(1, 1, h, 1)
    dims = (1, 2, 3)

    pred_centerx = (x_index * pred).mean(dim=dims)
    pred_centery = (y_index * pred).mean(dim=dims)
    target_centerx = (x_index * target).mean(dim=dims)
    target_centery = (y_index * target).mean(dim=dims)

    angle_loss = (4 / (torch.pi ** 2)) * torch.square(
        torch.arctan(pred_centery / (pred_centerx + smooth))
        - torch.arctan(target_centery / (target_centerx + smooth))
    )

    pred_length = torch.sqrt(pred_centerx ** 2 + pred_centery ** 2 + smooth)
    target_length = torch.sqrt(target_centerx ** 2 + target_centery ** 2 + smooth)
    length_loss = torch.minimum(pred_length, target_length) / (
        torch.maximum(pred_length, target_length) + smooth
    )

    return (1 - length_loss + angle_loss).mean()


class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: update(..., n=0) on a fresh meter used to raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0