File size: 817 Bytes
c5f4ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F


# Smoothing term added to both the numerator and denominator of the Dice ratio.
# It must be a small *positive* number (written `1e-10`; the typo `1-10` is -9,
# which corrupts the ratio and can drive the denominator to zero).
e = 1e-10


def dice_loss(pred, target, need_sigmoid=True, eps=None):
    """Return the soft Dice loss between ``pred`` and ``target``.

    Computes ``1 - (2 * sum(p * t) + eps) / (sum(p * p) + sum(t * t) + eps)``
    summed over every element of the tensors (no per-sample or per-channel
    reduction).

    Args:
        pred: Prediction tensor. Passed through ``torch.sigmoid`` first when
            ``need_sigmoid`` is true; otherwise used as-is (presumably already
            probabilities in [0, 1] -- not checked here).
        target: Target tensor; must have the same size as ``pred``.
        need_sigmoid: Whether to apply ``torch.sigmoid`` to ``pred``.
        eps: Smoothing constant added to numerator and denominator. ``None``
            (the default) uses the module-level ``e``. A positive value keeps
            the ratio defined when ``pred`` and ``target`` are both all zeros.

    Returns:
        A scalar tensor holding the loss.

    Raises:
        AssertionError: If ``pred`` and ``target`` differ in size.
    """
    assert target.size() == pred.size(), (
        f"pred and target must have the same size, got {tuple(pred.size())} "
        f"and {tuple(target.size())}"
    )
    if eps is None:
        # Looked up at call time, matching the original's use of the global.
        eps = e
    if need_sigmoid:
        pred = torch.sigmoid(pred)
    intersect = 2 * (pred * target).sum() + eps
    union = (pred * pred).sum() + (target * target).sum() + eps
    return 1 - intersect / union


class DiceLoss(nn.Module):
    """``nn.Module`` wrapper around ``dice_loss`` using its default settings.

    Because ``dice_loss`` is called without ``need_sigmoid``, its default
    (``True``) applies and ``pred`` is passed through ``torch.sigmoid``, so
    ``pred`` is expected to be raw logits.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return ``dice_loss(pred, target)`` as a scalar tensor."""
        loss = dice_loss(pred=pred, target=target)
        return loss
    

class DiceBCE(nn.Module):
    """Equal-weight (0.5 / 0.5) combination of soft Dice loss and BCE.

    Both terms treat ``pred`` as raw logits. ``dice_loss`` applies
    ``torch.sigmoid`` by default, and ``F.binary_cross_entropy_with_logits``
    applies it internally. ``target`` must match ``pred`` in size (asserted by
    ``dice_loss``). It presumably holds 0/1 values in a floating dtype, which
    the BCE-with-logits term expects.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return ``0.5 * dice + 0.5 * bce`` for logits ``pred`` and ``target``."""
        dice_term = dice_loss(pred=pred, target=target)
        bce_term = F.binary_cross_entropy_with_logits(input=pred, target=target)
        return 0.5 * dice_term + 0.5 * bce_term