Spaces:
Sleeping
Sleeping
File size: 1,552 Bytes
95b1715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from abc import ABC, abstractmethod

import torchvision.transforms as transforms

from utils.class_registry import ClassRegistry
# Registry of transform configs. Concrete configs in this module register
# themselves under a string key via
# ``@transforms_registry.add_to_registry(name=...)`` (e.g. "face_256").
transforms_registry = ClassRegistry()
class TransformsConfig(ABC):
    """Abstract base for named transform configurations.

    Concrete subclasses must implement ``get_transforms``. Inheriting from
    ``ABC`` is what makes ``@abstractmethod`` enforced: without ``ABCMeta``
    the decorator is inert, and this base (or a subclass that forgot to
    override ``get_transforms``) could be instantiated and would silently
    return None.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get_transforms(self):
        """Return this config's preprocessing pipelines.

        The implementations in this module return a dict mapping split
        names ("train", "test") to ``transforms.Compose`` objects.
        """
class FaceTransforms(TransformsConfig):
    """Resize / flip / normalize pipelines for face images.

    Subclasses are expected to set ``self.image_size`` (passed unchanged to
    ``transforms.Resize``) in their ``__init__`` after calling
    ``super().__init__()``.
    """

    # Per-channel mean/std handed to transforms.Normalize. With 0.5/0.5,
    # values in [0, 1] (ToTensor's output range) map to [-1, 1].
    norm_mean = (0.5, 0.5, 0.5)
    norm_std = (0.5, 0.5, 0.5)
    # Probability for the training-only RandomHorizontalFlip.
    hflip_prob = 0.5

    def __init__(self):
        super().__init__()
        # Placeholder; sized subclasses overwrite this with a concrete size.
        self.image_size = None

    def _tensor_and_normalize(self):
        """Return fresh ToTensor + Normalize steps shared by every split."""
        # Fresh instances per call so the train and test pipelines never
        # share transform objects (matches the original construction).
        return [
            transforms.ToTensor(),
            transforms.Normalize(list(self.norm_mean), list(self.norm_std)),
        ]

    def get_transforms(self):
        """Build the train/test preprocessing pipelines.

        Returns:
            dict with keys "train" and "test", each a ``transforms.Compose``.
            Both resize to ``self.image_size``, convert to a tensor and
            normalize; "train" additionally applies a random horizontal flip
            before the tensor conversion.

        Raises:
            ValueError: if ``self.image_size`` is still None, e.g. when called
                on ``FaceTransforms`` itself instead of a sized subclass.
        """
        if self.image_size is None:
            # Fail early with an actionable message instead of an opaque
            # error from inside transforms.Resize.
            raise ValueError(
                f"{type(self).__name__}.image_size is not set; use a sized "
                "subclass (e.g. 'face_256' or 'face_1024')"
            )

        train_steps = [
            transforms.Resize(self.image_size),
            transforms.RandomHorizontalFlip(self.hflip_prob),
            *self._tensor_and_normalize(),
        ]
        test_steps = [
            transforms.Resize(self.image_size),
            *self._tensor_and_normalize(),
        ]
        return {
            "train": transforms.Compose(train_steps),
            "test": transforms.Compose(test_steps),
        }
@transforms_registry.add_to_registry(name="face_256")
class Face256Transforms(FaceTransforms):
    """Face preprocessing at a fixed 256x256 resolution (key "face_256")."""

    def __init__(self):
        super().__init__()
        # Replace the None placeholder set by FaceTransforms.__init__.
        self.image_size = (256, 256)
@transforms_registry.add_to_registry(name="face_1024")
class Face1024Transforms(FaceTransforms):
    """Face preprocessing at a fixed 1024x1024 resolution (key "face_1024")."""

    def __init__(self):
        super().__init__()
        # Replace the None placeholder set by FaceTransforms.__init__.
        self.image_size = (1024, 1024)
|