"""Unit tests for the image transform factories in ``gat.datasets.transforms``.

The factories under test (``resize_*``, ``hflip``, ``to_*``, ``norm``) are
used here as sequences of transforms: the tests either wrap the result in
``torchvision.transforms.Compose`` or index its first element with ``[0]``.
"""
import unittest

import torch
from PIL import Image
from torchvision.transforms import Compose

from gat.datasets.transforms import (hflip, norm, resize_224, resize_256_224,
                                     resize_512_448, to_color, to_pil, to_ts)

# Channel statistics that ``norm()`` / ``norm('IMAGENET')`` are expected to
# configure on their first transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class TestResize(unittest.TestCase):
    """Resize pipelines map square inputs to the expected square outputs."""

    def _assert_resized(self, in_size, transforms, out_size):
        """Check that ``transforms`` maps a 3-channel square tensor to ``out_size``.

        Args:
            in_size: Side length of the random (3, in_size, in_size) input.
            transforms: Sequence of transforms to wrap in ``Compose``.
            out_size: Expected side length of the (3, out_size, out_size)
                output.
        """
        inputs = torch.rand(3, in_size, in_size)
        outputs = Compose(transforms)(inputs)
        self.assertEqual(outputs.shape, (3, out_size, out_size))

    def test_resize_256_224(self):
        self._assert_resized(256, resize_256_224(), 224)

    def test_resize_512_448(self):
        self._assert_resized(512, resize_512_448(), 448)

    def test_resize_224(self):
        self._assert_resized(224, resize_224(), 224)


class TestAug(unittest.TestCase):
    """Validation of the probability argument accepted by ``hflip``."""

    def test_probability_invalid(self):
        # NOTE(review): this expects hflip to reject bad probabilities with
        # ``assert``. Asserts are stripped under ``python -O``, so this test
        # would fail there. Consider raising ValueError in hflip instead.
        with self.assertRaises(AssertionError):
            hflip(-0.1)

    def test_probability_valid(self):
        for valid_p in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
            # subTest reports the exact failing probability instead of
            # stopping at the first failure.
            with self.subTest(p=valid_p):
                result = hflip(valid_p)
                self.assertEqual(result[0].p, valid_p)


class TestTypeTransforms(unittest.TestCase):
    """Conversions between PIL images, tensors and colour modes."""

    def test_to_ts(self):
        pil_image = Image.new('RGB', (224, 224))
        outputs = Compose(to_ts())(pil_image)
        self.assertIsInstance(outputs, torch.Tensor)

    def test_to_pil(self):
        inputs = torch.rand(3, 224, 224)
        outputs = Compose(to_pil())(inputs)
        self.assertIsInstance(outputs, Image.Image)

    def test_to_color_grayscale(self):
        inputs = Image.new('L', (224, 224))
        outputs = Compose(to_color())(inputs)
        self.assertEqual(outputs.mode, 'RGB')

    def test_to_color_rgb(self):
        inputs = Image.new('RGB', (224, 224))
        outputs = Compose(to_color())(inputs)
        self.assertEqual(outputs.mode, 'RGB')


class TestNorm(unittest.TestCase):
    """Dataset-specific normalisation statistics returned by ``norm``."""

    def test_default(self):
        transform = norm()[0]
        self.assertEqual(transform.mean, IMAGENET_MEAN)
        self.assertEqual(transform.std, IMAGENET_STD)

    def test_imagenet(self):
        transform = norm('IMAGENET')[0]
        self.assertEqual(transform.mean, IMAGENET_MEAN)
        self.assertEqual(transform.std, IMAGENET_STD)

    def test_invalid_ds(self):
        # An unknown dataset name is expected to surface as AttributeError.
        # Presumably norm looks the name up with getattr; verify in
        # gat.datasets.transforms.
        with self.assertRaises(AttributeError):
            norm('imagenets')


if __name__ == '__main__':
    unittest.main()