"""Custom loss functions for coronary segmentation with proper MONAI interface."""

import warnings

import torch
import torch.nn as nn
from monai.losses.cldice import soft_dice, soft_skel
from monai.networks.utils import one_hot


class DiceclDiceLoss(nn.Module):
    """
    Combined Dice and clDice loss with proper MONAI interface.

    This wrapper addresses the interface limitations of MONAI's SoftDiceclDiceLoss
    (see https://github.com/Project-MONAI/MONAI/issues/8239). Uses (input, target)
    argument order compatible with MONAI SupervisedTrainer. Handles softmax and
    one-hot encoding internally.

    The loss is ``(1 - alpha) * dice_loss + alpha * cldice_loss``, where both
    terms are computed over the same set of channels (all channels when
    ``include_background=True``, channels ``1:`` otherwise).

    Args:
        iter_: Number of iterations for soft skeleton computation.
        alpha: Weight for clDice loss (1-alpha for Dice loss).
        smooth: Smoothing parameter to avoid division by zero.
        include_background: Whether to include background class in loss computation.
        to_onehot_y: Whether to convert target to one-hot encoding.
        softmax: Whether to apply softmax to input predictions.
    """

    def __init__(
        self,
        iter_: int = 50,
        alpha: float = 0.5,
        smooth: float = 1.0,
        include_background: bool = False,
        to_onehot_y: bool = True,
        softmax: bool = True,
    ):
        super().__init__()
        self.iter = iter_
        self.alpha = alpha
        self.smooth = smooth
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.softmax = softmax

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted Dice + clDice loss.

        Args:
            input: Network output with the class/channel axis at dim 1; logits
                when ``softmax=True``, otherwise per-class probabilities.
                ``soft_skel`` expects a 4D (B, C, H, W) or 5D (B, C, H, W, D)
                tensor.
            target: Integer label map with a singleton channel axis at dim 1
                when ``to_onehot_y=True``; otherwise a (soft) one-hot target
                with the same shape as ``input``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If the (possibly one-hot encoded) target does not have
                the same shape as the prediction.
        """
        n_classes = input.shape[1]

        # Single-channel handling mirrors MONAI's DiceLoss: softmax over one
        # channel is constant 1, one-hot with one class cannot encode label 1,
        # and dropping the only channel would leave nothing to compare.
        y_pred = input
        if self.softmax:
            if n_classes == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                y_pred = torch.softmax(input, dim=1)

        y_true = target
        if self.to_onehot_y:
            if n_classes == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(target, num_classes=n_classes)

        # soft_skel relies on max-pooling, which needs a floating-point tensor.
        if not torch.is_floating_point(y_true):
            y_true = y_true.to(y_pred.dtype)

        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"ground truth has different shape ({tuple(y_true.shape)}) "
                f"from input ({tuple(y_pred.shape)})"
            )

        if not self.include_background:
            if n_classes == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                y_pred = y_pred[:, 1:, ...]
                y_true = y_true[:, 1:, ...]

        # Soft Dice is computed inline on the already-selected channels.
        # monai.losses.cldice.soft_dice is not used here: it already returns
        # 1 - coefficient (a loss) and always discards channel 0 itself, which
        # would invert the objective and drop a foreground class after the
        # background slice above.
        intersection = torch.sum(y_true * y_pred)
        dice_coeff = (2.0 * intersection + self.smooth) / (
            torch.sum(y_true) + torch.sum(y_pred) + self.smooth
        )
        dice_loss = 1.0 - dice_coeff

        # clDice: topology precision (predicted skeleton inside the true mask)
        # and topology sensitivity (true skeleton inside the predicted mask).
        skel_pred = soft_skel(y_pred, self.iter)
        skel_true = soft_skel(y_true, self.iter)
        tprec = (torch.sum(skel_pred * y_true) + self.smooth) / (
            torch.sum(skel_pred) + self.smooth
        )
        tsens = (torch.sum(skel_true * y_pred) + self.smooth) / (
            torch.sum(skel_true) + self.smooth
        )
        cl_dice_loss = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)

        return (1.0 - self.alpha) * dice_loss + self.alpha * cl_dice_loss