from typing import List, Union

import torchvision

from . import env
def resize_256_224() -> List:
    """Return transforms that resize the short side to 256, then center-crop 224x224."""
    scale = torchvision.transforms.Resize(size=256)
    crop = torchvision.transforms.CenterCrop(size=(224, 224))
    return [scale, crop]
def resize_512_448() -> List:
    """Return transforms that resize the short side to 512, then center-crop 448x448."""
    steps = [
        torchvision.transforms.Resize(size=512),
    ]
    steps.append(torchvision.transforms.CenterCrop(size=(448, 448)))
    return steps
def resize_224() -> List:
    """Return a single Resize transform targeting a short side of 224 (no crop)."""
    resize = torchvision.transforms.Resize(size=224)
    return [resize]
def hflip(p: float = 0.5) -> List:
    """Return a random-horizontal-flip transform applied with probability ``p``.

    Args:
        p: flip probability; must lie in the closed interval [0, 1].

    Returns:
        A one-element list containing ``RandomHorizontalFlip(p)``.

    Raises:
        ValueError: if ``p`` is outside [0, 1].
    """
    # Validate explicitly instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this check.
    if not 0 <= p <= 1:
        raise ValueError(f'p must be in [0, 1], got {p!r}')
    return [torchvision.transforms.RandomHorizontalFlip(p)]
def to_ts() -> List:
    """Return a ToTensor transform (PIL image / ndarray -> torch tensor)."""
    converter = torchvision.transforms.ToTensor()
    return [converter]
def to_pil() -> List:
    """Return a ToPILImage transform (tensor / ndarray -> PIL image)."""
    converter = torchvision.transforms.ToPILImage()
    return [converter]
def to_color() -> List:
    """Return a transform that converts any non-RGB PIL image to RGB mode."""
    def _ensure_rgb(img):
        # Pass RGB images through untouched; convert everything else.
        return img if img.mode == 'RGB' else img.convert('RGB')

    return [torchvision.transforms.Lambda(_ensure_rgb)]
def norm(dataset: str = 'IMAGENET',
         _callable: bool = False) -> Union[List, torchvision.transforms.Compose]:
    """Build a Normalize transform from per-dataset constants on ``env``.

    Looks up ``<DATASET>_DEFAULT_MEAN`` and ``<DATASET>_DEFAULT_STD`` on the
    package-local ``env`` module; the dataset name is upper-cased first, so
    lookups are case-insensitive.

    Args:
        dataset: dataset key, e.g. ``'IMAGENET'`` (case-insensitive).
        _callable: if True, wrap the transform list in a ``Compose`` and
            return that callable instead of the list.

    Returns:
        A one-element list ``[Normalize(mean, std)]``, or a ``Compose`` of it
        when ``_callable`` is True.  (Return annotation fixed: the original
        ``-> List`` was wrong for the ``_callable=True`` path.)

    Raises:
        AttributeError: if ``env`` defines no mean/std constants for
            ``dataset``.
    """
    dataset = dataset.upper()
    mean = getattr(env, dataset + '_DEFAULT_MEAN')
    std = getattr(env, dataset + '_DEFAULT_STD')
    transforms = [torchvision.transforms.Normalize(mean, std)]
    if _callable:
        return torchvision.transforms.Compose(transforms)
    return transforms