import torchvision.transforms as transforms


def get_transforms(image_size=(224, 224), train: bool = True) -> transforms.Compose:
    """Build the image preprocessing pipeline for training or evaluation.

    Both modes resize the input, convert it to a tensor and normalize it
    with the ImageNet per-channel statistics. Training mode additionally
    applies a random horizontal flip and a mild brightness/contrast jitter
    between the resize and the tensor conversion.

    Args:
        image_size: Target size, passed unchanged to ``transforms.Resize``.
            Defaults to ``(224, 224)``.
        train: When truthy, include the random augmentation steps;
            otherwise return the deterministic evaluation pipeline.

    Returns:
        A ``transforms.Compose`` that applies the selected steps in order.
    """
    # ImageNet stats, defined once so the train and eval pipelines cannot
    # drift apart.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [transforms.Resize(image_size)]

    if train:
        # Augmentations run after the resize and before ToTensor, the same
        # position they had in the original training pipeline.
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ]

    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)