# x-ray-vision/src/preprocess_data.py
# Commit df9c255 ("initial commit") by taheera.
import functools

import PIL.Image as Image
import torch
import torchvision.transforms as transforms
def build_transform_classification(normalize, crop_size=224, resize=256, tta=True):
transformations_list = []
if normalize.lower() == "imagenet":
normalize = transforms.Normalize(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
elif normalize.lower() == "chestx-ray":
normalize = transforms.Normalize(
[0.5056, 0.5056, 0.5056], [0.252, 0.252, 0.252])
elif normalize.lower() == "none":
normalize = None
else:
print("mean and std for [{}] dataset do not exist!".format(normalize))
exit(-1)
if tta:
transformations_list.append(transforms.Resize((resize, resize)))
transformations_list.append(transforms.TenCrop(crop_size))
transformations_list.append(
transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
transformations_list.append(transforms.Lambda(
lambda crops: torch.stack([normalize(crop) for crop in crops])))
else:
transformations_list.append(transforms.Resize((resize, resize)))
transformations_list.append(transforms.CenterCrop(crop_size))
transformations_list.append(transforms.ToTensor())
transformSequence = transforms.Compose(transformations_list)
print(transformSequence)
return transformSequence