""" Focal Loss Implementation for Multi-Class Classification Focal Loss addresses class imbalance by focusing on hard-to-classify examples. It down-weights easy examples and focuses training on hard negatives. Formula: FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t) Where: - p_t: predicted probability for true class - α_t: class-specific weight (handles class imbalance) - γ: focusing parameter (default 2.0, recommended 2.5 for hard classes) References: - Lin et al. "Focal Loss for Dense Object Detection" (2017) - https://arxiv.org/abs/1708.02002 """ import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): """ Focal Loss for multi-class classification with class weighting. Args: alpha (torch.Tensor or None): Class weights of shape [num_classes]. If None, all classes are weighted equally. gamma (float): Focusing parameter. Higher values focus more on hard examples. - gamma=0: equivalent to standard cross-entropy - gamma=1: moderate focus on hard examples - gamma=2: strong focus (original paper) - gamma=2.5: very strong focus (recommended for this task) reduction (str): Specifies the reduction to apply: 'none' | 'mean' | 'sum' Shape: - Input: (N, C) where N = batch size, C = number of classes - Target: (N) where each value is 0 ≤ targets[i] ≤ C-1 - Output: scalar if reduction='mean' or 'sum', (N) if reduction='none' """ def __init__(self, alpha=None, gamma=2.5, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction # Validate gamma parameter if gamma < 0: raise ValueError(f"gamma must be non-negative, got {gamma}") # Validate reduction parameter if reduction not in ['none', 'mean', 'sum']: raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction}") def forward(self, inputs, targets): """ Compute Focal Loss. Args: inputs (torch.Tensor): Raw logits from model (before softmax) Shape: (batch_size, num_classes) targets (torch.Tensor): Ground truth class labels Shape: (batch_size,) Returns: torch.Tensor: Computed focal loss (scalar if reduction='mean'/'sum') """ # Convert logits to probabilities probs = F.softmax(inputs, dim=1) # Get the probability of the true class for each sample # targets.unsqueeze(1) creates shape (N, 1) for gathering targets_one_hot = F.one_hot(targets, num_classes=inputs.size(1)) p_t = (probs * targets_one_hot).sum(dim=1) # Shape: (N,) # Compute focal weight: (1 - p_t)^gamma # This up-weights hard examples (low p_t) and down-weights easy examples (high p_t) focal_weight = (1.0 - p_t) ** self.gamma # Compute cross-entropy: -log(p_t) # Add epsilon for numerical stability ce_loss = -torch.log(p_t + 1e-8) # Combine: FL = focal_weight * ce_loss focal_loss = focal_weight * ce_loss # Apply class weights (alpha) if provided if self.alpha is not None: if self.alpha.device != inputs.device: self.alpha = self.alpha.to(inputs.device) # Get alpha for each sample based on its true class alpha_t = self.alpha[targets] # Shape: (N,) focal_loss = alpha_t * focal_loss # Apply reduction if self.reduction == 'none': return focal_loss elif self.reduction == 'mean': return focal_loss.mean() elif self.reduction == 'sum': return focal_loss.sum() def compute_class_weights(targets, num_classes=7, minority_boost=1.8): """ Compute balanced class weights with optional boost for minority classes. 
def compute_class_weights(targets, num_classes=7, minority_boost=1.8):
    """
    Compute balanced class weights with an optional boost for minority classes.

    Note: every class in [0, num_classes) should appear at least once in
    `targets`; sklearn's compute_class_weight raises on unseen labels.

    Args:
        targets (array-like): Ground-truth labels.
        num_classes (int): Total number of classes.
        minority_boost (float): Multiplicative boost for the smallest classes (default 1.8).

    Returns:
        torch.Tensor: Class weights of shape [num_classes].

    Example:
        >>> targets = [0, 0, 1, 1, 1, 2]
        >>> weights = compute_class_weights(targets, num_classes=3)
        >>> # Class 2 (smallest) will have the highest weight
    """
    # Imported here so sklearn stays an optional dependency
    from sklearn.utils.class_weight import compute_class_weight
    import numpy as np

    # Convert to numpy if needed
    if torch.is_tensor(targets):
        targets = targets.cpu().numpy()

    # Compute inverse-frequency ("balanced") weights using sklearn
    class_weights = compute_class_weight(
        'balanced',
        classes=np.arange(num_classes),
        y=targets
    )

    # Count samples per class to identify minorities
    unique, counts = np.unique(targets, return_counts=True)
    class_counts = np.zeros(num_classes)
    class_counts[unique] = counts

    # Treat classes below the median (nonzero) count as minority classes
    median_count = np.median(class_counts[class_counts > 0])
    minority_classes = np.where(class_counts < median_count)[0]

    # Apply the boost to minority classes that actually occur in the data
    for cls_idx in minority_classes:
        if class_counts[cls_idx] > 0:  # Only boost if the class exists
            class_weights[cls_idx] *= minority_boost

    # Convert to a torch tensor
    weights_tensor = torch.as_tensor(class_weights, dtype=torch.float32)

    print(f"📊 Class Weights (with {minority_boost}x minority boost):")
    for i in range(num_classes):
        count = int(class_counts[i])
        weight = class_weights[i]
        boost_marker = " ⬆️ BOOSTED" if (i in minority_classes and count > 0) else ""
        print(f"  Class {i}: count={count:5d}, weight={weight:.3f}{boost_marker}")

    return weights_tensor


# Example usage and testing
if __name__ == "__main__":
    print("🔥 Focal Loss Implementation Test\n")

    # Test 1: Basic functionality
    print("Test 1: Basic Focal Loss")
    batch_size = 8
    num_classes = 7

    # Simulate logits and targets
    logits = torch.randn(batch_size, num_classes)
    targets = torch.tensor([0, 1, 2, 3, 4, 5, 6, 1])

    # Create focal loss (no class weights)
    focal_loss = FocalLoss(alpha=None, gamma=2.5)
    loss = focal_loss(logits, targets)
    print(f"  Loss value: {loss.item():.4f}")
    print("  ✅ Basic test passed\n")

    # Test 2: With class weights
    print("Test 2: Focal Loss with Class Weights")
    class_weights = torch.tensor([2.0, 1.0, 1.0, 0.8, 1.2, 2.5, 1.5])
    focal_loss_weighted = FocalLoss(alpha=class_weights, gamma=2.5)
    loss_weighted = focal_loss_weighted(logits, targets)
    print(f"  Loss value: {loss_weighted.item():.4f}")
    print("  ✅ Weighted test passed\n")

    # Test 3: Compute class weights
    print("Test 3: Compute Class Weights")
    simulated_targets = torch.cat([
        torch.zeros(100),        # Class 0: 100 samples
        torch.ones(200),         # Class 1: 200 samples
        torch.full((150,), 2),   # Class 2: 150 samples
        torch.full((300,), 3),   # Class 3: 300 samples (largest)
        torch.full((180,), 4),   # Class 4: 180 samples
        torch.full((80,), 5),    # Class 5: 80 samples (smallest)
        torch.full((120,), 6),   # Class 6: 120 samples
    ]).long()

    weights = compute_class_weights(simulated_targets, num_classes=7, minority_boost=1.8)
    print("\n  ✅ Class weight computation passed\n")

    # Test 4: Gradient flow
    print("Test 4: Gradient Flow")
    logits.requires_grad = True
    loss = focal_loss_weighted(logits, targets)
    loss.backward()
    print(f"  Gradient exists: {logits.grad is not None}")
    print(f"  Gradient norm: {logits.grad.norm().item():.4f}")
    print("  ✅ Gradient flow test passed\n")

    print("✅ All tests passed! Focal Loss is ready for training.")
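
# Usage sketch (illustrative assumption, not part of the tests above):
# `model`, `train_loader`, and `device` are hypothetical names standing in for
# a standard PyTorch setup, showing how this loss slots into a training loop.
#
#     weights = compute_class_weights(all_train_labels, num_classes=7)
#     criterion = FocalLoss(alpha=weights, gamma=2.5)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     for inputs, labels in train_loader:
#         inputs, labels = inputs.to(device), labels.to(device)
#         optimizer.zero_grad()
#         loss = criterion(model(inputs), labels)
#         loss.backward()
#         optimizer.step()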