import torch
import torch.nn.functional as F


def focal_loss(inputs, targets, alpha=0.5, gamma=2, reduction='mean'):
    """Compute focal loss (Lin et al., 2017) on top of cross-entropy.

    Args:
        inputs: raw logits, shape (N, C, ...) as accepted by F.cross_entropy.
        targets: class indices (cast to long internally), shape (N, ...).
        alpha: class-balance weight. NOTE(review): the weighting below,
            ``alpha * t + (1 - alpha) * (1 - t)``, is only meaningful for
            binary {0, 1} targets; for multi-class indices it yields
            arbitrary per-class weights — confirm intended use with callers
            before relying on alpha != 0.5.
        gamma: focusing exponent; larger values down-weight easy examples.
        reduction: 'mean', 'sum', or anything else for per-element losses.

    Returns:
        Scalar tensor for 'mean'/'sum'; otherwise the unreduced loss tensor.
    """
    # Per-element cross entropy is -log(p_t), so p_t = exp(-ce).
    logpt = F.cross_entropy(inputs, targets.long(), reduction='none')
    pt = torch.exp(-logpt)
    # Modulating factor (1 - p_t)^gamma focuses training on hard examples.
    # Renamed from `focal_loss` — the original shadowed the function name.
    loss = (1 - pt) ** gamma * logpt
    alpha_weight = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_weight * loss
    if reduction == 'mean':
        return torch.mean(loss)
    if reduction == 'sum':
        return torch.sum(loss)
    return loss


def dice_loss(inputs, targets, epsilon=1e-7):
    """Compute soft Dice loss for multi-class segmentation.

    Args:
        inputs: raw logits of shape (N, C, H, W); softmax is applied here.
        targets: integer class map of shape (N, H, W) — the (0, 3, 1, 2)
            permute below fixes this layout.
        epsilon: small constant guarding against division by zero when a
            class is absent from both prediction and target.

    Returns:
        Scalar tensor: 1 - mean Dice coefficient over batch and classes.
    """
    # one_hot yields (N, H, W, C); move classes to dim 1 to match inputs.
    targets_one_hot = F.one_hot(targets.long(), num_classes=inputs.shape[1])
    targets_one_hot = targets_one_hot.permute(0, 3, 1, 2)
    probs = F.softmax(inputs, dim=1)
    # Match dtype and device of the predictions in one step (replaces the
    # original redundant .float() + deprecated .type(inputs.type()) pair).
    targets_one_hot = targets_one_hot.to(probs)
    # Per-sample, per-class soft Dice over the spatial dims.
    numerator = 2 * (probs * targets_one_hot).sum(dim=(2, 3))
    denominator = probs.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3))
    dice_coefficient = numerator / (denominator + epsilon)
    return 1 - dice_coefficient.mean()