| import torch | |
| import torch.nn as nn | |
class DiceLoss(nn.Module):
    """Dice Loss for segmentation.

    The loss is ``1 - dice`` where::

        dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    with ``p = sigmoid(inputs)`` and ``t = targets`` cast to float. Both
    tensors are flattened to 1-D first, so the coefficient is computed once
    over every element passed in rather than per sample.

    Args:
        smooth: Constant added to both numerator and denominator; it keeps
            the ratio defined when both sums are zero.
    """

    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar Dice loss for ``inputs`` against ``targets``.

        Args:
            inputs: Raw scores; a sigmoid is applied here, so they must not
                already be probabilities.
            targets: Tensor with the same number of elements as ``inputs``.
                Presumably a binary (0/1) mask -- confirm against callers.

        Returns:
            A 0-dim tensor holding ``1 - dice``.
        """
        # reshape() instead of view(): view() raises on non-contiguous
        # tensors (e.g. after permute/transpose), while reshape() falls back
        # to a copy only when a view is impossible.
        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = targets.reshape(-1).float()

        intersection = (probs * flat_targets).sum()
        dice = (2. * intersection + self.smooth) / (
            probs.sum() + flat_targets.sum() + self.smooth
        )
        return 1 - dice