File size: 1,073 Bytes
006869b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn.functional as F


def focal_loss(inputs, targets, alpha=0.5, gamma=2, reduction='mean'):
    """Focal loss on raw logits (Lin et al., "Focal Loss for Dense Object Detection").

    Args:
        inputs: raw logits in any shape accepted by ``F.cross_entropy``.
        targets: ground-truth labels; cast to long for the cross-entropy term.
            NOTE(review): the alpha weighting below only makes sense for
            binary 0/1 targets — confirm callers pass binary labels.
        alpha: class-balance factor applied to the positive class.
        gamma: focusing exponent; larger values down-weight easy examples.
        reduction: 'mean' or 'sum' for a scalar; any other value returns
            the per-element loss tensor.

    Returns:
        Scalar tensor for 'mean'/'sum', otherwise per-element losses.
    """
    # Per-element negative log-likelihood; pt recovers the predicted
    # probability of the true class from it.
    nll = F.cross_entropy(inputs, targets.long(), reduction='none')
    pt = torch.exp(-nll)

    # Modulating factor (1 - pt)^gamma focuses the loss on hard examples.
    modulated = ((1 - pt) ** gamma) * nll

    # alpha for positives (t == 1), 1 - alpha for negatives (t == 0).
    class_weight = alpha * targets + (1 - alpha) * (1 - targets)
    weighted = class_weight * modulated

    if reduction == 'sum':
        return torch.sum(weighted)
    if reduction == 'mean':
        return torch.mean(weighted)
    return weighted


def dice_loss(inputs, targets, epsilon=1e-7):
    """Soft multi-class Dice loss for dense (segmentation-style) prediction.

    Args:
        inputs: raw logits of shape (N, C, H, W); softmax over dim 1 is
            applied here, so callers pass unnormalized scores.
        targets: integer class-index map of shape (N, H, W) — the
            permute(0, 3, 1, 2) below assumes this layout.
        epsilon: small constant guarding against division by zero when a
            class is absent from both prediction and target.

    Returns:
        Scalar tensor: 1 minus the Dice coefficient averaged over batch
        and classes (0 == perfect overlap).
    """
    probs = F.softmax(inputs, dim=1)
    # One-hot to (N, H, W, C), move classes to dim 1, and cast once to the
    # prediction dtype. This replaces the original's redundant .float()
    # followed by the deprecated tensor .type() call, and uses the F alias
    # consistently with the rest of the file.
    one_hot = F.one_hot(targets.long(), num_classes=inputs.shape[1])
    one_hot = one_hot.permute(0, 3, 1, 2).to(dtype=probs.dtype)

    # Per-sample, per-class soft Dice over the spatial dimensions.
    intersection = 2 * (probs * one_hot).sum(dim=(2, 3))
    cardinality = probs.sum(dim=(2, 3)) + one_hot.sum(dim=(2, 3))
    dice = intersection / (cardinality + epsilon)
    return 1 - dice.mean()