| """
|
| Dataset-aware loss functions
|
| Implements Critical Fix #2: Dataset-Aware Loss Function
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Dict, Optional
|
|
|
|
|
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    ``forward`` expects raw logits and applies a sigmoid itself. The Dice
    coefficient is computed once over all elements of the batch (both
    tensors are flattened to 1-D); it is not averaged per sample.
    """

    def __init__(self, smooth: float = 1.0):
        """
        Initialize Dice loss

        Args:
            smooth: Additive smoothing term used in both the numerator and
                the denominator; keeps the ratio defined when both the
                prediction and target sums are zero.
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute Dice loss

        Args:
            pred: Predicted logits, e.g. (B, 1, H, W). A sigmoid is applied
                here, so do NOT pass probabilities.
            target: Ground-truth mask with the same number of elements as
                ``pred`` (presumably values in {0, 1} — confirm with the
                data pipeline).

        Returns:
            Scalar tensor equal to ``1 - dice``.
        """
        probs = torch.sigmoid(pred)

        # reshape (not view): view(-1) raises on non-contiguous tensors,
        # e.g. after permute/transpose or channel slicing.
        probs_flat = probs.reshape(-1)
        target_flat = target.reshape(-1)

        intersection = (probs_flat * target_flat).sum()
        dice = (2.0 * intersection + self.smooth) / (
            probs_flat.sum() + target_flat.sum() + self.smooth
        )

        return 1 - dice
|
|
|
|
|
class CombinedLoss(nn.Module):
    """
    Combined BCE + Dice loss for segmentation.

    Dataset-aware: the Dice term is only added when pixel-level masks are
    available; otherwise the total is just the weighted BCE term.
    """

    def __init__(self,
                 bce_weight: float = 1.0,
                 dice_weight: float = 1.0):
        """
        Initialize combined loss

        Args:
            bce_weight: Weight applied to the BCE term in ``total``.
            dice_weight: Weight applied to the Dice term in ``total``.
        """
        super().__init__()

        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                has_pixel_mask: bool = True) -> Dict[str, torch.Tensor]:
        """
        Compute loss (dataset-aware)

        Critical Fix #2: Only use Dice loss for datasets with pixel masks

        Args:
            pred: Predicted logits (B, 1, H, W)
            target: Ground truth mask (B, 1, H, W). Integer/bool masks are
                cast to ``pred.dtype``.
            has_pixel_mask: Whether the batch has pixel-level masks. May be a
                Python bool, a 0-d tensor, or per-sample flags (list or 1-D
                tensor); Dice is used only if every flag is true.

        Returns:
            Dictionary with 'bce', 'total' and, when Dice is used, 'dice'.
        """
        # BCEWithLogitsLoss rejects integer/bool targets; float targets are
        # left untouched so existing behaviour is unchanged.
        if not target.is_floating_point():
            target = target.to(pred.dtype)

        # `if tensor:` is ambiguous for multi-element tensors (e.g. flags
        # collated per sample), so reduce to a single Python bool.
        use_dice = bool(torch.as_tensor(has_pixel_mask).all())

        bce = self.bce_loss(pred, target)
        losses = {'bce': bce}

        if use_dice:
            dice = self.dice_loss(pred, target)
            losses['dice'] = dice
            losses['total'] = self.bce_weight * bce + self.dice_weight * dice
        else:
            losses['total'] = self.bce_weight * bce

        return losses
|
|
|
|
|
class DatasetAwareLoss(nn.Module):
    """
    Dataset-aware loss function wrapper.

    Reads the 'has_pixel_mask' flag(s) from batch metadata and forwards to
    CombinedLoss, which only adds the Dice term when masks are available.
    """

    def __init__(self, config):
        """
        Initialize dataset-aware loss

        Args:
            config: Configuration object exposing ``get(key, default)``;
                'loss.bce_weight' and 'loss.dice_weight' default to 1.0.
                (Presumably a dotted-key config wrapper — confirm.)
        """
        super().__init__()

        self.config = config

        bce_weight = config.get('loss.bce_weight', 1.0)
        dice_weight = config.get('loss.dice_weight', 1.0)

        self.combined_loss = CombinedLoss(
            bce_weight=bce_weight,
            dice_weight=dice_weight
        )

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                metadata: Dict) -> Dict[str, torch.Tensor]:
        """
        Compute loss with dataset awareness

        Args:
            pred: Predicted logits (B, 1, H, W)
            target: Ground truth mask (B, 1, H, W)
            metadata: Batch metadata. Either a list/tuple of per-sample dicts,
                or a single (possibly collated) dict whose 'has_pixel_mask'
                value may be a bool, a tensor, or a list of bools. A missing
                flag (or ``None`` metadata) counts as True.

        Returns:
            Dictionary with loss components
        """
        has_pixel_mask = self._resolve_has_pixel_mask(metadata)
        return self.combined_loss(pred, target, has_pixel_mask)

    @classmethod
    def _resolve_has_pixel_mask(cls, metadata) -> bool:
        """Reduce batch metadata to one bool: True only if every sample has a mask."""
        if metadata is None:
            return True
        if isinstance(metadata, (list, tuple)):
            return all(cls._flag_to_bool(m.get('has_pixel_mask', True))
                       for m in metadata)
        return cls._flag_to_bool(metadata.get('has_pixel_mask', True))

    @classmethod
    def _flag_to_bool(cls, flag) -> bool:
        """Collapse a bool / tensor / sequence of flags into a single bool (all-true)."""
        if isinstance(flag, torch.Tensor):
            # default_collate turns per-sample bools into a 1-D bool tensor;
            # `if tensor:` would be ambiguous for batch size > 1.
            return bool(flag.all())
        if isinstance(flag, (list, tuple)):
            return all(cls._flag_to_bool(f) for f in flag)
        return bool(flag)
|
|
|
|
|
def get_loss_function(config) -> DatasetAwareLoss:
    """
    Build the training loss described by ``config``.

    Args:
        config: Configuration object, handed unchanged to DatasetAwareLoss.

    Returns:
        A newly constructed DatasetAwareLoss instance.
    """
    loss_fn = DatasetAwareLoss(config)
    return loss_fn
|
|
|