import torchvision.transforms as T

from core.vision_encoder.tokenizer import SimpleTokenizer


def get_image_transform(
    image_size: int,
    center_crop: bool = False,
    # Bilinear matches the interpolation used during training.
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
):
    """Build the image preprocessing pipeline.

    Produces a ``T.Compose`` that resizes the input to ``image_size`` x
    ``image_size``, forces RGB mode, converts to a tensor, and normalizes
    each channel to the [-1, 1] range (mean 0.5, std 0.5).

    Args:
        image_size: Target side length of the square output image.
        center_crop: If True, resize the short side to ``image_size`` and
            take a center crop; if False, squash (stretch) the whole image
            to ``image_size`` x ``image_size``.
        interpolation: Resampling mode passed to ``T.Resize``.

    Returns:
        A ``torchvision.transforms.Compose`` callable mapping a PIL image
        to a normalized float tensor.
    """
    if center_crop:
        # Short-side resize followed by a square center crop.
        resize_ops = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size),
        ]
    else:
        # "Squash": most versatile — uses the full image, distorts aspect.
        resize_ops = [
            T.Resize((image_size, image_size), interpolation=interpolation),
        ]

    tail = [
        # Force 3-channel RGB so grayscale/palette inputs don't break ToTensor.
        T.Lambda(lambda x: x.convert("RGB")),
        T.ToTensor(),
        # Map [0, 1] -> [-1, 1] per channel, in place to avoid a copy.
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
    ]
    return T.Compose(resize_ops + tail)


def get_text_tokenizer(context_length: int):
    """Return a ``SimpleTokenizer`` configured with *context_length*."""
    return SimpleTokenizer(context_length=context_length)