|
|
import torch |
|
|
|
|
|
from kornia.augmentation import RandomCutMix, RandomMixUp |
|
|
from kornia.testing import assert_close |
|
|
|
|
|
|
|
|
class TestRandomMixUp:
    """Checks ``RandomMixUp`` on a two-sample batch: one all-ones sample, one all-zeros sample."""

    @staticmethod
    def _make_batch(device, dtype):
        """Build the fixture shared by the tests below.

        Returns:
            A ``(2, 1, 3, 4)`` tensor (ones sample stacked on a zeros sample) and the
            integer label tensor ``[1, 0]``.
        """
        ones = torch.ones(1, 3, 4, device=device, dtype=dtype)
        zeros = torch.zeros(1, 3, 4, device=device, dtype=dtype)
        return torch.stack([ones, zeros]), torch.tensor([1, 0], device=device)

    def test_smoke(self, device, dtype):
        """A default-constructed augmentation stringifies with its default parameters."""
        expected_repr = "RandomMixUp(lambda_val=None, p=1.0, p_batch=1.0, same_on_batch=False)"
        assert str(RandomMixUp()) == expected_repr

    def test_random_mixup_p1(self, device, dtype):
        """With ``p=1.0`` and seed 0 the output matches the hard-coded per-sample coefficients."""
        # Seed before constructing/calling so the sampled parameters are reproducible.
        torch.manual_seed(0)
        aug = RandomMixUp(p=1.0)

        images, labels = self._make_batch(device, dtype)
        # Expected coefficients; the test asserts they are reported in out_label[:, 2].
        lam = torch.tensor([0.1320, 0.3074], device=device, dtype=dtype)

        ones = torch.ones(1, 3, 4, device=device, dtype=dtype)
        expected = torch.stack([ones * (1 - lam[0]), ones * lam[1]])

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        # Column 0 holds the original labels, column 1 the swapped ones, column 2 the coefficients.
        assert torch.all(out_label[:, 0] == labels)
        assert torch.all(out_label[:, 1] == torch.tensor([0, 1], device=device, dtype=dtype))
        assert_close(out_label[:, 2], lam, rtol=1e-4, atol=1e-4)

    def test_random_mixup_p0(self, device, dtype):
        """With ``p=0.0`` both the images and the labels come back unchanged."""
        torch.manual_seed(0)
        aug = RandomMixUp(p=0.0)

        images, labels = self._make_batch(device, dtype)
        # Snapshot taken before the call so the comparison does not alias the input.
        expected = images.clone()

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert torch.all(out_label == labels)

    def test_random_mixup_lam0(self, device, dtype):
        """With ``lambda_val=(0.0, 0.0)`` the images are unchanged and the reported coefficients are zero."""
        torch.manual_seed(0)
        aug = RandomMixUp(lambda_val=(0.0, 0.0), p=1.0)

        images, labels = self._make_batch(device, dtype)
        lam = torch.tensor([0.0, 0.0], device=device, dtype=dtype)
        expected = images.clone()

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert torch.all(out_label[:, 0] == labels)
        assert torch.all(out_label[:, 1] == torch.tensor([0, 1], device=device))
        assert_close(out_label[:, 2], lam, rtol=1e-4, atol=1e-4)

    def test_random_mixup_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` and seed 0 both samples report the same coefficient."""
        torch.manual_seed(0)
        aug = RandomMixUp(same_on_batch=True, p=1.0)

        images, labels = self._make_batch(device, dtype)
        # Identical entries: a single coefficient is expected for the whole batch.
        lam = torch.tensor([0.0885, 0.0885], device=device, dtype=dtype)

        ones = torch.ones(1, 3, 4, device=device, dtype=dtype)
        expected = torch.stack([ones * (1 - lam[0]), ones * lam[1]])

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert torch.all(out_label[:, 0] == labels)
        assert torch.all(out_label[:, 1] == torch.tensor([0, 1], device=device, dtype=dtype))
        assert_close(out_label[:, 2], lam, rtol=1e-4, atol=1e-4)
|
|
|
|
|
|
|
|
class TestRandomCutMix:
    """Checks ``RandomCutMix`` on a two-sample batch: one all-ones sample, one all-zeros sample.

    Method names keep the ``test_random_mixup_*`` prefix (although they exercise
    CutMix) so that existing pytest node IDs stay unchanged.
    """

    @staticmethod
    def _make_batch(device, dtype):
        """Build the fixture shared by the tests below.

        Returns:
            A ``(2, 1, 3, 4)`` tensor (ones sample stacked on a zeros sample) and the
            integer label tensor ``[1, 0]``.
        """
        ones = torch.ones(1, 3, 4, device=device, dtype=dtype)
        zeros = torch.zeros(1, 3, 4, device=device, dtype=dtype)
        return torch.stack([ones, zeros]), torch.tensor([1, 0], device=device)

    def test_smoke(self, device, dtype):
        """An augmentation built with only ``width``/``height`` stringifies with the remaining defaults."""
        expected_repr = (
            "RandomCutMix(num_mix=1, beta=None, cut_size=None, height=3, width=3, p=1.0, "
            "p_batch=1.0, same_on_batch=False)"
        )
        assert str(RandomCutMix(width=3, height=3)) == expected_repr

    def test_random_mixup_p1(self, device, dtype):
        """With ``p=1.0`` and seed 76 the pasted regions and labels match the hard-coded values."""
        # Seed before constructing/calling so the sampled parameters are reproducible.
        torch.manual_seed(76)
        aug = RandomCutMix(width=4, height=3, p=1.0)

        images, labels = self._make_batch(device, dtype)

        expected = torch.tensor(
            [
                [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        # Leading index 0 selects the only entry along the first axis (presumably the
        # mix index, cf. the num_mix=5 test below); last axis is (label, swapped label, lambda).
        assert (out_label[0, :, 0] == labels).all()
        assert (out_label[0, :, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        assert (out_label[0, :, 2] == torch.tensor([0.5, 0.5], device=device, dtype=dtype)).all()

    def test_random_mixup_p0(self, device, dtype):
        """With ``p=0.0`` both the images and the labels come back unchanged."""
        torch.manual_seed(76)
        aug = RandomCutMix(p=0.0, width=4, height=3)

        images, labels = self._make_batch(device, dtype)
        # Snapshot taken before the call so the comparison does not alias the input.
        expected = images.clone()

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label == labels).all()

    def test_random_mixup_beta0(self, device, dtype):
        """With a near-zero ``beta`` (``1e-7``) and seed 76 the output matches the hard-coded values."""
        torch.manual_seed(76)
        aug = RandomCutMix(beta=1e-7, width=4, height=3, p=1.0)

        images, labels = self._make_batch(device, dtype)

        expected = torch.tensor(
            [
                [[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[0, :, 0] == labels).all()
        assert (out_label[0, :, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        # The expected value is only written to 5 decimals, so state the tolerance
        # explicitly (as every other check here does) instead of relying on
        # dtype-dependent defaults.
        assert_close(
            out_label[0, :, 2],
            torch.tensor([0.33333, 0.33333], device=device, dtype=dtype),
            rtol=1e-4,
            atol=1e-4,
        )

    def test_random_mixup_num2(self, device, dtype):
        """With ``num_mix=5`` and seed 76 the labels carry five entries along the leading axis."""
        torch.manual_seed(76)
        aug = RandomCutMix(width=4, height=3, num_mix=5, p=1.0)

        images, labels = self._make_batch(device, dtype)

        expected = torch.tensor(
            [
                [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        # One row per leading-axis entry (5 of them), one column per sample.
        expected_swapped = torch.tensor([[1, 0], [1, 0], [1, 0], [1, 0], [0, 1]], device=device)
        expected_lam = torch.tensor(
            [[0.0833, 0.3333], [0.0, 0.1667], [0.5, 0.0833], [0.0833, 0.0], [0.5, 0.3333]],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        # `labels` has shape (2,) and broadcasts against every leading-axis entry.
        assert (out_label[:, :, 0] == labels).all()
        assert (out_label[:, :, 1] == expected_swapped).all()
        assert_close(out_label[:, :, 2], expected_lam, rtol=1e-4, atol=1e-4)

    def test_random_mixup_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` and seed 42 the output matches the hard-coded values."""
        torch.manual_seed(42)
        aug = RandomCutMix(same_on_batch=True, width=4, height=3, p=1.0)

        images, labels = self._make_batch(device, dtype)

        expected = torch.tensor(
            [
                [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = aug(images, labels)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[0, :, 0] == labels).all()
        assert (out_label[0, :, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        assert_close(
            out_label[0, :, 2],
            torch.tensor([0.5000, 0.5000], device=device, dtype=dtype),
            rtol=1e-4,
            atol=1e-4,
        )
|
|
|