AndreCosta's picture
Initial clean commit with LFS configured
7b615ae
raw
history blame contribute delete
730 Bytes
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
def __init__(self, smooth=1e-6):
super(DiceLoss, self).__init__()
self.smooth = smooth
def forward(self, logits, targets):
logits = logits.float() # [B, C, H, W]
probs = torch.softmax(logits, dim=1)
preds = probs[:, 1, :, :] # [B, H, W]
if targets.ndim == 4:
targets = targets.squeeze(1) # [B, H, W]
targets = (targets == 1).float() # binariza se necessário
intersection = (preds * targets).sum(dim=(1, 2))
union = preds.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))
dice = (2 * intersection + self.smooth) / (union + self.smooth)
return 1 - dice.mean()