Spaces:
Running
Running
| """Image transforms for training, validation, and test splits.""" | |
| from torchvision import transforms | |
# Side length, in pixels, of the square images produced by every pipeline
# in this module.
IMG_SIZE = 224

# Per-channel normalization statistics passed to transforms.Normalize.
# These are the standard ImageNet values listed in the torchvision docs
# (channel order R, G, B); presumably chosen to match an ImageNet-pretrained
# backbone — confirm against the model definition.
IMAGENET_MEAN = [
    0.485,  # R
    0.456,  # G
    0.406,  # B
]
IMAGENET_STD = [
    0.229,  # R
    0.224,  # G
    0.225,  # B
]
def get_train_transforms(augment_level="standard"):
    """Build the torchvision augmentation pipeline for the training split.

    Both levels apply the following steps in order:
      * resize to ``IMG_SIZE`` x ``IMG_SIZE``;
      * apply a random horizontal flip;
      * take a random resized crop covering 80-100% of that image's area,
        rescaled back to ``IMG_SIZE``;
      * convert to a tensor;
      * normalize with the ImageNet mean/std.

    The ``"mild"`` level additionally applies colour jitter (before tensor
    conversion). It also applies random erasing with ``p=0.1`` to the
    normalized tensor.

    Args:
        augment_level: ``"standard"`` (default) or ``"mild"``.

    Returns:
        A ``transforms.Compose`` to be applied to each training image.

    Raises:
        ValueError: If ``augment_level`` is not a supported level.
    """
    # Reject unknown levels explicitly instead of silently falling back to
    # the standard pipeline, which would hide typos in configs.
    if augment_level not in ("standard", "mild"):
        raise ValueError(
            f"Unknown augment_level {augment_level!r}; "
            "expected 'standard' or 'mild'."
        )
    mild = augment_level == "mild"

    steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    ]
    if mild:
        # Kept ahead of ToTensor/Normalize so jitter acts on raw pixel
        # values rather than on normalized ones.
        steps.append(
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1)
        )
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    if mild:
        # RandomErasing only accepts tensor images, so it must follow ToTensor.
        steps.append(transforms.RandomErasing(p=0.1))
    return transforms.Compose(steps)
def get_eval_transforms():
    """Build the deterministic preprocessing pipeline for validation/test.

    Each image is resized to exactly ``IMG_SIZE`` x ``IMG_SIZE``. This does
    not preserve aspect ratio, and no cropping or random augmentation is
    applied. The image is then converted to a tensor and normalized with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Returns:
        A ``transforms.Compose`` to be applied to each evaluation image.
    """
    side = IMG_SIZE
    pipeline = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)