# Path: SAE/attacks/AIM/tests/test_datasets/test_transforms.py
# Uploaded by Ttius ("Upload 192 files"), commit 998bb30 (verified).
import unittest
import torch
from PIL import Image
from torchvision.transforms import Compose
from gat.datasets.transforms import (hflip, norm, resize_224, resize_256_224,
resize_512_448, to_color, to_pil, to_ts)
class TestResize(unittest.TestCase):
    """Output-shape checks for the resize transform factories.

    Each factory's return value is passed straight to ``Compose``, so it is
    used here as a sequence of transforms applied to a CHW tensor.
    """

    def _assert_resized(self, factory, in_size, out_size):
        """Run a (3, in_size, in_size) tensor through ``Compose(factory())``
        and assert the result has shape (3, out_size, out_size)."""
        inputs = torch.rand(3, in_size, in_size)
        outputs = Compose(factory())(inputs)
        self.assertEqual(outputs.shape, (3, out_size, out_size))

    def test_resize_256_224(self):
        """A 256x256 input ends up 224x224."""
        self._assert_resized(resize_256_224, 256, 224)

    def test_resize_512_448(self):
        """A 512x512 input ends up 448x448."""
        self._assert_resized(resize_512_448, 512, 448)

    def test_resize_224(self):
        """A non-224 square input ends up 224x224."""
        # The input must NOT already be 224x224: with a 224x224 input, an
        # identity transform would pass and the test would verify nothing.
        self._assert_resized(resize_224, 256, 224)
class TestAug(unittest.TestCase):
    """Argument handling of the ``hflip`` augmentation factory."""

    def test_probability_invalid(self):
        """A negative flip probability is rejected with ``AssertionError``."""
        # NOTE(review): only the lower bound is exercised. If hflip also
        # rejects p > 1, assert that too. If the check is a bare ``assert``,
        # it is stripped under ``python -O`` and this test would fail there.
        with self.assertRaises(AssertionError):
            hflip(-0.1)

    def test_probability_valid(self):
        """The requested probability is stored as ``.p`` on the first
        element of the list that ``hflip`` returns."""
        for valid_p in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
            # subTest reports which probability failed and keeps checking
            # the remaining values instead of stopping at the first failure.
            with self.subTest(p=valid_p):
                result = hflip(valid_p)
                self.assertEqual(result[0].p, valid_p)
class TestTypeTransforms(unittest.TestCase):
    """Type and colour-mode conversions performed by the transform factories."""

    @staticmethod
    def _apply(factory, value):
        """Build the transform list from ``factory`` and run ``value`` through it."""
        pipeline = Compose(factory())
        return pipeline(value)

    def test_to_ts(self):
        """``to_ts`` turns a PIL RGB image into a ``torch.Tensor``."""
        img = Image.new('RGB', (224, 224))
        converted = self._apply(to_ts, img)
        self.assertIsInstance(converted, torch.Tensor)

    def test_to_pil(self):
        """``to_pil`` turns a float CHW tensor into a PIL image."""
        tensor = torch.rand(3, 224, 224)
        converted = self._apply(to_pil, tensor)
        self.assertIsInstance(converted, Image.Image)

    def test_to_color_grayscale(self):
        """``to_color`` promotes a single-channel ('L') image to RGB."""
        gray = Image.new('L', (224, 224))
        self.assertEqual(self._apply(to_color, gray).mode, 'RGB')

    def test_to_color_rgb(self):
        """``to_color`` leaves an RGB image in RGB mode."""
        rgb = Image.new('RGB', (224, 224))
        self.assertEqual(self._apply(to_color, rgb).mode, 'RGB')
class TestNorm(unittest.TestCase):
    """Normalization statistics exposed by the ``norm`` factory."""

    # Expected ImageNet mean/std, kept in one place so the default and
    # explicit-dataset tests cannot drift apart.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def _assert_imagenet_stats(self, transforms):
        """Assert that ``transforms[0]`` carries the ImageNet mean and std."""
        first = transforms[0]
        self.assertEqual(first.mean, self._IMAGENET_MEAN)
        self.assertEqual(first.std, self._IMAGENET_STD)

    def test_default(self):
        """Calling ``norm()`` with no argument yields the ImageNet statistics."""
        self._assert_imagenet_stats(norm())

    def test_imagenet(self):
        """Requesting ``'IMAGENET'`` explicitly yields the same statistics."""
        self._assert_imagenet_stats(norm('IMAGENET'))

    def test_invalid_ds(self):
        """An unrecognised dataset name raises ``AttributeError``."""
        # 'imagenets' is a near-miss of the valid 'IMAGENET' name.
        with self.assertRaises(AttributeError):
            norm('imagenets')
if __name__ == '__main__':
    # When executed as a script, load and run every TestCase in this module.
    unittest.main()