File size: 602 Bytes
4c1e73e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Focal Loss for imbalanced classification.

    Per sample, computes ``alpha * (1 - p_t) ** gamma * CE``, where ``CE`` is the
    unreduced ``F.cross_entropy`` loss and ``p_t = exp(-CE)``. For class-index
    targets, ``p_t`` is the predicted probability of the true class. The
    ``(1 - p_t) ** gamma`` factor shrinks the loss of confidently-correct samples
    so that training focuses on hard ones.

    Args:
        alpha: Multiplier applied to every per-sample loss (typically a float).
        gamma: Focusing exponent; ``gamma=0`` reduces to ``alpha``-scaled
            cross-entropy.
        reduction: ``'mean'`` (default) averages the per-sample losses,
            ``'sum'`` adds them, and ``'none'`` returns them unreduced.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        # Fail fast on typos instead of silently falling back to a sum.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` w.r.t. ``targets``.

        Args:
            inputs: Unnormalized logits, in any layout ``F.cross_entropy`` accepts.
            targets: Targets matching ``inputs`` as ``F.cross_entropy`` expects.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or the unreduced
            per-sample loss tensor for ``'none'``.

        Raises:
            ValueError: If ``self.reduction`` was changed to an unsupported value.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # ce_loss = -log(p_t), so exp(-ce_loss) recovers p_t without a second softmax.
        pt = torch.exp(-ce_loss)
        loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        # Reached only if the attribute was mutated after construction.
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )
|