from abc import ABC, abstractmethod

import torchvision.transforms as transforms
| |
|
| |
|
class TransformsConfig(ABC):
    """Base class for objects that supply named transform pipelines.

    Subclasses must implement :meth:`get_transforms`. Inheriting from ``ABC``
    is what makes ``@abstractmethod`` effective. Instantiating this class, or
    a subclass that does not override ``get_transforms``, raises
    ``TypeError``. Without ``ABC``, the base method would silently return
    ``None``.
    """

    def __init__(self, opts):
        """Store the options object for use by subclasses.

        Args:
            opts: Options/configuration object. It is stored unchanged as
                ``self.opts`` and is not inspected here.
        """
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        """Return a dict mapping transform names to transform pipelines.

        Must be overridden by concrete subclasses.
        """
| |
|
| |
|
class EncodeTransforms(TransformsConfig):
    """Transform pipelines that resize images to 256x256 for encoding.

    The training pipeline resizes and then randomly mirrors horizontally
    (p=0.5). The test and inference pipelines only resize. Every pipeline
    ends with ``ToTensor`` followed by ``Normalize`` with mean 0.5 and
    std 0.5 on each of three channels.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Build and return the dictionary of transform pipelines.

        Returns:
            dict: ``'transform_gt_train'``, ``'transform_test'`` and
            ``'transform_inference'`` map to ``transforms.Compose``
            pipelines. ``'transform_source'`` maps to ``None``.
        """
        def pipeline(*augmentations):
            # Resize first, then any augmentations, then tensor conversion
            # and normalization. New objects are created on every call.
            return transforms.Compose([
                transforms.Resize((256, 256)),
                *augmentations,
                transforms.ToTensor(),
                transforms.Normalize([0.5] * 3, [0.5] * 3),
            ])

        return {
            'transform_gt_train': pipeline(transforms.RandomHorizontalFlip(0.5)),
            'transform_source': None,
            'transform_test': pipeline(),
            'transform_inference': pipeline(),
        }
| |
|
| |
|
class CarsEncodeTransforms(TransformsConfig):
    """Transform pipelines that resize images to (192, 256) for encoding.

    ``transforms.Resize`` interprets a sequence size as (height, width). The
    training pipeline adds a random horizontal flip (p=0.5); the test and
    inference pipelines only resize. Each pipeline finishes with
    ``ToTensor`` and ``Normalize`` with mean 0.5 and std 0.5 on each of
    three channels.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Build and return the dictionary of transform pipelines.

        Returns:
            dict: ``'transform_gt_train'``, ``'transform_test'`` and
            ``'transform_inference'`` map to ``transforms.Compose``
            pipelines. ``'transform_source'`` maps to ``None``.
        """
        def build(flip):
            steps = [transforms.Resize((192, 256))]
            if flip:
                steps.append(transforms.RandomHorizontalFlip(0.5))
            steps += [
                transforms.ToTensor(),
                transforms.Normalize([0.5] * 3, [0.5] * 3),
            ]
            return transforms.Compose(steps)

        result = {'transform_gt_train': build(flip=True)}
        result['transform_source'] = None
        result['transform_test'] = build(flip=False)
        result['transform_inference'] = build(flip=False)
        return result
| |
|