| import torch | |
| import torch.nn as nn | |
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of softmax cross-entropy.

    Each element's loss is ``(1 - pt) ** gamma * ce``, where ``ce`` is that
    element's cross-entropy and ``pt = exp(-ce)`` is the probability the
    model assigns to the target class. The per-element losses are then
    averaged into a single scalar.

    Args:
        gamma: Exponent of the modulating factor ``(1 - pt)``. With
            ``gamma=0`` the loss is the plain (optionally class-weighted)
            cross-entropy averaged over elements.
        weight: Optional per-class rescaling tensor, forwarded to
            ``nn.CrossEntropyLoss``. It scales the cross-entropy term only;
            it does not enter the computation of ``pt``.
    """

    def __init__(self, gamma=2.0, weight=None):
        super().__init__()
        self.gamma = gamma
        # reduction="none" keeps one loss value per element; the focal
        # modulation must be applied before averaging, not after.
        self.ce = nn.CrossEntropyLoss(weight=weight, reduction="none")

    def forward(self, logits, targets):
        """Return the mean focal loss for ``logits`` against ``targets``.

        Args:
            logits: Unnormalized class scores, in any layout accepted by
                ``nn.CrossEntropyLoss``. Cast to float32 before use.
            targets: Targets in any form accepted by ``nn.CrossEntropyLoss``.

        Returns:
            A scalar tensor: the mean over elements of
            ``(1 - pt) ** gamma * ce``.
        """
        logits = logits.float()

        # Unweighted per-element CE: exp(-ce) is then exactly the softmax
        # probability of the target class. Using the class-weighted CE here
        # would yield pt ** w instead of pt.
        plain_ce = nn.functional.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-plain_ce)
        modulator = (1 - pt) ** self.gamma

        # Class weights (if any) scale only the cross-entropy term.
        if self.ce.weight is None:
            scaled_ce = plain_ce
        else:
            scaled_ce = self.ce(logits, targets)

        return (modulator * scaled_ce).mean()