| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class FocalLoss(nn.Module):
    """Focal Loss for imbalanced classification.

    Computes, per element, ``alpha * (1 - p_t) ** gamma * CE``, where ``CE``
    is the unreduced ``F.cross_entropy(inputs, targets)`` and
    ``p_t = exp(-CE)``. For class-index targets, ``p_t`` is the softmax
    probability the model assigns to the target class. Well-classified
    elements (``p_t`` near 1) are therefore down-weighted.

    Args:
        alpha: Scalar weight multiplied into every element's loss.
        gamma: Focusing exponent. ``gamma=0`` reduces to ``alpha`` times
            plain cross entropy.
        reduction: ``'mean'`` (average over elements), ``'sum'``, or
            ``'none'`` (return the unreduced per-element losses). These
            match the options accepted by ``F.cross_entropy``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        # Fail fast: previously any unknown string silently meant 'sum'.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``.

        ``inputs`` and ``targets`` are passed straight to
        ``F.cross_entropy``, so they must satisfy its shape and dtype rules.
        The result is a scalar for ``'mean'`` or ``'sum'``. For ``'none'``
        it has the shape of the unreduced cross-entropy output.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # CE = -log(p_t), so exp(-CE) recovers p_t without a second softmax.
        pt = torch.exp(-ce_loss)
        loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        # Reached only if self.reduction was reassigned after construction.
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )