File size: 2,532 Bytes
62785f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Custom loss functions for coronary segmentation with proper MONAI interface."""

import torch
import torch.nn as nn
from monai.losses.cldice import soft_dice, soft_skel
from monai.networks.utils import one_hot


class DiceclDiceLoss(nn.Module):
    """
    Combined Dice and clDice loss with proper MONAI interface.

    This wrapper addresses the interface limitations of MONAI's SoftDiceclDiceLoss
    (see https://github.com/Project-MONAI/MONAI/issues/8239).

    Uses (input, target) argument order compatible with MONAI SupervisedTrainer.
    Handles softmax and one-hot encoding internally.

    The result is ``(1 - alpha) * dice_loss + alpha * cl_dice_loss``. Both terms
    are scalars: every sum runs over all remaining elements at once (batch,
    channels and spatial positions together).

    Args:
        iter_: Number of iterations for soft skeleton computation.
        alpha: Weight for clDice loss (1-alpha for Dice loss).
        smooth: Smoothing parameter to avoid division by zero.
        include_background: Whether to include background class in loss computation.
        to_onehot_y: Whether to convert target to one-hot encoding.
        softmax: Whether to apply softmax to input predictions.
    """

    def __init__(
        self,
        iter_: int = 50,
        alpha: float = 0.5,
        smooth: float = 1.0,
        include_background: bool = False,
        to_onehot_y: bool = True,
        softmax: bool = True,
    ):
        super().__init__()
        self.iter = iter_
        self.alpha = alpha
        self.smooth = smooth
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.softmax = softmax

    def _soft_dice_loss(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """Return ``1 - soft Dice coefficient`` over all elements of the given tensors.

        Computed here instead of through ``monai.losses.cldice.soft_dice``: per
        MONAI's cldice module that helper already returns the loss (so
        subtracting it from 1 yields the coefficient) and always drops channel 0
        itself, which would ignore ``include_background`` and remove a second
        channel from tensors whose background was already sliced off.
        """
        intersection = torch.sum(y_true * y_pred)
        denominator = torch.sum(y_true) + torch.sum(y_pred) + self.smooth
        return 1.0 - (2.0 * intersection + self.smooth) / denominator

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted Dice + clDice loss.

        Args:
            input: Predictions with the class channel on dim 1. Softmax is taken
                over dim 1 when ``softmax`` is enabled.
            target: Ground truth. Converted with MONAI ``one_hot`` to
                ``input.shape[1]`` channels when ``to_onehot_y`` is enabled;
                otherwise it is used unchanged and must already have the same
                shape as the prediction.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If prediction and target shapes differ after the
                optional one-hot conversion.
        """
        n_classes = input.shape[1]

        y_pred = torch.softmax(input, dim=1) if self.softmax else input
        y_true = one_hot(target, num_classes=n_classes) if self.to_onehot_y else target

        # Mismatched shapes would otherwise broadcast silently and produce a
        # meaningless loss value.
        if y_pred.shape != y_true.shape:
            raise ValueError(
                f"prediction shape {tuple(y_pred.shape)} does not match "
                f"target shape {tuple(y_true.shape)}"
            )

        if not self.include_background:
            # Channel 0 is treated as background and excluded from both terms.
            y_pred = y_pred[:, 1:, ...]
            y_true = y_true[:, 1:, ...]

        dice_loss = self._soft_dice_loss(y_true, y_pred)

        # With alpha == 0 the clDice term has zero weight, so skip the
        # iterative soft-skeleton computation entirely.
        if self.alpha == 0:
            return dice_loss

        skel_pred = soft_skel(y_pred, self.iter)
        skel_true = soft_skel(y_true, self.iter)

        # Topology precision: share of the predicted skeleton inside the target.
        tprec = (torch.sum(skel_pred * y_true) + self.smooth) / (
            torch.sum(skel_pred) + self.smooth
        )
        # Topology sensitivity: share of the target skeleton inside the prediction.
        tsens = (torch.sum(skel_true * y_pred) + self.smooth) / (
            torch.sum(skel_true) + self.smooth
        )
        # clDice loss is one minus the harmonic mean of the two ratios.
        cl_dice_loss = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)

        return (1.0 - self.alpha) * dice_loss + self.alpha * cl_dice_loss