|
|
import numpy as np |
|
|
import pytest |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from scipy.ndimage import convolve |
|
|
from torch.autograd import gradcheck |
|
|
|
|
|
import kornia |
|
|
import kornia.testing as utils |
|
|
from kornia.testing import assert_close |
|
|
|
|
|
|
|
|
class HausdorffERLossNumpy(nn.Module): |
|
|
"""Binary Hausdorff loss based on morphological erosion. |
|
|
|
|
|
Taken from https://github.com/PatRyg99/HausdorffLoss/blob/master/hausdorff_loss.py |
|
|
""" |
|
|
|
|
|
def __init__(self, alpha=2.0, erosions=10, **kwargs): |
|
|
super().__init__() |
|
|
self.alpha = alpha |
|
|
self.erosions = erosions |
|
|
self.prepare_kernels() |
|
|
|
|
|
def prepare_kernels(self): |
|
|
cross = np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]]) |
|
|
bound = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]) |
|
|
|
|
|
self.kernel2D = cross * 0.2 |
|
|
|
|
|
self.kernel3D = np.array([bound, cross, bound]).squeeze()[None] * (1 / 7) |
|
|
|
|
|
@torch.no_grad() |
|
|
def perform_erosion(self, pred: np.ndarray, target: np.ndarray) -> np.ndarray: |
|
|
bound = (pred - target) ** 2 |
|
|
|
|
|
if bound.ndim == 5: |
|
|
kernel = self.kernel3D |
|
|
elif bound.ndim == 4: |
|
|
kernel = self.kernel2D |
|
|
else: |
|
|
raise ValueError(f"Dimension {bound.ndim} is nor supported.") |
|
|
|
|
|
eroted = np.zeros_like(bound) |
|
|
|
|
|
for batch in range(len(bound)): |
|
|
|
|
|
for k in range(self.erosions): |
|
|
|
|
|
|
|
|
dilation = convolve(bound[batch], kernel, mode="constant", cval=0.0) |
|
|
|
|
|
|
|
|
erosion = dilation - 0.5 |
|
|
erosion[erosion < 0] = 0 |
|
|
|
|
|
if erosion.ptp() != 0: |
|
|
erosion = (erosion - erosion.min()) / erosion.ptp() |
|
|
|
|
|
|
|
|
bound[batch] = erosion |
|
|
eroted[batch] += erosion * (k + 1) ** self.alpha |
|
|
|
|
|
return eroted |
|
|
|
|
|
def forward_one(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Uses one binary channel: 1 - fg, 0 - bg |
|
|
pred: (b, 1, x, y, z) or (b, 1, x, y) |
|
|
target: (b, 1, x, y, z) or (b, 1, x, y) |
|
|
""" |
|
|
assert pred.size(1) == target.size(1) == 1 |
|
|
|
|
|
|
|
|
eroted = torch.from_numpy(self.perform_erosion(pred.cpu().numpy(), target.cpu().numpy())).to( |
|
|
dtype=pred.dtype, device=pred.device |
|
|
) |
|
|
|
|
|
loss = eroted.mean() |
|
|
|
|
|
return loss |
|
|
|
|
|
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Uses one binary channel: 1 - fg, 0 - bg |
|
|
pred: (b, 1, x, y, z) or (b, 1, x, y) |
|
|
target: (b, 1, x, y, z) or (b, 1, x, y) |
|
|
""" |
|
|
assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported" |
|
|
assert pred.dim() == target.dim() and target.size(1) == 1, "Prediction and target need to be of same dimension" |
|
|
return torch.stack( |
|
|
[ |
|
|
self.forward_one( |
|
|
pred[:, i : i + 1], |
|
|
torch.where( |
|
|
target == i, |
|
|
torch.tensor(1, device=target.device, dtype=target.dtype), |
|
|
torch.tensor(0, device=target.device, dtype=target.dtype), |
|
|
), |
|
|
) |
|
|
for i in range(pred.size(1)) |
|
|
] |
|
|
).mean() |
|
|
|
|
|
|
|
|
class TestHausdorffLoss: |
|
|
@pytest.mark.parametrize("reduction", ['mean', 'none', 'sum']) |
|
|
@pytest.mark.parametrize( |
|
|
"hd,shape", [[kornia.losses.HausdorffERLoss, (10, 10)], [kornia.losses.HausdorffERLoss3D, (10, 10, 10)]] |
|
|
) |
|
|
def test_smoke_none(self, hd, shape, reduction, device, dtype): |
|
|
num_classes = 3 |
|
|
logits = torch.rand(2, num_classes, *shape, dtype=dtype, device=device) |
|
|
labels = (torch.rand(2, 1, *shape, dtype=dtype, device=device) * (num_classes - 1)).long() |
|
|
loss = hd(reduction=reduction) |
|
|
|
|
|
loss(logits, labels) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
"hd,shape", [[kornia.losses.HausdorffERLoss, (50, 50)], [kornia.losses.HausdorffERLoss3D, (50, 50, 50)]] |
|
|
) |
|
|
def test_numeric(self, hd, shape, device, dtype): |
|
|
num_classes = 3 |
|
|
logits = torch.rand(2, num_classes, *shape, dtype=dtype, device=device) |
|
|
labels = (torch.rand(2, 1, *shape, dtype=dtype, device=device) * (num_classes - 1)).long() |
|
|
loss = hd(k=10) |
|
|
loss_np = HausdorffERLossNumpy(erosions=10) |
|
|
|
|
|
expected = loss_np(logits, labels) |
|
|
actual = loss(logits, labels) |
|
|
assert_close(actual, expected) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
"hd,shape", [[kornia.losses.HausdorffERLoss, (5, 5)], [kornia.losses.HausdorffERLoss3D, (5, 5, 5)]] |
|
|
) |
|
|
def test_gradcheck(self, hd, shape, device): |
|
|
num_classes = 3 |
|
|
logits = torch.rand(2, num_classes, *shape, device=device) |
|
|
labels = (torch.rand(2, 1, *shape, device=device) * (num_classes - 1)).long() |
|
|
loss = hd(k=2) |
|
|
|
|
|
logits = utils.tensor_to_gradcheck_var(logits) |
|
|
assert gradcheck(loss, (logits, labels), raise_exception=True) |
|
|
|