import pytest import torch from kornia.augmentation.random_generator import ( center_crop_generator3d, random_affine_generator3d, random_crop_generator3d, random_motion_blur_generator3d, random_perspective_generator3d, random_rotation_generator3d, ) from kornia.testing import assert_close class RandomGeneratorBaseTests: def test_valid_param_combinations(self, device, dtype): raise NotImplementedError def test_invalid_param_combinations(self, device, dtype): raise NotImplementedError def test_random_gen(self, device, dtype): raise NotImplementedError def test_same_on_batch(self, device, dtype): raise NotImplementedError class TestRandomPerspectiveGen3D(RandomGeneratorBaseTests): @pytest.mark.parametrize('batch_size', [0, 1, 8]) @pytest.mark.parametrize('depth,height,width', [(200, 200, 200)]) @pytest.mark.parametrize('distortion_scale', [torch.tensor(0.0), torch.tensor(0.5), torch.tensor(1.0)]) @pytest.mark.parametrize('same_on_batch', [True, False]) def test_valid_param_combinations( self, depth, height, width, distortion_scale, batch_size, same_on_batch, device, dtype ): random_perspective_generator3d( batch_size=batch_size, depth=depth, height=height, width=width, distortion_scale=distortion_scale.to(device=device, dtype=dtype), same_on_batch=same_on_batch, ) @pytest.mark.parametrize( 'depth,height,width,distortion_scale', [ # Should be failed if distortion_scale > 1. or distortion_scale < 0. (-100, 100, 100, torch.tensor(0.5)), (100, -100, 100, torch.tensor(0.5)), (100, 100, -100, torch.tensor(-0.5)), (100, 100, 100, torch.tensor(1.5)), (100, 100, 100, torch.tensor([0.0, 0.5])), ], ) def test_invalid_param_combinations(self, depth, height, width, distortion_scale, device, dtype): with pytest.raises(Exception): random_perspective_generator3d( batch_size=8, height=height, width=width, distortion_scale=distortion_scale.to(device=device, dtype=dtype), ) def test_random_gen(self, device, dtype): torch.manual_seed(42) batch_size = 2 res = random_perspective_generator3d(batch_size, 200, 200, 200, torch.tensor(0.5, device=device, dtype=dtype)) expected = dict( start_points=torch.tensor( [ [ [0.0, 0.0, 0.0], [199.0, 0.0, 0.0], [199.0, 199.0, 0.0], [0.0, 199.0, 0.0], [0.0, 0.0, 199.0], [199.0, 0.0, 199.0], [199.0, 199.0, 199.0], [0.0, 199.0, 199.0], ], [ [0.0, 0.0, 0.0], [199.0, 0.0, 0.0], [199.0, 199.0, 0.0], [0.0, 199.0, 0.0], [0.0, 0.0, 199.0], [199.0, 0.0, 199.0], [199.0, 199.0, 199.0], [0.0, 199.0, 199.0], ], ], device=device, dtype=dtype, ), end_points=torch.tensor( [ [ [44.1135, 45.7502, 19.1432], [151.0347, 19.5224, 30.0448], [186.1714, 159.3179, 47.0386], [6.6593, 152.2701, 29.6790], [43.4702, 28.3858, 161.9453], [177.5298, 44.2721, 170.3048], [185.6710, 167.6275, 185.5184], [22.0682, 184.1540, 157.4157], ], [ [5.2657, 13.4747, 17.9406], [189.0318, 27.3596, 0.3080], [151.4223, 195.2367, 44.3007], [29.1605, 182.1176, 40.4487], [28.8963, 45.1991, 171.2670], [181.8843, 31.7171, 180.7795], [163.4786, 151.6794, 159.5485], [14.0707, 159.5684, 169.5268], ], ], device=device, dtype=dtype, ), ) assert res.keys() == expected.keys() assert_close(res['start_points'], expected['start_points'], atol=1e-4, rtol=1e-4) assert_close(res['end_points'], expected['end_points'], atol=1e-4, rtol=1e-4) def test_same_on_batch(self, device, dtype): torch.manual_seed(42) batch_size = 2 res = random_perspective_generator3d( batch_size, 200, 200, 200, torch.tensor(0.5, device=device, dtype=dtype), same_on_batch=True ) expected = dict( start_points=torch.tensor( [ [ [0.0, 0.0, 0.0], [199.0, 0.0, 0.0], [199.0, 199.0, 0.0], [0.0, 199.0, 0.0], [0.0, 0.0, 199.0], [199.0, 0.0, 199.0], [199.0, 199.0, 199.0], [0.0, 199.0, 199.0], ], [ [0.0, 0.0, 0.0], [199.0, 0.0, 0.0], [199.0, 199.0, 0.0], [0.0, 199.0, 0.0], [0.0, 0.0, 199.0], [199.0, 0.0, 199.0], [199.0, 199.0, 199.0], [0.0, 199.0, 199.0], ], ], device=device, dtype=dtype, ), end_points=torch.tensor( [ [ [44.1135, 45.7502, 19.1432], [151.0347, 19.5224, 30.0448], [186.1714, 159.3179, 47.0386], [6.6593, 152.2701, 29.6790], [43.4702, 28.3858, 161.9453], [177.5298, 44.2721, 170.3048], [185.6710, 167.6275, 185.5184], [22.0682, 184.1540, 157.4157], ], [ [44.1135, 45.7502, 19.1432], [151.0347, 19.5224, 30.0448], [186.1714, 159.3179, 47.0386], [6.6593, 152.2701, 29.6790], [43.4702, 28.3858, 161.9453], [177.5298, 44.2721, 170.3048], [185.6710, 167.6275, 185.5184], [22.0682, 184.1540, 157.4157], ], ], device=device, dtype=dtype, ), ) assert res.keys() == expected.keys() assert_close(res['start_points'], expected['start_points'], atol=1e-4, rtol=1e-4) assert_close(res['end_points'], expected['end_points'], atol=1e-4, rtol=1e-4) class TestRandomAffineGen3D(RandomGeneratorBaseTests): @pytest.mark.parametrize('batch_size', [0, 1, 8]) @pytest.mark.parametrize('depth,height,width', [(200, 300, 400)]) @pytest.mark.parametrize('degrees', [torch.tensor([(0, 30), (0, 30), (0, 30)])]) @pytest.mark.parametrize('translate', [None, torch.tensor([0.1, 0.1, 0.1])]) @pytest.mark.parametrize('scale', [None, torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])]) @pytest.mark.parametrize('shear', [None, torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])]) @pytest.mark.parametrize('same_on_batch', [True, False]) def test_valid_param_combinations( self, batch_size, depth, height, width, degrees, translate, scale, shear, same_on_batch, device, dtype ): random_affine_generator3d( batch_size=batch_size, depth=depth, height=height, width=width, degrees=degrees.to(device=device, dtype=dtype), translate=translate.to(device=device, dtype=dtype) if translate is not None else None, scale=scale.to(device=device, dtype=dtype) if scale is not None else None, shears=shear.to(device=device, dtype=dtype) if shear is not None else None, same_on_batch=same_on_batch, ) @pytest.mark.parametrize( 'depth,height,width,degrees,translate,scale,shear', [ (-100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, None), (100, -100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, None), (100, 100, -100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, None), (100, 100, 100, torch.tensor([0, 9]), None, None, None), (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), torch.tensor([0.1, 0.2]), None, None), (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), torch.tensor([0.1, 0.2]), None, None), (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), torch.tensor([0.1]), None, None), (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, torch.tensor([[0.2, 0.2, 0.2]]), None), (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, torch.tensor([0.2]), None), (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, torch.tensor([[20, 20, 30]])), (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, torch.tensor([20])), ], ) def test_invalid_param_combinations(self, depth, height, width, degrees, translate, scale, shear, device, dtype): with pytest.raises(Exception): random_affine_generator3d( batch_size=8, depth=depth, height=height, width=width, degrees=degrees.to(device=device, dtype=dtype), translate=translate.to(device=device, dtype=dtype) if translate is not None else None, scale=scale.to(device=device, dtype=dtype) if scale is not None else None, shears=shear.to(device=device, dtype=dtype) if shear is not None else None, ) def test_random_gen(self, device, dtype): torch.manual_seed(42) degrees = torch.tensor([[10, 20], [10, 20], [10, 20]]) translate = torch.tensor([0.1, 0.1, 0.1]) scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]]) shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]]) res = random_affine_generator3d( batch_size=2, depth=200, height=200, width=200, degrees=degrees.to(device=device, dtype=dtype), translate=translate.to(device=device, dtype=dtype) if translate is not None else None, scale=scale.to(device=device, dtype=dtype) if scale is not None else None, shears=shear.to(device=device, dtype=dtype) if shear is not None else None, ) expected = dict( translations=torch.tensor( [[14.7762, 9.6438, 15.4177], [2.7086, -2.8238, 2.9562]], device=device, dtype=dtype ), center=torch.tensor([[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]], device=device, dtype=dtype), scale=torch.tensor([[0.8283, 1.1704, 1.1673], [1.0968, 0.7666, 0.9968]], device=device, dtype=dtype), angles=torch.tensor([[18.8227, 13.8286, 13.9045], [19.1500, 19.5931, 16.0090]], device=device, dtype=dtype), sxy=torch.tensor([5.3316, 12.5490], device=device, dtype=dtype), sxz=torch.tensor([5.3926, 8.8273], device=device, dtype=dtype), syx=torch.tensor([5.9384, 16.6337], device=device, dtype=dtype), syz=torch.tensor([2.1063, 5.3899], device=device, dtype=dtype), szx=torch.tensor([7.1763, 3.9873], device=device, dtype=dtype), szy=torch.tensor([10.9438, 0.1232], device=device, dtype=dtype), ) assert res.keys() == expected.keys() assert_close(res['translations'], expected['translations'], rtol=1e-4, atol=1e-4) assert_close(res['center'], expected['center'], rtol=1e-4, atol=1e-4) assert_close(res['scale'], expected['scale'], rtol=1e-4, atol=1e-4) assert_close(res['angles'], expected['angles'], rtol=1e-4, atol=1e-4) assert_close(res['sxy'], expected['sxy'], rtol=1e-4, atol=1e-4) assert_close(res['sxz'], expected['sxz'], rtol=1e-4, atol=1e-4) assert_close(res['syx'], expected['syx'], rtol=1e-4, atol=1e-4) assert_close(res['syz'], expected['syz'], rtol=1e-4, atol=1e-4) assert_close(res['szx'], expected['szx'], rtol=1e-4, atol=1e-4) assert_close(res['szy'], expected['szy'], rtol=1e-4, atol=1e-4) def test_same_on_batch(self, device, dtype): torch.manual_seed(42) degrees = torch.tensor([[10, 20], [10, 20], [10, 20]]) translate = torch.tensor([0.1, 0.1, 0.1]) scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]]) shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]]) res = random_affine_generator3d( batch_size=2, depth=200, height=200, width=200, degrees=degrees.to(device=device, dtype=dtype), translate=translate.to(device=device, dtype=dtype) if translate is not None else None, scale=scale.to(device=device, dtype=dtype) if scale is not None else None, shears=shear.to(device=device, dtype=dtype) if shear is not None else None, same_on_batch=True, ) expected = dict( translations=torch.tensor( [[-9.7371, 11.7457, 17.6309], [-9.7371, 11.7457, 17.6309]], device=device, dtype=dtype ), center=torch.tensor([[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]], device=device, dtype=dtype), scale=torch.tensor([[1.1797, 0.8952, 1.0004], [1.1797, 0.8952, 1.0004]], device=device, dtype=dtype), angles=torch.tensor([[18.8227, 19.1500, 13.8286], [18.8227, 19.1500, 13.8286]], device=device, dtype=dtype), sxy=torch.tensor([2.6637, 2.6637], device=device, dtype=dtype), sxz=torch.tensor([18.6920, 18.6920], device=device, dtype=dtype), syx=torch.tensor([11.8716, 11.8716], device=device, dtype=dtype), syz=torch.tensor([17.3881, 17.3881], device=device, dtype=dtype), szx=torch.tensor([11.3543, 11.3543], device=device, dtype=dtype), szy=torch.tensor([14.8219, 14.8219], device=device, dtype=dtype), ) assert res.keys() == expected.keys() assert_close(res['translations'], expected['translations'], rtol=1e-4, atol=1e-4) assert_close(res['center'], expected['center'], rtol=1e-4, atol=1e-4) assert_close(res['scale'], expected['scale'], rtol=1e-4, atol=1e-4) assert_close(res['angles'], expected['angles'], rtol=1e-4, atol=1e-4) assert_close(res['sxy'], expected['sxy'], rtol=1e-4, atol=1e-4) assert_close(res['sxz'], expected['sxz'], rtol=1e-4, atol=1e-4) assert_close(res['syx'], expected['syx'], rtol=1e-4, atol=1e-4) assert_close(res['syz'], expected['syz'], rtol=1e-4, atol=1e-4) assert_close(res['szx'], expected['szx'], rtol=1e-4, atol=1e-4) assert_close(res['szy'], expected['szy'], rtol=1e-4, atol=1e-4) class TestRandomRotationGen3D(RandomGeneratorBaseTests): @pytest.mark.parametrize('batch_size', [0, 1, 8]) @pytest.mark.parametrize('degrees', [torch.tensor([[0, 30], [0, 30], [0, 30]])]) @pytest.mark.parametrize('same_on_batch', [True, False]) def test_valid_param_combinations(self, batch_size, degrees, same_on_batch, device, dtype): random_rotation_generator3d( batch_size=batch_size, degrees=degrees.to(device=device, dtype=dtype), same_on_batch=same_on_batch ) @pytest.mark.parametrize( 'degrees', [(torch.tensor(10)), (torch.tensor([10])), (torch.tensor([[0, 30]])), (torch.tensor([[0, 30], [0, 30]]))], ) def test_invalid_param_combinations(self, degrees, device, dtype): with pytest.raises(Exception): random_rotation_generator3d(batch_size=8, degrees=degrees.to(device=device, dtype=dtype)) def test_random_gen(self, device, dtype): torch.manual_seed(42) degrees = torch.tensor([[0, 30], [0, 30], [0, 30]]) res = random_rotation_generator3d( batch_size=2, degrees=degrees.to(device=device, dtype=dtype), same_on_batch=False ) expected = dict( yaw=torch.tensor([26.4681, 27.4501], device=device, dtype=dtype), pitch=torch.tensor([11.4859, 28.7792], device=device, dtype=dtype), roll=torch.tensor([11.7134, 18.0269], device=device, dtype=dtype), ) assert res.keys() == expected.keys() assert_close(res['yaw'], expected['yaw'], atol=1e-4, rtol=1e-4) assert_close(res['pitch'], expected['pitch'], atol=1e-4, rtol=1e-4) assert_close(res['roll'], expected['roll'], atol=1e-4, rtol=1e-4) def test_same_on_batch(self, device, dtype): torch.manual_seed(42) degrees = torch.tensor([[0, 30], [0, 30], [0, 30]]) res = random_rotation_generator3d( batch_size=2, degrees=degrees.to(device=device, dtype=dtype), same_on_batch=True ) expected = dict( yaw=torch.tensor([26.4681, 26.4681], device=device, dtype=dtype), pitch=torch.tensor([27.4501, 27.4501], device=device, dtype=dtype), roll=torch.tensor([11.4859, 11.4859], device=device, dtype=dtype), ) assert res.keys() == expected.keys() assert_close(res['yaw'], expected['yaw'], atol=1e-4, rtol=1e-4) assert_close(res['pitch'], expected['pitch'], atol=1e-4, rtol=1e-4) assert_close(res['roll'], expected['roll'], atol=1e-4, rtol=1e-4) class TestRandomCropGen3D(RandomGeneratorBaseTests): @pytest.mark.parametrize('batch_size', [0, 2]) @pytest.mark.parametrize('input_size', [(200, 200, 200)]) @pytest.mark.parametrize('size', [(100, 100, 100), torch.tensor([50, 60, 70])]) @pytest.mark.parametrize('resize_to', [None, (100, 100, 100)]) @pytest.mark.parametrize('same_on_batch', [True, False]) def test_valid_param_combinations(self, batch_size, input_size, size, resize_to, same_on_batch, device, dtype): if isinstance(size, torch.Tensor): size = size.repeat(batch_size, 1).to(device=device, dtype=dtype) random_crop_generator3d( batch_size=batch_size, input_size=input_size, size=size.to(device=device, dtype=dtype) if isinstance(size, torch.Tensor) else size, resize_to=resize_to, same_on_batch=same_on_batch, ) @pytest.mark.parametrize( 'input_size,size,resize_to', [ ((-300, 300, 300), (200, 200, 200), (100, 100, 100)), ((100, 100, 100), (200, 200, 200), (100, 100, 100)), ((200, 200, 200), torch.tensor([50, 50, 50]), (100, 100, 100)), ((100, 100, 100), torch.tensor([[50, 60, 70], [50, 60, 70]]), (100, 100)), ], ) def test_invalid_param_combinations(self, input_size, size, resize_to, device, dtype): with pytest.raises(Exception): random_crop_generator3d( batch_size=2, input_size=input_size, size=size.to(device=device, dtype=dtype) if isinstance(size, torch.Tensor) else size, resize_to=resize_to, ) def test_random_gen(self, device, dtype): torch.manual_seed(42) res = random_crop_generator3d( batch_size=2, input_size=(200, 200, 200), size=torch.tensor([[50, 60, 70], [50, 60, 70]], device=device, dtype=dtype), resize_to=(100, 100, 100), ) expected = dict( src=torch.tensor( [ [ [115, 53, 58], [184, 53, 58], [184, 112, 58], [115, 112, 58], [115, 53, 107], [184, 53, 107], [184, 112, 107], [115, 112, 107], ], [ [119, 135, 90], [188, 135, 90], [188, 194, 90], [119, 194, 90], [119, 135, 139], [188, 135, 139], [188, 194, 139], [119, 194, 139], ], ], device=device, dtype=dtype, ), dst=torch.tensor( [ [ [0, 0, 0], [99, 0, 0], [99, 99, 0], [0, 99, 0], [0, 0, 99], [99, 0, 99], [99, 99, 99], [0, 99, 99], ], [ [0, 0, 0], [99, 0, 0], [99, 99, 0], [0, 99, 0], [0, 0, 99], [99, 0, 99], [99, 99, 99], [0, 99, 99], ], ], device=device, dtype=dtype, ), ) assert res.keys() == expected.keys() assert_close(res['src'], expected['src'], atol=1e-4, rtol=1e-4) assert_close(res['dst'], expected['dst'], atol=1e-4, rtol=1e-4) def test_same_on_batch(self, device, dtype): torch.manual_seed(42) res = random_crop_generator3d( batch_size=2, input_size=(200, 200, 200), size=torch.tensor([[50, 60, 70], [50, 60, 70]], device=device, dtype=dtype), resize_to=(100, 100, 100), same_on_batch=True, ) expected = dict( src=torch.tensor( [ [ [115, 129, 57], [184, 129, 57], [184, 188, 57], [115, 188, 57], [115, 129, 106], [184, 129, 106], [184, 188, 106], [115, 188, 106], ], [ [115, 129, 57], [184, 129, 57], [184, 188, 57], [115, 188, 57], [115, 129, 106], [184, 129, 106], [184, 188, 106], [115, 188, 106], ], ], device=device, dtype=dtype, ), dst=torch.tensor( [ [ [0, 0, 0], [99, 0, 0], [99, 99, 0], [0, 99, 0], [0, 0, 99], [99, 0, 99], [99, 99, 99], [0, 99, 99], ], [ [0, 0, 0], [99, 0, 0], [99, 99, 0], [0, 99, 0], [0, 0, 99], [99, 0, 99], [99, 99, 99], [0, 99, 99], ], ], device=device, dtype=dtype, ), ) assert res.keys() == expected.keys() assert_close(res['src'], expected['src'], atol=1e-4, rtol=1e-4) assert_close(res['dst'], expected['dst'], atol=1e-4, rtol=1e-4) class TestCenterCropGen3D(RandomGeneratorBaseTests): @pytest.mark.parametrize('batch_size', [0, 2]) @pytest.mark.parametrize('depth,height,width', [(200, 200, 200)]) @pytest.mark.parametrize('size', [(100, 100, 100)]) def test_valid_param_combinations(self, batch_size, depth, height, width, size, device, dtype): center_crop_generator3d(batch_size=batch_size, depth=depth, height=height, width=width, size=size) @pytest.mark.parametrize( 'depth,height,width,size', [ (200, 200, -200, (100, 100, 100)), (200, -200, 200, (100, 100)), (200, 100, 100, (300, 120, 100)), (200, 150, 100, (120, 180, 100)), (200, 100, 150, (120, 80, 200)), ], ) def test_invalid_param_combinations(self, depth, height, width, size, device, dtype): with pytest.raises(Exception): center_crop_generator3d(batch_size=2, depth=depth, height=height, width=width, size=size) def test_random_gen(self, device, dtype): torch.manual_seed(42) res = center_crop_generator3d(batch_size=2, depth=200, height=200, width=200, size=(120, 150, 100)) expected = dict( src=torch.tensor( [ [ [50, 25, 40], [149, 25, 40], [149, 174, 40], [50, 174, 40], [50, 25, 159], [149, 25, 159], [149, 174, 159], [50, 174, 159], ] ], device=device, dtype=torch.long, ).repeat(2, 1, 1), dst=torch.tensor( [ [ [0, 0, 0], [99, 0, 0], [99, 149, 0], [0, 149, 0], [0, 0, 119], [99, 0, 119], [99, 149, 119], [0, 149, 119], ] ], device=device, dtype=torch.long, ).repeat(2, 1, 1), ) assert res.keys() == expected.keys() assert_close(res['src'].to(device=device), expected['src'], atol=1e-4, rtol=1e-4) assert_close(res['dst'].to(device=device), expected['dst'], atol=1e-4, rtol=1e-4) def test_same_on_batch(self, device, dtype): pass class TestRandomMotionBlur3D(RandomGeneratorBaseTests): @pytest.mark.parametrize('batch_size', [0, 1, 8]) @pytest.mark.parametrize('kernel_size', [3, (3, 5)]) @pytest.mark.parametrize('angle', [torch.tensor([(10, 30), (30, 60), (60, 90)])]) @pytest.mark.parametrize('direction', [torch.tensor([-1, -1]), torch.tensor([-1, 1]), torch.tensor([1, 1])]) @pytest.mark.parametrize('same_on_batch', [True, False]) def test_valid_param_combinations(self, batch_size, kernel_size, angle, direction, same_on_batch, device, dtype): random_motion_blur_generator3d( batch_size=batch_size, kernel_size=kernel_size, angle=angle.to(device=device, dtype=dtype), direction=direction.to(device=device, dtype=dtype), same_on_batch=same_on_batch, ) @pytest.mark.parametrize( 'kernel_size,angle,direction', [ (4, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 1])), (1, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 1])), ((3, 4, 5), torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 1])), (3, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-2, 1])), (3, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 2])), ], ) def test_invalid_param_combinations(self, kernel_size, angle, direction, device, dtype): with pytest.raises(Exception): random_motion_blur_generator3d( batch_size=8, kernel_size=kernel_size, angle=angle.to(device=device, dtype=dtype), direction=direction.to(device=device, dtype=dtype), ) def test_random_gen(self, device, dtype): torch.manual_seed(42) angle = torch.tensor([(10, 30), (30, 60), (60, 90)]) direction = torch.tensor([-1, 1]) res = random_motion_blur_generator3d( batch_size=2, kernel_size=3, angle=angle.to(device=device, dtype=dtype), direction=direction.to(device=device, dtype=dtype), same_on_batch=False, ) expected = dict( ksize_factor=torch.tensor([3, 3], device=device, dtype=torch.int32), angle_factor=torch.tensor( [[27.6454, 41.4859, 71.7134], [28.3001, 58.7792, 78.0269]], device=device, dtype=dtype ), direction_factor=torch.tensor([-0.4869, 0.5873], device=device, dtype=dtype), ) assert res.keys() == expected.keys() assert_close(res['ksize_factor'], expected['ksize_factor'], rtol=1e-4, atol=1e-4) assert_close(res['angle_factor'], expected['angle_factor'], rtol=1e-4, atol=1e-4) assert_close(res['direction_factor'], expected['direction_factor'], rtol=1e-4, atol=1e-4) def test_same_on_batch(self, device, dtype): torch.manual_seed(42) angle = torch.tensor([(10, 30), (30, 60), (60, 90)]) direction = torch.tensor([-1, 1]) res = random_motion_blur_generator3d( batch_size=2, kernel_size=3, angle=angle.to(device=device, dtype=dtype), direction=direction.to(device=device, dtype=dtype), same_on_batch=True, ) expected = dict( ksize_factor=torch.tensor([3, 3], device=device, dtype=torch.int32), angle_factor=torch.tensor( [[27.6454, 57.4501, 71.4859], [27.6454, 57.4501, 71.4859]], device=device, dtype=dtype ), direction_factor=torch.tensor([0.9186, 0.9186], device=device, dtype=dtype), ) assert res.keys() == expected.keys() assert_close(res['ksize_factor'], expected['ksize_factor'], rtol=1e-4, atol=1e-4) assert_close(res['angle_factor'], expected['angle_factor'], rtol=1e-4, atol=1e-4) assert_close(res['direction_factor'], expected['direction_factor'], rtol=1e-4, atol=1e-4)