"""Loss builder. Everything is treated as MULTICLASS (binary == 2 classes), which sidesteps the binary/multiclass mode pitfall and unifies all datasets. ce_dice : CrossEntropy + multiclass Dice (default, robust for medical seg) ce : CrossEntropy only dice : multiclass Dice only Inputs: logits [B,C,H,W], target [B,H,W] (long, ids 0..C-1). """ from __future__ import annotations import torch import torch.nn as nn import segmentation_models_pytorch as smp class CEDiceLoss(nn.Module): def __init__(self, mode: str = "ce_dice"): super().__init__() self.mode = mode self.ce = nn.CrossEntropyLoss() self.dice = smp.losses.DiceLoss(mode=smp.losses.MULTICLASS_MODE, from_logits=True) def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.mode == "ce": return self.ce(logits, target) if self.mode == "dice": return self.dice(logits, target) return self.ce(logits, target) + self.dice(logits, target) def build_loss(name: str = "ce_dice") -> nn.Module: return CEDiceLoss(name)