# Author: YassineB
# Commit 1b1dde5: "Add clip for test inference"
from typing import Dict, List, Any
from PIL import Image
import clip
import torch
import requests
import io
# Run on CUDA when a GPU is available to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Callable handler that returns the CLIP ViT-B/32 embedding of one image.

    Construct it once (this loads the model), then call the instance with a
    request payload dict to get ``{"feature_vector": [...]}``.
    """

    # Seconds passed to ``requests.get(timeout=...)`` so a stalled image host
    # cannot block a request indefinitely.
    download_timeout = 30

    def __init__(self, path=""):
        """Load the CLIP ViT-B/32 model and its matching preprocessing transform.

        Args:
            path: Accepted for the handler interface but unused; the model is
                selected by name in ``clip.load``.
        """
        self.model, self.preprocess = clip.load('ViT-B/32', device)
        # Inference only: make sure train-time layers are disabled.
        self.model.eval()
        self.model = self.model.to(device)

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """Compute the CLIP image embedding for a single image.

        Args:
            data: Request payload dict. The image is taken from
                ``data["inputs"]`` (which is removed from ``data``), or from
                ``data`` itself when there is no ``"inputs"`` key. Accepted
                forms:

                * a string starting with ``http://``, ``https://`` or
                  ``www.``: downloaded (``https://`` is prepended to
                  ``www.`` URLs);
                * a dict with an ``"image"`` key: its value is passed to
                  ``PIL.Image.open`` (a path or binary file-like object);
                * a ``PIL.Image.Image``: used as-is.

        Returns:
            A dict ``{"feature_vector": [...]}`` whose value is the output of
            ``model.encode_image`` for the image, as a list of floats.

        Raises:
            ValueError: if the payload is not one of the accepted forms.
            requests.HTTPError: if downloading a URL returns an error status.
        """
        inputs = data.pop("inputs", data)
        img = self._load_image(inputs)

        # Add a batch dimension of 1 and move to the model's device.
        image_input = self.preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
        # Only the single image in the batch is returned.
        return {"feature_vector": image_features[0].tolist()}

    @classmethod
    def _load_image(cls, inputs: Any) -> "Image.Image":
        """Convert a supported payload (see ``__call__``) into a PIL image."""
        # Check the type first: calling str methods on a dict/PIL payload,
        # or indexing a str with 'image', would crash with an unrelated error.
        if isinstance(inputs, str):
            url = inputs
            if url.startswith("www."):
                # requests rejects scheme-less URLs (MissingSchema).
                url = "https://" + url
            if not url.startswith(("http://", "https://")):
                raise ValueError(
                    "String inputs must be an http(s):// or www. image URL"
                )
            response = requests.get(url, timeout=cls.download_timeout)
            # Surface the HTTP error instead of letting PIL fail on an error page.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        if isinstance(inputs, dict):
            if "image" not in inputs:
                raise ValueError("Dict inputs must contain an 'image' key")
            return Image.open(inputs["image"])
        if isinstance(inputs, Image.Image):
            return inputs
        raise ValueError(f"Unsupported input type: {type(inputs).__name__}")