import torchvision.transforms as transforms


class CLIPTransform(object):
    """Image preprocessing pipeline: random-crop augmentation for training,
    deterministic resize/center-crop for evaluation."""

    def __init__(self, mode='train'):
        # Normalize with ImageNet channel statistics
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if mode == 'train':
            # Random crop covering 50-100% of the image area, resized to 224x224
            self.transforms = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
                transforms.ToTensor(),
                normalize
            ])
        else:
            # Deterministic preprocessing for validation/test
            self.transforms = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize
            ])

    def __call__(self, image):
        return self.transforms(image)
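

# Minimal usage sketch (not part of the original file): the transform expects a
# PIL image, as produced by torchvision datasets. The image created below is a
# hypothetical placeholder; any RGB PIL image works.
if __name__ == '__main__':
    from PIL import Image

    train_transform = CLIPTransform(mode='train')
    eval_transform = CLIPTransform(mode='val')

    image = Image.new('RGB', (320, 240))
    x = train_transform(image)  # tensor of shape [3, 224, 224]
    y = eval_transform(image)   # tensor of shape [3, 224, 224]
    print(x.shape, y.shape)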