File size: 3,149 Bytes
e99a83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional FOV mask

    The model should output raw logits, not sigmoid probabilities.

    Dice is computed per sample over all non-batch dimensions and then
    averaged over the batch; the returned loss is ``1 - mean(dice)``.
    ``smooth`` is added to both numerator and denominator, so a sample whose
    prediction and target are both empty scores dice == 1 (zero loss).

    The reductions are carried out in at least float32. With float16 logits
    the per-sample sums over H*W pixels can exceed the float16 maximum
    (~65504; a 512x512 image alone has 262144 pixels), which would turn the
    Dice ratio into inf/inf = NaN.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        # Additive smoothing term applied to numerator and denominator.
        self.smooth = smooth

    def forward(self, logits, targets, mask=None):
        """
        Args:
            logits: raw model outputs.
            targets: ground truth (binary or soft), same shape as logits.
                Float, integer and bool tensors are accepted.
            mask: optional tensor multiplied into both the probabilities and
                the targets; pixels where it is 0 do not contribute.

        Returns:
            Scalar tensor ``1 - mean(per-sample dice)``.
        """
        # Promote half/bfloat16 (and integer/bool targets) to float32;
        # float32/float64 inputs keep their precision.
        compute_dtype = torch.promote_types(
            torch.promote_types(logits.dtype, targets.dtype),
            torch.float32,
        )

        probs = torch.sigmoid(logits.to(compute_dtype))
        targets = targets.to(compute_dtype)

        if mask is not None:
            mask = mask.to(compute_dtype)
            probs = probs * mask
            targets = targets * mask

        # Collapse everything except the batch dimension.
        probs = probs.flatten(1)
        targets = targets.flatten(1)

        intersection = (probs * targets).sum(dim=1)
        denominator = probs.sum(dim=1) + targets.sum(dim=1)

        dice = (2.0 * intersection + self.smooth) / (
            denominator + self.smooth
        )

        return 1.0 - dice.mean()


class BCEDiceLoss(nn.Module):
    """
    BCEWithLogits + Dice loss for binary vessel segmentation.

    The optional mask argument is intended for the DRIVE FOV mask, so that
    background outside the retinal field of view does not dominate training.

    The total loss is::

        bce_weight * masked_mean(BCE) + dice_weight * DiceLoss

    Targets may be float, integer or bool tensors. Non-floating targets are
    cast to the logits' dtype because ``binary_cross_entropy_with_logits``
    needs a floating-point target. The mask may likewise be bool or float.
    """

    def __init__(
        self,
        bce_weight=1.0,
        dice_weight=1.0,
        smooth=1.0,
    ):
        super().__init__()

        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, logits, targets, mask=None):
        """
        Args:
            logits: raw model outputs, e.g. [B, 1, H, W].
            targets: ground truth with the same shape as logits.
            mask: optional per-pixel weights (e.g. FOV mask); zeros exclude
                pixels from both the BCE mean and the Dice term.

        Returns:
            Scalar tensor with the weighted sum of the BCE and Dice losses.
        """
        if not targets.is_floating_point():
            # bool/uint8/long masks loaded from disk would otherwise make
            # binary_cross_entropy_with_logits fail.
            targets = targets.to(logits.dtype)

        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction="none",
        )

        if mask is not None:
            # Float weights give a floating normaliser even for bool masks.
            # Broadcasting both tensors to a common shape makes the
            # normaliser count every weighted element, even when the mask
            # broadcasts (e.g. a single-channel FOV mask).
            weights, bce = torch.broadcast_tensors(mask.to(bce.dtype), bce)
            bce = (bce * weights).sum() / weights.sum().clamp_min(1.0)
        else:
            bce = bce.mean()

        dice = self.dice(logits, targets, mask)

        return self.bce_weight * bce + self.dice_weight * dice


@torch.no_grad()
def compute_dice_score(
    logits,
    targets,
    mask=None,
    threshold=0.5,
    eps=1e-7,
):
    """
    Thresholded (hard) Dice score, meant for logging rather than training.

    Logits are passed through a sigmoid and binarised with ``> threshold``.
    Dice is then computed per sample over every non-batch dimension and
    averaged over the batch. ``eps`` is added to numerator and denominator,
    so a sample with empty prediction and empty target scores 1.0.

    Expected shapes:
        logits:  [B, 1, H, W]
        targets: [B, 1, H, W]
        mask:    [B, 1, H, W], optional; zeros exclude pixels

    Returns:
        Python float with the batch-mean Dice score.
    """
    hard_pred = torch.sigmoid(logits).gt(threshold).float()
    truth = targets

    if mask is not None:
        hard_pred = hard_pred * mask
        truth = truth * mask

    # One row per sample.
    hard_pred = hard_pred.flatten(start_dim=1)
    truth = truth.flatten(start_dim=1)

    overlap = (hard_pred * truth).sum(dim=1)
    total = hard_pred.sum(dim=1) + truth.sum(dim=1)

    per_sample = (overlap * 2.0 + eps) / (total + eps)
    return per_sample.mean().item()


if __name__ == "__main__":
    # Smoke test:
    # python losses.py
    #
    # Runs the loss (forward + backward) and the metric on random data, and
    # fails loudly if anything is non-finite or out of range. Otherwise
    # "passed" would be printed unconditionally.

    torch.manual_seed(0)  # reproducible numbers between runs

    logits = torch.randn(2, 1, 512, 512, requires_grad=True)
    targets = torch.randint(0, 2, (2, 1, 512, 512)).float()
    fov = torch.ones(2, 1, 512, 512)

    criterion = BCEDiceLoss(
        bce_weight=1.0,
        dice_weight=1.0,
    )

    loss = criterion(logits, targets, fov)
    loss.backward()
    dice = compute_dice_score(logits, targets, fov)

    print("Loss:", loss.item())
    print("Dice:", dice)

    # Explicit raises instead of assert: asserts are stripped under -O.
    if not torch.isfinite(loss):
        raise RuntimeError(f"Loss is not finite: {loss.item()}")
    if logits.grad is None or not torch.isfinite(logits.grad).all():
        raise RuntimeError("Gradient w.r.t. logits is missing or not finite")
    if not 0.0 <= dice <= 1.0:
        raise RuntimeError(f"Dice score out of [0, 1]: {dice}")

    print("Smoke test passed.")