"""
Focal Loss Implementation for Multi-Class Classification
Focal Loss addresses class imbalance by focusing on hard-to-classify examples.
It down-weights easy examples and focuses training on hard negatives.
Formula: FL(p_t) = -Ξ±_t * (1 - p_t)^Ξ³ * log(p_t)
Where:
- p_t: predicted probability for true class
- Ξ±_t: class-specific weight (handles class imbalance)
- Ξ³: focusing parameter (default 2.0, recommended 2.5 for hard classes)
References:
- Lin et al. "Focal Loss for Dense Object Detection" (2017)
- https://arxiv.org/abs/1708.02002
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
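
# Quick numeric intuition for the modulating factor (1 - p_t) ** gamma (a worked
# example, not used by the code below): with gamma = 2, an easy example at
# p_t = 0.9 is scaled by (1 - 0.9) ** 2 = 0.01, while a hard example at
# p_t = 0.1 is scaled by (1 - 0.1) ** 2 = 0.81, an 81x larger relative contribution.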
class FocalLoss(nn.Module):
"""
Focal Loss for multi-class classification with class weighting.
Args:
alpha (torch.Tensor or None): Class weights of shape [num_classes].
If None, all classes are weighted equally.
gamma (float): Focusing parameter. Higher values focus more on hard examples.
- gamma=0: equivalent to standard cross-entropy
- gamma=1: moderate focus on hard examples
- gamma=2: strong focus (original paper)
- gamma=2.5: very strong focus (recommended for this task)
reduction (str): Specifies the reduction to apply: 'none' | 'mean' | 'sum'
Shape:
- Input: (N, C) where N = batch size, C = number of classes
- Target: (N) where each value is 0 ≀ targets[i] ≀ C-1
- Output: scalar if reduction='mean' or 'sum', (N) if reduction='none'
"""
    def __init__(self, alpha=None, gamma=2.5, reduction='mean'):
        super().__init__()
        # Validate parameters up front so a bad configuration fails fast
        if gamma < 0:
            raise ValueError(f"gamma must be non-negative, got {gamma}")
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, inputs, targets):
        """
        Compute Focal Loss.

        Args:
            inputs (torch.Tensor): Raw logits from the model (before softmax).
                Shape: (batch_size, num_classes)
            targets (torch.Tensor): Ground-truth class labels.
                Shape: (batch_size,)

        Returns:
            torch.Tensor: Computed focal loss (scalar if reduction='mean'/'sum')
        """
        # Convert logits to probabilities
        probs = F.softmax(inputs, dim=1)
        # Select the probability of the true class for each sample via a
        # one-hot mask (equivalent to gathering probs at the target indices)
        targets_one_hot = F.one_hot(targets, num_classes=inputs.size(1))
        p_t = (probs * targets_one_hot).sum(dim=1)  # Shape: (N,)
        # Compute focal weight: (1 - p_t)^gamma
        # This up-weights hard examples (low p_t) and down-weights easy ones (high p_t)
        focal_weight = (1.0 - p_t) ** self.gamma
        # Cross-entropy term: -log(p_t), with a small epsilon for numerical stability
        ce_loss = -torch.log(p_t + 1e-8)
        # Combine: FL = focal_weight * ce_loss
        focal_loss = focal_weight * ce_loss
        # Apply class weights (alpha) if provided
        if self.alpha is not None:
            alpha = self.alpha.to(inputs.device)  # keep alpha on the input's device
            # Look up alpha for each sample based on its true class
            alpha_t = alpha[targets]  # Shape: (N,)
            focal_loss = alpha_t * focal_loss
        # Apply reduction
        if self.reduction == 'none':
            return focal_loss
        elif self.reduction == 'mean':
            return focal_loss.mean()
        else:  # 'sum' (validated in __init__)
            return focal_loss.sum()
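
# For reference, the same loss can be written more compactly on top of
# F.cross_entropy, which returns -log(p_t) straight from log-softmax and so
# avoids the explicit softmax/one-hot pass and the epsilon. This is a sketch
# for comparison only (the function name is hypothetical); it matches
# FocalLoss with reduction='mean' up to the 1e-8 epsilon above.
def focal_loss_functional(inputs, targets, alpha=None, gamma=2.5):
    """Minimal functional sketch of focal loss with mean reduction."""
    ce = F.cross_entropy(inputs, targets, reduction='none')  # -log(p_t), shape (N,)
    p_t = torch.exp(-ce)  # recover the true-class probability
    loss = (1.0 - p_t) ** gamma * ce
    if alpha is not None:
        loss = alpha.to(inputs.device)[targets] * loss
    return loss.mean()
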
def compute_class_weights(targets, num_classes=7, minority_boost=1.8):
    """
    Compute balanced class weights with an optional boost for minority classes.

    Args:
        targets (array-like): Ground-truth labels. Every class in
            [0, num_classes) must appear at least once, since sklearn's
            compute_class_weight raises for classes absent from y.
        num_classes (int): Total number of classes
        minority_boost (float): Multiplicative boost for the smallest classes (default 1.8)

    Returns:
        torch.Tensor: Class weights of shape [num_classes]

    Example:
        >>> targets = [0, 0, 1, 1, 1, 2]
        >>> weights = compute_class_weights(targets, num_classes=3)
        >>> # Class 2 (smallest) will have the highest weight
    """
    from sklearn.utils.class_weight import compute_class_weight
    import numpy as np
    # Convert to numpy if needed
    if torch.is_tensor(targets):
        targets = targets.cpu().numpy()
    # Compute balanced weights using sklearn:
    # weight[c] = n_samples / (n_classes * count[c])
    class_weights = compute_class_weight(
        'balanced',
        classes=np.arange(num_classes),
        y=targets
    )
    # Count samples per class (zero for any class index not present)
    unique, counts = np.unique(targets, return_counts=True)
    class_counts = np.zeros(num_classes)
    class_counts[unique] = counts
    # Treat classes below the median count as minorities
    median_count = np.median(class_counts[class_counts > 0])
    minority_classes = np.where(class_counts < median_count)[0]
    # Apply the boost to minority classes (e.g., Classes 0 and 5)
    for cls_idx in minority_classes:
        if class_counts[cls_idx] > 0:  # Only boost classes that actually occur
            class_weights[cls_idx] *= minority_boost
    # Convert to a torch tensor
    weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
print(f"πŸ“Š Class Weights (with {minority_boost}x minority boost):")
for i in range(num_classes):
count = int(class_counts[i])
weight = class_weights[i]
boost_marker = " ⬆️ BOOSTED" if i in minority_classes else ""
print(f" Class {i}: count={count:5d}, weight={weight:.3f}{boost_marker}")
return weights_tensor
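
# If sklearn is unavailable, its 'balanced' heuristic can be reproduced in a
# few lines of pure PyTorch. This helper is an illustrative sketch (the name
# is hypothetical); compute_class_weights above does not use it.
def balanced_weights_manual(targets, num_classes):
    """Sketch of sklearn's 'balanced' weighting: n_samples / (n_present * count)."""
    counts = torch.bincount(targets, minlength=num_classes).float()
    n_present = (counts > 0).sum().float()  # number of classes that actually occur
    weights = counts.sum() / (n_present * counts.clamp(min=1.0))
    weights[counts == 0] = 0.0  # absent classes receive no weight
    return weights
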
# Example usage and testing
if __name__ == "__main__":
    print("🔥 Focal Loss Implementation Test\n")

    # Test 1: Basic functionality
    print("Test 1: Basic Focal Loss")
    batch_size = 8
    num_classes = 7
    # Simulate logits and targets
    logits = torch.randn(batch_size, num_classes)
    targets = torch.tensor([0, 1, 2, 3, 4, 5, 6, 1])
    # Create focal loss (no class weights)
    focal_loss = FocalLoss(alpha=None, gamma=2.5)
    loss = focal_loss(logits, targets)
    print(f"  Loss value: {loss.item():.4f}")
    print("  ✅ Basic test passed\n")

    # Test 2: With class weights
    print("Test 2: Focal Loss with Class Weights")
    class_weights = torch.tensor([2.0, 1.0, 1.0, 0.8, 1.2, 2.5, 1.5])
    focal_loss_weighted = FocalLoss(alpha=class_weights, gamma=2.5)
    loss_weighted = focal_loss_weighted(logits, targets)
    print(f"  Loss value: {loss_weighted.item():.4f}")
    print("  ✅ Weighted test passed\n")

    # Test 3: Compute class weights
    print("Test 3: Compute Class Weights")
    simulated_targets = torch.cat([
        torch.zeros(100),       # Class 0: 100 samples
        torch.ones(200),        # Class 1: 200 samples
        torch.full((150,), 2),  # Class 2: 150 samples
        torch.full((300,), 3),  # Class 3: 300 samples (largest)
        torch.full((180,), 4),  # Class 4: 180 samples
        torch.full((80,), 5),   # Class 5: 80 samples (smallest)
        torch.full((120,), 6),  # Class 6: 120 samples
    ]).long()
    weights = compute_class_weights(simulated_targets, num_classes=7, minority_boost=1.8)
    print("\n  ✅ Class weight computation passed\n")
# Test 4: Gradient flow
print("Test 4: Gradient Flow")
logits.requires_grad = True
loss = focal_loss_weighted(logits, targets)
loss.backward()
print(f" Gradient exists: {logits.grad is not None}")
print(f" Gradient norm: {logits.grad.norm().item():.4f}")
print(" βœ… Gradient flow test passed\n")
print("βœ… All tests passed! Focal Loss is ready for training.")