File size: 1,340 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import List, Union

import torchvision

from . import env


def resize_256_224() -> List:
    """Resize with an integer size of 256, then center-crop to 224x224.

    Returns:
        A two-element list ``[Resize, CenterCrop]`` of torchvision
        transforms, suitable for concatenating with other transform lists
        or passing to ``torchvision.transforms.Compose``.
    """
    resize = torchvision.transforms.Resize(size=256)
    crop = torchvision.transforms.CenterCrop(size=(224, 224))
    return [resize, crop]


def resize_512_448() -> List:
    """Resize with an integer size of 512, then center-crop to 448x448.

    Returns:
        A two-element list ``[Resize, CenterCrop]`` of torchvision
        transforms (the same recipe as ``resize_256_224`` at twice the
        resolution).
    """
    pipeline = []
    pipeline.append(torchvision.transforms.Resize(size=512))
    pipeline.append(torchvision.transforms.CenterCrop(size=(448, 448)))
    return pipeline


def resize_224() -> List:
    """Return a single ``Resize`` transform with an integer size of 224.

    No crop follows. For an integer ``size`` torchvision matches the
    smaller image edge, so the aspect ratio is kept.
    """
    resize = torchvision.transforms.Resize(size=224)
    return [resize]


def hflip(p: float = 0.5) -> List:
    """Build a random horizontal-flip transform.

    Args:
        p: Probability of flipping an image; must lie in ``[0, 1]``.

    Returns:
        A one-element list holding a ``RandomHorizontalFlip(p)``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]`` or is NaN.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let invalid probabilities through silently.
    # The chained comparison is also False for NaN, so NaN is rejected.
    if not 0 <= p <= 1:
        raise ValueError(f"hflip probability must be in [0, 1], got {p!r}")
    return [torchvision.transforms.RandomHorizontalFlip(p)]


def to_ts() -> List:
    """Return a one-element list holding a ``ToTensor`` transform."""
    to_tensor = torchvision.transforms.ToTensor()
    return [to_tensor]


def to_pil() -> List:
    """Return a one-element list holding a ``ToPILImage`` transform."""
    to_image = torchvision.transforms.ToPILImage()
    return [to_image]


def _to_rgb(img):
    """Return ``img`` converted to ``'RGB'`` mode, or unchanged if already RGB.

    This is a module-level function rather than a lambda so that transform
    pipelines containing it can be pickled, for example when data-loading
    worker processes are started with the ``spawn`` method.
    """
    # assumes a PIL-style image exposing ``.mode`` and ``.convert`` — TODO confirm callers
    return img if img.mode == 'RGB' else img.convert('RGB')


def to_color() -> List:
    """Return a one-element transform list that forces images into RGB mode.

    Returns:
        A list holding a ``torchvision.transforms.Lambda`` that wraps
        ``_to_rgb``. Images already in ``'RGB'`` mode pass through unchanged.
    """
    return [torchvision.transforms.Lambda(_to_rgb)]


def norm(
    dataset: str = 'IMAGENET', _callable: bool = False
) -> Union[List, 'torchvision.transforms.Compose']:
    """Build a ``Normalize`` transform from per-dataset statistics in ``env``.

    The mean and std are read from the ``env`` module as
    ``<DATASET>_DEFAULT_MEAN`` and ``<DATASET>_DEFAULT_STD``, where
    ``<DATASET>`` is ``dataset`` upper-cased.

    Args:
        dataset: Dataset name, case-insensitive (e.g. ``'imagenet'``).
        _callable: If True, wrap the transform in ``Compose`` and return it
            directly instead of returning a list.

    Returns:
        A one-element list holding the ``Normalize`` transform, or a
        ``Compose`` wrapping that list when ``_callable`` is True.

    Raises:
        AttributeError: If ``env`` has no mean/std entry for ``dataset``.
    """
    prefix = dataset.upper()
    try:
        mean = getattr(env, prefix + '_DEFAULT_MEAN')
        std = getattr(env, prefix + '_DEFAULT_STD')
    except AttributeError as err:
        # Same exception type as before, but name the missing entries so
        # a typo in ``dataset`` is easy to diagnose.
        raise AttributeError(
            f"no normalization statistics for dataset {dataset!r}: expected "
            f"env.{prefix}_DEFAULT_MEAN and env.{prefix}_DEFAULT_STD"
        ) from err
    transforms = [torchvision.transforms.Normalize(mean, std)]
    if _callable:
        return torchvision.transforms.Compose(transforms)
    return transforms