import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """Soft Dice loss for foreground-vs-rest segmentation.

    The loss is ``1 - mean_b(dice_b)`` where, for each sample ``b``::

        dice_b = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    ``p`` is the predicted foreground probability and ``t`` the binary
    foreground mask. ``smooth`` keeps the ratio finite when both the
    prediction and the target are empty (the Dice score is then 1).

    Args:
        smooth: Additive smoothing term used in numerator and denominator.
        foreground_class: Channel of the logits, and label value in the
            targets, treated as foreground when the logits have more than
            one channel. Ignored for single-channel logits, which use a
            sigmoid and treat target value 1 as foreground.
    """

    def __init__(self, smooth: float = 1e-6, foreground_class: int = 1) -> None:
        super().__init__()
        self.smooth = smooth
        self.foreground_class = foreground_class

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the Dice loss.

        Args:
            logits: Raw scores of shape ``[B, C, *spatial]`` (e.g.
                ``[B, C, H, W]``). With ``C > 1`` a softmax over the channel
                dim is applied; with ``C == 1`` a sigmoid is applied.
            targets: Label map of shape ``[B, *spatial]`` or
                ``[B, 1, *spatial]``. Entries equal to the foreground label
                count as foreground; every other value counts as background.

        Returns:
            Scalar tensor ``1 - mean Dice`` over the batch.

        Raises:
            ValueError: If the target shape does not match the prediction
                shape once the channel dimension has been removed.
        """
        # Work in float32 so half-precision logits keep accuracy in the
        # softmax and in the reductions below.
        logits = logits.float()

        if logits.shape[1] == 1:
            # A softmax over a single channel is identically 1, so a
            # single-channel output must be interpreted with a sigmoid.
            preds = torch.sigmoid(logits[:, 0])
            fg_label = 1
        else:
            preds = torch.softmax(logits, dim=1)[:, self.foreground_class]
            fg_label = self.foreground_class

        if targets.ndim == logits.ndim:
            targets = targets.squeeze(1)  # [B, 1, *spatial] -> [B, *spatial]
        if targets.shape != preds.shape:
            # Without this check, broadcasting would silently compute a
            # meaningless loss for mismatched inputs.
            raise ValueError(
                f"targets shape {tuple(targets.shape)} does not match "
                f"predictions shape {tuple(preds.shape)}"
            )

        # Binarize: the foreground label becomes 1.0, everything else 0.0.
        targets = (targets == fg_label).float()

        # Flatten all non-batch dims so any number of spatial dims works.
        preds = preds.reshape(preds.shape[0], -1)
        targets = targets.reshape(targets.shape[0], -1)

        intersection = (preds * targets).sum(dim=1)
        cardinality = preds.sum(dim=1) + targets.sum(dim=1)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice.mean()