vision
rbanfield's picture
continue debugging adventures
e36f852
raw
history blame
1.58 kB
import base64
import logging
import sys
from io import BytesIO

import torch
from PIL import Image
from transformers import (
    CLIPProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the CLIP checkpoint every component is loaded from.
_MODEL_ID = "rbanfield/clip-vit-large-patch14"

# Run on the GPU when one is available; models and input batches are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler:
    """Inference-endpoint handler that embeds either text or a base64 image with CLIP.

    The request payload is expected to look like ``{"inputs": {"text": ...}}``
    or ``{"inputs": {"image": "<base64-encoded image bytes>"}}``. When both
    keys carry truthy values, text takes precedence.
    """

    def __init__(self, path=""):
        """Load the text encoder, vision encoder and processor.

        Args:
            path: Accepted for the endpoint-handler calling convention but not
                used; all components are loaded from ``_MODEL_ID``.
        """
        # Use the *projected* text encoder so text embeddings live in the same
        # joint space as the vision model's ``image_embeds``. The bare
        # CLIPTextModel's ``pooler_output`` is taken before the text projection
        # and is therefore not comparable to image embeddings.
        self.text_model = CLIPTextModelWithProjection.from_pretrained(_MODEL_ID).to(device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(_MODEL_ID).to(device)
        self.processor = CLIPProcessor.from_pretrained(_MODEL_ID)

    def __call__(self, data):
        """Return CLIP embeddings for the text or image in ``data["inputs"]``.

        Args:
            data: Request payload dict. Its ``"inputs"`` entry is removed
                (popped) and should be a dict with a ``"text"`` key (string or
                list of strings) or an ``"image"`` key (base64-encoded image).

        Returns:
            A nested list of floats (one embedding row per input), or ``None``
            when ``inputs`` is missing, is not a dict, or has neither a truthy
            ``"text"`` nor a truthy ``"image"`` value.
        """
        inputs = data.pop("inputs", None)
        if not isinstance(inputs, dict):
            # Missing/None or a bare string payload: nothing we can embed.
            # Indexing such a value by "text"/"image" would raise TypeError.
            logger.debug("no usable inputs (got %r)", type(inputs))
            return None

        text_input = inputs.get("text")
        image_input = inputs.get("image")

        if text_input:
            logger.debug("in text mode: %r", text_input)
            batch = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                return self.text_model(**batch).text_embeds.tolist()

        if image_input:
            # Log only the payload size; the base64 string itself can be huge.
            logger.debug("in image mode (%d base64 chars)", len(image_input))
            image = Image.open(BytesIO(base64.b64decode(image_input)))
            batch = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                return self.image_model(**batch).image_embeds.tolist()

        return None