# CFPVesselSeg / losses.py
# Author: farrell236 -- commit e99a83c ("add src")
import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional FOV mask

    The model should output raw logits, not sigmoid probabilities.

    The Dice coefficient is computed independently for each sample (all
    non-batch dimensions are flattened) and the loss is ``1 - mean(dice)``
    over the batch. ``smooth`` is added to both numerator and denominator so
    the ratio stays defined when prediction and target are both empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets, mask=None):
        soft_pred = logits.sigmoid()
        truth = targets

        if mask is not None:
            # Pixels outside the mask are zeroed so they add nothing to
            # either the overlap or the per-sample totals.
            soft_pred = soft_pred * mask
            truth = truth * mask

        # Keep the batch axis, collapse everything else -> per-sample sums.
        soft_pred = torch.flatten(soft_pred, start_dim=1)
        truth = torch.flatten(truth, start_dim=1)

        overlap = torch.sum(soft_pred * truth, dim=1)
        total = torch.sum(soft_pred, dim=1) + torch.sum(truth, dim=1)

        dice_per_sample = (overlap * 2.0 + self.smooth) / (total + self.smooth)
        return 1.0 - torch.mean(dice_per_sample)
class BCEDiceLoss(nn.Module):
    """
    BCEWithLogits + Dice loss for binary vessel segmentation.

    The optional mask argument is intended for the DRIVE FOV mask, so that
    background outside the retinal field of view does not dominate training.

    Args:
        bce_weight: multiplier applied to the (masked) mean BCE term.
        dice_weight: multiplier applied to the soft Dice term.
        smooth: smoothing constant forwarded to ``DiceLoss``.

    ``forward(logits, targets, mask=None)`` returns the scalar
    ``bce_weight * bce + dice_weight * dice``, where ``bce`` is the mean
    per-pixel BCE over the pixels selected by ``mask`` (or over all pixels
    when no mask is given).
    """

    def __init__(
        self,
        bce_weight=1.0,
        dice_weight=1.0,
        smooth=1.0,
    ):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, logits, targets, mask=None):
        # binary_cross_entropy_with_logits needs floating-point targets, but
        # label masks are often loaded as uint8/long/bool; cast them so those
        # inputs work. A no-op when targets already share the logits dtype.
        targets = targets.to(logits.dtype)

        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction="none",
        )

        if mask is not None:
            weight = mask.to(bce.dtype)
            # Broadcast the mask to the per-pixel loss shape *before*
            # counting. Otherwise a mask shared across the batch
            # (e.g. [1, 1, H, W]) is counted once in the denominator but
            # applied B times in the numerator, inflating the loss by B.
            weight, bce = torch.broadcast_tensors(weight, bce)
            bce = (bce * weight).sum() / weight.sum().clamp_min(1.0)
        else:
            bce = bce.mean()

        dice = self.dice(logits, targets, mask)
        return self.bce_weight * bce + self.dice_weight * dice
@torch.no_grad()
def compute_dice_score(
    logits,
    targets,
    mask=None,
    threshold=0.5,
    eps=1e-7,
):
    """
    Hard Dice score for monitoring.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional

    Logits are passed through a sigmoid and binarised with ``prob > threshold``.
    Dice is computed per sample and averaged over the batch. ``eps`` keeps the
    ratio defined, so a sample with an empty prediction and an empty target
    scores 1.0. Returns a Python float. Runs without autograd tracking.
    """
    binary_pred = torch.sigmoid(logits).gt(threshold).float()

    if mask is not None:
        # Restrict both prediction and ground truth to the masked region.
        binary_pred, targets = binary_pred * mask, targets * mask

    # One row per sample: sums below are taken over all non-batch dims.
    binary_pred = torch.flatten(binary_pred, start_dim=1)
    targets = torch.flatten(targets, start_dim=1)

    hits = torch.sum(binary_pred * targets, dim=1)
    volume = torch.sum(binary_pred, dim=1) + torch.sum(targets, dim=1)

    per_sample = (hits * 2.0 + eps) / (volume + eps)
    return per_sample.mean().item()
if __name__ == "__main__":
    # Smoke test -- run directly with:
    #   python losses.py
    # Random logits/labels and an all-ones FOV mask: this only checks that the
    # loss and the metric run end to end, not that the values are meaningful.
    batch_shape = (2, 1, 512, 512)
    logits = torch.randn(*batch_shape)
    targets = torch.randint(0, 2, batch_shape).float()
    fov = torch.ones(*batch_shape)

    criterion = BCEDiceLoss(bce_weight=1.0, dice_weight=1.0)
    loss = criterion(logits, targets, fov)
    dice = compute_dice_score(logits, targets, fov)

    for label, value in (("Loss:", loss.item()), ("Dice:", dice)):
        print(label, value)
    print("Smoke test passed.")