File size: 4,850 Bytes
ff0e79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""

Dataset-aware loss functions

Implements Critical Fix #2: Dataset-Aware Loss Function

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional


class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    Accepts raw logits; a sigmoid is applied internally before the Dice
    coefficient is computed over every pixel in the batch.
    """

    def __init__(self, smooth: float = 1.0):
        """Initialize the Dice loss.

        Args:
            smooth: Additive smoothing term that keeps the ratio finite
                when both prediction and target sum to zero.
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return ``1 - Dice`` for a batch of predictions.

        Args:
            pred: Predicted logits of shape (B, 1, H, W).
            target: Ground-truth mask of shape (B, 1, H, W).

        Returns:
            Scalar Dice loss value.
        """
        probs = torch.sigmoid(pred).view(-1)
        truth = target.view(-1)

        # Soft Dice coefficient over the flattened batch.
        overlap = (probs * truth).sum()
        denominator = probs.sum() + truth.sum() + self.smooth
        dice_coeff = (2. * overlap + self.smooth) / denominator

        return 1 - dice_coeff


class CombinedLoss(nn.Module):
    """Weighted BCE + Dice loss for segmentation.

    Dataset-aware: the Dice term is added only when the batch comes from
    a dataset that provides pixel-level masks (Critical Fix #2).
    """

    def __init__(self,
                 bce_weight: float = 1.0,
                 dice_weight: float = 1.0):
        """Initialize the combined loss.

        Args:
            bce_weight: Multiplier for the BCE term.
            dice_weight: Multiplier for the Dice term.
        """
        super().__init__()

        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                has_pixel_mask: bool = True) -> Dict[str, torch.Tensor]:
        """Compute the dataset-aware combined loss.

        Args:
            pred: Predicted logits (B, 1, H, W).
            target: Ground-truth mask (B, 1, H, W).
            has_pixel_mask: True when the dataset supplies pixel-level
                masks; when False only BCE contributes (e.g. CASIA).

        Returns:
            Dict with 'total', 'bce', and — when pixel masks exist — 'dice'.
        """
        # BCE is always part of the objective.
        bce = self.bce_loss(pred, target)
        losses = {'bce': bce}

        if not has_pixel_mask:
            # Critical Fix #2: mask-free datasets (CASIA) use BCE only.
            losses['total'] = self.bce_weight * bce
            return losses

        # Pixel masks available: add the weighted Dice term.
        dice = self.dice_loss(pred, target)
        losses['dice'] = dice
        losses['total'] = self.bce_weight * bce + self.dice_weight * dice
        return losses


class DatasetAwareLoss(nn.Module):
    """Dataset-aware loss wrapper.

    Reads 'has_pixel_mask' flags from batch metadata and forwards them
    to CombinedLoss, so Dice is only applied when every sample in the
    batch carries a pixel-level mask.
    """

    def __init__(self, config):
        """Initialize the dataset-aware loss.

        Args:
            config: Configuration object exposing ``get(key, default)``;
                consulted for 'loss.bce_weight' and 'loss.dice_weight'.
        """
        super().__init__()

        self.config = config

        self.combined_loss = CombinedLoss(
            bce_weight=config.get('loss.bce_weight', 1.0),
            dice_weight=config.get('loss.dice_weight', 1.0),
        )

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                metadata: Dict) -> Dict[str, torch.Tensor]:
        """Compute the loss, honouring pixel-mask availability.

        Args:
            pred: Predicted logits (B, 1, H, W).
            target: Ground-truth mask (B, 1, H, W).
            metadata: Either a single metadata dict or a list of
                per-sample dicts, each optionally carrying the
                'has_pixel_mask' flag (defaults to True when absent).

        Returns:
            Dict of loss components from CombinedLoss.
        """
        if isinstance(metadata, list):
            # Use Dice only if *every* sample in the batch has a mask.
            has_pixel_mask = all(
                item.get('has_pixel_mask', True) for item in metadata
            )
        else:
            has_pixel_mask = metadata.get('has_pixel_mask', True)

        return self.combined_loss(pred, target, has_pixel_mask)


def get_loss_function(config) -> DatasetAwareLoss:
    """Factory for the project's dataset-aware loss.

    Args:
        config: Configuration object, passed straight through to
            DatasetAwareLoss.

    Returns:
        A ready-to-use DatasetAwareLoss instance.
    """
    return DatasetAwareLoss(config)