File size: 3,675 Bytes
3aafbf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import torch.nn as nn
import numpy as np
import torch
import torch.nn.functional as F
from skimage import measure
def SoftIoULoss(pred, target):
    """Soft IoU loss computed from raw logits.

    ``pred`` is passed through a sigmoid. For each sample, the intersection,
    prediction mass and target mass are summed over dims (1, 2, 3). The
    smoothed ratio ``(I + 1) / (P + T - I + 1)`` is averaged over the batch.

    Args:
        pred: logits; must have at least 4 dims (reduced over dims 1-3).
        target: tensor broadcast-multiplied with the probabilities; also
            reduced over dims (1, 2, 3).

    Returns:
        0-dim tensor equal to ``1 - mean_batch(soft IoU)``.
    """
    probs = pred.sigmoid()
    reduce_dims = (1, 2, 3)
    smooth = 1

    overlap = (probs * target).sum(dim=reduce_dims)
    union = probs.sum(dim=reduce_dims) + target.sum(dim=reduce_dims) - overlap

    iou_per_sample = (overlap + smooth) / (union + smooth)
    return 1 - iou_per_sample.mean()
def Dice(pred, target, warm_epoch=1, epoch=1, layer=0):
    """Soft Dice loss computed from raw logits.

    ``pred`` is passed through a sigmoid. For each sample, the intersection
    ``I``, prediction mass ``P`` and target mass ``T`` are summed over dims
    (1, 2, 3). The smoothed Dice coefficient ``(2I + 1) / (P + T + 1)`` is
    then averaged over the batch.

    Args:
        pred: logits; must have at least 4 dims (reduced over dims 1-3).
        target: tensor broadcast-multiplied with the probabilities; also
            reduced over dims (1, 2, 3).
        warm_epoch, epoch, layer: not used by this function.

    Returns:
        0-dim tensor equal to ``1 - mean_batch(Dice)``; 0 for a perfect match.
    """
    pred = torch.sigmoid(pred)
    smooth = 1
    intersection_sum = torch.sum(pred * target, dim=(1, 2, 3))
    pred_sum = torch.sum(pred, dim=(1, 2, 3))
    target_sum = torch.sum(target, dim=(1, 2, 3))
    # Standard Dice denominator is P + T. An extra "+ I" term here would cap
    # the coefficient at ~2/3 for a perfect prediction.
    dice = (2 * intersection_sum + smooth) / (pred_sum + target_sum + smooth)
    return 1 - dice.mean()
class SLSIoULoss(nn.Module):
    """IoU loss that, after a warm-up period, is scaled by a mass-matching factor.

    During warm-up (``epoch <= warm_epoch``) this is plain soft IoU loss.
    Afterwards, each sample's IoU is multiplied by

        alpha = (min(P, T) + d) / (max(P, T) + d),  d = ((P - T) / 2) ** 2

    where ``P`` and ``T`` are the per-sample predicted and target masses.
    When ``with_shape`` is set, the module-level ``LLoss`` term is added.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_log, target, warm_epoch, epoch, with_shape=True):
        """Compute the loss.

        Args:
            pred_log: logits; reduced over dims (1, 2, 3).
            target: tensor broadcast-multiplied with the probabilities; also
                reduced over dims (1, 2, 3).
            warm_epoch: last epoch (inclusive) that uses plain IoU loss.
            epoch: current epoch.
            with_shape: after warm-up, whether to add ``LLoss(pred, target)``.

        Returns:
            0-dim loss tensor.
        """
        pred = torch.sigmoid(pred_log)
        # NOTE(review): with smooth = 0, a sample whose pred_sum and
        # target_sum are both 0 yields 0/0. That requires every sigmoid
        # output to underflow to 0, so it is rare but possible.
        smooth = 0.0
        dims = (1, 2, 3)
        intersection_sum = torch.sum(pred * target, dim=dims)
        pred_sum = torch.sum(pred, dim=dims)
        target_sum = torch.sum(target, dim=dims)
        iou = (intersection_sum + smooth) / \
            (pred_sum + target_sum - intersection_sum + smooth)

        if epoch <= warm_epoch:
            return 1 - iou.mean()

        # Mass-matching factor in (0, 1]; equals 1 when pred_sum == target_sum.
        dis = torch.pow((pred_sum - target_sum) / 2, 2)
        alpha = (torch.min(pred_sum, target_sum) + dis + smooth) / \
            (torch.max(pred_sum, target_sum) + dis + smooth)
        siou_loss = 1 - (alpha * iou).mean()
        if with_shape:
            # Computed only here: the location term is unused on every
            # other path, and it is a per-sample computation.
            return siou_loss + LLoss(pred, target)
        return siou_loss
def LLoss(pred, target):
    """Location loss comparing where the mass of ``pred`` and ``target`` lies.

    For every sample, a pseudo-centre ``(cx, cy)`` is computed as the mean of
    ``coordinate * value`` over all non-batch elements. The x/y coordinates
    are normalised by width/height to ``[0, 1)``. The mean divides by the
    element count, not by the total mass, so this is not a true centroid.
    The predicted and target pseudo-centres are compared as 2-D vectors:

    * angle term:  ``4 / pi**2 * (atan(cy_p / cx_p) - atan(cy_t / cx_t)) ** 2``
    * length term: ``1 - min(|c_p|, |c_t|) / max(|c_p|, |c_t|)``

    A small ``1e-8`` is added inside divisions and square roots.

    Args:
        pred: tensor whose ``shape[2]`` is used as height and ``shape[3]``
            as width (presumably (B, C, H, W) probabilities — confirm).
        target: tensor broadcastable against ``pred`` over the trailing
            (H, W) dims, with the same batch size.

    Returns:
        0-dim tensor: ``(1 - length + angle)`` averaged over the batch, or
        0.0 for an empty batch. The value is built on a ``requires_grad=True``
        zero tensor, so outside ``torch.no_grad`` it requires grad even when
        ``pred`` does not.
    """
    loss = torch.tensor(0.0, requires_grad=True).to(pred)
    batch = pred.shape[0]
    if batch == 0:
        return loss
    h = pred.shape[2]
    w = pred.shape[3]
    # Normalised coordinate grids that broadcast over the trailing (H, W) dims.
    x_index = (torch.arange(0, w, 1).to(pred) / w).view(1, w)
    y_index = (torch.arange(0, h, 1).to(pred) / h).view(h, 1)
    smooth = 1e-8

    # Vectorised over the batch: one mean per sample over all other elements.
    pred_centerx = (x_index * pred).flatten(1).mean(dim=1)
    pred_centery = (y_index * pred).flatten(1).mean(dim=1)
    target_centerx = (x_index * target).flatten(1).mean(dim=1)
    target_centery = (y_index * target).flatten(1).mean(dim=1)

    angle_loss = (4 / (torch.pi ** 2)) * torch.square(
        torch.arctan(pred_centery / (pred_centerx + smooth))
        - torch.arctan(target_centery / (target_centerx + smooth))
    )
    pred_length = torch.sqrt(pred_centerx * pred_centerx + pred_centery * pred_centery + smooth)
    target_length = torch.sqrt(target_centerx * target_centerx + target_centery * target_centery + smooth)
    length_loss = torch.min(pred_length, target_length) / (
        torch.max(pred_length, target_length) + smooth
    )
    return loss + (1 - length_loss + angle_loss).mean()
class AverageMeter:
    """Computes and stores the latest value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset the latest value, running sum, count and average to 0."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        ``avg`` becomes ``sum / count``. While ``count`` is 0 (e.g. only
        ``n=0`` updates so far), ``avg`` stays 0 instead of raising
        ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0