# Source: uploaded by InPeerReview ("Upload 161 files"), commit 226675b (verified).
import torch
import torch.nn as nn
class dice_loss(nn.Module):
    """Soft Dice loss, ``1 - soft Dice coefficient``.

    Args:
        batch: If True, all images in the batch are treated as a single entity
            and one Dice coefficient is computed over every element. If False,
            a Dice coefficient is computed per sample (dimension 0) and the
            coefficients are averaged.
        smooth: Constant added to numerator and denominator; avoids division by
            zero when prediction and target are both empty (coefficient -> 1).
    """

    def __init__(self, batch=True, smooth=0.00001):
        super(dice_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the (mean) soft Dice coefficient between y_pred and y_true."""
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            # Reduce over every non-batch dim, so both (N, H, W) and
            # (N, C, H, W) inputs work; for 4-D input this equals summing
            # dims 1, 2 and 3.
            i = y_true.flatten(1).sum(1)
            j = y_pred.flatten(1).sum(1)
            intersection = (y_true * y_pred).flatten(1).sum(1)
        score = (2. * intersection + self.smooth) / (i + j + self.smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Compute the loss; y_pred is cast to float32 first.

        Defined as ``forward`` (not ``__call__``) so nn.Module hooks run.
        """
        return self.soft_dice_loss(y_pred.to(dtype=torch.float32), y_true)
class dice_focal_loss(nn.Module):
    """Soft-Dice loss plus binary cross-entropy, both computed from raw logits.

    Despite the name (kept for backward compatibility), the second term is
    plain ``nn.BCEWithLogitsLoss``, not a focal loss.
    """

    def __init__(self):
        super(dice_focal_loss, self).__init__()
        # Attribute names (including the "binnary" spelling) are kept so that
        # external code referring to them keeps working.
        self.focal_loss = nn.BCEWithLogitsLoss()
        self.binnary_dice = dice_loss()

    def forward(self, scores, labels):
        """Return ``(dice_loss_value, bce_loss_value)``.

        scores: raw logits; sigmoid is applied here for the Dice term.
        labels: targets; BCEWithLogitsLoss expects a floating-point target of
            the same size as scores.
        """
        # sigmoid and BCEWithLogitsLoss do not modify their inputs, so the
        # former .clone() copies were unnecessary.
        diceloss = self.binnary_dice(torch.sigmoid(scores), labels)
        bceloss = self.focal_loss(scores, labels)
        return diceloss, bceloss
class FCCDN_loss_without_seg(nn.Module):
    """Loss for the binary change-detection task: soft Dice + BCE on logits."""

    def __init__(self):
        super(FCCDN_loss_without_seg, self).__init__()
        # Built once here instead of on every call; it holds no trainable
        # parameters, so the module's state_dict is unaffected.
        self.criterion_change = dice_focal_loss()

    def forward(self, scores, labels):
        """Return the scalar change loss.

        scores: change-prediction logits.
        labels: binary change labels; cast to float before the loss.
        For tensors with more than 3 dims, dim 1 is squeezed (a no-op unless
        it has size 1). Presumably this turns (N, 1, H, W) into
        (N, H, W) — TODO confirm against callers.
        """
        if scores.dim() > 3:
            scores = scores.squeeze(1)
        if labels.dim() > 3:
            labels = labels.squeeze(1)
        diceloss, bceloss = self.criterion_change(scores, labels.float())
        loss_change = diceloss + bceloss
        return loss_change.mean()