| import unittest | |
| import torch | |
| from PIL import Image | |
| from torchvision.transforms import Compose | |
| from gat.datasets.transforms import (hflip, norm, resize_224, resize_256_224, | |
| resize_512_448, to_color, to_pil, to_ts) | |
class TestResize(unittest.TestCase):
    """Output-shape checks for the resize transform factories.

    Each factory's return value is passed to ``Compose`` and applied to a
    random square float tensor of shape (3, side, side).  Every case runs
    the original fixture size plus a larger square input.  With only an
    input that already matches a target size, the resize step is never
    exercised, and for ``resize_224`` a no-op transform would pass.
    """

    def _assert_output_shape(self, transforms, side, expected):
        """Apply ``Compose(transforms)`` to a random 3 x side x side tensor
        and assert the resulting shape equals ``expected``."""
        inputs = torch.rand(3, side, side)
        outputs = Compose(transforms)(inputs)
        self.assertEqual(outputs.shape, expected)

    def test_resize_256_224(self):
        # 256 is the original fixture; 320 forces the resize step to shrink.
        for side in (256, 320):
            with self.subTest(side=side):
                self._assert_output_shape(resize_256_224(), side, (3, 224, 224))

    def test_resize_512_448(self):
        # 512 is the original fixture; 640 forces the resize step to shrink.
        for side in (512, 640):
            with self.subTest(side=side):
                self._assert_output_shape(resize_512_448(), side, (3, 448, 448))

    def test_resize_224(self):
        # 224 checks that an already-sized input is preserved; 256 checks that
        # a larger input is actually brought down to 224 x 224.
        for side in (224, 256):
            with self.subTest(side=side):
                self._assert_output_shape(resize_224(), side, (3, 224, 224))
class TestAug(unittest.TestCase):
    """Tests for the ``hflip`` augmentation factory's probability handling."""

    def test_probability_invalid(self):
        # A negative flip probability is expected to raise AssertionError.
        with self.assertRaises(AssertionError):
            hflip(-0.1)

    def test_probability_valid(self):
        # Each in-range probability should be stored unchanged as ``.p`` on the
        # first transform returned by hflip.  subTest reports which value
        # failed and keeps checking the rest instead of stopping at the first.
        for valid_p in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
            with self.subTest(p=valid_p):
                result = hflip(valid_p)
                self.assertEqual(result[0].p, valid_p)
class TestTypeTransforms(unittest.TestCase):
    """Check the type and colour-mode conversions done by the conversion
    transform factories (each factory's result is run through ``Compose``)."""

    @staticmethod
    def _apply(factory, value):
        """Build a pipeline from ``factory()`` and run it on ``value``."""
        return Compose(factory())(value)

    def test_to_ts(self):
        # PIL image in, torch tensor out.
        image = Image.new('RGB', (224, 224))
        converted = self._apply(to_ts, image)
        self.assertIsInstance(converted, torch.Tensor)

    def test_to_pil(self):
        # Torch tensor in, PIL image out.
        tensor = torch.rand(3, 224, 224)
        converted = self._apply(to_pil, tensor)
        self.assertIsInstance(converted, Image.Image)

    def test_to_color_grayscale(self):
        # A single-channel 'L' image is expected to come back in 'RGB' mode.
        gray = Image.new('L', (224, 224))
        self.assertEqual(self._apply(to_color, gray).mode, 'RGB')

    def test_to_color_rgb(self):
        # An image that is already 'RGB' should stay 'RGB'.
        rgb = Image.new('RGB', (224, 224))
        self.assertEqual(self._apply(to_color, rgb).mode, 'RGB')
class TestNorm(unittest.TestCase):
    """Check the mean/std statistics wired into the ``norm`` transform."""

    # Per-channel statistics these tests expect for the ImageNet preset.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def _check_imagenet_stats(self, *args):
        """Assert that ``norm(*args)[0]`` carries the ImageNet mean and std."""
        self.assertEqual(norm(*args)[0].mean, self.IMAGENET_MEAN)
        self.assertEqual(norm(*args)[0].std, self.IMAGENET_STD)

    def test_default(self):
        # With no argument, norm is expected to use the ImageNet statistics.
        self._check_imagenet_stats()

    def test_imagenet(self):
        # Explicitly requesting the ImageNet preset.
        self._check_imagenet_stats('IMAGENET')

    def test_invalid_ds(self):
        # An unrecognized dataset name is expected to raise AttributeError.
        with self.assertRaises(AttributeError):
            norm('imagenets')
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main()