# SAE/attacks/AIM/src/gat/datasets/transforms.py
# Uploaded by Ttius in commit 998bb30 ("Upload 192 files").
from typing import List, Union

import torchvision

from . import env
def resize_256_224() -> List:
    """Return the resize-then-crop steps for 224x224 inputs.

    The pipeline resizes with ``size=256`` and then center-crops to
    224x224. It is returned as a list so callers can concatenate it with
    other step lists.
    """
    crop_hw = (224, 224)
    resize_step = torchvision.transforms.Resize(size=256)
    crop_step = torchvision.transforms.CenterCrop(size=crop_hw)
    return [resize_step, crop_step]
def resize_512_448() -> List:
    """Return the resize-then-crop steps for 448x448 inputs.

    The pipeline resizes with ``size=512`` and then center-crops to
    448x448. It is returned as a list so callers can concatenate it with
    other step lists.
    """
    crop_hw = (448, 448)
    resize_step = torchvision.transforms.Resize(size=512)
    crop_step = torchvision.transforms.CenterCrop(size=crop_hw)
    return [resize_step, crop_step]
def resize_224() -> List:
    """Return a single-step list that resizes with ``size=224`` (no crop)."""
    resize_step = torchvision.transforms.Resize(size=224)
    return [resize_step]
def hflip(p: float = 0.5) -> List:
    """Return a single-step list with a random horizontal flip.

    Args:
        p: Probability passed to ``RandomHorizontalFlip``; must lie in
            ``[0, 1]``.

    Raises:
        AssertionError: If ``p`` is outside ``[0, 1]`` (NaN fails too,
            since every comparison with NaN is False).
    """
    # An explicit check, not an ``assert`` statement, so the validation
    # still runs under ``python -O``. AssertionError is kept so existing
    # callers see the same exception type as before.
    if not 0 <= p <= 1:
        raise AssertionError(f'hflip probability must be in [0, 1], got {p!r}')
    return [torchvision.transforms.RandomHorizontalFlip(p)]
def to_ts() -> List:
    """Return a single-step list holding torchvision's ``ToTensor`` transform."""
    tensor_step = torchvision.transforms.ToTensor()
    return [tensor_step]
def to_pil() -> List:
    """Return a single-step list holding torchvision's ``ToPILImage`` transform."""
    pil_step = torchvision.transforms.ToPILImage()
    return [pil_step]
def ensure_rgb(img):
    """Return ``img`` unchanged if its ``mode`` is ``'RGB'``, else ``img.convert('RGB')``.

    Expects an image object with a ``mode`` attribute and a ``convert``
    method (presumably a PIL image; confirm against callers).
    """
    if img.mode == 'RGB':
        return img
    return img.convert('RGB')


def to_color() -> List:
    """Return a single-step list that converts non-RGB images to RGB."""
    # Use a module-level function, not a lambda. Lambdas cannot be pickled,
    # and DataLoader workers must pickle the dataset (and its transforms)
    # under the 'spawn' start method.
    return [torchvision.transforms.Lambda(ensure_rgb)]
def norm(
    dataset: str = 'IMAGENET', _callable: bool = False
) -> Union[List, 'torchvision.transforms.Compose']:
    """Build a ``Normalize`` transform using per-dataset statistics from ``env``.

    The mean and std are read from ``env`` as ``<DATASET>_DEFAULT_MEAN``
    and ``<DATASET>_DEFAULT_STD``. ``dataset`` is upper-cased first, so the
    key is case-insensitive.

    Args:
        dataset: Dataset key used to look up the statistics in ``env``.
        _callable: If True, return the transform wrapped in a ``Compose``
            instead of as a list.

    Returns:
        A one-element list holding the ``Normalize`` transform, or a
        ``Compose`` over that list when ``_callable`` is True.

    Raises:
        AttributeError: If ``env`` lacks either statistic for ``dataset``.
    """
    key = dataset.upper()
    try:
        mean = getattr(env, key + '_DEFAULT_MEAN')
        std = getattr(env, key + '_DEFAULT_STD')
    except AttributeError as err:
        # Re-raise with the names we looked for and the datasets env does
        # define, instead of a bare missing-attribute error.
        suffix = '_DEFAULT_MEAN'
        known = sorted(
            name[: -len(suffix)] for name in dir(env) if name.endswith(suffix)
        )
        raise AttributeError(
            f'No normalization statistics for dataset {dataset!r}: expected '
            f'env.{key}_DEFAULT_MEAN and env.{key}_DEFAULT_STD; '
            f'datasets with a mean defined: {known}'
        ) from err
    transforms = [torchvision.transforms.Normalize(mean, std)]
    if _callable:
        return torchvision.transforms.Compose(transforms)
    return transforms