Spaces:
Sleeping
Sleeping
File size: 1,406 Bytes
df9c255 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
import torchvision.transforms as transforms
import PIL.Image as Image
def _apply_to_crops(fn, crops):
    """Apply ``fn`` to every crop in ``crops`` and stack the results along a new dim 0.

    Defined at module level and bound with ``functools.partial`` (rather than
    written as a lambda/closure) so the resulting transform pipeline can be
    pickled, e.g. when a DataLoader ships the dataset to spawned worker processes.
    """
    return torch.stack([fn(crop) for crop in crops])


# Channel (mean, std) statistics for transforms.Normalize, keyed by the
# lower-cased dataset name accepted by build_transform_classification.
_NORMALIZATION_STATS = {
    "imagenet": ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    "chestx-ray": ((0.5056, 0.5056, 0.5056), (0.252, 0.252, 0.252)),
}


def build_transform_classification(normalize, crop_size=224, resize=256, tta=True):
    """Build the evaluation-time image transform pipeline for classification.

    Args:
        normalize: Name of the normalization statistics to apply, matched
            case-insensitively: ``"imagenet"``, ``"chestx-ray"``, or ``"none"``
            to skip normalization entirely.
        crop_size: Size passed to ``TenCrop`` / ``CenterCrop``.
        resize: The image is first resized to ``(resize, resize)``.
        tta: If True, apply test-time augmentation: ``TenCrop`` yields several
            crops, each is converted to a tensor (and normalized, if enabled),
            and the crops are stacked along a new leading dimension. If False,
            a single center crop is converted to a tensor (and normalized).

    Returns:
        The ``transforms.Compose`` pipeline (also printed, as before).

    Raises:
        ValueError: If ``normalize`` names an unknown dataset.
    """
    import functools

    key = normalize.lower()
    if key == "none":
        normalizer = None
    elif key in _NORMALIZATION_STATS:
        mean, std = _NORMALIZATION_STATS[key]
        normalizer = transforms.Normalize(list(mean), list(std))
    else:
        # Raise instead of exit(-1): library code should not terminate the
        # interpreter, and `exit` is only defined when the `site` module loads.
        raise ValueError("mean and std for [{}] dataset do not exist!".format(normalize))

    steps = [transforms.Resize((resize, resize))]
    if tta:
        steps.append(transforms.TenCrop(crop_size))
        # One ToTensor instance is shared by all crops instead of one per crop.
        steps.append(transforms.Lambda(functools.partial(_apply_to_crops, transforms.ToTensor())))
        # Skip the per-crop normalization when disabled; previously this called None(crop).
        if normalizer is not None:
            steps.append(transforms.Lambda(functools.partial(_apply_to_crops, normalizer)))
    else:
        steps.append(transforms.CenterCrop(crop_size))
        steps.append(transforms.ToTensor())
        # The non-TTA path previously dropped normalization altogether.
        if normalizer is not None:
            steps.append(normalizer)

    transform_sequence = transforms.Compose(steps)
    print(transform_sequence)
    return transform_sequence
|