"""Custom loss functions for segmentation.""" import torch import torch.nn as nn import torch.nn.functional as F class DiceLoss(nn.Module): """Soft Dice loss operating on logits.""" def __init__(self, smooth: float = 1.0): super().__init__() self.smooth = smooth def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: probs = torch.sigmoid(logits) probs_flat = probs.view(probs.size(0), -1) targets_flat = targets.view(targets.size(0), -1) intersection = (probs_flat * targets_flat).sum(dim=1) union = probs_flat.sum(dim=1) + targets_flat.sum(dim=1) dice = (2.0 * intersection + self.smooth) / (union + self.smooth) return 1.0 - dice.mean() class BCEDiceLoss(nn.Module): """Weighted combination of BCE and Dice loss.""" def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5): super().__init__() self.bce_weight = bce_weight self.dice_weight = dice_weight self.bce = nn.BCEWithLogitsLoss() self.dice = DiceLoss() def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: return self.bce_weight * self.bce(logits, targets) + self.dice_weight * self.dice(logits, targets)