vision
rbanfield's picture
debugging error not seen locally
eddf2f6
raw
history blame
1.48 kB
from io import BytesIO
import base64
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
# Run on the GPU when torch reports a usable CUDA device; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EndpointHandler:
    """Callable handler that returns CLIP embeddings for text or for an image.

    A request is a dict whose "inputs" entry is itself a dict holding either
    a "text" value or an "image" value (a base64-encoded image). When both
    are present and truthy, "text" takes precedence.

    NOTE(review): the text branch returns ``pooler_output`` of a plain
    ``CLIPTextModel`` while the image branch returns ``image_embeds`` of a
    ``CLIPVisionModelWithProjection``. Presumably the two are therefore not
    in the same embedding space -- confirm whether the text side should use
    a projection model as well before comparing the two kinds of vectors.
    """

    def __init__(self, path=""):
        """Load the CLIP text model, vision model and processor.

        Args:
            path: Accepted for interface compatibility but not used; the
                weights are always loaded from the
                "rbanfield/clip-vit-large-patch14" checkpoint.
        """
        # Remember the target device so request tensors can be moved to the
        # same place as the model weights.
        self.device = device
        self.text_model = CLIPTextModel.from_pretrained(
            "rbanfield/clip-vit-large-patch14"
        ).to(self.device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(
            "rbanfield/clip-vit-large-patch14"
        ).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(
            "rbanfield/clip-vit-large-patch14"
        )

    def __call__(self, data):
        """Embed the text or image carried by a request.

        Args:
            data: Request dict. Its "inputs" entry is removed from the dict
                and is expected to be a dict with a "text" and/or "image"
                key.

        Returns:
            The embedding tensor converted with ``.tolist()``, or ``None``
            when "inputs" is missing, is not a dict, or holds neither a
            truthy "text" nor a truthy "image".
        """
        inputs = data.pop("inputs", None)
        # A missing or malformed payload is treated like an empty one rather
        # than failing with a TypeError on the membership tests below.
        if not isinstance(inputs, dict):
            return None

        text_input = inputs.get("text")
        if text_input:
            return self._embed_text(text_input)

        image_input = inputs.get("image")
        if image_input:
            return self._embed_image(image_input)

        return None

    def _embed_text(self, text_input):
        """Return the text model's pooled output for ``text_input`` as a list."""
        # truncation=True keeps over-long text within the tokenizer's
        # configured maximum length instead of overflowing the model.
        batch = self.processor(
            text=text_input,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)
        with torch.no_grad():
            return self.text_model(**batch).pooler_output.tolist()

    def _embed_image(self, image_input):
        """Return the vision model's image embeddings for a base64 image as a list."""
        raw_bytes = base64.b64decode(image_input)
        # The context manager releases the decoder's resources once the
        # processor has turned the image into tensors.
        with Image.open(BytesIO(raw_bytes)) as image:
            batch = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.image_model(**batch).image_embeds.tolist()