Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision.transforms as transforms | |
| import PIL.Image as Image | |
def _apply_to_crops(fn, crops):
    """Apply ``fn`` to every element of ``crops`` and stack the results.

    Kept at module level (instead of a lambda) so that a ``transforms.Lambda``
    wrapping ``functools.partial(_apply_to_crops, fn)`` stays picklable, which
    DataLoader worker processes started with the "spawn" method require.
    """
    return torch.stack([fn(crop) for crop in crops])


def build_transform_classification(normalize, crop_size=224, resize=256, tta=True):
    """Build the evaluation-time image transform pipeline for classification.

    Args:
        normalize: Name of the normalization statistics to apply:
            ``"imagenet"``, ``"chestx-ray"`` or ``"none"`` (case-insensitive).
            ``None`` is accepted as a synonym for ``"none"``.
        crop_size: Side length of the square crop(s) taken after resizing.
        resize: Side length of the square the image is resized to first.
        tta: If True, take ten crops (``TenCrop``) and return them stacked
            along a new leading dimension (test-time augmentation). Otherwise
            take a single center crop.

    Returns:
        A ``transforms.Compose`` pipeline. The selected normalization is
        applied in both the TTA and the single-crop modes.

    Raises:
        ValueError: if ``normalize`` names statistics that are not defined here.
    """
    from functools import partial

    # Per-channel (mean, std) for each supported dataset.
    stats = {
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "chestx-ray": ([0.5056, 0.5056, 0.5056], [0.252, 0.252, 0.252]),
    }

    key = "none" if normalize is None else normalize.lower()
    if key == "none":
        normalizer = None
    elif key in stats:
        mean, std = stats[key]
        normalizer = transforms.Normalize(mean, std)
    else:
        # Raise instead of exit(-1): a library function should not terminate
        # the caller's process.
        raise ValueError("mean and std for [{}] dataset do not exist!".format(normalize))

    steps = [transforms.Resize((resize, resize))]
    if tta:
        steps.append(transforms.TenCrop(crop_size))
        # TenCrop yields a tuple of images. Convert each one to a tensor and stack them.
        steps.append(transforms.Lambda(partial(_apply_to_crops, transforms.ToTensor())))
        if normalizer is not None:
            # Iterating the stacked tensor yields one crop per slice along dim 0.
            steps.append(transforms.Lambda(partial(_apply_to_crops, normalizer)))
    else:
        steps.append(transforms.CenterCrop(crop_size))
        steps.append(transforms.ToTensor())
        if normalizer is not None:
            steps.append(normalizer)

    transform_sequence = transforms.Compose(steps)
    # Kept from the original: log the constructed pipeline for inspection.
    print(transform_sequence)
    return transform_sequence