"""Image transforms for training, validation, and test splits.""" from torchvision import transforms IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] IMG_SIZE = 224 def get_train_transforms(augment_level="standard"): if augment_level == "mild": return transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), transforms.RandomErasing(p=0.1), ]) return transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ]) def get_eval_transforms(): return transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ])