"""
Focal Loss Implementation for Multi-Class Classification
Focal Loss addresses class imbalance by focusing on hard-to-classify examples.
It down-weights easy examples and focuses training on hard negatives.
Formula: FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
Where:
- p_t: predicted probability for true class
- α_t: class-specific weight (handles class imbalance)
- γ: focusing parameter (default 2.0, recommended 2.5 for hard classes)
References:
- Lin et al. "Focal Loss for Dense Object Detection" (2017)
- https://arxiv.org/abs/1708.02002
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
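
# Worked example of the formula above (illustrative numbers, not from the
# paper): with gamma = 2 and an easy example at p_t = 0.9, the focal weight is
# (1 - 0.9)^2 = 0.01, so the cross-entropy term -log(0.9) ≈ 0.105 shrinks to
# ≈ 0.00105. A hard example at p_t = 0.1 keeps weight (1 - 0.1)^2 = 0.81,
# preserving most of -log(0.1) ≈ 2.303, so hard examples dominate the gradient.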

class FocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification with class weighting.

    Args:
        alpha (torch.Tensor or None): Class weights of shape [num_classes].
            If None, all classes are weighted equally.
        gamma (float): Focusing parameter. Higher values focus more on hard examples.
            - gamma=0: equivalent to standard cross-entropy
            - gamma=1: moderate focus on hard examples
            - gamma=2: strong focus (original paper)
            - gamma=2.5: very strong focus (recommended for this task)
        reduction (str): Specifies the reduction to apply: 'none' | 'mean' | 'sum'

    Shape:
        - Input: (N, C) where N = batch size, C = number of classes
        - Target: (N,) where each value satisfies 0 ≤ targets[i] ≤ C-1
        - Output: scalar if reduction='mean' or 'sum'; (N,) if reduction='none'
    """

    def __init__(self, alpha=None, gamma=2.5, reduction='mean'):
        super().__init__()

        # Validate gamma parameter
        if gamma < 0:
            raise ValueError(f"gamma must be non-negative, got {gamma}")

        # Validate reduction parameter
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction}")

        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute Focal Loss.

        Args:
            inputs (torch.Tensor): Raw logits from the model (before softmax).
                Shape: (batch_size, num_classes)
            targets (torch.Tensor): Ground-truth class labels.
                Shape: (batch_size,)

        Returns:
            torch.Tensor: Computed focal loss (scalar if reduction='mean'/'sum')
        """
        # Work in log space for numerical stability: log_softmax avoids the
        # underflow that softmax followed by log() can produce for extreme logits.
        log_probs = F.log_softmax(inputs, dim=1)

        # Get the log-probability of the true class for each sample.
        # targets.unsqueeze(1) creates shape (N, 1) for gathering.
        log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # Shape: (N,)
        p_t = log_p_t.exp()

        # Compute focal weight: (1 - p_t)^gamma
        # This up-weights hard examples (low p_t) and down-weights easy examples (high p_t)
        focal_weight = (1.0 - p_t) ** self.gamma

        # Cross-entropy term: -log(p_t)
        ce_loss = -log_p_t

        # Combine: FL = focal_weight * ce_loss
        focal_loss = focal_weight * ce_loss

        # Apply class weights (alpha) if provided
        if self.alpha is not None:
            # Move alpha to the input device via a local copy so forward()
            # does not mutate module state.
            alpha = self.alpha.to(inputs.device)
            # Get alpha for each sample based on its true class
            alpha_t = alpha[targets]  # Shape: (N,)
            focal_loss = alpha_t * focal_loss

        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss  # reduction == 'none'
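

# Quick sanity check (a minimal sketch added for illustration, not part of the
# original test suite; the helper name is hypothetical): with alpha=None and
# gamma=0 the focal weight (1 - p_t)^0 equals 1, so FocalLoss should collapse
# to standard cross-entropy, as the class docstring states.
def _focal_matches_cross_entropy(atol=1e-6):
    """Return True if FocalLoss(gamma=0) matches F.cross_entropy on random data."""
    logits = torch.randn(16, 7)
    targets = torch.randint(0, 7, (16,))
    fl = FocalLoss(alpha=None, gamma=0.0)(logits, targets)
    ce = F.cross_entropy(logits, targets)
    return torch.allclose(fl, ce, atol=atol)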


def compute_class_weights(targets, num_classes=7, minority_boost=1.8):
    """
    Compute balanced class weights with an optional boost for minority classes.

    Args:
        targets (array-like): Ground-truth labels. Every class in
            [0, num_classes) must appear at least once, since sklearn's
            'balanced' mode raises for classes absent from y.
        num_classes (int): Total number of classes.
        minority_boost (float): Multiplicative boost for the smallest classes
            (default 1.8).

    Returns:
        torch.Tensor: Class weights of shape [num_classes]

    Example:
        >>> targets = [0, 0, 1, 1, 1, 2]
        >>> weights = compute_class_weights(targets, num_classes=3)
        >>> # Class 2 (smallest) will have the highest weight
    """
    import numpy as np
    from sklearn.utils.class_weight import compute_class_weight

    # Convert to numpy if needed
    if torch.is_tensor(targets):
        targets = targets.cpu().numpy()

    # Compute balanced weights using sklearn:
    # weight[c] = n_samples / (num_classes * count[c])
    class_weights = compute_class_weight(
        'balanced',
        classes=np.arange(num_classes),
        y=targets,
    )

    # Count samples per class to identify minorities
    unique, counts = np.unique(targets, return_counts=True)
    class_counts = np.zeros(num_classes)
    class_counts[unique] = counts

    # Treat classes whose count falls below the median as minority classes
    median_count = np.median(class_counts[class_counts > 0])
    minority_classes = np.where(class_counts < median_count)[0]

    # Apply boost to minority classes (e.g., Classes 0 and 5)
    for cls_idx in minority_classes:
        if class_counts[cls_idx] > 0:  # Only boost if class exists
            class_weights[cls_idx] *= minority_boost

    # Convert to torch tensor
    weights_tensor = torch.as_tensor(class_weights, dtype=torch.float32)

    print(f"📊 Class Weights (with {minority_boost}x minority boost):")
    for i in range(num_classes):
        count = int(class_counts[i])
        weight = class_weights[i]
        boost_marker = " ⬆️ BOOSTED" if i in minority_classes else ""
        print(f"   Class {i}: count={count:5d}, weight={weight:.3f}{boost_marker}")

    return weights_tensor
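

# Typical wiring into a training loop (a sketch; `model`, `train_labels`,
# `loader`, and `optimizer` are assumed to exist in the calling script):
#
#     weights = compute_class_weights(train_labels, num_classes=7)
#     criterion = FocalLoss(alpha=weights, gamma=2.5)
#     for x, y in loader:
#         optimizer.zero_grad()
#         loss = criterion(model(x), y)
#         loss.backward()
#         optimizer.step()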


# Example usage and testing
if __name__ == "__main__":
    print("🔥 Focal Loss Implementation Test\n")

    # Test 1: Basic functionality
    print("Test 1: Basic Focal Loss")
    batch_size = 8
    num_classes = 7

    # Simulate logits and targets
    logits = torch.randn(batch_size, num_classes)
    targets = torch.tensor([0, 1, 2, 3, 4, 5, 6, 1])

    # Create focal loss (no class weights)
    focal_loss = FocalLoss(alpha=None, gamma=2.5)
    loss = focal_loss(logits, targets)
    print(f"   Loss value: {loss.item():.4f}")
    print("   ✅ Basic test passed\n")

    # Test 2: With class weights
    print("Test 2: Focal Loss with Class Weights")
    class_weights = torch.tensor([2.0, 1.0, 1.0, 0.8, 1.2, 2.5, 1.5])
    focal_loss_weighted = FocalLoss(alpha=class_weights, gamma=2.5)
    loss_weighted = focal_loss_weighted(logits, targets)
    print(f"   Loss value: {loss_weighted.item():.4f}")
    print("   ✅ Weighted test passed\n")

    # Test 3: Compute class weights
    print("Test 3: Compute Class Weights")
    simulated_targets = torch.cat([
        torch.zeros(100),       # Class 0: 100 samples
        torch.ones(200),        # Class 1: 200 samples
        torch.full((150,), 2),  # Class 2: 150 samples
        torch.full((300,), 3),  # Class 3: 300 samples (largest)
        torch.full((180,), 4),  # Class 4: 180 samples
        torch.full((80,), 5),   # Class 5: 80 samples (smallest)
        torch.full((120,), 6),  # Class 6: 120 samples
    ]).long()
    weights = compute_class_weights(simulated_targets, num_classes=7, minority_boost=1.8)
    print("\n   ✅ Class weight computation passed\n")

    # Test 4: Gradient flow
    print("Test 4: Gradient Flow")
    logits.requires_grad = True
    loss = focal_loss_weighted(logits, targets)
    loss.backward()
    print(f"   Gradient exists: {logits.grad is not None}")
    print(f"   Gradient norm: {logits.grad.norm().item():.4f}")
    print("   ✅ Gradient flow test passed\n")

    print("✅ All tests passed! Focal Loss is ready for training.")