from abc import abstractmethod import torchvision.transforms as transforms from utils.class_registry import ClassRegistry transforms_registry = ClassRegistry() class TransformsConfig(object): def __init__(self): pass @abstractmethod def get_transforms(self): pass class FaceTransforms(TransformsConfig): def __init__(self): super(FaceTransforms, self).__init__() self.image_size = None def get_transforms(self): transforms_dict = { "train": transforms.Compose( [ transforms.Resize(self.image_size), transforms.RandomHorizontalFlip(0.5), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ] ), "test": transforms.Compose( [ transforms.Resize(self.image_size), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ] ) } return transforms_dict @transforms_registry.add_to_registry(name="face_256") class Face256Transforms(FaceTransforms): def __init__(self): super(Face256Transforms, self).__init__() self.image_size = (256, 256) @transforms_registry.add_to_registry(name="face_1024") class Face1024Transforms(FaceTransforms): def __init__(self): super(Face1024Transforms, self).__init__() self.image_size = (1024, 1024)