# compvis/test/geometry/subpix/test_spatial_softargmax.py
# Dexter's picture
# Upload folder using huggingface_hub
# 36c95ba verified
import pytest
import torch
from torch.autograd import gradcheck
from torch.nn.functional import mse_loss
import kornia
import kornia.testing as utils # test utils
from kornia.geometry.subpix.spatial_soft_argmax import _get_center_kernel2d, _get_center_kernel3d
from kornia.testing import assert_close
class TestCenterKernel2d:
    """Checks for the private ``_get_center_kernel2d`` helper."""

    def test_smoke(self, device, dtype):
        """A (3, 4) request produces a kernel of shape (2, 2, 3, 4)."""
        out = _get_center_kernel2d(3, 4, device=device).to(dtype=dtype)
        assert out.shape == (2, 2, 3, 4)

    def test_odd(self, device, dtype):
        """With an odd 3x3 size, each diagonal slice holds a single 1 at the center."""
        out = _get_center_kernel2d(3, 3, device=device).to(dtype=dtype)

        # Only the [k, k] slices are non-zero; the off-diagonal slices stay all zero.
        reference = torch.zeros(2, 2, 3, 3, device=device, dtype=dtype)
        for axis in range(2):
            reference[axis, axis, 1, 1] = 1.0

        assert_close(out, reference, atol=1e-4, rtol=1e-4)

    def test_even(self, device, dtype):
        """With an even 2x2 size, each diagonal slice is uniformly 0.25."""
        out = _get_center_kernel2d(2, 2, device=device).to(dtype=dtype)

        reference = torch.zeros(2, 2, 2, 2, device=device, dtype=dtype)
        for axis in range(2):
            reference[axis, axis] = 0.25

        assert_close(out, reference, atol=1e-4, rtol=1e-4)
class TestCenterKernel3d:
    """Checks for the private ``_get_center_kernel3d`` helper."""

    def test_smoke(self, device, dtype):
        """A (6, 3, 4) request produces a kernel of shape (3, 3, 6, 3, 4)."""
        out = _get_center_kernel3d(6, 3, 4, device=device).to(dtype=dtype)
        assert out.shape == (3, 3, 6, 3, 4)

    def test_odd(self, device, dtype):
        """With odd sizes (3, 5, 7), each diagonal slice holds a single 1 at index (1, 2, 3)."""
        out = _get_center_kernel3d(3, 5, 7, device=device).to(dtype=dtype)

        reference = torch.zeros(3, 3, 3, 5, 7, device=device, dtype=dtype)
        for axis in range(3):
            reference[axis, axis, 1, 2, 3] = 1.0

        assert_close(out, reference, atol=1e-4, rtol=1e-4)

    def test_even(self, device, dtype):
        """With sizes (2, 4, 3), the weight 0.25 is spread over the four central cells."""
        out = _get_center_kernel3d(2, 4, 3, device=device).to(dtype=dtype)

        reference = torch.zeros(3, 3, 2, 4, 3, device=device, dtype=dtype)
        for axis in range(3):
            # Both entries of the even depth axis and the two middle rows share the mass.
            reference[axis, axis, :, 1:3, 1] = 0.25

        assert_close(out, reference, atol=1e-4, rtol=1e-4)
class TestSpatialSoftArgmax2d:
    """Tests for ``spatial_soft_argmax2d`` / ``SpatialSoftArgmax2d`` and the heatmap helpers used with them."""

    @staticmethod
    def _peaked_heatmap(row, col, device, dtype):
        """Return a zero (1, 1, 2, 3) heatmap whose ``[row, col]`` entry is set to 1e16."""
        heatmap = torch.zeros(1, 1, 2, 3, device=device, dtype=dtype)
        heatmap[..., row, col] = 1e16
        return heatmap

    def test_smoke(self, device, dtype):
        """The module maps a (1, 1, 2, 3) input to a (1, 1, 2) output."""
        heatmap = torch.zeros(1, 1, 2, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.SpatialSoftArgmax2d()
        assert module(heatmap).shape == (1, 1, 2)

    def test_smoke_batch(self, device, dtype):
        """The module maps a (2, 1, 2, 3) input to a (2, 1, 2) output."""
        heatmap = torch.zeros(2, 1, 2, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.SpatialSoftArgmax2d()
        assert module(heatmap).shape == (2, 1, 2)

    def test_top_left_normalized(self, device, dtype):
        """A peak at the top-left corner is reported as (-1, -1) in normalized coordinates."""
        heatmap = self._peaked_heatmap(0, 0, device, dtype)
        coord = kornia.geometry.subpix.spatial_soft_argmax2d(heatmap, normalized_coordinates=True)
        assert_close(coord[..., 0].item(), -1.0, atol=1e-4, rtol=1e-4)
        assert_close(coord[..., 1].item(), -1.0, atol=1e-4, rtol=1e-4)

    def test_top_left(self, device, dtype):
        """A peak at the top-left corner is reported as (0, 0) in pixel coordinates."""
        heatmap = self._peaked_heatmap(0, 0, device, dtype)
        coord = kornia.geometry.subpix.spatial_soft_argmax2d(heatmap, normalized_coordinates=False)
        assert_close(coord[..., 0].item(), 0.0, atol=1e-4, rtol=1e-4)
        assert_close(coord[..., 1].item(), 0.0, atol=1e-4, rtol=1e-4)

    def test_bottom_right_normalized(self, device, dtype):
        """A peak at the bottom-right corner is reported as (1, 1) in normalized coordinates."""
        heatmap = self._peaked_heatmap(-1, -1, device, dtype)
        coord = kornia.geometry.subpix.spatial_soft_argmax2d(heatmap, normalized_coordinates=True)
        assert_close(coord[..., 0].item(), 1.0, atol=1e-4, rtol=1e-4)
        assert_close(coord[..., 1].item(), 1.0, atol=1e-4, rtol=1e-4)

    def test_bottom_right(self, device, dtype):
        """A peak at the bottom-right corner of a 2x3 map is reported as (2, 1) in pixel coordinates."""
        heatmap = self._peaked_heatmap(-1, -1, device, dtype)
        coord = kornia.geometry.subpix.spatial_soft_argmax2d(heatmap, normalized_coordinates=False)
        # Component 0 follows the column index (last column is 2), component 1 the row index.
        assert_close(coord[..., 0].item(), 2.0, atol=1e-4, rtol=1e-4)
        assert_close(coord[..., 1].item(), 1.0, atol=1e-4, rtol=1e-4)

    def test_batch2_n2(self, device, dtype):
        """Each (batch, channel) map is resolved independently to the corner holding its peak."""
        heatmaps = torch.zeros(2, 2, 2, 3, device=device, dtype=dtype)
        heatmaps[0, 0, 0, 0] = 1e16  # top-left
        heatmaps[0, 1, 0, -1] = 1e16  # top-right
        heatmaps[1, 0, -1, 0] = 1e16  # bottom-left
        heatmaps[1, 1, -1, -1] = 1e16  # bottom-right

        # No ``normalized_coordinates`` argument is passed; the expectations lie in [-1, 1].
        coord = kornia.geometry.subpix.spatial_soft_argmax2d(heatmaps)

        corners = [
            ((0, 0), (-1.0, -1.0)),  # top-left
            ((0, 1), (1.0, -1.0)),  # top-right
            ((1, 0), (-1.0, 1.0)),  # bottom-left
            ((1, 1), (1.0, 1.0)),  # bottom-right
        ]
        for (batch, channel), (first, second) in corners:
            assert_close(coord[batch, channel, 0].item(), first, atol=1e-4, rtol=1e-4)
            assert_close(coord[batch, channel, 1].item(), second, atol=1e-4, rtol=1e-4)

    def test_gradcheck(self, device, dtype):
        """Numerical and analytical gradients of ``spatial_soft_argmax2d`` agree."""
        sample = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)
        sample = utils.tensor_to_gradcheck_var(sample)  # to var
        # Pass a real one-element tuple; ``(sample)`` would just be the tensor itself.
        assert gradcheck(kornia.geometry.subpix.spatial_soft_argmax2d, (sample,), raise_exception=True)

    def test_end_to_end(self, device, dtype):
        """Softmax -> expectation -> losses on a constant input, then backpropagate."""
        logits = torch.full((1, 2, 7, 7), 1.0, requires_grad=True, device=device, dtype=dtype)
        target = torch.as_tensor([[[0.0, 0.0], [1.0, 1.0]]], device=device, dtype=dtype)
        std = torch.tensor([1.0, 1.0], device=device, dtype=dtype)

        # Each spatial softmax map must sum to one.
        hm = kornia.geometry.subpix.spatial_softmax2d(logits)
        unit_mass = torch.tensor([[1.0, 1.0]], device=device, dtype=dtype)
        assert_close(hm.sum(-1).sum(-1), unit_mass, atol=1e-4, rtol=1e-4)

        # A constant input gives a uniform map, whose expected coordinate is (0, 0).
        pred = kornia.geometry.subpix.spatial_expectation2d(hm)
        centered = torch.as_tensor([[[0.0, 0.0], [0.0, 0.0]]], device=device, dtype=dtype)
        assert_close(pred, centered, atol=1e-4, rtol=1e-4)

        # ``size_average`` / ``reduce`` are deprecated in favour of ``reduction``; passing
        # them as ``None`` was a no-op, so only ``reduction`` is given.
        loss1 = mse_loss(pred, target, reduction='none').mean(-1, keepdim=False)
        expected_loss1 = torch.as_tensor([[0.0, 1.0]], device=device, dtype=dtype)
        assert_close(loss1, expected_loss1, atol=1e-4, rtol=1e-4)

        target_hm = kornia.geometry.subpix.render_gaussian2d(target, std, logits.shape[-2:]).contiguous()
        loss2 = kornia.losses.js_div_loss_2d(hm, target_hm, reduction='none')
        expected_loss2 = torch.as_tensor([[0.0087, 0.0818]], device=device, dtype=dtype)
        assert_close(loss2, expected_loss2, rtol=0, atol=1e-3)

        loss = (loss1 + loss2).mean()
        loss.backward()

        # The combined loss must have propagated a gradient back to the leaf input.
        assert logits.grad is not None
        assert logits.grad.shape == logits.shape

    def test_jit(self, device, dtype):
        """The scripted function returns the same result as the eager one."""
        sample = torch.rand((2, 3, 7, 7), dtype=dtype, device=device)
        op = kornia.geometry.subpix.spatial_soft_argmax2d
        op_jit = torch.jit.script(op)
        assert_close(op(sample), op_jit(sample), rtol=0, atol=1e-5)

    @pytest.mark.skip(reason="it works but raises some warnings.")
    def test_jit_trace(self, device, dtype):
        """The traced function returns the same result as the eager one (currently skipped)."""
        sample = torch.rand((2, 3, 7, 7), dtype=dtype, device=device)
        op = kornia.geometry.subpix.spatial_soft_argmax2d
        op_jit = torch.jit.trace(op, (sample,))
        assert_close(op(sample), op_jit(sample), rtol=0, atol=1e-5)
class TestConvSoftArgmax2d:
    """Tests for ``ConvSoftArgmax2d`` / ``conv_soft_argmax2d``."""

    @staticmethod
    def _diag_heatmap(device, dtype):
        """Return the shared (1, 1, 5, 5) input: zeros with ones at (1, 1) and (3, 3)."""
        heatmap = torch.zeros(1, 1, 5, 5, device=device, dtype=dtype)
        heatmap[0, 0, 1, 1] = 1.0
        heatmap[0, 0, 3, 3] = 1.0
        return heatmap

    def test_smoke(self, device, dtype):
        """A (1, 1, 3, 3) input yields coordinates of shape (1, 1, 2, 3, 3)."""
        heatmap = torch.zeros(1, 1, 3, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.ConvSoftArgmax2d((3, 3))
        assert module(heatmap).shape == (1, 1, 2, 3, 3)

    def test_smoke_batch(self, device, dtype):
        """Default construction handles a (2, 5, 3, 3) input and yields (2, 5, 2, 3, 3)."""
        heatmap = torch.zeros(2, 5, 3, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.ConvSoftArgmax2d()
        assert module(heatmap).shape == (2, 5, 2, 3, 3)

    def test_smoke_with_val(self, device, dtype):
        """With ``output_value=True`` the module returns a (coords, values) pair."""
        heatmap = torch.zeros(1, 1, 3, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.ConvSoftArgmax2d((3, 3), output_value=True)
        coords, val = module(heatmap)
        assert coords.shape == (1, 1, 2, 3, 3)
        assert val.shape == (1, 1, 3, 3)

    def test_smoke_batch_with_val(self, device, dtype):
        """The (coords, values) pair keeps the batch and channel dimensions."""
        heatmap = torch.zeros(2, 5, 3, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.ConvSoftArgmax2d((3, 3), output_value=True)
        coords, val = module(heatmap)
        assert coords.shape == (2, 5, 2, 3, 3)
        assert val.shape == (2, 5, 3, 3)

    def test_gradcheck(self, device, dtype):
        """Numerical and analytical gradients of ``conv_soft_argmax2d`` agree."""
        sample = torch.rand(2, 3, 5, 5, device=device, dtype=dtype)
        sample = utils.tensor_to_gradcheck_var(sample)  # to var
        # Pass a real one-element tuple; ``(sample)`` would just be the tensor itself.
        assert gradcheck(kornia.geometry.subpix.conv_soft_argmax2d, (sample,), raise_exception=True)

    def test_cold_diag(self, device, dtype):
        """Temperature 0.05, pixel coordinates: values are 1 at the peaks, coords are 1 and 3."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax2d(
            (3, 3), (2, 2), (0, 0), temperature=0.05, normalized_coordinates=False, output_value=True
        )
        expected_val = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[1.0, 3.0], [1.0, 3.0]], [[1.0, 1.0], [3.0, 3.0]]]]], device=device, dtype=dtype
        )

        coords, val = softargmax(self._diag_heatmap(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)

    def test_hot_diag(self, device, dtype):
        """Temperature 10, pixel coordinates: same coords, peak values drop to 0.1214."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax2d(
            (3, 3), (2, 2), (0, 0), temperature=10.0, normalized_coordinates=False, output_value=True
        )
        expected_val = torch.tensor([[[[0.1214, 0.0], [0.0, 0.1214]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[1.0, 3.0], [1.0, 3.0]], [[1.0, 1.0], [3.0, 3.0]]]]], device=device, dtype=dtype
        )

        coords, val = softargmax(self._diag_heatmap(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)

    def test_cold_diag_norm(self, device, dtype):
        """Temperature 0.05, normalized coordinates: coords become -0.5 and 0.5."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax2d(
            (3, 3), (2, 2), (0, 0), temperature=0.05, normalized_coordinates=True, output_value=True
        )
        expected_val = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]]]], device=device, dtype=dtype
        )

        coords, val = softargmax(self._diag_heatmap(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)

    def test_hot_diag_norm(self, device, dtype):
        """Temperature 10, normalized coordinates: coords -0.5 / 0.5, peak values 0.1214."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax2d(
            (3, 3), (2, 2), (0, 0), temperature=10.0, normalized_coordinates=True, output_value=True
        )
        expected_val = torch.tensor([[[[0.1214, 0.0], [0.0, 0.1214]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]]]], device=device, dtype=dtype
        )

        coords, val = softargmax(self._diag_heatmap(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)
class TestConvSoftArgmax3d:
    """Tests for ``ConvSoftArgmax3d`` / ``conv_soft_argmax3d``."""

    @staticmethod
    def _diag_volume(device, dtype):
        """Return the shared (1, 1, 1, 5, 5) input: zeros with ones at (1, 1) and (3, 3) of its only slice."""
        volume = torch.zeros(1, 1, 1, 5, 5, device=device, dtype=dtype)
        volume[0, 0, 0, 1, 1] = 1.0
        volume[0, 0, 0, 3, 3] = 1.0
        return volume

    def test_smoke(self, device, dtype):
        """A (1, 1, 3, 3, 3) input yields coordinates of shape (1, 1, 3, 3, 3, 3)."""
        volume = torch.zeros(1, 1, 3, 3, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.ConvSoftArgmax3d((3, 3, 3), output_value=False)
        assert module(volume).shape == (1, 1, 3, 3, 3, 3)

    def test_smoke_with_val(self, device, dtype):
        """With ``output_value=True`` the module returns a (coords, values) pair."""
        volume = torch.zeros(1, 1, 3, 3, 3, device=device, dtype=dtype)
        module = kornia.geometry.subpix.ConvSoftArgmax3d((3, 3, 3), output_value=True)
        coords, val = module(volume)
        assert coords.shape == (1, 1, 3, 3, 3, 3)
        assert val.shape == (1, 1, 3, 3, 3)

    def test_gradcheck(self, device, dtype):
        """Numerical and analytical gradients of ``conv_soft_argmax3d`` agree."""
        sample = torch.rand(1, 2, 3, 5, 5, device=device, dtype=dtype)
        sample = utils.tensor_to_gradcheck_var(sample)  # to var
        # Pass a real one-element tuple; ``(sample)`` would just be the tensor itself.
        assert gradcheck(kornia.geometry.subpix.conv_soft_argmax3d, (sample,), raise_exception=True)

    def test_cold_diag(self, device, dtype):
        """Temperature 0.05, pixel coordinates: first coord channel is 0, the others are 1 and 3."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax3d(
            (1, 3, 3), (1, 2, 2), (0, 0, 0), temperature=0.05, normalized_coordinates=False, output_value=True
        )
        expected_val = torch.tensor([[[[[1.0, 0.0], [0.0, 1.0]]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[[0.0, 0.0], [0.0, 0.0]]], [[[1.0, 3.0], [1.0, 3.0]]], [[[1.0, 1.0], [3.0, 3.0]]]]]],
            device=device,
            dtype=dtype,
        )

        coords, val = softargmax(self._diag_volume(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)

    def test_hot_diag(self, device, dtype):
        """Temperature 10, pixel coordinates: same coords, peak values drop to 0.1214."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax3d(
            (1, 3, 3), (1, 2, 2), (0, 0, 0), temperature=10.0, normalized_coordinates=False, output_value=True
        )
        expected_val = torch.tensor([[[[[0.1214, 0.0], [0.0, 0.1214]]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[[0.0, 0.0], [0.0, 0.0]]], [[[1.0, 3.0], [1.0, 3.0]]], [[[1.0, 1.0], [3.0, 3.0]]]]]],
            device=device,
            dtype=dtype,
        )

        coords, val = softargmax(self._diag_volume(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)

    def test_cold_diag_norm(self, device, dtype):
        """Temperature 0.05, normalized coordinates: first coord channel is -1, the others are -0.5 and 0.5."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax3d(
            (1, 3, 3), (1, 2, 2), (0, 0, 0), temperature=0.05, normalized_coordinates=True, output_value=True
        )
        expected_val = torch.tensor([[[[[1.0, 0.0], [0.0, 1.0]]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[[-1.0, -1.0], [-1.0, -1.0]]], [[[-0.5, 0.5], [-0.5, 0.5]]], [[[-0.5, -0.5], [0.5, 0.5]]]]]],
            device=device,
            dtype=dtype,
        )

        coords, val = softargmax(self._diag_volume(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)

    def test_hot_diag_norm(self, device, dtype):
        """Temperature 10, normalized coordinates: same coords as the cold case, peak values 0.1214."""
        softargmax = kornia.geometry.subpix.ConvSoftArgmax3d(
            (1, 3, 3), (1, 2, 2), (0, 0, 0), temperature=10.0, normalized_coordinates=True, output_value=True
        )
        expected_val = torch.tensor([[[[[0.1214, 0.0], [0.0, 0.1214]]]]], device=device, dtype=dtype)
        expected_coord = torch.tensor(
            [[[[[[-1.0, -1.0], [-1.0, -1.0]]], [[[-0.5, 0.5], [-0.5, 0.5]]], [[[-0.5, -0.5], [0.5, 0.5]]]]]],
            device=device,
            dtype=dtype,
        )

        coords, val = softargmax(self._diag_volume(device, dtype))

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)
class TestConvQuadInterp3d:
    """Tests for ``kornia.geometry.ConvQuadInterp3d``."""

    # Shared gate for the tests that are skipped on torch 1.x releases older than 1.9.
    _requires_torch_1_9 = pytest.mark.skipif(
        (int(torch.__version__.split('.')[0]) == 1) and (int(torch.__version__.split('.')[1]) < 9),
        reason='<1.9.0 not supporting',
    )

    @_requires_torch_1_9
    def test_smoke(self, device, dtype):
        """A (2, 3, 3, 4, 4) input yields coords (2, 3, 3, 3, 4, 4) and values (2, 3, 3, 4, 4)."""
        volume = torch.randn(2, 3, 3, 4, 4, device=device, dtype=dtype)
        nms = kornia.geometry.ConvQuadInterp3d(1)
        coord, val = nms(volume)
        assert coord.shape == (2, 3, 3, 3, 4, 4)
        assert val.shape == (2, 3, 3, 4, 4)

    @_requires_torch_1_9
    def test_gradcheck(self, device, dtype):
        """Numerical and analytical gradients agree around a single dominant voxel."""
        volume = torch.rand(1, 1, 3, 5, 5, device=device, dtype=dtype)
        volume[0, 0, 1, 2, 2] += 20.0  # make the central voxel clearly the largest
        volume = utils.tensor_to_gradcheck_var(volume)  # to var
        # Pass a real one-element tuple; ``(volume)`` would just be the tensor itself.
        assert gradcheck(
            kornia.geometry.ConvQuadInterp3d(strict_maxima_bonus=0),
            (volume,),
            raise_exception=True,
            atol=1e-3,
            rtol=1e-3,
        )

    def test_diag(self, device, dtype):
        """A blurred plus-shaped blob is refined at its center; every other coordinate stays on the grid."""
        # Three 5x5 slices; only the middle one is non-zero (a plus shape around (2, 2)).
        slices = torch.zeros(1, 3, 5, 5, device=device, dtype=dtype)
        slices[0, 1, 1, 2] = 1.0
        slices[0, 1, 2, 1:4] = torch.tensor([1.0, 1.2, 1.1], device=device, dtype=dtype)
        slices[0, 1, 3, 2] = 1.0
        volume = kornia.filters.gaussian_blur2d(slices, (5, 5), (0.5, 0.5)).unsqueeze(0)

        softargmax = kornia.geometry.ConvQuadInterp3d(10)

        # Expected values: zero outer slices, reference numbers for the middle slice.
        # NOTE(review): the central 11.086 presumably includes the constructor argument 10
        # as a bonus added at the maximum; confirm against ConvQuadInterp3d.
        expected_val = torch.zeros(1, 1, 3, 5, 5, device=device, dtype=dtype)
        expected_val[0, 0, 1] = torch.tensor(
            [
                [2.2504e-04, 2.3146e-02, 1.6808e-01, 2.3188e-02, 2.3628e-04],
                [2.3146e-02, 1.8118e-01, 7.4338e-01, 1.8955e-01, 2.5413e-02],
                [1.6807e-01, 7.4227e-01, 1.1086e01, 8.0414e-01, 1.8482e-01],
                [2.3146e-02, 1.8118e-01, 7.4338e-01, 1.8955e-01, 2.5413e-02],
                [2.2504e-04, 2.3146e-02, 1.6808e-01, 2.3188e-02, 2.3628e-04],
            ],
            device=device,
            dtype=dtype,
        )

        # Expected coordinates: channel 0 is the slice index, channel 1 the column index and
        # channel 2 the row index of every voxel, i.e. an identity grid over (3, 5, 5).
        slice_idx = torch.arange(3, device=device).to(dtype).view(3, 1, 1).expand(3, 5, 5)
        col_idx = torch.arange(5, device=device).to(dtype).view(1, 1, 5).expand(3, 5, 5)
        row_idx = torch.arange(5, device=device).to(dtype).view(1, 5, 1).expand(3, 5, 5)
        expected_coord = torch.stack([slice_idx, col_idx, row_idx])[None, None]
        # The only off-grid entry: the row coordinate of the central voxel.
        expected_coord[0, 0, 2, 1, 2, 2] = 2.0495

        coords, val = softargmax(volume)

        assert_close(val, expected_val, atol=1e-4, rtol=1e-4)
        assert_close(coords, expected_coord, atol=1e-4, rtol=1e-4)