"""
Focal Loss Implementation for Multi-Class Classification

Focal Loss addresses class imbalance by focusing training on hard-to-classify
examples: easy, well-classified examples are down-weighted so that hard,
misclassified ones dominate the gradient.

Formula: FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

Where:
- p_t: predicted probability for true class
- α_t: class-specific weight (handles class imbalance)
- γ: focusing parameter (2.0 in the original paper; this implementation defaults to 2.5)

References:
- Lin et al. "Focal Loss for Dense Object Detection" (2017)
- https://arxiv.org/abs/1708.02002
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification with class weighting.
    
    Args:
        alpha (torch.Tensor or None): Class weights of shape [num_classes].
            If None, all classes are weighted equally.
        gamma (float): Focusing parameter. Higher values focus more on hard examples.
            - gamma=0: equivalent to standard cross-entropy
            - gamma=1: moderate focus on hard examples
            - gamma=2: strong focus (original paper)
            - gamma=2.5: very strong focus (recommended for this task)
        reduction (str): Specifies the reduction to apply: 'none' | 'mean' | 'sum'
    
    Shape:
        - Input: (N, C) where N = batch size, C = number of classes
        - Target: (N) where each value is 0 ≤ targets[i] ≤ C-1
        - Output: scalar if reduction='mean' or 'sum', (N) if reduction='none'
    """
    
    def __init__(self, alpha=None, gamma=2.5, reduction='mean'):
        super().__init__()
        
        # Validate parameters before storing them
        if gamma < 0:
            raise ValueError(f"gamma must be non-negative, got {gamma}")
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction}")
        
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        """
        Compute Focal Loss.
        
        Args:
            inputs (torch.Tensor): Raw logits from model (before softmax)
                                   Shape: (batch_size, num_classes)
            targets (torch.Tensor): Ground truth class labels
                                    Shape: (batch_size,)
        
        Returns:
            torch.Tensor: Computed focal loss (scalar if reduction='mean'/'sum')
        """
        # Convert logits to probabilities
        probs = F.softmax(inputs, dim=1)
        
        # Get the probability of the true class for each sample via a one-hot
        # mask (equivalent to probs.gather(1, targets.unsqueeze(1)).squeeze(1))
        targets_one_hot = F.one_hot(targets, num_classes=inputs.size(1))
        p_t = (probs * targets_one_hot).sum(dim=1)  # Shape: (N,)
        
        # Compute focal weight: (1 - p_t)^gamma
        # This up-weights hard examples (low p_t) and down-weights easy examples (high p_t)
        focal_weight = (1.0 - p_t) ** self.gamma
        
        # Compute cross-entropy: -log(p_t)
        # Add epsilon for numerical stability
        ce_loss = -torch.log(p_t + 1e-8)
        
        # Combine: FL = focal_weight * ce_loss
        focal_loss = focal_weight * ce_loss
        
        # Apply class weights (alpha) if provided
        if self.alpha is not None:
            # Move the weights to the input's device (a no-op if already
            # there) without mutating module state inside forward()
            alpha = self.alpha.to(inputs.device)
            
            # Get alpha for each sample based on its true class
            alpha_t = alpha[targets]  # Shape: (N,)
            focal_loss = alpha_t * focal_loss
        
        # Apply reduction
        if self.reduction == 'none':
            return focal_loss
        elif self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
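

# A numerically steadier variant (a sketch with the same semantics, not the
# code path used by the class above): since cross-entropy already equals
# -log(p_t), one can recover p_t = exp(-CE) from log-softmax internally and
# skip the explicit epsilon entirely.
def focal_loss_logspace(inputs, targets, gamma=2.5, alpha=None):
    """Functional focal loss via log-softmax; agrees with FocalLoss up to
    the 1e-8 epsilon used in FocalLoss.forward."""
    ce = F.cross_entropy(inputs, targets, reduction='none')  # -log(p_t), shape (N,)
    p_t = torch.exp(-ce)                                     # stable recovery of p_t
    loss = (1.0 - p_t) ** gamma * ce
    if alpha is not None:
        loss = alpha.to(inputs.device)[targets] * loss
    return loss.mean()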


def compute_class_weights(targets, num_classes=7, minority_boost=1.8):
    """
    Compute balanced class weights with optional boost for minority classes.
    
    Args:
        targets (array-like): Ground truth labels
        num_classes (int): Total number of classes
        minority_boost (float): Multiplicative boost for smallest classes (default 1.8)
    
    Returns:
        torch.Tensor: Class weights of shape [num_classes]
    
    Example:
        >>> targets = [0, 0, 1, 1, 1, 2]
        >>> weights = compute_class_weights(targets, num_classes=3)
        >>> # Class 2 (smallest) will have higher weight
    """
    from sklearn.utils.class_weight import compute_class_weight
    import numpy as np
    
    # Convert to numpy if needed
    if torch.is_tensor(targets):
        targets = targets.cpu().numpy()
    
    # Compute balanced weights using sklearn
    class_weights = compute_class_weight(
        'balanced',
        classes=np.arange(num_classes),
        y=targets
    )
    
    # Identify minority classes: those whose count falls below the median
    # count of the classes actually present
    unique, counts = np.unique(targets, return_counts=True)
    class_counts = np.zeros(num_classes)
    class_counts[unique] = counts
    
    # Find classes below median count
    median_count = np.median(class_counts[class_counts > 0])
    minority_classes = np.where(class_counts < median_count)[0]
    
    # Apply boost to minority classes (e.g., Classes 0 and 5)
    for cls_idx in minority_classes:
        # Defensive guard: sklearn already requires every label in `classes`
        # to appear in y, so absent classes cannot normally reach this point
        if class_counts[cls_idx] > 0:
            class_weights[cls_idx] *= minority_boost
    
    # Convert to torch tensor
    weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
    
    print(f"📊 Class Weights (with {minority_boost}x minority boost):")
    for i in range(num_classes):
        count = int(class_counts[i])
        weight = class_weights[i]
        boost_marker = " ⬆️ BOOSTED" if i in minority_classes else ""
        print(f"   Class {i}: count={count:5d}, weight={weight:.3f}{boost_marker}")
    
    return weights_tensor
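

# A minimal end-to-end sketch (hypothetical names: `model`, `train_loader`,
# and `train_labels` are placeholders, not defined in this file) of how the
# two pieces above would typically be wired together:
#
#     weights = compute_class_weights(train_labels, num_classes=7)
#     criterion = FocalLoss(alpha=weights, gamma=2.5)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#
#     for inputs, targets in train_loader:
#         optimizer.zero_grad()
#         loss = criterion(model(inputs), targets)
#         loss.backward()
#         optimizer.step()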


# Example usage and testing
if __name__ == "__main__":
    print("🔥 Focal Loss Implementation Test\n")
    
    # Test 1: Basic functionality
    print("Test 1: Basic Focal Loss")
    batch_size = 8
    num_classes = 7
    
    # Simulate logits and targets
    logits = torch.randn(batch_size, num_classes)
    targets = torch.tensor([0, 1, 2, 3, 4, 5, 6, 1])
    
    # Create focal loss (no class weights)
    focal_loss = FocalLoss(alpha=None, gamma=2.5)
    loss = focal_loss(logits, targets)
    print(f"   Loss value: {loss.item():.4f}")
    print("   ✅ Basic test passed\n")
    
    # Test 2: With class weights
    print("Test 2: Focal Loss with Class Weights")
    class_weights = torch.tensor([2.0, 1.0, 1.0, 0.8, 1.2, 2.5, 1.5])
    focal_loss_weighted = FocalLoss(alpha=class_weights, gamma=2.5)
    loss_weighted = focal_loss_weighted(logits, targets)
    print(f"   Loss value: {loss_weighted.item():.4f}")
    print("   ✅ Weighted test passed\n")
    
    # Test 3: Compute class weights
    print("Test 3: Compute Class Weights")
    simulated_targets = torch.cat([
        torch.zeros(100),      # Class 0: 100 samples
        torch.ones(200),       # Class 1: 200 samples
        torch.full((150,), 2), # Class 2: 150 samples
        torch.full((300,), 3), # Class 3: 300 samples (largest)
        torch.full((180,), 4), # Class 4: 180 samples
        torch.full((80,), 5),  # Class 5: 80 samples (smallest)
        torch.full((120,), 6), # Class 6: 120 samples
    ]).long()
    
    weights = compute_class_weights(simulated_targets, num_classes=7, minority_boost=1.8)
    print(f"\n   ✅ Class weight computation passed\n")
    
    # Test 4: Gradient flow
    print("Test 4: Gradient Flow")
    logits.requires_grad = True
    loss = focal_loss_weighted(logits, targets)
    loss.backward()
    print(f"   Gradient exists: {logits.grad is not None}")
    print(f"   Gradient norm: {logits.grad.norm().item():.4f}")
    print("   ✅ Gradient flow test passed\n")
    
    print("✅ All tests passed! Focal Loss is ready for training.")