File size: 1,406 Bytes
df9c255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

import torch
import torchvision.transforms as transforms
import PIL.Image as Image


def build_transform_classification(normalize, crop_size=224, resize=256, tta=True):
    """Build the evaluation-time preprocessing pipeline for classification.

    The image is first resized to ``(resize, resize)``. With ``tta`` enabled,
    ``transforms.TenCrop`` is applied and every crop is converted to a tensor
    (and normalized, if requested) before the crops are stacked into a single
    tensor with one entry per crop. With ``tta`` disabled, a single center
    crop is converted to a tensor (and normalized, if requested).

    Args:
        normalize: Name of the normalization preset, matched
            case-insensitively: ``"imagenet"``, ``"chestx-ray"`` or
            ``"none"``. ``None`` is accepted as a synonym for ``"none"``.
        crop_size: Side length passed to ``TenCrop`` / ``CenterCrop``.
        resize: Side length the image is resized to before cropping.
        tta: If true, use ten-crop test-time augmentation; otherwise use a
            single center crop.

    Returns:
        A ``transforms.Compose`` holding the assembled pipeline. The pipeline
        is also printed to stdout.

    Raises:
        SystemExit: With code ``-1`` (after printing a message) when
            ``normalize`` does not name a known preset.
    """
    # Per-channel (mean, std) for each preset; None means "do not normalize".
    stats_by_name = {
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "chestx-ray": ([0.5056, 0.5056, 0.5056], [0.252, 0.252, 0.252]),
        "none": None,
    }

    preset = "none" if normalize is None else normalize.lower()
    if preset not in stats_by_name:
        print("mean and std for [{}] dataset do not exist!".format(normalize))
        # Same exception and exit code as the former ``exit(-1)``, without
        # relying on the ``site``-injected ``exit`` helper being available.
        raise SystemExit(-1)

    stats = stats_by_name[preset]
    # Keep the transform in its own local instead of rebinding the string
    # parameter, so the lambdas below close over an unambiguous object.
    normalizer = None if stats is None else transforms.Normalize(*stats)

    to_tensor = transforms.ToTensor()
    transformations_list = [transforms.Resize((resize, resize))]
    if tta:
        transformations_list.append(transforms.TenCrop(crop_size))
        # TenCrop returns a tuple of images, so tensor conversion and
        # normalization must be mapped over the crops and re-stacked.
        transformations_list.append(transforms.Lambda(
            lambda crops: torch.stack([to_tensor(crop) for crop in crops])))
        if normalizer is not None:
            # Skipped for the "none" preset; calling None here used to raise
            # TypeError on the first image.
            transformations_list.append(transforms.Lambda(
                lambda crops: torch.stack([normalizer(crop) for crop in crops])))
    else:
        transformations_list.append(transforms.CenterCrop(crop_size))
        transformations_list.append(to_tensor)
        if normalizer is not None:
            # The requested normalization was previously dropped on this
            # path, making tta=False outputs inconsistent with tta=True.
            transformations_list.append(normalizer)

    transformSequence = transforms.Compose(transformations_list)
    print(transformSequence)
    return transformSequence