File size: 5,185 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import numpy as np
import pytest
import torch
import torch.nn as nn
from scipy.ndimage import convolve
from torch.autograd import gradcheck

import kornia
import kornia.testing as utils
from kornia.testing import assert_close


class HausdorffERLossNumpy(nn.Module):
    """Reference NumPy/SciPy implementation of the erosion-based Hausdorff loss.

    Taken from https://github.com/PatRyg99/HausdorffLoss/blob/master/hausdorff_loss.py

    Args:
        alpha: exponent applied to the 1-based erosion step index when
            weighting the contribution of each erosion step.
        erosions: number of erosion iterations performed per batch element.
        **kwargs: accepted for signature compatibility and ignored.
    """

    def __init__(self, alpha=2.0, erosions=10, **kwargs):
        super().__init__()
        self.alpha = alpha
        self.erosions = erosions
        self.prepare_kernels()

    def prepare_kernels(self):
        """Create the cross-shaped convolution kernels used for 2D and 3D inputs.

        Both kernels carry a leading singleton axis so they can be convolved
        with a single batch element that still has its channel axis.
        """
        cross = np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])
        bound = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])

        # 5 non-zero taps scaled by 0.2 -> weights sum to 1.
        self.kernel2D = cross * 0.2
        # NOTE: updated from np.array([bound, cross, bound]) * (1 / 7)
        # 7 non-zero taps scaled by 1/7 -> weights sum to 1; shape (1, 3, 3, 3).
        self.kernel3D = np.array([bound, cross, bound]).squeeze()[None] * (1 / 7)

    @torch.no_grad()
    def perform_erosion(self, pred: np.ndarray, target: np.ndarray) -> np.ndarray:
        """Iteratively erode the squared error map and accumulate the weighted result.

        Args:
            pred: prediction array with 4 (2D case) or 5 (3D case) dimensions.
            target: target array broadcastable against ``pred``.

        Returns:
            Array shaped like ``(pred - target) ** 2`` holding, per batch
            element, the sum over steps ``k`` of ``erosion_k * (k + 1) ** alpha``.

        Raises:
            ValueError: if the squared error map is neither 4- nor 5-dimensional.
        """
        bound = (pred - target) ** 2

        if bound.ndim == 5:
            kernel = self.kernel3D
        elif bound.ndim == 4:
            kernel = self.kernel2D
        else:
            raise ValueError(f"Dimension {bound.ndim} is not supported.")

        eroted = np.zeros_like(bound)

        for batch in range(len(bound)):
            for k in range(self.erosions):
                # compute convolution with kernel (zero padding outside the image)
                dilation = convolve(bound[batch], kernel, mode="constant", cval=0.0)

                # apply soft thresholding at 0.5
                erosion = dilation - 0.5
                erosion[erosion < 0] = 0

                # normalize to [0, 1] unless the map is constant.
                # np.ptp() is used because the ndarray.ptp() method was removed in NumPy 2.0.
                value_range = np.ptp(erosion)
                if value_range != 0:
                    erosion = (erosion - erosion.min()) / value_range

                # the eroded map is the input of the next step and is added to
                # the accumulator with a weight growing with the step index
                bound[batch] = erosion
                eroted[batch] += erosion * (k + 1) ** self.alpha

        return eroted

    def forward_one(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss for a single binary channel.

        Uses one binary channel: 1 - fg, 0 - bg
        pred: (b, 1, x, y, z) or (b, 1, x, y)
        target: (b, 1, x, y, z) or (b, 1, x, y)

        Returns:
            Scalar tensor (mean of the accumulated erosion map) with the dtype
            and device of ``pred``.
        """
        assert pred.size(1) == target.size(1) == 1
        # pred = torch.sigmoid(pred)

        accumulated = self.perform_erosion(pred.cpu().numpy(), target.cpu().numpy())
        eroted = torch.from_numpy(accumulated).to(dtype=pred.dtype, device=pred.device)

        return eroted.mean()

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss averaged over every channel of ``pred``.

        Channel ``i`` of ``pred`` is compared against the binary mask
        ``target == i``.

        pred: (b, c, x, y, z) or (b, c, x, y)
        target: (b, 1, x, y, z) or (b, 1, x, y), compared against channel indices

        Returns:
            Scalar tensor: the mean of the per-channel losses.
        """
        assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported"
        assert pred.dim() == target.dim() and target.size(1) == 1, "Prediction and target need to be of same dimension"
        per_channel = [
            # (target == i) cast back to target's dtype is the 1/0 mask for channel i
            self.forward_one(pred[:, i : i + 1], (target == i).to(target.dtype))
            for i in range(pred.size(1))
        ]
        return torch.stack(per_channel).mean()


class TestHausdorffLoss:
    """Smoke, numeric and gradient tests for the kornia Hausdorff erosion losses."""

    @staticmethod
    def _make_inputs(shape, device, dtype=None, num_classes=3):
        """Return random ``(logits, labels)`` for a batch of two samples.

        ``logits`` has shape ``(2, num_classes, *shape)``; ``labels`` is an
        int64 tensor of shape ``(2, 1, *shape)``.
        """
        logits = torch.rand(2, num_classes, *shape, dtype=dtype, device=device)
        # randint's upper bound is exclusive, so every class index in
        # [0, num_classes) can occur. Scaling torch.rand by (num_classes - 1)
        # and truncating would never produce the last class.
        labels = torch.randint(0, num_classes, (2, 1, *shape), device=device)
        return logits, labels

    @pytest.mark.parametrize("reduction", ['mean', 'none', 'sum'])
    @pytest.mark.parametrize(
        "hd,shape", [[kornia.losses.HausdorffERLoss, (10, 10)], [kornia.losses.HausdorffERLoss3D, (10, 10, 10)]]
    )
    def test_smoke_none(self, hd, shape, reduction, device, dtype):
        # Only checks that the forward pass runs for every reduction mode.
        logits, labels = self._make_inputs(shape, device, dtype)
        loss = hd(reduction=reduction)

        loss(logits, labels)

    @pytest.mark.parametrize(
        "hd,shape", [[kornia.losses.HausdorffERLoss, (50, 50)], [kornia.losses.HausdorffERLoss3D, (50, 50, 50)]]
    )
    def test_numeric(self, hd, shape, device, dtype):
        # Compares the kornia loss with the NumPy/SciPy reference using the
        # same number of erosion steps.
        logits, labels = self._make_inputs(shape, device, dtype)
        loss = hd(k=10)
        loss_np = HausdorffERLossNumpy(erosions=10)

        expected = loss_np(logits, labels)
        actual = loss(logits, labels)
        assert_close(actual, expected)

    @pytest.mark.parametrize(
        "hd,shape", [[kornia.losses.HausdorffERLoss, (5, 5)], [kornia.losses.HausdorffERLoss3D, (5, 5, 5)]]
    )
    def test_gradcheck(self, hd, shape, device):
        # Small inputs and few erosions keep the numerical gradient check cheap.
        logits, labels = self._make_inputs(shape, device)
        loss = hd(k=2)

        logits = utils.tensor_to_gradcheck_var(logits)  # to var
        assert gradcheck(loss, (logits, labels), raise_exception=True)