File size: 730 Bytes
7b615ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
    """Soft Dice loss for a single foreground class of a segmentation output.

    The loss is ``1 - mean(dice)``, where the Dice coefficient is computed
    per sample from the predicted probability map of one class and the
    binary mask of the pixels carrying that class label.

    Args:
        smooth: Constant added to both numerator and denominator of the Dice
            ratio so that the ratio stays defined when prediction and target
            are both empty.
        class_index: Class treated as foreground. It selects the softmax
            channel used as the prediction and the label value that is
            binarized in ``targets``. Defaults to ``1``.
    """

    def __init__(self, smooth: float = 1e-6, *, class_index: int = 1):
        super().__init__()
        self.smooth = smooth
        self.class_index = class_index

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar Dice loss averaged over the batch.

        Args:
            logits: Raw scores, expected as ``[B, C, H, W]``. With ``C > 1``
                a softmax over the channel dimension is applied and channel
                ``class_index`` is used. With ``C == 1`` the single channel
                is taken as the foreground logit and passed through a
                sigmoid.
            targets: Label map, expected as ``[B, H, W]`` or ``[B, 1, H, W]``.
                Every element equal to ``class_index`` counts as foreground;
                everything else counts as background.

        Returns:
            A zero-dimensional tensor holding ``1 - mean(dice)``.
        """
        logits = logits.float()  # [B, C, H, W]

        if logits.shape[1] == 1:
            # A single-channel head has no second softmax channel to index;
            # interpret the channel as the foreground logit instead.
            preds = torch.sigmoid(logits[:, 0, :, :])  # [B, H, W]
        else:
            preds = torch.softmax(logits, dim=1)[:, self.class_index, :, :]  # [B, H, W]

        if targets.ndim == 4:
            # Drop the singleton channel dimension: [B, 1, H, W] -> [B, H, W]
            targets = targets.squeeze(1)

        # Binarize so that multi-valued label maps also work.
        targets = (targets == self.class_index).float()

        intersection = (preds * targets).sum(dim=(1, 2))
        # Named "union" by convention; this is the sum of both mask sizes,
        # which is the denominator of the Dice coefficient.
        union = preds.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))

        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()
|