| | import pytest |
| | import torch |
| |
|
| | from kornia.augmentation.random_generator import ( |
| | center_crop_generator, |
| | random_affine_generator, |
| | random_color_jitter_generator, |
| | random_crop_generator, |
| | random_crop_size_generator, |
| | random_cutmix_generator, |
| | random_mixup_generator, |
| | random_motion_blur_generator, |
| | random_perspective_generator, |
| | random_posterize_generator, |
| | random_prob_generator, |
| | random_rectangles_params_generator, |
| | random_rotation_generator, |
| | random_sharpness_generator, |
| | random_solarize_generator, |
| | ) |
| | from kornia.testing import assert_close |
| | from kornia.utils._compat import torch_version_geq |
| |
|
| |
|
class RandomGeneratorBaseTests:
    """Abstract contract shared by all random-generator test suites below.

    Each concrete suite must override all four methods; the base
    implementations raise ``NotImplementedError`` so that a subclass that
    forgets one fails loudly instead of silently passing.
    """

    def test_valid_param_combinations(self, device, dtype):
        # Must call the generator with every accepted parameter combination.
        raise NotImplementedError

    def test_invalid_param_combinations(self, device, dtype):
        # Must assert that malformed parameters raise.
        raise NotImplementedError

    def test_random_gen(self, device, dtype):
        # Must pin the generator output for a fixed RNG seed.
        raise NotImplementedError

    def test_same_on_batch(self, device, dtype):
        # Must verify that same_on_batch=True repeats one sample across the batch.
        raise NotImplementedError
| |
|
| |
|
class TestRandomProbGen(RandomGeneratorBaseTests):
    """Test suite for ``random_prob_generator``."""

    @pytest.mark.parametrize('p', [0.0, 0.5, 1.0])
    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, p, batch_size, same_on_batch, device, dtype):
        # Any probability within [0, 1] must be accepted without raising.
        random_prob_generator(batch_size=batch_size, p=p, same_on_batch=same_on_batch)

    @pytest.mark.parametrize(
        'p',
        [
            (-1.0),
            (2.0),
        ],
    )
    def test_invalid_param_combinations(self, p, device, dtype):
        # Probabilities outside [0, 1] must be rejected.
        with pytest.raises(Exception):
            random_prob_generator(batch_size=8, p=p)

    @pytest.mark.parametrize(
        'p,expected',
        [(0.0, [False] * 8), (0.5, [False, False, True, False, True, False, True, False]), (1.0, [True] * 8)],
    )
    def test_random_gen(self, p, expected, device, dtype):
        torch.manual_seed(42)  # fixed seed makes the sampled mask reproducible
        batch_size = 8
        res = random_prob_generator(batch_size=batch_size, p=p)
        expected_mask = torch.tensor(expected)
        # Every element of the generated mask must match the expected one.
        assert int((res == expected_mask).long().sum()) == batch_size

    @pytest.mark.parametrize("seed,expected", [(42, [False] * 8), (0, [True] * 8)])
    def test_same_on_batch(self, seed, expected, device, dtype):
        torch.manual_seed(seed)
        batch_size = 8
        res = random_prob_generator(batch_size=batch_size, p=0.5, same_on_batch=True)
        expected_mask = torch.tensor(expected)
        # With same_on_batch=True a single draw is broadcast to the whole batch.
        assert int((res == expected_mask).long().sum()) == batch_size
| |
|
| |
|
class TestColorJitterGen(RandomGeneratorBaseTests):
    """Test suite for ``random_color_jitter_generator``."""

    @pytest.mark.parametrize('brightness', [None, torch.tensor([0.8, 1.2])])
    @pytest.mark.parametrize('contrast', [None, torch.tensor([0.8, 1.2])])
    @pytest.mark.parametrize('saturation', [None, torch.tensor([0.8, 1.2])])
    @pytest.mark.parametrize('hue', [None, torch.tensor([-0.1, 0.1])])
    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(
        self, brightness, contrast, saturation, hue, batch_size, same_on_batch, device, dtype
    ):
        # Every combination of None / two-element factor ranges must be accepted.
        random_color_jitter_generator(
            batch_size,
            brightness.to(device=device, dtype=dtype) if brightness is not None else None,
            contrast.to(device=device, dtype=dtype) if contrast is not None else None,
            saturation.to(device=device, dtype=dtype) if saturation is not None else None,
            hue.to(device=device, dtype=dtype) if hue is not None else None,
            same_on_batch,
        )

    @pytest.mark.parametrize(
        'brightness,contrast,saturation,hue',
        [
            # out-of-range values, scalar tensors and wrong-length ranges must all raise
            (torch.tensor([-1.0, 2.0]), None, None, None),
            (torch.tensor([0.0, 3.0]), None, None, None),
            (torch.tensor(0.0), None, None, None),
            (torch.tensor([0.0]), None, None, None),
            (torch.tensor([0.0, 1.0, 2.0]), None, None, None),
            (None, torch.tensor([-1.0, 2.0]), None, None),
            (None, torch.tensor(0.0), None, None),
            (None, torch.tensor([0.0]), None, None),
            (None, torch.tensor([0.0, 1.0, 2.0]), None, None),
            (None, None, torch.tensor([-1.0, 2.0]), None),
            (None, None, torch.tensor(0.0), None),
            (None, None, torch.tensor([0.0]), None),
            (None, None, torch.tensor([0.0, 1.0, 2.0]), None),
            (None, None, None, torch.tensor([-1.0, 0.0])),
            (None, None, None, torch.tensor([0, 1.0])),
            (None, None, None, torch.tensor(0.0)),
            (None, None, None, torch.tensor([0.0])),
            (None, None, None, torch.tensor([0.0, 1.0, 2.0])),
        ],
    )
    def test_invalid_param_combinations(self, brightness, contrast, saturation, hue, device, dtype):
        with pytest.raises(Exception):
            random_color_jitter_generator(
                8,
                brightness.to(device=device, dtype=dtype) if brightness is not None else None,
                contrast.to(device=device, dtype=dtype) if contrast is not None else None,
                saturation.to(device=device, dtype=dtype) if saturation is not None else None,
                hue.to(device=device, dtype=dtype) if hue is not None else None,
            )

    def test_random_gen(self, device, dtype):
        # Known closeness failure on CUDA with torch >= 1.10, so skip there.
        if torch_version_geq(1, 10) and "cuda" in str(device):
            pytest.skip("AssertionError: Tensor-likes are not close!")
        torch.manual_seed(42)  # expected values below are pinned to this seed
        batch_size = 8
        jitter_params = random_color_jitter_generator(
            batch_size,
            brightness=torch.tensor([0.8, 1.2], device=device, dtype=dtype),
            contrast=torch.tensor([0.7, 1.3], device=device, dtype=dtype),
            saturation=torch.tensor([0.6, 1.4], device=device, dtype=dtype),
            hue=torch.tensor([-0.1, 0.1], device=device, dtype=dtype),
        )

        expected_jitter_params = {
            'brightness_factor': torch.tensor(
                [1.1529, 1.1660, 0.9531, 1.1837, 0.9562, 1.0404, 0.9026, 1.1175], device=device, dtype=dtype
            ),
            'contrast_factor': torch.tensor(
                [1.2645, 0.7799, 1.2608, 1.0561, 1.2216, 1.0406, 1.1447, 0.9576], device=device, dtype=dtype
            ),
            'hue_factor': torch.tensor(
                [0.0771, 0.0148, -0.0467, 0.0255, -0.0461, -0.0117, -0.0406, 0.0663], device=device, dtype=dtype
            ),
            'saturation_factor': torch.tensor(
                [0.6843, 0.8156, 0.8871, 0.7595, 1.0378, 0.6049, 1.3612, 0.6602], device=device, dtype=dtype
            ),
            'order': torch.tensor([3, 2, 0, 1], device=device, dtype=dtype),
        }

        # The returned dict must contain exactly these keys, no extras.
        assert set(list(jitter_params.keys())) == {
            'brightness_factor',
            'contrast_factor',
            'hue_factor',
            'saturation_factor',
            'order',
        }, "Redundant keys found apart from \
            'brightness_factor', 'contrast_factor', 'hue_factor', 'saturation_factor', 'order'"

        assert_close(
            jitter_params['brightness_factor'], expected_jitter_params['brightness_factor'], rtol=1e-4, atol=1e-4
        )
        assert_close(jitter_params['contrast_factor'], expected_jitter_params['contrast_factor'], rtol=1e-4, atol=1e-4)
        assert_close(jitter_params['hue_factor'], expected_jitter_params['hue_factor'], rtol=1e-4, atol=1e-4)
        assert_close(
            jitter_params['saturation_factor'], expected_jitter_params['saturation_factor'], rtol=1e-4, atol=1e-4
        )
        assert_close(jitter_params['order'].to(dtype), expected_jitter_params['order'], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        batch_size = 8
        jitter_params = random_color_jitter_generator(
            batch_size,
            brightness=torch.tensor([0.8, 1.2], device=device, dtype=dtype),
            contrast=torch.tensor([0.7, 1.3], device=device, dtype=dtype),
            saturation=torch.tensor([0.6, 1.4], device=device, dtype=dtype),
            hue=torch.tensor([-0.1, 0.1], device=device, dtype=dtype),
            same_on_batch=True,
        )

        # Each factor is sampled once and repeated across the whole batch.
        expected_res = {
            'brightness_factor': torch.tensor([1.1529] * batch_size, device=device, dtype=dtype),
            'contrast_factor': torch.tensor([1.2490] * batch_size, device=device, dtype=dtype),
            'hue_factor': torch.tensor([-0.0234] * batch_size, device=device, dtype=dtype),
            'saturation_factor': torch.tensor([1.3674] * batch_size, device=device, dtype=dtype),
            'order': torch.tensor([2, 3, 0, 1], device=device, dtype=dtype),
        }

        assert_close(jitter_params['brightness_factor'], expected_res['brightness_factor'], rtol=1e-4, atol=1e-4)
        assert_close(jitter_params['contrast_factor'], expected_res['contrast_factor'], rtol=1e-4, atol=1e-4)
        assert_close(jitter_params['hue_factor'], expected_res['hue_factor'], rtol=1e-4, atol=1e-4)
        assert_close(jitter_params['saturation_factor'], expected_res['saturation_factor'], rtol=1e-4, atol=1e-4)
        assert_close(jitter_params['order'].to(dtype), expected_res['order'], rtol=1e-4, atol=1e-4)
| |
|
| |
|
class TestRandomPerspectiveGen(RandomGeneratorBaseTests):
    """Test suite for ``random_perspective_generator``.

    Fix: ``test_valid_param_combinations`` previously hard-coded
    ``batch_size=8``, silently ignoring the parametrized batch sizes
    (0, 1, 8); the parametrized value is now forwarded, matching the
    sibling suites.
    """

    @pytest.mark.parametrize('height,width', [(200, 200)])
    @pytest.mark.parametrize('distortion_scale', [torch.tensor(0.0), torch.tensor(0.5), torch.tensor(1.0)])
    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, height, width, distortion_scale, batch_size, same_on_batch, device, dtype):
        # Distortion scales in [0, 1] must be accepted for every batch size.
        random_perspective_generator(
            batch_size=batch_size,
            height=height,
            width=width,
            distortion_scale=distortion_scale.to(device=device, dtype=dtype),
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'height,width,distortion_scale',
        [
            # negative sizes, out-of-range scales and non-scalar scales must raise
            (-100, 100, torch.tensor(0.5)),
            (100, -100, torch.tensor(0.5)),
            (100, 100, torch.tensor(-0.5)),
            (100, 100, torch.tensor(1.5)),
            (100, 100, torch.tensor([0.0, 0.5])),
        ],
    )
    def test_invalid_param_combinations(self, height, width, distortion_scale, device, dtype):
        with pytest.raises(Exception):
            random_perspective_generator(
                batch_size=8,
                height=height,
                width=width,
                distortion_scale=distortion_scale.to(device=device, dtype=dtype),
            )

    def test_random_gen(self, device, dtype):
        torch.manual_seed(42)  # expected corner points below are pinned to this seed
        batch_size = 2
        res = random_perspective_generator(batch_size, 200, 200, torch.tensor(0.5, device=device, dtype=dtype))
        expected = dict(
            start_points=torch.tensor(
                [
                    [[0.0, 0.0], [199.0, 0.0], [199.0, 199.0], [0.0, 199.0]],
                    [[0.0, 0.0], [199.0, 0.0], [199.0, 199.0], [0.0, 199.0]],
                ],
                device=device,
                dtype=dtype,
            ),
            end_points=torch.tensor(
                [
                    [[44.1135, 45.7502], [179.8568, 47.9653], [179.4776, 168.9552], [12.8286, 159.3179]],
                    [[47.0386, 6.6593], [152.2701, 29.6790], [155.5298, 170.6142], [37.0547, 177.5298]],
                ],
                device=device,
                dtype=dtype,
            ),
        )
        assert res.keys() == expected.keys()
        assert_close(res['start_points'], expected['start_points'])
        assert_close(res['end_points'], expected['end_points'])

    def test_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        batch_size = 2
        res = random_perspective_generator(
            batch_size, 200, 200, torch.tensor(0.5, device=device, dtype=dtype), same_on_batch=True
        )
        # One sampled quadrilateral repeated for both batch entries.
        expected = dict(
            start_points=torch.tensor(
                [[[0.0, 0.0], [199.0, 0.0], [199.0, 199.0], [0.0, 199.0]]], device=device, dtype=dtype
            ).repeat(2, 1, 1),
            end_points=torch.tensor(
                [[[44.1135, 45.7502], [179.8568, 47.9653], [179.4776, 168.9552], [12.8286, 159.3179]]],
                device=device,
                dtype=dtype,
            ).repeat(2, 1, 1),
        )
        assert res.keys() == expected.keys()
        assert_close(res['start_points'], expected['start_points'])
        assert_close(res['end_points'], expected['end_points'])
| |
|
| |
|
class TestRandomAffineGen(RandomGeneratorBaseTests):
    """Test suite for ``random_affine_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 4])
    @pytest.mark.parametrize('height', [200])
    @pytest.mark.parametrize('width', [300])
    @pytest.mark.parametrize('degrees', [torch.tensor([0, 30])])
    @pytest.mark.parametrize('translate', [None, torch.tensor([0.1, 0.1])])
    @pytest.mark.parametrize('scale', [None, torch.tensor([0.7, 1.2])])
    @pytest.mark.parametrize('shear', [None, torch.tensor([[0, 20], [0, 20]])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(
        self, batch_size, height, width, degrees, translate, scale, shear, same_on_batch, device, dtype
    ):
        # translate/scale/shear are optional; None and well-shaped tensors must both work.
        random_affine_generator(
            batch_size=batch_size,
            height=height,
            width=width,
            degrees=degrees.to(device=device, dtype=dtype),
            translate=translate.to(device=device, dtype=dtype) if translate is not None else None,
            scale=scale.to(device=device, dtype=dtype) if scale is not None else None,
            shear=shear.to(device=device, dtype=dtype) if shear is not None else None,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'height,width,degrees,translate,scale,shear',
        [
            # negative sizes, non-tensor degrees, and wrong-length ranges must raise
            (-100, 100, torch.tensor([10, 20]), None, None, None),
            (100, -100, torch.tensor([10, 20]), None, None, None),
            (100, 100, 0.5, None, None, None),
            (100, 100, torch.tensor([10, 20, 30]), None, None, None),
            (100, 100, torch.tensor([10, 20]), torch.tensor([0.1]), None, None),
            (10, 10, torch.tensor([1, 2]), torch.tensor([0.1, 0.2, 0.3]), None, None),
            (100, 100, torch.tensor([10, 20]), None, torch.tensor([1]), None),
            (100, 100, torch.tensor([10, 20]), None, torch.tensor([1, 2, 3]), None),
            (100, 100, torch.tensor([10, 20]), None, None, torch.tensor([1])),
            (100, 100, torch.tensor([10, 20]), None, None, torch.tensor([1, 2])),
            (10, 10, torch.tensor([1, 2]), None, None, torch.tensor([1, 2, 3])),
            (10, 10, torch.tensor([1, 2]), None, None, torch.tensor([1, 2, 3, 4])),
            (10, 10, torch.tensor([1, 2]), None, None, torch.tensor([1, 2, 3, 4, 5])),
        ],
    )
    def test_invalid_param_combinations(self, height, width, degrees, translate, scale, shear, device, dtype):
        with pytest.raises(Exception):
            random_affine_generator(
                batch_size=8,
                height=height,
                width=width,
                degrees=degrees.to(device=device, dtype=dtype),
                translate=translate.to(device=device, dtype=dtype) if translate is not None else None,
                scale=scale.to(device=device, dtype=dtype) if scale is not None else None,
                shear=shear.to(device=device, dtype=dtype) if shear is not None else None,
            )

    def test_random_gen(self, device, dtype):
        torch.manual_seed(42)  # expected params below are pinned to this seed
        degrees = torch.tensor([10, 20], device=device, dtype=dtype)
        translate = torch.tensor([0.1, 0.1], device=device, dtype=dtype)
        scale = torch.tensor([0.7, 1.2], device=device, dtype=dtype)
        shear = torch.tensor([[10, 20], [10, 20]], device=device, dtype=dtype)
        res = random_affine_generator(
            batch_size=2,
            height=200,
            width=200,
            degrees=degrees,
            translate=translate,
            scale=scale,
            shear=shear,
            same_on_batch=False,
        )
        expected = dict(
            translations=torch.tensor([[-4.3821, -9.7371], [4.0358, 11.7457]], device=device, dtype=dtype),
            center=torch.tensor([[99.5000, 99.5000], [99.5000, 99.5000]], device=device, dtype=dtype),
            scale=torch.tensor([[0.8914, 0.8914], [1.1797, 1.1797]], device=device, dtype=dtype),
            angle=torch.tensor([18.8227, 19.1500], device=device, dtype=dtype),
            sx=torch.tensor([19.4077, 11.3319], device=device, dtype=dtype),
            sy=torch.tensor([19.3460, 15.9358], device=device, dtype=dtype),
        )
        assert res.keys() == expected.keys()
        assert_close(res['translations'], expected['translations'], rtol=1e-4, atol=1e-4)
        assert_close(res['center'], expected['center'], rtol=1e-4, atol=1e-4)
        assert_close(res['scale'], expected['scale'], rtol=1e-4, atol=1e-4)
        assert_close(res['angle'], expected['angle'], rtol=1e-4, atol=1e-4)
        assert_close(res['sx'], expected['sx'], rtol=1e-4, atol=1e-4)
        assert_close(res['sy'], expected['sy'], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        degrees = torch.tensor([10, 20], device=device, dtype=dtype)
        translate = torch.tensor([0.1, 0.1], device=device, dtype=dtype)
        scale = torch.tensor([0.7, 1.2], device=device, dtype=dtype)
        shear = torch.tensor([[10, 20], [10, 20]], device=device, dtype=dtype)
        res = random_affine_generator(
            batch_size=2,
            height=200,
            width=200,
            degrees=degrees,
            translate=translate,
            scale=scale,
            shear=shear,
            same_on_batch=True,
        )
        # One sampled transform, repeated for both batch entries.
        expected = dict(
            translations=torch.tensor([[-4.6854, 18.3722], [-4.6854, 18.3722]], device=device, dtype=dtype),
            center=torch.tensor([[99.5000, 99.5000], [99.5000, 99.5000]], device=device, dtype=dtype),
            scale=torch.tensor([[1.1575, 1.1575], [1.1575, 1.1575]], device=device, dtype=dtype),
            angle=torch.tensor([18.8227, 18.8227], device=device, dtype=dtype),
            sx=torch.tensor([13.9045, 13.9045], device=device, dtype=dtype),
            sy=torch.tensor([16.0090, 16.0090], device=device, dtype=dtype),
        )
        assert res.keys() == expected.keys()
        assert_close(res['translations'], expected['translations'], rtol=1e-4, atol=1e-4)
        assert_close(res['center'], expected['center'], rtol=1e-4, atol=1e-4)
        assert_close(res['scale'], expected['scale'], rtol=1e-4, atol=1e-4)
        assert_close(res['angle'], expected['angle'], rtol=1e-4, atol=1e-4)
        assert_close(res['sx'], expected['sx'], rtol=1e-4, atol=1e-4)
        assert_close(res['sy'], expected['sy'], rtol=1e-4, atol=1e-4)
| |
|
| |
|
class TestRandomRotationGen(RandomGeneratorBaseTests):
    """Test suite for ``random_rotation_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('degrees', [torch.tensor([0, 30])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, degrees, same_on_batch, device, dtype):
        # A well-formed two-element degree range must be accepted for any batch size.
        random_rotation_generator(
            batch_size=batch_size, degrees=degrees.to(device=device, dtype=dtype), same_on_batch=same_on_batch
        )

    @pytest.mark.parametrize('degrees', [(torch.tensor(10)), (torch.tensor([10])), (torch.tensor([10, 20, 30]))])
    def test_invalid_param_combinations(self, degrees, device, dtype):
        # Degree ranges that are not exactly two elements must be rejected.
        batch_size = 8
        with pytest.raises(Exception):
            random_rotation_generator(batch_size=batch_size, degrees=degrees.to(device=device, dtype=dtype))

    def test_random_gen(self, device, dtype):
        torch.manual_seed(42)  # expected angles below are pinned to this seed
        degree_range = torch.tensor([10, 20])
        res = random_rotation_generator(
            batch_size=2, degrees=degree_range.to(device=device, dtype=dtype), same_on_batch=False
        )
        expected = dict(degrees=torch.tensor([18.8227, 19.1500], device=device, dtype=dtype))
        assert res.keys() == expected.keys()
        assert_close(res['degrees'], expected['degrees'])

    def test_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        degree_range = torch.tensor([10, 20])
        res = random_rotation_generator(
            batch_size=2, degrees=degree_range.to(device=device, dtype=dtype), same_on_batch=True
        )
        # A single sampled angle shared by both batch entries.
        expected = dict(degrees=torch.tensor([18.8227, 18.8227], device=device, dtype=dtype))
        assert res.keys() == expected.keys()
        assert_close(res['degrees'], expected['degrees'])
| |
|
| |
|
class TestRandomCropGen(RandomGeneratorBaseTests):
    """Test suite for ``random_crop_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 2])
    @pytest.mark.parametrize('input_size', [(200, 200)])
    @pytest.mark.parametrize('size', [(100, 100), torch.tensor([50, 50])])
    @pytest.mark.parametrize('resize_to', [None, (100, 100)])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, input_size, size, resize_to, same_on_batch, device, dtype):
        # A tensor size is per-sample, so expand it to (batch_size, 2) first.
        if isinstance(size, torch.Tensor):
            size = size.repeat(batch_size, 1).to(device=device, dtype=dtype)
        random_crop_generator(
            batch_size=batch_size,
            input_size=input_size,
            size=size,
            resize_to=resize_to,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'input_size,size,resize_to',
        # negative input size, or a per-sample size tensor not matching the batch, must raise
        [((-300, 300), (200, 200), (100, 100)), ((200, 200), torch.tensor([50, 50]), (100, 100))],
    )
    def test_invalid_param_combinations(self, input_size, size, resize_to, device, dtype):
        batch_size = 2
        with pytest.raises(Exception):
            random_crop_generator(
                batch_size=batch_size,
                input_size=input_size,
                size=size.to(device=device, dtype=dtype) if isinstance(size, torch.Tensor) else size,
                resize_to=resize_to,
            )

    def test_random_gen(self, device, dtype):
        torch.manual_seed(42)  # expected crop corners below are pinned to this seed
        res = random_crop_generator(
            batch_size=2,
            input_size=(100, 100),
            size=torch.tensor([[50, 60], [70, 80]], device=device, dtype=dtype),
            resize_to=(200, 200),
        )
        expected = dict(
            src=torch.tensor(
                [[[36, 19], [95, 19], [95, 68], [36, 68]], [[19, 29], [98, 29], [98, 98], [19, 98]]],
                device=device,
                dtype=dtype,
            ),
            dst=torch.tensor(
                [[[0, 0], [199, 0], [199, 199], [0, 199]], [[0, 0], [199, 0], [199, 199], [0, 199]]],
                device=device,
                dtype=dtype,
            ),
            input_size=torch.tensor([[100, 100], [100, 100]], device=device, dtype=torch.long),
        )
        assert res.keys() == expected.keys()
        assert_close(res['src'], expected['src'])
        assert_close(res['dst'], expected['dst'])

    def test_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        res = random_crop_generator(
            batch_size=2,
            input_size=(100, 100),
            size=torch.tensor([[50, 60], [70, 80]], device=device, dtype=dtype),
            resize_to=(200, 200),
            same_on_batch=True,
        )
        # Same top-left offset for both samples; extents still follow each per-sample size.
        expected = dict(
            src=torch.tensor(
                [[[36, 46], [95, 46], [95, 95], [36, 95]], [[36, 46], [115, 46], [115, 115], [36, 115]]],
                device=device,
                dtype=dtype,
            ),
            dst=torch.tensor(
                [[[0, 0], [199, 0], [199, 199], [0, 199]], [[0, 0], [199, 0], [199, 199], [0, 199]]],
                device=device,
                dtype=dtype,
            ),
            input_size=torch.tensor([[100, 100], [100, 100]], device=device, dtype=torch.long),
        )
        assert res.keys() == expected.keys()
        assert_close(res['src'], expected['src'])
        assert_close(res['dst'], expected['dst'])
| |
|
| |
|
class TestRandomCropSizeGen(RandomGeneratorBaseTests):
    """Test suite for ``random_crop_size_generator``.

    Fix: ``test_invalid_param_combinations`` previously passed
    ``same_on_batch=same_on_batch`` where no such local existed; the
    resulting ``NameError`` satisfied ``pytest.raises(Exception)`` and
    masked the validation error the test was meant to check. The bogus
    keyword has been removed.
    """

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('size', [(200, 200)])
    @pytest.mark.parametrize('scale', [torch.tensor([0.7, 1.3])])
    @pytest.mark.parametrize('ratio', [torch.tensor([0.9, 1.1])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, size, scale, ratio, same_on_batch, device, dtype):
        random_crop_size_generator(
            batch_size=batch_size,
            size=size,
            scale=scale.to(device=device, dtype=dtype),
            ratio=ratio.to(device=device, dtype=dtype),
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'size,scale,ratio',
        [
            # non-pair sizes and wrong-length scale/ratio ranges must raise
            ((100), torch.tensor([0.7, 1.3]), torch.tensor([0.9, 1.1])),
            ((100, 100, 100), torch.tensor([0.7, 1.3]), torch.tensor([0.9, 1.1])),
            ((100, 100), torch.tensor([0.7]), torch.tensor([0.9, 1.1])),
            ((100, 100), torch.tensor([0.7, 1.3, 1.5]), torch.tensor([0.9, 1.1])),
            ((100, 100), torch.tensor([0.7, 1.3]), torch.tensor([0.9])),
            ((100, 100), torch.tensor([0.7, 1.3]), torch.tensor([0.9, 1.1, 1.3])),
        ],
    )
    def test_invalid_param_combinations(self, size, scale, ratio, device, dtype):
        batch_size = 2
        with pytest.raises(Exception):
            random_crop_size_generator(
                batch_size=batch_size,
                size=size,
                scale=scale.to(device=device, dtype=dtype),
                ratio=ratio.to(device=device, dtype=dtype),
            )

    def test_random_gen(self, device, dtype):
        torch.manual_seed(42)  # expected sizes below are pinned to this seed
        res = random_crop_size_generator(
            batch_size=8,
            size=(100, 100),
            scale=torch.tensor([0.7, 1.3], device=device, dtype=dtype),
            ratio=torch.tensor([0.9, 1.1], device=device, dtype=dtype),
            same_on_batch=False,
        )
        expected = dict(
            size=torch.tensor(
                [[99, 94], [91, 95], [90, 96], [87, 86], [94, 98], [87, 81], [85, 93], [83, 90]],
                device=device,
                dtype=dtype,
            )
        )
        assert res.keys() == expected.keys()
        assert_close(res['size'], expected['size'])

        # With scale ~= 1 and ratio == 1 every sampled crop collapses to the full size.
        res = random_crop_size_generator(
            batch_size=100,
            size=(100, 100),
            scale=torch.tensor([0.999, 1.0], device=device, dtype=dtype),
            ratio=torch.tensor([1.0, 1.0], device=device, dtype=dtype),
            same_on_batch=False,
        )
        expected = dict(size=torch.tensor([[100, 100]], device=device, dtype=dtype).repeat(100, 1))
        assert res.keys() == expected.keys()
        assert_close(res['size'], expected['size'])

    def test_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        res = random_crop_size_generator(
            batch_size=8,
            size=(100, 100),
            scale=torch.tensor([0.7, 1.3], device=device, dtype=dtype),
            ratio=torch.tensor([0.9, 1.1], device=device, dtype=dtype),
            same_on_batch=True,
        )
        # A single sampled size repeated over the whole batch.
        expected = dict(
            size=torch.tensor(
                [[99, 95], [99, 95], [99, 95], [99, 95], [99, 95], [99, 95], [99, 95], [99, 95]],
                device=device,
                dtype=dtype,
            )
        )
        assert res.keys() == expected.keys()
        assert_close(res['size'], expected['size'])
| |
|
| |
|
class TestRandomRectangleGen(RandomGeneratorBaseTests):
    """Test suite for ``random_rectangles_params_generator``.

    Fixes:
    - ``test_invalid_param_combinations`` passed ``same_on_batch=same_on_batch``
      where no such local existed; the resulting ``NameError`` satisfied
      ``pytest.raises(Exception)`` and masked the real validation error.
      The bogus keyword has been removed.
    - ``test_random_gen`` and ``test_same_on_batch`` asserted ``widths``
      twice and never checked ``heights`` (though the expectation was
      present); the duplicate now checks ``heights``.
    """

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('height', [200])
    @pytest.mark.parametrize('width', [300])
    @pytest.mark.parametrize('scale', [torch.tensor([0.7, 1.1])])
    @pytest.mark.parametrize('ratio', [torch.tensor([0.7, 1.1])])
    @pytest.mark.parametrize('value', [0])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(
        self, batch_size, height, width, scale, ratio, value, same_on_batch, device, dtype
    ):
        random_rectangles_params_generator(
            batch_size=batch_size,
            height=height,
            width=width,
            scale=scale.to(device=device, dtype=dtype),
            ratio=ratio.to(device=device, dtype=dtype),
            value=value,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'height,width,scale,ratio,value',
        [
            # negative sizes, wrong-length ranges and out-of-range / tensor values must raise
            (-100, 100, torch.tensor([0.7, 1.3]), torch.tensor([0.7, 1.3]), 0),
            (100, -100, torch.tensor([0.7, 1.3]), torch.tensor([0.7, 1.3]), 0),
            (100, -100, torch.tensor([0.7]), torch.tensor([0.7, 1.3]), 0),
            (100, 100, torch.tensor([0.7, 1.3, 1.5]), torch.tensor([0.7, 1.3]), 0),
            (100, 100, torch.tensor([0.7, 1.3]), torch.tensor([0.7]), 0),
            (100, 100, torch.tensor([0.7, 1.3]), torch.tensor([0.7, 1.3, 1.5]), 0),
            (100, 100, torch.tensor([0.7, 1.3]), torch.tensor([0.7, 1.3]), -1),
            (100, 100, torch.tensor([0.7, 1.3]), torch.tensor([0.7, 1.3]), 2),
            (100, 100, torch.tensor([0.5, 0.7]), torch.tensor([0.7, 0.9]), torch.tensor(0.5)),
        ],
    )
    def test_invalid_param_combinations(self, height, width, scale, ratio, value, device, dtype):
        batch_size = 8
        with pytest.raises(Exception):
            random_rectangles_params_generator(
                batch_size=batch_size,
                height=height,
                width=width,
                scale=scale.to(device=device, dtype=dtype),
                ratio=ratio.to(device=device, dtype=dtype),
                value=value,
            )

    def test_random_gen(self, device, dtype):
        torch.manual_seed(42)  # expected rectangles below are pinned to this seed
        width, height = 100, 150
        scale = torch.tensor([0.7, 1.3], device=device, dtype=dtype)
        ratio = torch.tensor([0.7, 1.3], device=device, dtype=dtype)
        value = 0.5
        res = random_rectangles_params_generator(
            batch_size=2, height=height, width=width, scale=scale, ratio=ratio, value=value, same_on_batch=False
        )
        expected = dict(
            widths=torch.tensor([100, 100], device=device, dtype=dtype),
            heights=torch.tensor([0, 0], device=device, dtype=dtype),
            xs=torch.tensor([0, 0], device=device, dtype=dtype),
            ys=torch.tensor([6, 8], device=device, dtype=dtype),
            values=torch.tensor([0.5000, 0.5000], device=device, dtype=dtype),
        )
        assert res.keys() == expected.keys()
        assert_close(res['widths'], expected['widths'])
        assert_close(res['heights'], expected['heights'])
        assert_close(res['xs'], expected['xs'])
        assert_close(res['ys'], expected['ys'])
        assert_close(res['values'], expected['values'])

    def test_same_on_batch(self, device, dtype):
        torch.manual_seed(42)
        width, height = 100, 150
        scale = torch.tensor([0.7, 1.3], device=device, dtype=dtype)
        ratio = torch.tensor([0.7, 1.3], device=device, dtype=dtype)
        value = 0.5
        res = random_rectangles_params_generator(
            batch_size=2, height=height, width=width, scale=scale, ratio=ratio, value=value, same_on_batch=True
        )
        # One sampled rectangle shared by both batch entries.
        expected = dict(
            widths=torch.tensor([100, 100], device=device, dtype=dtype),
            heights=torch.tensor([0, 0], device=device, dtype=dtype),
            xs=torch.tensor([0, 0], device=device, dtype=dtype),
            ys=torch.tensor([10, 10], device=device, dtype=dtype),
            values=torch.tensor([0.5000, 0.5000], device=device, dtype=dtype),
        )
        assert res.keys() == expected.keys()
        assert_close(res['widths'], expected['widths'])
        assert_close(res['heights'], expected['heights'])
        assert_close(res['xs'], expected['xs'])
        assert_close(res['ys'], expected['ys'])
        assert_close(res['values'], expected['values'])
| |
|
| |
|
class TestCenterCropGen(RandomGeneratorBaseTests):
    """Test suite for ``center_crop_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 2])
    @pytest.mark.parametrize('height', [200])
    @pytest.mark.parametrize('width', [200])
    @pytest.mark.parametrize('size', [(100, 100)])
    def test_valid_param_combinations(self, batch_size, height, width, size, device, dtype):
        center_crop_generator(batch_size=batch_size, height=height, width=width, size=size)

    @pytest.mark.parametrize(
        'height,width,size',
        [
            # negative dimensions, or a crop larger than the input, must raise
            (200, -200, (100, 100)),
            (-200, 200, (100, 100)),
            (100, 100, (120, 120)),
            (150, 100, (120, 120)),
            (100, 150, (120, 120)),
        ],
    )
    def test_invalid_param_combinations(self, height, width, size, device, dtype):
        batch_size = 2
        with pytest.raises(Exception):
            center_crop_generator(batch_size=batch_size, height=height, width=width, size=size)

    def test_random_gen(self, device, dtype):
        torch.manual_seed(42)
        res = center_crop_generator(batch_size=2, height=200, width=200, size=(120, 150))
        # Center crop is deterministic: both samples get the same centered box.
        expected = dict(
            src=torch.tensor(
                [[[25, 40], [174, 40], [174, 159], [25, 159]], [[25, 40], [174, 40], [174, 159], [25, 159]]],
                device=device,
                dtype=torch.long,
            ),
            dst=torch.tensor(
                [[[0, 0], [149, 0], [149, 119], [0, 119]], [[0, 0], [149, 0], [149, 119], [0, 119]]],
                device=device,
                dtype=torch.long,
            ),
            input_size=torch.tensor([[200, 200], [200, 200]], device=device, dtype=torch.long),
        )
        assert res.keys() == expected.keys()
        assert_close(res['src'].to(device=device), expected['src'])
        assert_close(res['dst'].to(device=device), expected['dst'])

    def test_same_on_batch(self, device, dtype):
        # Intentionally empty: center crop involves no randomness, so
        # same_on_batch has nothing to vary.
        pass
| |
|
| |
|
class TestRandomMotionBlur(RandomGeneratorBaseTests):
    """Parameter-generation tests for ``random_motion_blur_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('kernel_size', [3, (3, 5)])
    @pytest.mark.parametrize('angle', [torch.tensor([10, 30])])
    @pytest.mark.parametrize('direction', [torch.tensor([-1, -1]), torch.tensor([1, 1])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, kernel_size, angle, direction, same_on_batch, device, dtype):
        # Every valid combination must generate parameters without raising.
        random_motion_blur_generator(
            batch_size=batch_size,
            kernel_size=kernel_size,
            angle=angle.to(device=device, dtype=dtype),
            direction=direction.to(device=device, dtype=dtype),
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'kernel_size,angle,direction',
        [
            (4, torch.tensor([30, 100]), torch.tensor([-1, 1])),  # even kernel size
            (1, torch.tensor([30, 100]), torch.tensor([-1, 1])),  # kernel size too small
            ((1, 2, 3), torch.tensor([30, 100]), torch.tensor([-1, 1])),  # not a 2-tuple
            (3, torch.tensor([30, 100]), torch.tensor([-2, 1])),  # direction below -1
            (3, torch.tensor([30, 100]), torch.tensor([-1, 2])),  # direction above 1
        ],
    )
    def test_invalid_param_combinations(self, kernel_size, angle, direction, device, dtype):
        # Each invalid combination must be rejected with an exception.
        with pytest.raises(Exception):
            random_motion_blur_generator(
                batch_size=8,
                kernel_size=kernel_size,
                angle=angle.to(device=device, dtype=dtype),
                direction=direction.to(device=device, dtype=dtype),
            )

    def test_random_gen(self, device, dtype):
        # A fixed seed pins the sampled factors to known reference values.
        torch.manual_seed(42)
        params = random_motion_blur_generator(
            batch_size=2,
            kernel_size=3,
            angle=torch.tensor([30, 90]).to(device=device, dtype=dtype),
            direction=torch.tensor([-1, 1]).to(device=device, dtype=dtype),
            same_on_batch=False,
        )
        reference = dict(
            ksize_factor=torch.tensor([3, 3], device=device, dtype=torch.int32),
            angle_factor=torch.tensor([82.9362, 84.9002], device=device, dtype=dtype),
            direction_factor=torch.tensor([-0.2343, 0.9186], device=device, dtype=dtype),
        )
        assert params.keys() == reference.keys()
        for name in reference:
            assert_close(params[name], reference[name], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        # With same_on_batch=True all batch elements share one sampled value.
        torch.manual_seed(42)
        params = random_motion_blur_generator(
            batch_size=2,
            kernel_size=3,
            angle=torch.tensor([30, 90]).to(device=device, dtype=dtype),
            direction=torch.tensor([-1, 1]).to(device=device, dtype=dtype),
            same_on_batch=True,
        )
        reference = dict(
            ksize_factor=torch.tensor([3, 3], device=device, dtype=torch.int32),
            angle_factor=torch.tensor([82.9362, 82.9362], device=device, dtype=dtype),
            direction_factor=torch.tensor([0.8300, 0.8300], device=device, dtype=dtype),
        )
        assert params.keys() == reference.keys()
        for name in reference:
            assert_close(params[name], reference[name], rtol=1e-4, atol=1e-4)
| |
|
| |
|
class TestRandomSolarizeGen(RandomGeneratorBaseTests):
    """Tests for ``random_solarize_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('thresholds', [torch.tensor([0, 1]), torch.tensor([0.4, 0.6])])
    @pytest.mark.parametrize('additions', [torch.tensor([-0.5, 0.5])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, thresholds, additions, same_on_batch, device, dtype):
        # Smoke test: valid combinations must not raise.
        random_solarize_generator(
            batch_size=batch_size,
            thresholds=thresholds.to(device=device, dtype=dtype),
            additions=additions.to(device=device, dtype=dtype),
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'thresholds,additions',
        [
            (torch.tensor([0, 2]), torch.tensor([-0.5, 0.5])),  # threshold above 1
            (torch.tensor([-1, 1]), torch.tensor([-0.5, 0.5])),  # threshold below 0
            ([0, 1], torch.tensor([-0.5, 0.5])),  # thresholds not a tensor
            (torch.tensor([0, 1]), torch.tensor([-0.5, 1])),  # addition above 0.5
            (torch.tensor([0, 1]), torch.tensor([-1, 0.5])),  # addition below -0.5
            (torch.tensor([0, 1]), [-0.5, 0.5]),  # additions not a tensor
        ],
    )
    def test_invalid_param_combinations(self, thresholds, additions, device, dtype):
        # Bug fix: ``batch_size`` was an undefined name, so ``pytest.raises``
        # was satisfied by a NameError rather than the generator's validation.
        # Convert only tensors, so the list cases reach the generator instead
        # of failing earlier on ``.to()`` with AttributeError.
        if isinstance(thresholds, torch.Tensor):
            thresholds = thresholds.to(device=device, dtype=dtype)
        if isinstance(additions, torch.Tensor):
            additions = additions.to(device=device, dtype=dtype)
        with pytest.raises(Exception):
            random_solarize_generator(batch_size=8, thresholds=thresholds, additions=additions)

    def test_random_gen(self, device, dtype):
        # A fixed seed pins the sampled factors to known reference values.
        torch.manual_seed(42)
        batch_size = 8
        res = random_solarize_generator(
            batch_size=batch_size,
            thresholds=torch.tensor([0, 1], device=device, dtype=dtype),
            additions=torch.tensor([-0.5, 0.5], device=device, dtype=dtype),
            same_on_batch=False,
        )
        expected = dict(
            thresholds_factor=torch.tensor(
                [0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936], device=device, dtype=dtype
            ),
            additions_factor=torch.tensor(
                [0.4408, -0.3668, 0.4346, 0.0936, 0.3694, 0.0677, 0.2411, -0.0706], device=device, dtype=dtype
            ),
        )
        assert res.keys() == expected.keys()
        assert_close(res['thresholds_factor'], expected['thresholds_factor'], rtol=1e-4, atol=1e-4)
        assert_close(res['additions_factor'], expected['additions_factor'], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        # With same_on_batch=True all batch elements share one sampled value.
        torch.manual_seed(42)
        batch_size = 8
        res = random_solarize_generator(
            batch_size=batch_size,
            thresholds=torch.tensor([0, 1], device=device, dtype=dtype),
            additions=torch.tensor([-0.5, 0.5], device=device, dtype=dtype),
            same_on_batch=True,
        )
        expected = dict(
            thresholds_factor=torch.tensor(
                [0.8823, 0.8823, 0.8823, 0.8823, 0.8823, 0.8823, 0.8823, 0.8823], device=device, dtype=dtype
            ),
            additions_factor=torch.tensor(
                [0.4150, 0.4150, 0.4150, 0.4150, 0.4150, 0.4150, 0.4150, 0.4150], device=device, dtype=dtype
            ),
        )
        assert res.keys() == expected.keys()
        assert_close(res['thresholds_factor'], expected['thresholds_factor'], rtol=1e-4, atol=1e-4)
        assert_close(res['additions_factor'], expected['additions_factor'], rtol=1e-4, atol=1e-4)
| |
|
| |
|
class TestRandomPosterizeGen(RandomGeneratorBaseTests):
    """Tests for ``random_posterize_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('bits', [torch.tensor([0, 8])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, bits, same_on_batch, device, dtype):
        # Smoke test: valid combinations must not raise.
        random_posterize_generator(
            batch_size=batch_size, bits=bits.to(device=device, dtype=dtype), same_on_batch=same_on_batch
        )

    @pytest.mark.parametrize('bits', [(torch.tensor([-1, 1])), (torch.tensor([0, 9])), (torch.tensor([3])), ([0, 8])])
    def test_invalid_param_combinations(self, bits, device, dtype):
        # Bug fix: ``batch_size`` and ``same_on_batch`` were undefined names,
        # so ``pytest.raises`` was satisfied by a NameError rather than the
        # generator's validation. Convert only tensors, so the list case
        # reaches the generator instead of failing on ``.to()``.
        if isinstance(bits, torch.Tensor):
            bits = bits.to(device=device, dtype=dtype)
        with pytest.raises(Exception):
            random_posterize_generator(batch_size=8, bits=bits)

    def test_random_gen(self, device, dtype):
        # A fixed seed pins the sampled bits to known reference values.
        torch.manual_seed(9)
        batch_size = 8
        res = random_posterize_generator(
            batch_size=batch_size, bits=torch.tensor([0, 8], device=device, dtype=dtype), same_on_batch=False
        )
        expected = dict(bits_factor=torch.tensor([5, 2, 3, 6, 7, 7, 2, 7], device=device, dtype=torch.int32))
        assert res.keys() == expected.keys()
        assert_close(res['bits_factor'], expected['bits_factor'], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        # With same_on_batch=True all batch elements share one sampled value.
        torch.manual_seed(9)
        batch_size = 8
        res = random_posterize_generator(
            batch_size=batch_size, bits=torch.tensor([0, 8], device=device, dtype=dtype), same_on_batch=True
        )
        expected = dict(bits_factor=torch.tensor([5, 5, 5, 5, 5, 5, 5, 5], device=device, dtype=torch.int32))
        assert res.keys() == expected.keys()
        assert_close(res['bits_factor'], expected['bits_factor'], rtol=1e-4, atol=1e-4)
| |
|
| |
|
class TestRandomSharpnessGen(RandomGeneratorBaseTests):
    """Tests for ``random_sharpness_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('sharpness', [torch.tensor([0.0, 1.0])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, sharpness, same_on_batch, device, dtype):
        # Smoke test: valid combinations must not raise.
        random_sharpness_generator(
            batch_size=batch_size, sharpness=sharpness.to(device=device, dtype=dtype), same_on_batch=same_on_batch
        )

    @pytest.mark.parametrize('sharpness', [(torch.tensor([-1, 5])), (torch.tensor([3])), ([0, 1.0])])
    def test_invalid_param_combinations(self, sharpness, device, dtype):
        # Bug fix: ``batch_size`` and ``same_on_batch`` were undefined names,
        # so ``pytest.raises`` was satisfied by a NameError rather than the
        # generator's validation. Convert only tensors, so the list case
        # reaches the generator instead of failing on ``.to()``.
        if isinstance(sharpness, torch.Tensor):
            sharpness = sharpness.to(device=device, dtype=dtype)
        with pytest.raises(Exception):
            random_sharpness_generator(batch_size=8, sharpness=sharpness)

    def test_random_gen(self, device, dtype):
        # A fixed seed pins the sampled factors to known reference values.
        torch.manual_seed(42)
        batch_size = 8
        res = random_sharpness_generator(
            batch_size=batch_size, sharpness=torch.tensor([0.0, 1.0], device=device, dtype=dtype), same_on_batch=False
        )
        expected = dict(
            sharpness_factor=torch.tensor(
                [0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936], device=device, dtype=dtype
            )
        )
        assert res.keys() == expected.keys()
        assert_close(res['sharpness_factor'], expected['sharpness_factor'], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        # With same_on_batch=True all batch elements share one sampled value.
        torch.manual_seed(42)
        batch_size = 8
        res = random_sharpness_generator(
            batch_size=batch_size, sharpness=torch.tensor([0.0, 1.0], device=device, dtype=dtype), same_on_batch=True
        )
        expected = dict(
            sharpness_factor=torch.tensor(
                [0.8823, 0.8823, 0.8823, 0.8823, 0.8823, 0.8823, 0.8823, 0.8823], device=device, dtype=dtype
            )
        )
        assert res.keys() == expected.keys()
        assert_close(res['sharpness_factor'], expected['sharpness_factor'], rtol=1e-4, atol=1e-4)
| |
|
| |
|
class TestRandomMixUpGen(RandomGeneratorBaseTests):
    """Tests for ``random_mixup_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('p', [0.0, 0.5, 1.0])
    @pytest.mark.parametrize('lambda_val', [None, torch.tensor([0.0, 1.0])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, p, lambda_val, same_on_batch, device, dtype):
        # Smoke test: valid combinations (including lambda_val=None) must not raise.
        random_mixup_generator(
            batch_size=batch_size,
            p=p,
            lambda_val=lambda_val.to(device=device, dtype=dtype)
            if isinstance(lambda_val, (torch.Tensor))
            else lambda_val,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'lambda_val', [(torch.tensor([-1, 1])), (torch.tensor([0, 2])), (torch.tensor([0, 0.5, 1])), ([0.0, 1.0])]
    )
    def test_invalid_param_combinations(self, lambda_val, device, dtype):
        # Bug fix: ``.to()`` was called unconditionally, so the list case
        # raised AttributeError before reaching the generator's validation.
        if isinstance(lambda_val, torch.Tensor):
            lambda_val = lambda_val.to(device=device, dtype=dtype)
        with pytest.raises(Exception):
            random_mixup_generator(batch_size=8, lambda_val=lambda_val)

    def test_random_gen(self, device, dtype):
        # A fixed seed pins the sampled pairs and lambdas to reference values.
        torch.manual_seed(42)
        batch_size = 8
        res = random_mixup_generator(
            batch_size=batch_size,
            p=0.5,
            lambda_val=torch.tensor([0.0, 1.0], device=device, dtype=dtype),
            same_on_batch=False,
        )
        expected = dict(
            mixup_pairs=torch.tensor([6, 1, 0, 7, 2, 5, 3, 4], device=device, dtype=torch.long),
            mixup_lambdas=torch.tensor(
                [0.0000, 0.0000, 0.5739, 0.0000, 0.6274, 0.0000, 0.4414, 0.0000], device=device, dtype=dtype
            ),
        )
        assert res.keys() == expected.keys()
        assert_close(res['mixup_pairs'], expected['mixup_pairs'], rtol=1e-4, atol=1e-4)
        assert_close(res['mixup_lambdas'], expected['mixup_lambdas'], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        # With same_on_batch=True all batch elements share one lambda.
        torch.manual_seed(9)
        batch_size = 8
        res = random_mixup_generator(
            batch_size=batch_size,
            p=0.9999,
            lambda_val=torch.tensor([0.0, 1.0], device=device, dtype=dtype),
            same_on_batch=True,
        )
        expected = dict(
            mixup_pairs=torch.tensor([4, 6, 7, 5, 0, 1, 3, 2], device=device, dtype=torch.long),
            mixup_lambdas=torch.tensor(
                [0.3804, 0.3804, 0.3804, 0.3804, 0.3804, 0.3804, 0.3804, 0.3804], device=device, dtype=dtype
            ),
        )
        assert res.keys() == expected.keys()
        assert_close(res['mixup_pairs'], expected['mixup_pairs'], rtol=1e-4, atol=1e-4)
        assert_close(res['mixup_lambdas'], expected['mixup_lambdas'], rtol=1e-4, atol=1e-4)
| |
|
| |
|
class TestRandomCutMixGen(RandomGeneratorBaseTests):
    """Tests for ``random_cutmix_generator``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('p', [0, 0.5, 1.0])
    @pytest.mark.parametrize('width,height', [(200, 200)])
    @pytest.mark.parametrize('num_mix', [1, 3])
    @pytest.mark.parametrize('beta', [None, torch.tensor(1e-15), torch.tensor(1.0)])
    @pytest.mark.parametrize('cut_size', [None, torch.tensor([0.0, 1.0]), torch.tensor([0.3, 0.6])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(
        self, batch_size, p, width, height, num_mix, beta, cut_size, same_on_batch, device, dtype
    ):
        # Smoke test: valid combinations (including None defaults) must not raise.
        random_cutmix_generator(
            batch_size=batch_size,
            p=p,
            width=width,
            height=height,
            num_mix=num_mix,
            beta=beta.to(device=device, dtype=dtype) if isinstance(beta, (torch.Tensor)) else beta,
            cut_size=cut_size.to(device=device, dtype=dtype) if isinstance(cut_size, (torch.Tensor)) else cut_size,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'width,height,num_mix,beta,cut_size',
        [
            (200, -200, 1, None, None),  # negative height
            (-200, 200, 1, None, None),  # negative width
            (200, 200, 0, None, None),  # num_mix must be >= 1
            (200, 200, 1.5, None, None),  # num_mix must be an int
            (200, 200, 1, torch.tensor([0.0, 1.0]), None),  # beta must be scalar
            (200, 200, 1, None, torch.tensor([-1.0, 1.0])),  # cut_size below 0
            (200, 200, 1, None, torch.tensor([0.0, 2.0])),  # cut_size above 1
        ],
    )
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_invalid_param_combinations(self, width, height, num_mix, beta, cut_size, same_on_batch, device, dtype):
        with pytest.raises(Exception):
            random_cutmix_generator(
                batch_size=8,
                p=0.5,
                width=width,
                height=height,
                num_mix=num_mix,
                beta=beta.to(device=device, dtype=dtype) if isinstance(beta, (torch.Tensor)) else beta,
                # Bug fix: this previously read ``beta.to(...)`` (copy-paste
                # error), so the invalid cut_size rows never actually passed
                # cut_size to the generator.
                cut_size=cut_size.to(device=device, dtype=dtype) if isinstance(cut_size, (torch.Tensor)) else cut_size,
                same_on_batch=same_on_batch,
            )

    def test_random_gen(self, device, dtype):
        # A fixed seed pins the sampled pairs and crop boxes to reference values.
        torch.manual_seed(42)
        batch_size = 2
        res = random_cutmix_generator(
            batch_size=batch_size,
            width=200,
            height=200,
            p=0.5,
            num_mix=1,
            beta=torch.tensor(1.0, device=device, dtype=dtype),
            cut_size=torch.tensor([0.0, 1.0], device=device, dtype=dtype),
            same_on_batch=False,
        )
        expected = dict(
            mix_pairs=torch.tensor([[0, 1]], device=device, dtype=torch.long),
            crop_src=torch.tensor(
                [[[[71, 108], [70, 108], [70, 107], [71, 107]], [[39, 1], [38, 1], [38, 0], [39, 0]]]],
                device=device,
                dtype=dtype,
            ),
        )
        assert res.keys() == expected.keys()
        assert_close(res['mix_pairs'], expected['mix_pairs'], rtol=1e-4, atol=1e-4)
        assert_close(res['crop_src'], expected['crop_src'], rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        # With same_on_batch=True both batch elements get the same crop box.
        torch.manual_seed(42)
        batch_size = 2
        res = random_cutmix_generator(
            batch_size=batch_size,
            width=200,
            height=200,
            p=0.5,
            num_mix=1,
            beta=torch.tensor(1.0, device=device, dtype=dtype),
            cut_size=torch.tensor([0.0, 1.0], device=device, dtype=dtype),
            same_on_batch=True,
        )
        expected = dict(
            mix_pairs=torch.tensor([[1, 0]], device=device, dtype=torch.long),
            crop_src=torch.tensor(
                [[[[114, 53], [113, 53], [113, 52], [114, 52]], [[114, 53], [113, 53], [113, 52], [114, 52]]]],
                device=device,
                dtype=dtype,
            ),
        )
        assert res.keys() == expected.keys()
        assert_close(res['mix_pairs'], expected['mix_pairs'], rtol=1e-4, atol=1e-4)
        assert_close(res['crop_src'], expected['crop_src'], rtol=1e-4, atol=1e-4)
| |
|