# File: compvis/test/augmentation/test_augmentation_mix.py
# Uploaded by Dexter via huggingface_hub ("Upload folder using huggingface_hub"), commit 36c95ba (verified).
import torch
from kornia.augmentation import RandomCutMix, RandomMixUp
from kornia.testing import assert_close
class TestRandomMixUp:
    """Tests for :class:`kornia.augmentation.RandomMixUp`.

    Each test feeds a ``(2, 1, 3, 4)`` batch made of one all-ones image and one
    all-zeros image. Every output pixel therefore equals the weight the ones-image
    received, so the sampled mixing coefficients can be read straight off the
    result. The hard-coded coefficients are tied to the ``torch.manual_seed`` call
    at the start of each test.

    ``device`` and ``dtype`` are pytest fixtures, presumably supplied by the
    project's ``conftest.py``.
    """

    @staticmethod
    def _make_batch(device, dtype):
        """Return a ``(2, 1, 3, 4)`` batch: an all-ones image followed by an all-zeros image."""
        return torch.stack(
            [torch.ones(1, 3, 4, device=device, dtype=dtype), torch.zeros(1, 3, 4, device=device, dtype=dtype)]
        )

    def test_smoke(self, device, dtype):
        f = RandomMixUp()
        expected_repr = "RandomMixUp(lambda_val=None, p=1.0, p_batch=1.0, same_on_batch=False)"
        assert str(f) == expected_repr

    def test_random_mixup_p1(self, device, dtype):
        torch.manual_seed(0)
        f = RandomMixUp(p=1.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        # Per-sample coefficients drawn under seed 0.
        lam = torch.tensor([0.1320, 0.3074], device=device, dtype=dtype)
        # Sample 0 (ones) is mixed with the zeros image and keeps weight 1 - lam[0].
        # Sample 1 (zeros) picks up weight lam[1] from the ones image.
        expected = torch.stack(
            [
                torch.ones(1, 3, 4, device=device, dtype=dtype) * (1 - lam[0]),
                torch.ones(1, 3, 4, device=device, dtype=dtype) * lam[1],
            ]
        )

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        # Label columns checked here:
        # [original label, label of the sample it was mixed with, coefficient].
        assert (out_label[:, 0] == label).all()
        assert (out_label[:, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        assert_close(out_label[:, 2], lam, rtol=1e-4, atol=1e-4)

    def test_random_mixup_p0(self, device, dtype):
        torch.manual_seed(0)
        f = RandomMixUp(p=0.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        # With p=0 nothing is mixed: the images pass through unchanged and the
        # returned labels compare equal to the input labels.
        expected = images.clone()

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label == label).all()

    def test_random_mixup_lam0(self, device, dtype):
        torch.manual_seed(0)
        # lambda_val=(0.0, 0.0) pins every coefficient to 0. The pixels are left
        # untouched, while the label tensor still records the partner and a zero
        # coefficient.
        f = RandomMixUp(lambda_val=(0.0, 0.0), p=1.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        lam = torch.tensor([0.0, 0.0], device=device, dtype=dtype)
        expected = images.clone()

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[:, 0] == label).all()
        assert (out_label[:, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        assert_close(out_label[:, 2], lam, rtol=1e-4, atol=1e-4)

    def test_random_mixup_same_on_batch(self, device, dtype):
        torch.manual_seed(0)
        f = RandomMixUp(same_on_batch=True, p=1.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        # same_on_batch=True: one coefficient (drawn under seed 0) is shared by both samples.
        lam = torch.tensor([0.0885, 0.0885], device=device, dtype=dtype)
        expected = torch.stack(
            [
                torch.ones(1, 3, 4, device=device, dtype=dtype) * (1 - lam[0]),
                torch.ones(1, 3, 4, device=device, dtype=dtype) * lam[1],
            ]
        )

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[:, 0] == label).all()
        assert (out_label[:, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        assert_close(out_label[:, 2], lam, rtol=1e-4, atol=1e-4)
class TestRandomCutMix:
    """Tests for :class:`kornia.augmentation.RandomCutMix`.

    Each test feeds a ``(2, 1, 3, 4)`` batch of one all-ones and one all-zeros image,
    so the pasted rectangles show up as 0/1 regions in the output.

    As the assertions below check, labels come back indexed as
    ``out_label[mix, sample, field]``. The fields are:

    - the original label;
    - the label of the sample the patch presumably came from;
    - the pasted area ratio.

    Expected rectangles and ratios are tied to the ``torch.manual_seed`` call at the
    start of each test. ``device`` and ``dtype`` are pytest fixtures, presumably
    supplied by the project's ``conftest.py``.
    """

    @staticmethod
    def _make_batch(device, dtype):
        """Return a ``(2, 1, 3, 4)`` batch: an all-ones image followed by an all-zeros image."""
        return torch.stack(
            [torch.ones(1, 3, 4, device=device, dtype=dtype), torch.zeros(1, 3, 4, device=device, dtype=dtype)]
        )

    def test_smoke(self, device, dtype):
        f = RandomCutMix(width=3, height=3)
        expected_repr = (
            "RandomCutMix(num_mix=1, beta=None, cut_size=None, height=3, width=3, p=1.0, "
            "p_batch=1.0, same_on_batch=False)"
        )
        assert str(f) == expected_repr

    def test_random_mixup_p1(self, device, dtype):
        torch.manual_seed(76)
        f = RandomCutMix(width=4, height=3, p=1.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        # Under seed 76 the top-left 2x3 region (6 of 12 pixels) is swapped between the images.
        expected = torch.tensor(
            [
                [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[0, :, 0] == label).all()
        assert (out_label[0, :, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        # The area ratio is a float; compare with a tolerance (as the other ratio checks
        # in this class do) rather than with exact equality.
        assert_close(
            out_label[0, :, 2], torch.tensor([0.5, 0.5], device=device, dtype=dtype), rtol=1e-4, atol=1e-4
        )

    def test_random_mixup_p0(self, device, dtype):
        torch.manual_seed(76)
        f = RandomCutMix(p=0.0, width=4, height=3)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        # With p=0 nothing is pasted: the images pass through unchanged and the
        # returned labels compare equal to the input labels.
        expected = images.clone()

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label == label).all()

    def test_random_mixup_beta0(self, device, dtype):
        torch.manual_seed(76)
        # Original intent: beta == 0 should resample a 0.5 area. Since torch 1.8.0
        # beta cannot be exactly 0, so a tiny positive beta stands in for it.
        f = RandomCutMix(beta=1e-7, width=4, height=3, p=1.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        # Under seed 76 the top-left 2x2 region is swapped between the images.
        expected = torch.tensor(
            [
                [[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[0, :, 0] == label).all()
        assert (out_label[0, :, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        # Cut area = 4 / 12. Use the exact fraction: the old literal 0.33333 only just fit
        # inside the default float64 tolerance.
        assert_close(out_label[0, :, 2], torch.tensor([4 / 12, 4 / 12], device=device, dtype=dtype))

    def test_random_mixup_num2(self, device, dtype):
        torch.manual_seed(76)
        # num_mix=5 applies five successive pastes; out_label gains one row per mix.
        f = RandomCutMix(width=4, height=3, num_mix=5, p=1.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        expected = torch.tensor(
            [
                [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[:, :, 0] == label).all()
        assert (out_label[:, :, 1] == torch.tensor([[1, 0], [1, 0], [1, 0], [1, 0], [0, 1]], device=device)).all()
        assert_close(
            out_label[:, :, 2],
            torch.tensor(
                [[0.0833, 0.3333], [0.0, 0.1667], [0.5, 0.0833], [0.0833, 0.0], [0.5, 0.3333]],
                device=device,
                dtype=dtype,
            ),
            rtol=1e-4,
            atol=1e-4,
        )

    def test_random_mixup_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        # same_on_batch=True: both samples get the same cut region (top-left 2x3 under seed 42).
        f = RandomCutMix(same_on_batch=True, width=4, height=3, p=1.0)
        images = self._make_batch(device, dtype)
        label = torch.tensor([1, 0], device=device)
        expected = torch.tensor(
            [
                [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
                [[[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )

        out_image, out_label = f(images, label)

        assert_close(out_image, expected, rtol=1e-4, atol=1e-4)
        assert (out_label[0, :, 0] == label).all()
        assert (out_label[0, :, 1] == torch.tensor([0, 1], device=device, dtype=dtype)).all()
        assert_close(
            out_label[0, :, 2], torch.tensor([0.5000, 0.5000], device=device, dtype=dtype), rtol=1e-4, atol=1e-4
        )