Spaces:
Sleeping
Sleeping
| from abc import abstractmethod | |
| import torchvision.transforms as transforms | |
| from utils.class_registry import ClassRegistry | |
| transforms_registry = ClassRegistry() | |
| class TransformsConfig(object): | |
| def __init__(self): | |
| pass | |
| def get_transforms(self): | |
| pass | |
| class FaceTransforms(TransformsConfig): | |
| def __init__(self): | |
| super(FaceTransforms, self).__init__() | |
| self.image_size = None | |
| def get_transforms(self): | |
| transforms_dict = { | |
| "train": transforms.Compose( | |
| [ | |
| transforms.Resize(self.image_size), | |
| transforms.RandomHorizontalFlip(0.5), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), | |
| ] | |
| ), | |
| "test": transforms.Compose( | |
| [ | |
| transforms.Resize(self.image_size), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), | |
| ] | |
| ) | |
| } | |
| return transforms_dict | |
| class Face256Transforms(FaceTransforms): | |
| def __init__(self): | |
| super(Face256Transforms, self).__init__() | |
| self.image_size = (256, 256) | |
| class Face1024Transforms(FaceTransforms): | |
| def __init__(self): | |
| super(Face1024Transforms, self).__init__() | |
| self.image_size = (1024, 1024) | |