InPeerReview's picture
Upload 4 files
3aafbf3 verified
import torch.nn as nn
import numpy as np
import torch
import torch.nn.functional as F
from skimage import measure
def SoftIoULoss(pred, target):
pred = torch.sigmoid(pred)
smooth = 1
intersection = pred * target
intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
pred_sum = torch.sum(pred, dim=(1, 2, 3))
target_sum = torch.sum(target, dim=(1, 2, 3))
loss = (intersection_sum + smooth) / \
(pred_sum + target_sum - intersection_sum + smooth)
loss = 1 - loss.mean()
return loss
def Dice(pred, target, warm_epoch=1, epoch=1, layer=0):
pred = torch.sigmoid(pred)
smooth = 1
intersection = pred * target
intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
pred_sum = torch.sum(pred, dim=(1, 2, 3))
target_sum = torch.sum(target, dim=(1, 2, 3))
loss = (2 * intersection_sum + smooth) / \
(pred_sum + target_sum + intersection_sum + smooth)
loss = 1 - loss.mean()
return loss
class SLSIoULoss(nn.Module):
def __init__(self):
super(SLSIoULoss, self).__init__()
def forward(self, pred_log, target, warm_epoch, epoch, with_shape=True):
pred = torch.sigmoid(pred_log)
smooth = 0.0
intersection = pred * target
intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
pred_sum = torch.sum(pred, dim=(1, 2, 3))
target_sum = torch.sum(target, dim=(1, 2, 3))
dis = torch.pow((pred_sum - target_sum) / 2, 2)
alpha = (torch.min(pred_sum, target_sum) + dis + smooth) / (torch.max(pred_sum, target_sum) + dis + smooth)
loss = (intersection_sum + smooth) / \
(pred_sum + target_sum - intersection_sum + smooth)
lloss = LLoss(pred, target)
if epoch > warm_epoch:
siou_loss = alpha * loss
if with_shape:
loss = 1 - siou_loss.mean() + lloss
else:
loss = 1 - siou_loss.mean()
else:
loss = 1 - loss.mean()
return loss
def LLoss(pred, target):
loss = torch.tensor(0.0, requires_grad=True).to(pred)
patch_size = pred.shape[0]
h = pred.shape[2]
w = pred.shape[3]
x_index = torch.arange(0, w, 1).view(1, 1, w).repeat((1, h, 1)).to(pred) / w
y_index = torch.arange(0, h, 1).view(1, h, 1).repeat((1, 1, w)).to(pred) / h
smooth = 1e-8
for i in range(patch_size):
pred_centerx = (x_index * pred[i]).mean()
pred_centery = (y_index * pred[i]).mean()
target_centerx = (x_index * target[i]).mean()
target_centery = (y_index * target[i]).mean()
angle_loss = (4 / (torch.pi ** 2)) * (torch.square(torch.arctan((pred_centery) / (pred_centerx + smooth))
- torch.arctan(
(target_centery) / (target_centerx + smooth))))
pred_length = torch.sqrt(pred_centerx * pred_centerx + pred_centery * pred_centery + smooth)
target_length = torch.sqrt(target_centerx * target_centerx + target_centery * target_centery + smooth)
length_loss = (torch.min(pred_length, target_length)) / (torch.max(pred_length, target_length) + smooth)
loss = loss + (1 - length_loss + angle_loss) / patch_size
return loss
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count