import torch
import torch.nn as nn
import torch.nn.functional as F


def _loss_dtype(logits):
    """
    Return the floating dtype used for loss / metric arithmetic.

    Half-precision inputs (float16 / bfloat16) are promoted to float32 so
    that per-image sums (e.g. 512 * 512 = 262144 pixels) cannot overflow
    the float16 range (max ~65504) and turn the Dice ratio into NaN.
    float32 and float64 inputs keep their own dtype.
    """
    return torch.promote_types(logits.dtype, torch.float32)


class DiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W], values in {0, 1} (float, integer or bool)
        mask:    optional FOV mask broadcastable to [B, 1, H, W]
                 (float, integer or bool)

    The model should output raw logits, not sigmoid probabilities.
    Dice is computed per sample and then averaged over the batch.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        # Added to both numerator and denominator so the ratio stays
        # defined for samples with no foreground at all.
        self.smooth = smooth

    def forward(self, logits, targets, mask=None):
        """
        Return ``1 - mean_over_batch(soft_dice)`` as a scalar tensor.

        Arithmetic is done in at least float32 (see ``_loss_dtype``).
        """
        dtype = _loss_dtype(logits)
        probs = torch.sigmoid(logits.to(dtype))
        targets = targets.to(dtype)

        if mask is not None:
            # Pixels outside the field of view contribute neither to the
            # intersection nor to the denominator.
            mask = mask.to(dtype)
            probs = probs * mask
            targets = targets * mask

        # One row per sample so Dice is computed per image.
        probs = probs.flatten(1)
        targets = targets.flatten(1)

        intersection = (probs * targets).sum(dim=1)
        denominator = probs.sum(dim=1) + targets.sum(dim=1)

        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)
        return 1.0 - dice.mean()


class BCEDiceLoss(nn.Module):
    """
    BCEWithLogits + Dice loss for binary vessel segmentation.

    The optional mask argument is intended for the DRIVE FOV mask, so that
    background outside the retinal field of view does not dominate
    training. The mask may be float, integer or bool and may broadcast
    against the logits (e.g. a single [1, 1, H, W] mask for the batch).
    """

    def __init__(
        self,
        bce_weight=1.0,
        dice_weight=1.0,
        smooth=1.0,
    ):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, logits, targets, mask=None):
        """
        Return ``bce_weight * BCE + dice_weight * Dice`` as a scalar tensor.

        With a mask, BCE is averaged over the in-mask pixels only.
        """
        dtype = _loss_dtype(logits)
        logits = logits.to(dtype)
        # BCEWithLogits requires floating targets with the logits' dtype;
        # label masks are often loaded as integer/uint8/bool tensors.
        targets = targets.to(dtype)

        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction="none",
        )

        if mask is not None:
            # Expand to the per-pixel loss shape so a broadcast mask is
            # counted once per sample in the normaliser, not once in total.
            mask = mask.to(dtype).expand_as(bce)
            bce = (bce * mask).sum() / mask.sum().clamp_min(1.0)
        else:
            bce = bce.mean()

        dice = self.dice(logits, targets, mask)

        return self.bce_weight * bce + self.dice_weight * dice


@torch.no_grad()
def compute_dice_score(
    logits,
    targets,
    mask=None,
    threshold=0.5,
    eps=1e-7,
):
    """
    Hard (thresholded) Dice score for monitoring.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W], values in {0, 1} (float, integer or bool)
        mask:    optional, broadcastable to [B, 1, H, W]

    Args:
        threshold: a pixel is predicted foreground when sigmoid(logit)
            is strictly greater than this value.
        eps: added to numerator and denominator; a sample with no
            foreground in either prediction or target scores 1.0.

    Returns:
        Mean per-sample Dice as a Python float.
    """
    dtype = _loss_dtype(logits)
    preds = (torch.sigmoid(logits.to(dtype)) > threshold).to(dtype)
    targets = targets.to(dtype)

    if mask is not None:
        mask = mask.to(dtype)
        preds = preds * mask
        targets = targets * mask

    preds = preds.flatten(1)
    targets = targets.flatten(1)

    intersection = (preds * targets).sum(dim=1)
    denominator = preds.sum(dim=1) + targets.sum(dim=1)

    dice = (2.0 * intersection + eps) / (denominator + eps)
    return dice.mean().item()


if __name__ == "__main__":
    # Smoke test:
    #   python losses.py
    torch.manual_seed(0)

    logits = torch.randn(2, 1, 512, 512)
    targets = torch.randint(0, 2, (2, 1, 512, 512)).float()
    fov = torch.ones(2, 1, 512, 512)

    criterion = BCEDiceLoss(
        bce_weight=1.0,
        dice_weight=1.0,
    )

    loss = criterion(logits, targets, fov)
    dice = compute_dice_score(logits, targets, fov)

    print("Loss:", loss.item())
    print("Dice:", dice)
    print("Smoke test passed.")