| """Custom loss functions for coronary segmentation with proper MONAI interface.""" |
|
|
| import torch |
| import torch.nn as nn |
| from monai.losses.cldice import soft_dice, soft_skel |
| from monai.networks.utils import one_hot |
|
|
|
|
class DiceclDiceLoss(nn.Module):
    """
    Combined Dice and clDice loss with proper MONAI interface.

    This wrapper addresses the interface limitations of MONAI's SoftDiceclDiceLoss
    (see https://github.com/Project-MONAI/MONAI/issues/8239).

    Uses (input, target) argument order compatible with MONAI SupervisedTrainer.
    Handles softmax and one-hot encoding internally.

    The soft Dice term is computed directly on the channels selected by
    ``include_background`` instead of via ``monai.losses.cldice.soft_dice``.
    That helper already returns a loss (``1 - coefficient``) and always drops
    channel 0 itself. Calling it after removing the background would therefore
    also discard the first foreground class.

    Args:
        iter_: Number of iterations for soft skeleton computation.
        alpha: Weight for clDice loss (1-alpha for Dice loss).
        smooth: Smoothing parameter to avoid division by zero.
        include_background: Whether to include the background class
            (channel 0) in loss computation.
        to_onehot_y: Whether to convert target to one-hot encoding.
        softmax: Whether to apply softmax to input predictions.
    """

    def __init__(
        self,
        iter_: int = 50,
        alpha: float = 0.5,
        smooth: float = 1.0,
        include_background: bool = False,
        to_onehot_y: bool = True,
        softmax: bool = True,
    ):
        super().__init__()
        self.iter = iter_
        self.alpha = alpha
        self.smooth = smooth
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.softmax = softmax

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute ``(1 - alpha) * dice_loss + alpha * cldice_loss``.

        Args:
            input: Network output with the class channel on dim 1
                (logits when ``softmax=True``).
            target: Labels. With ``to_onehot_y=True`` this is an integer label
                map, given either with a size-1 channel dim on dim 1 or without
                a channel dim (one fewer dimension than ``input``). With
                ``to_onehot_y=False`` it must already be laid out like the
                (softmaxed) prediction.

        Returns:
            Scalar loss tensor.
        """
        n_classes = input.shape[1]

        y_pred = torch.softmax(input, dim=1) if self.softmax else input

        if self.to_onehot_y:
            # MONAI's one_hot expects a size-1 channel on dim 1; also accept
            # label maps that were passed without that channel dimension.
            if target.ndim == input.ndim - 1:
                target = target.unsqueeze(1)
            y_true = one_hot(target, num_classes=n_classes)
        else:
            y_true = target

        if not self.include_background:
            y_pred = y_pred[:, 1:, ...]
            y_true = y_true[:, 1:, ...]

        # Soft Dice loss over every remaining channel. Computed inline (same
        # formula as MONAI's soft_dice) so no further channel is sliced away.
        intersection = torch.sum(y_true * y_pred)
        dice_coeff = (2.0 * intersection + self.smooth) / (
            torch.sum(y_true) + torch.sum(y_pred) + self.smooth
        )
        dice_loss = 1.0 - dice_coeff

        # clDice: topology precision/sensitivity from soft skeletons.
        skel_pred = soft_skel(y_pred, self.iter)
        skel_true = soft_skel(y_true, self.iter)

        tprec = (torch.sum(skel_pred * y_true) + self.smooth) / (
            torch.sum(skel_pred) + self.smooth
        )
        tsens = (torch.sum(skel_true * y_pred) + self.smooth) / (
            torch.sum(skel_true) + self.smooth
        )
        cl_dice_loss = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)

        return (1.0 - self.alpha) * dice_loss + self.alpha * cl_dice_loss
|
|