"""Segmentation loss modules: BCE + Dice, and an optional Lovasz hinge wrapper."""
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge
except ImportError:
    # The Lovasz-Softmax implementation is an optional dependency. Record its
    # absence so LovaszHingeLoss fails with a clear message when used, instead
    # of an opaque NameError inside forward().
    lovasz_hinge = None

__all__ = ['BCEDiceLoss', 'LovaszHingeLoss']


class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and soft Dice loss.

    The Dice term is computed per sample (all non-batch dimensions are
    flattened), then averaged over the batch and returned as ``1 - mean_dice``.

    Args:
        bce_weight: Multiplier applied to the BCE term (default 0.5).
        smooth: Additive smoothing in the Dice numerator and denominator;
            it keeps the ratio finite when a sample has no positive pixels.
    """

    def __init__(self, bce_weight=0.5, smooth=1e-5):
        super().__init__()
        self.bce_weight = bce_weight
        self.smooth = smooth

    def forward(self, input, target):
        """Return ``bce_weight * BCE + (1 - mean Dice)`` as a scalar tensor.

        Args:
            input: Raw logits; dimension 0 is the batch.
            target: Float tensor with the same shape as ``input``. This is
                required by ``binary_cross_entropy_with_logits``. Values are
                presumably in [0, 1] (binary masks); confirm against callers.
        """
        bce = F.binary_cross_entropy_with_logits(input, target)

        probs = torch.sigmoid(input)
        num = target.size(0)
        # reshape (not view) so non-contiguous tensors, e.g. after permute(),
        # are accepted; it only copies when a view is impossible.
        probs = probs.reshape(num, -1)
        target = target.reshape(num, -1)

        intersection = (probs * target).sum(1)
        dice = (2. * intersection + self.smooth) / (
            probs.sum(1) + target.sum(1) + self.smooth
        )
        dice = 1 - dice.sum() / num
        return self.bce_weight * bce + dice


class LovaszHingeLoss(nn.Module):
    """Wrapper around ``lovasz_hinge`` from the optional LovaszSoftmax package.

    Dimension 1 is squeezed from both tensors before the call, so the inputs
    are presumably ``(N, 1, H, W)`` single-channel maps (TODO confirm).
    ``per_image=True`` is passed through to ``lovasz_hinge``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the Lovasz hinge loss of logits ``input`` against ``target``.

        Raises:
            ImportError: If the LovaszSoftmax package could not be imported.
        """
        if lovasz_hinge is None:
            raise ImportError(
                "LovaszHingeLoss requires the LovaszSoftmax package "
                "(LovaszSoftmax.pytorch.lovasz_losses), which could not be imported."
            )
        # squeeze(1) only removes dim 1 when its size is 1.
        input = input.squeeze(1)
        target = target.squeeze(1)
        loss = lovasz_hinge(input, target, per_image=True)
        return loss