Spaces:
Running
Running
File size: 3,149 Bytes
e99a83c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 | import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional FOV mask

    The model should output raw logits, not sigmoid probabilities.

    Every dimension after the batch dimension is flattened, a soft Dice
    coefficient is computed per sample, and the loss is ``1 - mean(dice)``
    over the batch. Reductions run in at least float32 so that half-precision
    inputs do not overflow when summing over all pixels of an image.
    """

    def __init__(self, smooth=1.0):
        """
        Args:
            smooth: constant added to both numerator and denominator of the
                Dice ratio; keeps it defined (and equal to 1) for a sample
                whose prediction and target are both all-zero.
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets, mask=None):
        """
        Compute the soft Dice loss.

        Args:
            logits: raw model outputs (sigmoid is applied here).
            targets: ground truth with the same shape as ``logits``; bool,
                integer or floating dtypes are accepted.
            mask: optional tensor multiplied element-wise into both the
                probabilities and the targets, so it must broadcast against
                them; zero entries remove pixels from the Dice computation.

        Returns:
            A 0-dim tensor equal to ``1 - mean_b(dice_b)``.
        """
        # The sums below run over H*W elements per sample (262,144 for a
        # 512x512 image), which exceeds the float16 maximum (65,504) and
        # would turn the ratio into inf / inf = NaN. Promote to at least
        # float32; float64 inputs stay float64 and int/bool targets become
        # floating point.
        compute_dtype = torch.promote_types(
            torch.promote_types(logits.dtype, targets.dtype),
            torch.float32,
        )
        probs = torch.sigmoid(logits.to(compute_dtype))
        targets = targets.to(compute_dtype)
        if mask is not None:
            mask = mask.to(compute_dtype)
            probs = probs * mask
            targets = targets * mask

        # One row per sample: [B, C*H*W].
        probs = probs.flatten(1)
        targets = targets.flatten(1)

        intersection = (probs * targets).sum(dim=1)
        denominator = probs.sum(dim=1) + targets.sum(dim=1)

        dice = (2.0 * intersection + self.smooth) / (
            denominator + self.smooth
        )
        return 1.0 - dice.mean()
class BCEDiceLoss(nn.Module):
    """
    BCEWithLogits + Dice loss for binary vessel segmentation.

    The optional mask argument is intended for the DRIVE FOV mask, so that
    background outside the retinal field of view does not dominate training.

    The returned value is::

        bce_weight * BCE + dice_weight * Dice

    where BCE is the per-element binary cross-entropy averaged over the
    masked elements (or over all elements when no mask is given), and Dice is
    the soft Dice loss computed by ``DiceLoss`` on the same inputs.
    """

    def __init__(
        self,
        bce_weight=1.0,
        dice_weight=1.0,
        smooth=1.0,
    ):
        """
        Args:
            bce_weight: multiplier applied to the BCE term.
            dice_weight: multiplier applied to the Dice term.
            smooth: smoothing constant forwarded to ``DiceLoss``.
        """
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, logits, targets, mask=None):
        """
        Compute the weighted BCE + Dice loss.

        Args:
            logits: raw model outputs (not sigmoid probabilities).
            targets: binary ground truth with the same shape as ``logits``.
                Non-floating (bool/integer) label maps are converted to
                ``logits.dtype``.
            mask: optional weights that broadcast against ``logits``; they
                are multiplied into the per-element BCE and forwarded to the
                Dice term.

        Returns:
            A 0-dim tensor with the weighted sum of both terms.
        """
        # binary_cross_entropy_with_logits needs floating-point targets; an
        # integer/bool label map would fail inside the loss computation.
        if not targets.is_floating_point():
            targets = targets.to(logits.dtype)

        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction="none",
        )
        if mask is not None:
            weights = mask.to(bce.dtype)
            weighted = bce * weights
            # Count the weights over the shape that is actually summed, so a
            # mask shared across the batch (e.g. [1, 1, H, W]) still gives a
            # per-element mean instead of one inflated by the batch size.
            count = weights.expand_as(weighted).sum().clamp_min(1.0)
            bce = weighted.sum() / count
        else:
            bce = bce.mean()

        dice = self.dice(logits, targets, mask)
        return self.bce_weight * bce + self.dice_weight * dice
@torch.no_grad()
def compute_dice_score(
    logits,
    targets,
    mask=None,
    threshold=0.5,
    eps=1e-7,
):
    """
    Hard Dice score for monitoring.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional

    Logits are passed through a sigmoid and binarised with ``> threshold``.
    When a mask is given it is multiplied into both the binary predictions
    and the targets. A Dice coefficient is computed per sample over all
    non-batch dimensions, with ``eps`` added to numerator and denominator
    (so a sample with empty prediction and empty target scores 1.0). The
    batch mean is returned as a Python float. Runs under ``torch.no_grad()``.
    """
    hard = torch.sigmoid(logits).gt(threshold).float()
    if mask is not None:
        hard, targets = hard * mask, targets * mask

    # One row per sample.
    pred_rows = torch.flatten(hard, start_dim=1)
    true_rows = torch.flatten(targets, start_dim=1)

    overlap = torch.sum(pred_rows * true_rows, dim=1)
    total = torch.sum(pred_rows, dim=1) + torch.sum(true_rows, dim=1)

    per_sample = (2.0 * overlap + eps) / (total + eps)
    return per_sample.mean().item()
if __name__ == "__main__":
    # Smoke test, run directly with:
    #   python losses.py
    # Builds random logits, random binary targets and an all-ones FOV mask,
    # then checks that the training loss and the monitoring metric evaluate.
    batch_shape = (2, 1, 512, 512)
    logits = torch.randn(*batch_shape)
    targets = torch.randint(0, 2, batch_shape).float()
    fov = torch.ones(*batch_shape)

    criterion = BCEDiceLoss(bce_weight=1.0, dice_weight=1.0)
    loss = criterion(logits, targets, fov)
    dice = compute_dice_score(logits, targets, fov)

    for label, value in (("Loss:", loss.item()), ("Dice:", dice)):
        print(label, value)
    print("Smoke test passed.")