File size: 2,524 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import unittest

import torch
from PIL import Image
from torchvision.transforms import Compose

from gat.datasets.transforms import (hflip, norm, resize_224, resize_256_224,
                                     resize_512_448, to_color, to_pil, to_ts)


class TestResize(unittest.TestCase):
    """Output-shape checks for the resize transform factories."""

    def _check_square_resize(self, factory, side_in, side_out):
        # Build a random 3-channel square tensor, run it through the
        # transforms returned by ``factory`` wrapped in ``Compose``, and
        # compare the resulting shape against the expected square size.
        image = torch.rand(3, side_in, side_in)
        pipeline = Compose(factory())
        resized = pipeline(image)
        self.assertEqual(resized.shape, (3, side_out, side_out))

    def test_resize_256_224(self):
        # 256x256 input is expected to come out as 224x224.
        self._check_square_resize(resize_256_224, 256, 224)

    def test_resize_512_448(self):
        # 512x512 input is expected to come out as 448x448.
        self._check_square_resize(resize_512_448, 512, 448)

    def test_resize_224(self):
        # 224x224 input is expected to keep its 224x224 size.
        self._check_square_resize(resize_224, 224, 224)


class TestAug(unittest.TestCase):
    """Checks the probability argument passed to the ``hflip`` factory."""

    def test_probability_invalid(self):
        # A negative probability is expected to be rejected with an
        # AssertionError.
        with self.assertRaises(AssertionError):
            hflip(-0.1)

    def test_probability_valid(self):
        # Each probability runs in its own subTest so that a failure reports
        # the offending value and the remaining values are still exercised,
        # instead of the loop stopping at the first failing assertion.
        for valid_p in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
            with self.subTest(p=valid_p):
                result = hflip(valid_p)
                # The first returned transform must carry the probability
                # it was built with.
                self.assertEqual(result[0].p, valid_p)


class TestTypeTransforms(unittest.TestCase):
    """Checks the type / colour-mode conversion transform factories."""

    def test_to_ts(self):
        # A PIL image passed through ``to_ts`` must come out as a tensor.
        source = Image.new('RGB', (224, 224))
        pipeline = Compose(to_ts())
        converted = pipeline(source)
        self.assertIsInstance(converted, torch.Tensor)

    def test_to_pil(self):
        # A random tensor passed through ``to_pil`` must come out as a
        # PIL image.
        source = torch.rand(3, 224, 224)
        pipeline = Compose(to_pil())
        converted = pipeline(source)
        self.assertIsInstance(converted, Image.Image)

    def test_to_color_grayscale(self):
        # A single-channel ('L') image must end up in 'RGB' mode.
        gray = Image.new('L', (224, 224))
        pipeline = Compose(to_color())
        converted = pipeline(gray)
        self.assertEqual(converted.mode, 'RGB')

    def test_to_color_rgb(self):
        # An image that is already 'RGB' must still be 'RGB' afterwards.
        rgb = Image.new('RGB', (224, 224))
        pipeline = Compose(to_color())
        converted = pipeline(rgb)
        self.assertEqual(converted.mode, 'RGB')


class TestNorm(unittest.TestCase):
    """Checks the statistics on the first transform returned by ``norm``."""

    # Expected per-channel values asserted for both the no-argument call
    # and the explicit 'IMAGENET' call.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def test_default(self):
        # Calling ``norm`` without arguments.
        default_mean = norm()[0].mean
        self.assertEqual(default_mean, self._MEAN)
        default_std = norm()[0].std
        self.assertEqual(default_std, self._STD)

    def test_imagenet(self):
        # Calling ``norm`` with the dataset name given explicitly.
        imagenet_mean = norm('IMAGENET')[0].mean
        self.assertEqual(imagenet_mean, self._MEAN)
        imagenet_std = norm('IMAGENET')[0].std
        self.assertEqual(imagenet_std, self._STD)

    def test_invalid_ds(self):
        # An unrecognised dataset name is expected to raise AttributeError.
        self.assertRaises(AttributeError, norm, 'imagenets')


# Run the test cases through unittest's command-line runner when this file
# is executed directly as a script.
if __name__ == '__main__':
    unittest.main()