Forge-EMB-mmclip / utils /vision_transform.py
AL-GR's picture
Upload model
2666e68 verified
raw
history blame contribute delete
775 Bytes
import torchvision.transforms as transforms
class CLIPTransform(object):
    """Preprocessing pipeline for CLIP-style image models.

    In 'train' mode the spatial stage is a RandomResizedCrop to 224x224
    (scale 0.5-1.0); in any other mode it resizes the shorter edge to 224
    and center-crops to 224x224. Both modes then convert to a tensor and
    normalize per channel.
    """

    def __init__(self, mode='train'):
        # NOTE(review): these are the ImageNet statistics, not the original
        # CLIP ones (0.4814/0.4578/0.4082) — presumably intentional; confirm.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if mode == 'train':
            spatial = [transforms.RandomResizedCrop(224, scale=(0.5, 1.0))]
        else:
            spatial = [transforms.Resize(224), transforms.CenterCrop(224)]
        self.transforms = transforms.Compose(
            spatial + [transforms.ToTensor(), normalize])

    def __call__(self, image):
        """Apply the configured pipeline to one PIL image; returns a tensor."""
        return self.transforms(image)