Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class DiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional FOV mask

    The model should output raw logits, not sigmoid probabilities.

    Dice is computed per sample (all non-batch dimensions are flattened
    together) and then averaged over the batch. Arithmetic is carried out in
    at least float32, so half-precision inputs cannot overflow the per-sample
    sums, and integer/boolean targets or masks are accepted.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        # Additive smoothing in numerator and denominator; keeps the ratio
        # defined (and equal to 1) when prediction and target are both empty.
        self.smooth = smooth

    def forward(self, logits, targets, mask=None):
        """
        Return ``1 - mean_over_batch(dice)`` as a scalar tensor.

        Args:
            logits: raw model outputs.
            targets: ground-truth labels (0/1 or soft labels); any dtype.
            mask: optional mask; pixels where it is 0 are excluded from both
                the prediction and the target sums.
        """
        # fp16/bf16 are upcast to float32 (fp16 sums over a 512x512 image
        # exceed the fp16 max of 65504); float64 inputs stay float64.
        compute_dtype = torch.promote_types(logits.dtype, torch.float32)
        probs = torch.sigmoid(logits.to(compute_dtype))
        targets = targets.to(compute_dtype)
        if mask is not None:
            mask = mask.to(compute_dtype)
            probs = probs * mask
            targets = targets * mask

        probs = probs.flatten(1)
        targets = targets.flatten(1)

        intersection = (probs * targets).sum(dim=1)
        denominator = probs.sum(dim=1) + targets.sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)
        return 1.0 - dice.mean()
class BCEDiceLoss(nn.Module):
    """
    BCEWithLogits + Dice loss for binary vessel segmentation.

    The optional mask argument is intended for the DRIVE FOV mask, so that
    background outside the retinal field of view does not dominate training.

    Args:
        bce_weight: multiplier on the (masked) mean BCE term.
        dice_weight: multiplier on the soft Dice loss term.
        smooth: smoothing constant forwarded to ``DiceLoss``.
    """

    def __init__(
        self,
        bce_weight=1.0,
        dice_weight=1.0,
        smooth=1.0,
    ):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, logits, targets, mask=None):
        """
        Return ``bce_weight * BCE + dice_weight * Dice`` as a scalar tensor.

        With a mask, BCE is averaged over masked-in pixels only: the sum of
        masked losses is divided by the mask sum, counted over the broadcast
        shape and clamped to at least 1.
        """
        # binary_cross_entropy_with_logits needs a floating target; accept
        # bool/uint8 label maps. Float targets are left untouched.
        if not targets.is_floating_point():
            targets = targets.to(logits.dtype)

        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction="none",
        )
        if mask is not None:
            mask = mask.to(bce.dtype)
            weighted = bce * mask
            # Count mask pixels after broadcasting, so a [B,1,H,W] mask over
            # multi-channel logits still yields a true per-pixel mean.
            count = mask.expand_as(weighted).sum().clamp_min(1.0)
            bce = weighted.sum() / count
        else:
            bce = bce.mean()

        dice = self.dice(logits, targets, mask)
        return self.bce_weight * bce + self.dice_weight * dice
@torch.no_grad()
def compute_dice_score(
    logits,
    targets,
    mask=None,
    threshold=0.5,
    eps=1e-7,
):
    """
    Hard Dice score for monitoring.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional

    Predictions are ``sigmoid(logits) > threshold``. Dice is computed per
    sample and averaged over the batch. Returns a Python float. When both
    prediction and target are empty, the sample scores 1.0 (eps / eps).

    Runs under ``torch.no_grad()``: this is a metric, so no autograd graph
    is built even if ``logits`` requires grad.
    """
    # Upcast half-precision inputs: fp16 sums over a 512x512 image overflow
    # to inf, which would turn a perfect prediction into a score of 0.
    compute_dtype = torch.promote_types(logits.dtype, torch.float32)
    probs = torch.sigmoid(logits.to(compute_dtype))
    preds = (probs > threshold).float()
    targets = targets.to(preds.dtype)
    if mask is not None:
        mask = mask.to(preds.dtype)
        preds = preds * mask
        targets = targets * mask

    preds = preds.flatten(1)
    targets = targets.flatten(1)

    intersection = (preds * targets).sum(dim=1)
    denominator = preds.sum(dim=1) + targets.sum(dim=1)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return dice.mean().item()
if __name__ == "__main__":
    # Quick sanity check on random data; run the module directly:
    #   python losses.py
    shape = (2, 1, 512, 512)
    fake_logits = torch.randn(*shape)
    fake_targets = torch.randint(0, 2, shape).float()
    fov_mask = torch.ones(*shape)

    loss_fn = BCEDiceLoss(bce_weight=1.0, dice_weight=1.0)
    loss_value = loss_fn(fake_logits, fake_targets, fov_mask).item()
    dice_value = compute_dice_score(fake_logits, fake_targets, fov_mask)

    for label, value in (("Loss:", loss_value), ("Dice:", dice_value)):
        print(label, value)
    print("Smoke test passed.")