File size: 2,532 Bytes
62785f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | """Custom loss functions for coronary segmentation with proper MONAI interface."""
import warnings

import torch
import torch.nn as nn
from monai.losses.cldice import soft_dice, soft_skel
from monai.networks.utils import one_hot
class DiceclDiceLoss(nn.Module):
    """
    Combined Dice and clDice loss with a MONAI-style ``(input, target)`` interface.

    This wrapper addresses the interface limitations of MONAI's SoftDiceclDiceLoss
    (see https://github.com/Project-MONAI/MONAI/issues/8239). It takes arguments in
    the ``(input, target)`` order expected by MONAI's SupervisedTrainer and handles
    softmax and one-hot encoding internally.

    The returned value is ``(1 - alpha) * dice_loss + alpha * cl_dice_loss``. Both
    terms are computed over every channel that remains after optional background
    removal, with sums taken over the whole batch.

    Args:
        iter_: Number of iterations for soft skeleton computation.
        alpha: Weight for clDice loss (1-alpha for Dice loss).
        smooth: Smoothing parameter to avoid division by zero.
        include_background: Whether to include the background channel (index 0)
            in the loss computation. Ignored, with a warning, for single-channel
            input.
        to_onehot_y: Whether to convert target to one-hot encoding. Ignored, with
            a warning, for single-channel input.
        softmax: Whether to apply softmax to input predictions. Ignored, with a
            warning, for single-channel input.
    """

    def __init__(
        self,
        iter_: int = 50,
        alpha: float = 0.5,
        smooth: float = 1.0,
        include_background: bool = False,
        to_onehot_y: bool = True,
        softmax: bool = True,
    ):
        super().__init__()
        self.iter = iter_
        self.alpha = alpha
        self.smooth = smooth
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.softmax = softmax

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the combined Dice + clDice loss.

        Args:
            input: Prediction logits, or probabilities when ``softmax=False``, with
                channels at dim 1.
            target: Ground truth. With ``to_onehot_y=True`` it is passed to MONAI's
                ``one_hot``, which expects a label map with a singleton channel at
                dim 1. Otherwise it must already have the same shape as ``input``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If the processed target and prediction shapes differ.
        """
        n_pred_ch = input.shape[1]
        y_pred = input
        y_true = target

        # Softmax, one-hot encoding and background removal are all meaningless
        # for a single-channel prediction. Mirror MONAI's DiceLoss: warn and skip.
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                y_pred = torch.softmax(y_pred, dim=1)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # Drop channel 0 (background) so only foreground classes are scored.
                y_pred = y_pred[:, 1:, ...]
                y_true = y_true[:, 1:, ...]

        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"ground truth has different shape ({tuple(y_true.shape)}) "
                f"from input ({tuple(y_pred.shape)})"
            )

        # Integer/bool targets are possible when to_onehot_y=False. Match the
        # prediction dtype so every product and sum below is floating-point.
        y_true = y_true.to(dtype=y_pred.dtype)

        # Dice is computed inline over the remaining channels. MONAI's `soft_dice`
        # is deliberately not used here. It already returns a *loss*
        # (1 - coefficient), and it slices off channel 0 on its own. That would
        # discard a foreground class once the background has been removed above.
        intersection = torch.sum(y_true * y_pred)
        dice_coeff = (2.0 * intersection + self.smooth) / (
            torch.sum(y_true) + torch.sum(y_pred) + self.smooth
        )
        dice_loss = 1.0 - dice_coeff

        # clDice: topology precision/sensitivity between each mask and the other's
        # soft skeleton, combined as their harmonic mean.
        skel_pred = soft_skel(y_pred, self.iter)
        skel_true = soft_skel(y_true, self.iter)
        tprec = (torch.sum(skel_pred * y_true) + self.smooth) / (
            torch.sum(skel_pred) + self.smooth
        )
        tsens = (torch.sum(skel_true * y_pred) + self.smooth) / (
            torch.sum(skel_true) + self.smooth
        )
        cl_dice_loss = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)

        return (1.0 - self.alpha) * dice_loss + self.alpha * cl_dice_loss
|