from typing import List, Union

import torchvision

from . import env
def resize_256_224() -> List:
    """Build the resize-then-center-crop pipeline that yields 224x224 images.

    Returns:
        A two-element list: ``Resize(size=256)`` followed by
        ``CenterCrop(size=(224, 224))``.
    """
    tv = torchvision.transforms
    resize_step = tv.Resize(size=256)
    crop_step = tv.CenterCrop(size=(224, 224))
    return [resize_step, crop_step]
def resize_512_448() -> List:
    """Build the resize-then-center-crop pipeline that yields 448x448 images.

    Returns:
        A two-element list: ``Resize(size=512)`` followed by
        ``CenterCrop(size=(448, 448))``.
    """
    tv = torchvision.transforms
    resize_step = tv.Resize(size=512)
    crop_step = tv.CenterCrop(size=(448, 448))
    return [resize_step, crop_step]
def resize_224() -> List:
    """Return a one-element list holding ``Resize(size=224)``.

    Unlike the ``resize_*_*`` helpers, no center crop is appended.
    """
    resize_step = torchvision.transforms.Resize(size=224)
    return [resize_step]
def hflip(p: float = 0.5) -> List:
    """Return a random horizontal-flip transform as a one-element list.

    Args:
        p: Probability passed to ``RandomHorizontalFlip``; must lie in
            the closed interval ``[0, 1]``.

    Returns:
        ``[RandomHorizontalFlip(p)]``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]`` (this includes NaN,
            since every comparison with NaN is false).
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an invalid probability through silently.
    if not 0 <= p <= 1:
        raise ValueError(f"hflip probability must be in [0, 1], got {p!r}")
    return [torchvision.transforms.RandomHorizontalFlip(p)]
def to_ts() -> List:
    """Return a one-element list holding torchvision's ``ToTensor`` transform."""
    tensor_step = torchvision.transforms.ToTensor()
    return [tensor_step]
def to_pil() -> List:
    """Return a one-element list holding torchvision's ``ToPILImage`` transform."""
    pil_step = torchvision.transforms.ToPILImage()
    return [pil_step]
def _to_rgb_image(img):
    """Return ``img`` converted to ``'RGB'`` mode, or ``img`` unchanged if already RGB.

    NOTE(review): assumes ``img`` is a PIL-style image exposing ``.mode`` and
    ``.convert()``; confirm that tensors never reach this transform.
    """
    if img.mode != 'RGB':
        return img.convert('RGB')
    return img


def to_color() -> List:
    """Return a one-element list with a transform that forces RGB images.

    Non-RGB inputs (e.g. grayscale or RGBA) are converted with
    ``convert('RGB')``; RGB inputs are passed through unchanged.
    """
    # Use a module-level function rather than an inline lambda: lambdas cannot
    # be pickled, and DataLoader workers started with the ``spawn`` method
    # must pickle the dataset together with its transforms.
    return [torchvision.transforms.Lambda(_to_rgb_image)]
def norm(
    dataset: str = 'IMAGENET',
    _callable: bool = False,
) -> Union[List, "torchvision.transforms.Compose"]:
    """Build a ``Normalize`` transform from per-dataset statistics in ``env``.

    Args:
        dataset: Dataset name, case-insensitive. ``env`` must define
            ``<DATASET>_DEFAULT_MEAN`` and ``<DATASET>_DEFAULT_STD``, where
            ``<DATASET>`` is ``dataset.upper()``.
        _callable: If true, wrap the transform in ``Compose`` so the result can
            be applied directly. Otherwise return the plain one-element list,
            matching the other helpers in this module.

    Returns:
        ``[Normalize(mean, std)]``, or ``Compose([Normalize(mean, std)])``
        when ``_callable`` is true.

    Raises:
        AttributeError: If ``env`` has no mean/std constants for ``dataset``.
    """
    prefix = dataset.upper()
    mean = getattr(env, prefix + '_DEFAULT_MEAN')
    std = getattr(env, prefix + '_DEFAULT_STD')
    transforms = [torchvision.transforms.Normalize(mean, std)]
    if _callable:
        return torchvision.transforms.Compose(transforms)
    return transforms