|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class dice_loss(nn.Module):
    """Soft Dice loss, ``1 - dice(y_pred, y_true)``.

    No activation is applied to ``y_pred``: pass probabilities, not logits.

    Args:
        batch: If True (default), a single Dice score is computed by pooling
            every element of the whole batch. If False, one Dice score is
            computed per sample (reducing over every non-batch dimension)
            and the per-sample scores are averaged.
        smooth: Additive smoothing term used in both numerator and
            denominator; prevents division by zero when prediction and
            target are both empty.
    """

    def __init__(self, batch=True, smooth=1e-5):
        super(dice_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the soft Dice coefficient (averaged over samples if ``batch`` is False)."""
        if self.batch:
            # One score for the whole batch: sum over every element.
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            # Per-sample sums over all non-batch dimensions, so inputs of any
            # rank >= 2 work, e.g. (N, H, W) as well as (N, C, H, W).
            i = y_true.flatten(1).sum(1)
            j = y_pred.flatten(1).sum(1)
            intersection = (y_true * y_pred).flatten(1).sum(1)

        score = (2. * intersection + self.smooth) / (i + j + self.smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Compute the Dice loss.

        Implemented as ``forward`` (not ``__call__``) so that ``nn.Module``
        hooks fire as usual. Predictions are cast to float32 before the
        reduction (e.g. when they arrive as half precision).
        """
        return self.soft_dice_loss(y_pred.to(dtype=torch.float32), y_true)
|
|
|
|
|
|
|
|
class dice_focal_loss(nn.Module):
    """Return the Dice loss and the BCE-with-logits loss for binary logits.

    Note: despite the class and attribute names, the second term is plain
    ``nn.BCEWithLogitsLoss`` (no focal modulation). The names are kept so
    existing code that references them keeps working.
    """

    def __init__(self):
        super(dice_focal_loss, self).__init__()
        # Named "focal_loss", but this is standard binary cross-entropy on logits.
        self.focal_loss = nn.BCEWithLogitsLoss()
        self.binnary_dice = dice_loss()

    def forward(self, scores, labels):
        """Compute both loss terms.

        Args:
            scores: Raw logits; no sigmoid applied.
            labels: Floating-point targets with the same shape as ``scores``.

        Returns:
            Tuple ``(dice_term, bce_term)``. Both are scalar tensors with the
            default ``mean`` reduction.
        """
        # Dice is computed on probabilities, BCE-with-logits on raw scores.
        # Neither op mutates ``scores``, so no defensive clone() is needed.
        dice_term = self.binnary_dice(torch.sigmoid(scores), labels)
        bce_term = self.focal_loss(scores, labels)
        return dice_term, bce_term
|
|
|
|
|
class FCCDN_loss_without_seg(nn.Module):
    """Loss for the binary change detection task.

    Sums the Dice term and the BCE-with-logits term returned by
    ``dice_focal_loss`` for the predicted change map.
    """

    def __init__(self):
        super(FCCDN_loss_without_seg, self).__init__()
        # Build the criterion once instead of on every call. It holds no
        # learnable parameters.
        self.criterion_change = dice_focal_loss()

    def forward(self, scores, labels):
        """Return ``dice + bce`` for change logits ``scores`` and targets ``labels``.

        Inputs with more than 3 dimensions have dim 1 squeezed; this is a
        no-op unless that dim has size 1. For example, ``(N, 1, H, W)``
        becomes ``(N, H, W)``. Labels are cast to float for the BCE term.
        """
        scores = scores.squeeze(1) if scores.dim() > 3 else scores
        labels = labels.squeeze(1) if labels.dim() > 3 else labels

        diceloss, bceloss = self.criterion_change(scores, labels.float())

        # Both terms are already scalars with the default reductions; mean()
        # keeps the result a scalar in case either term is ever unreduced.
        return (diceloss + bceloss).mean()
|
|
|