from typing import Sequence

import torchvision.transforms as transforms
class CLIPTransform:
    """Callable image-preprocessing pipeline with a train and an eval variant.

    In ``'train'`` mode the pipeline is::

        RandomResizedCrop(image_size, scale=scale) -> ToTensor -> Normalize

    In every other mode (``'val'``, ``'test'``, or anything else), the pipeline is::

        Resize(image_size) -> CenterCrop(image_size) -> ToTensor -> Normalize

    With the defaults, behaviour is identical to the original hard-coded
    version: 224-pixel crops, a train crop scale of (0.5, 1.0), and
    ImageNet mean/std.

    NOTE(review): despite the class name, the default normalization uses the
    ImageNet statistics, not the ones used by OpenAI CLIP's preprocessing
    (exposed here as ``CLIP_MEAN`` / ``CLIP_STD``). If this feeds a pretrained
    CLIP image encoder, callers probably want
    ``CLIPTransform(mode, mean=CLIPTransform.CLIP_MEAN, std=CLIPTransform.CLIP_STD)``.
    TODO confirm which model consumes these tensors. The defaults are kept
    unchanged so existing callers see no behaviour change.
    """

    # Per-channel (R, G, B) statistics. The ImageNet values are the
    # historical defaults of this class.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    # Values used by OpenAI CLIP's reference preprocessing.
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(
        self,
        mode: str = 'train',
        image_size: int = 224,
        scale: Sequence[float] = (0.5, 1.0),
        mean: Sequence[float] = IMAGENET_MEAN,
        std: Sequence[float] = IMAGENET_STD,
    ) -> None:
        """Build the transform pipeline.

        Args:
            mode: ``'train'`` selects random-resized-crop augmentation. Any
                other value selects the deterministic resize + center-crop
                pipeline.
            image_size: Side length of the square output crop. In eval mode
                it is also the target passed to ``Resize``, which, given an
                int, resizes the shorter side to this length.
            scale: Lower and upper bounds on the crop area fraction for
                ``RandomResizedCrop``. Used only in train mode.
            mean: Per-channel mean passed to ``Normalize``.
            std: Per-channel standard deviation passed to ``Normalize``.
        """
        self.mode = mode
        normalize = transforms.Normalize(mean=list(mean), std=list(std))

        if mode == 'train':
            # Stochastic augmentation: random area/aspect crop, then resize.
            steps = [transforms.RandomResizedCrop(image_size, scale=tuple(scale))]
        else:
            # Deterministic evaluation view: shorter side -> image_size, then
            # take the central square.
            steps = [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]

        self.transforms = transforms.Compose(steps + [transforms.ToTensor(), normalize])

    def __call__(self, image):
        """Apply the configured pipeline to ``image`` and return the result.

        The input must be something ``torchvision`` transforms accept, e.g. a
        PIL image. TODO confirm what callers pass.
        """
        return self.transforms(image)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(mode={self.mode!r}, transforms={self.transforms!r})"