Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # from torch.autograd import Variable | |
class FocalLoss(nn.Module):
    """Multi-class focal loss built on softmax cross-entropy.

    Every sample contributes
    ``balance_param * (1 - pt) ** focusing_param * ce`` where ``ce`` is that
    sample's cross-entropy and ``pt = exp(-ce)`` is the probability the model
    assigns to the sample's target class. The per-sample contributions are
    summed into a single scalar.

    Args:
        focusing_param: Exponent applied to ``(1 - pt)``; larger values
            shrink the contribution of samples that are already classified
            with high confidence. ``0`` disables the modulation.
        balance_param: Constant factor the loss is multiplied by.
    """

    def __init__(self, focusing_param=2, balance_param=0.25):  # TODO: try changing balance_param
        super().__init__()
        self.focusing_param = focusing_param
        self.balance_param = balance_param

    def compute(self, output, target):
        """Return the summed focal loss for a batch.

        Args:
            output: Unnormalised class scores, in any layout accepted by
                ``F.cross_entropy`` (typically ``(N, C)``).
            target: Class indices matching ``output`` (typically ``(N,)``).

        Returns:
            A zero-dimensional tensor holding the summed, scaled focal loss.
        """
        # Keep one cross-entropy value per sample.  Reducing first (as with
        # reduction='sum') would make ``pt`` the product of every sample's
        # probability, so the focal term could no longer single out the
        # hard examples.
        logpt = -F.cross_entropy(output, target, reduction='none')
        pt = torch.exp(logpt)
        focal_loss = -((1 - pt) ** self.focusing_param) * logpt
        return self.balance_param * focal_loss.sum()

    def forward(self, output, target):
        """Alias for :meth:`compute` so the module can be called directly."""
        return self.compute(output, target)
def test_focal_loss():
    """Smoke-test FocalLoss: evaluate it on random data and backpropagate.

    Prints the random logits, the random targets and the resulting loss, then
    runs ``backward()`` to confirm the loss is differentiable with respect to
    the logits.
    """
    loss = FocalLoss()
    # 3 samples, 5 classes; tensors track gradients directly, so the
    # deprecated autograd.Variable wrapper is not needed.
    logits = torch.randn(3, 5, requires_grad=True)
    target = torch.randint(0, 5, (3,))
    print(logits)
    print(target)
    # Call compute() explicitly: it is the method FocalLoss is known to expose.
    output = loss.compute(logits, target)
    print(output)
    output.backward()


if __name__ == '__main__':
    test_focal_loss()