# compvis/test/enhance/test_equalization.py
# Dexter's picture
# Upload folder using huggingface_hub
# 36c95ba verified
from typing import Tuple
import pytest
import torch
from torch.autograd import gradcheck
from kornia import enhance
from kornia.geometry import rotate
from kornia.testing import assert_close, BaseTester, tensor_to_gradcheck_var
class TestEqualization(BaseTester):
    """Test suite exercising ``enhance.equalize_clahe``.

    Covers output metadata, accepted input shapes, optional arguments,
    argument validation, gradient checking and comparison of the first
    output row against hard-coded reference values.
    """

    def test_smoke(self, device, dtype):
        """Default call returns a tensor with the input's shape, device and dtype."""
        channels, height, width = 1, 10, 20
        image = torch.rand(channels, height, width, device=device, dtype=dtype)

        out = enhance.equalize_clahe(image)

        assert isinstance(out, torch.Tensor)
        assert out.shape == image.shape
        assert out.device == image.device
        assert out.dtype == image.dtype

    @pytest.mark.parametrize("B, C", [(None, 1), (None, 3), (1, 1), (1, 3), (4, 1), (4, 3)])
    def test_cardinality(self, B, C, device, dtype):
        """Output shape equals input shape for both unbatched and batched inputs."""
        H, W = 10, 20
        # ``B is None`` selects the unbatched (C, H, W) layout.
        shape = (C, H, W) if B is None else (B, C, H, W)
        image = torch.rand(shape, device=device, dtype=dtype)

        out = enhance.equalize_clahe(image)

        assert out.shape == image.shape

    @pytest.mark.parametrize("clip, grid", [(0.0, None), (None, (2, 2)), (2.0, (2, 2))])
    def test_optional_params(self, clip, grid, device, dtype):
        """Each optional argument may be given alone, or both may be given positionally."""
        channels, height, width = 1, 10, 20
        image = torch.rand(channels, height, width, device=device, dtype=dtype)

        if clip is not None and grid is not None:
            # Both supplied: pass them positionally.
            out = enhance.equalize_clahe(image, clip, grid)
        elif clip is None:
            out = enhance.equalize_clahe(image, grid_size=grid)
        else:
            out = enhance.equalize_clahe(image, clip_limit=clip)

        assert isinstance(out, torch.Tensor)
        assert out.shape == image.shape

    @pytest.mark.parametrize(
        "B, clip, grid, exception_type",
        [
            (0, 1.0, (2, 2), ValueError),
            (1, 1, (2, 2), TypeError),
            (1, 2.0, 2, TypeError),
            (1, 2.0, (2, 2, 2), TypeError),
            (1, 2.0, (2, 2.0), TypeError),
            (1, 2.0, (2, 0), ValueError),
        ],
    )
    def test_exception(self, B, clip, grid, exception_type):
        """Each listed argument combination raises the paired exception type."""
        channels, height, width = 1, 10, 20
        image = torch.rand(B, channels, height, width)

        with pytest.raises(exception_type):
            enhance.equalize_clahe(image, clip, grid)

    @pytest.mark.parametrize("dims", [(1, 1, 1, 1, 1), (1, 1)])
    def test_exception_tensor_dims(self, dims):
        """Five-dimensional and two-dimensional inputs raise ``ValueError``."""
        image = torch.rand(dims)

        with pytest.raises(ValueError):
            enhance.equalize_clahe(image)

    def test_exception_tensor_type(self):
        """A plain Python list as input raises ``TypeError``."""
        with pytest.raises(TypeError):
            enhance.equalize_clahe([1, 2, 3])

    def test_gradcheck(self, device, dtype):
        """Numerical gradient check of a rotation followed by ``equalize_clahe``."""
        torch.random.manual_seed(4)
        batch, channels, height, width = 1, 1, 11, 11
        sample = torch.rand(batch, channels, height, width, device=device, dtype=dtype)
        sample = tensor_to_gradcheck_var(sample)

        def rotate_then_equalize(tensor, clip, grid, differentiable):
            # Rotate by a fixed angle of 30.0 before equalizing.
            angle = torch.tensor(30.0, dtype=tensor.dtype, device=device)
            rotated = rotate(tensor, angle)
            return enhance.equalize_clahe(rotated, clip, grid, differentiable)

        check_args = (sample, 40.0, (2, 2), True)
        assert gradcheck(rotate_then_equalize, check_args, nondet_tol=1e-4, raise_exception=True)

    @pytest.mark.skip(reason="args and kwargs in decorator")
    def test_jit(self, device, dtype):
        """Scripted and eager outputs should agree (currently skipped)."""
        batch_size, channels, height, width = 1, 2, 10, 20
        sample = torch.rand(batch_size, channels, height, width, device=device, dtype=dtype)

        eager_fn = enhance.equalize_clahe
        scripted_fn = torch.jit.script(eager_fn)

        assert_close(eager_fn(sample), scripted_fn(sample))

    def test_module(self):
        """Nothing to test: equalize_clahe is only a function."""
        pass

    @pytest.fixture()
    def img(self, device, dtype):
        """Return a 1x1x20x20 image whose rows are a linear ramp from 0 to 1."""
        height, width = 20, 20
        # TODO: test with a more realistic pattern
        # NOTE(review): ``dtype`` is requested but never passed to a tensor call
        # here - confirm whether that is intentional.
        ramp = torch.arange(width, device=device).div(float(width - 1))
        # Add a row axis, repeat the row ``height`` times, then add two leading axes.
        rows = ramp[None].expand(height, width)
        return rows[None][None]

    def test_he(self, img):
        """Compare the first output row for ``clip_limit=0.0`` and a ``(1, 1)`` grid."""
        # should be similar to enhance.equalize but slower. Similar because the lut is computed in a different way.
        out = enhance.equalize_clahe(img, clip_limit=0.0, grid_size=(1, 1))

        # NOTE: for next versions we need to improve the computation of the LUT
        # and test with a better image
        reference_row = [
            0.0471, 0.0980, 0.1490, 0.2000, 0.2471,
            0.2980, 0.3490, 0.3490, 0.4471, 0.4471,
            0.5490, 0.5490, 0.6471, 0.6471, 0.6980,
            0.7490, 0.8000, 0.8471, 0.8980, 1.0000,
        ]
        reference = torch.tensor([[reference_row]], dtype=out.dtype, device=out.device)

        assert torch.allclose(out[..., 0, :], reference, atol=1e-04, rtol=1e-04)

    def test_ahe(self, img):
        """Compare the first output row for ``clip_limit=0.0`` and an ``(8, 8)`` grid."""
        out = enhance.equalize_clahe(img, clip_limit=0.0, grid_size=(8, 8))

        # NOTE: for next versions we need to improve the computation of the LUT
        # and test with a better image
        reference_row = [
            0.2471, 0.4980, 0.7490, 0.6667, 0.4980,
            0.4980, 0.7490, 0.4993, 0.4980, 0.2471,
            0.7490, 0.4993, 0.4980, 0.2471, 0.4980,
            0.4993, 0.3333, 0.2471, 0.4980, 1.0000,
        ]
        reference = torch.tensor([[reference_row]], dtype=out.dtype, device=out.device)

        assert torch.allclose(out[..., 0, :], reference, atol=1e-04, rtol=1e-04)

    def test_clahe(self, img):
        """Compare the first output row for ``clip_limit=2.0`` on an ``(8, 8)`` grid.

        Both the default path and the ``slow_and_differentiable=True`` path are
        compared, each against its own reference row.
        """
        settings = {"clip_limit": 2.0, "grid_size": (8, 8)}
        out = enhance.equalize_clahe(img, **settings)
        out_diff = enhance.equalize_clahe(img, slow_and_differentiable=True, **settings)

        # NOTE: for next versions we need to improve the computation of the LUT
        # and test with a better image
        default_row = [
            0.1216, 0.8745, 0.9373, 0.9163, 0.8745,
            0.8745, 0.9373, 0.8745, 0.8745, 0.8118,
            0.9373, 0.8745, 0.8745, 0.8118, 0.8745,
            0.8745, 0.8327, 0.8118, 0.8745, 1.0000,
        ]
        differentiable_row = [
            0.1250, 0.8752, 0.9042, 0.9167, 0.8401,
            0.8852, 0.9302, 0.9120, 0.8750, 0.8370,
            0.9620, 0.9077, 0.8750, 0.8754, 0.9204,
            0.9167, 0.8370, 0.8806, 0.9096, 1.0000,
        ]
        # Both reference tensors take dtype/device from the default-path output.
        expected = torch.tensor([[default_row]], dtype=out.dtype, device=out.device)
        exp_diff = torch.tensor([[differentiable_row]], dtype=out.dtype, device=out.device)

        assert torch.allclose(out[..., 0, :], expected, atol=1e-04, rtol=1e-04)
        assert torch.allclose(out_diff[..., 0, :], exp_diff, atol=1e-04, rtol=1e-04)