# kbressem's picture
# Upload folder using huggingface_hub
# 62785f9 verified
"""Custom loss functions for coronary segmentation with proper MONAI interface."""
import torch
import torch.nn as nn
from monai.losses.cldice import soft_dice, soft_skel
from monai.networks.utils import one_hot
class DiceclDiceLoss(nn.Module):
    """Weighted combination of soft Dice and clDice losses.

    Exposes an (input, target) call signature compatible with MONAI's
    SupervisedTrainer, working around the interface limitations of MONAI's
    SoftDiceclDiceLoss (see https://github.com/Project-MONAI/MONAI/issues/8239).
    Softmax activation and one-hot target encoding are handled internally
    when enabled.

    Args:
        iter_: Number of iterations used by the soft skeleton computation.
        alpha: Weight of the clDice term; the Dice term is weighted 1 - alpha.
        smooth: Additive smoothing constant that guards against zero division.
        include_background: If False, channel 0 is dropped before computing
            either loss term.
        to_onehot_y: If True, the target is one-hot encoded to the input's
            channel count.
        softmax: If True, a channel-wise softmax is applied to the input.
    """

    def __init__(
        self,
        iter_: int = 50,
        alpha: float = 0.5,
        smooth: float = 1.0,
        include_background: bool = False,
        to_onehot_y: bool = True,
        softmax: bool = True,
    ):
        super().__init__()
        self.iter = iter_
        self.alpha = alpha
        self.smooth = smooth
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.softmax = softmax

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return (1 - alpha) * dice_loss + alpha * cl_dice_loss."""
        num_channels = input.shape[1]

        # Optional activation / encoding so raw logits and label maps work.
        probs = torch.softmax(input, dim=1) if self.softmax else input
        labels = (
            one_hot(target, num_classes=num_channels) if self.to_onehot_y else target
        )

        if not self.include_background:
            # Drop the background channel from both prediction and reference.
            probs = probs[:, 1:, ...]
            labels = labels[:, 1:, ...]

        dice_term = 1.0 - soft_dice(labels, probs, self.smooth)

        # Topology-aware term: each soft skeleton is compared against the
        # opposite (full) mask to measure centerline precision/sensitivity.
        pred_skeleton = soft_skel(probs, self.iter)
        true_skeleton = soft_skel(labels, self.iter)
        topo_precision = (torch.sum(pred_skeleton * labels) + self.smooth) / (
            torch.sum(pred_skeleton) + self.smooth
        )
        topo_sensitivity = (torch.sum(true_skeleton * probs) + self.smooth) / (
            torch.sum(true_skeleton) + self.smooth
        )
        cl_dice_term = 1.0 - 2.0 * (topo_precision * topo_sensitivity) / (
            topo_precision + topo_sensitivity
        )

        return (1.0 - self.alpha) * dice_term + self.alpha * cl_dice_term