| import pytest | |
| import torch | |
| import kornia | |
| class TestOneHot: | |
| def test_smoke(self): | |
| num_classes = 4 | |
| labels = torch.zeros(2, 2, 1, dtype=torch.int64) | |
| labels[0, 0, 0] = 0 | |
| labels[0, 1, 0] = 1 | |
| labels[1, 0, 0] = 2 | |
| labels[1, 1, 0] = 3 | |
| # convert labels to one hot tensor | |
| one_hot = kornia.utils.one_hot(labels, num_classes) | |
| assert pytest.approx(one_hot[0, labels[0, 0, 0], 0, 0].item(), 1.0) | |
| assert pytest.approx(one_hot[0, labels[0, 1, 0], 1, 0].item(), 1.0) | |
| assert pytest.approx(one_hot[1, labels[1, 0, 0], 0, 0].item(), 1.0) | |
| assert pytest.approx(one_hot[1, labels[1, 1, 0], 1, 0].item(), 1.0) | |