|
|
from typing import Dict, List, Any |
|
|
from PIL import Image |
|
|
import clip |
|
|
import torch |
|
|
import requests |
|
|
import io |
|
|
|
|
|
# Pick the compute device once at import time: prefer the GPU when CUDA
# is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
|
|
|
class EndpointHandler():
    """Inference-endpoint handler that embeds images with CLIP ViT-B/32.

    The model is loaded once in ``__init__``; each request goes through
    ``__call__`` and returns the image's CLIP feature vector.
    """

    def __init__(self, path=""):
        """Load the CLIP model and its preprocessing transform.

        Args:
            path (:obj:`str`): model directory supplied by the endpoint
                runtime (unused — the checkpoint name is fixed).
        """
        # clip.load already places the model on `device`, so no extra
        # .to(device) call is needed; eval() freezes dropout/batch-norm
        # for deterministic inference.
        self.model, self.preprocess = clip.load('ViT-B/32', device)
        self.model.eval()

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """
        Args:
            data (:obj:):
                includes the input data and the parameters for the inference.
        Return:
            A :obj:`dict`:. The object returned should be a dict like {"feature_vector": [0.6331314444541931,0.8802216053009033,...,-0.7866355180740356,]} containing :
            - "feature_vector": A list of floats corresponding to the image embedding.
        """
        # Accept either the HF payload shape {"inputs": ...} or the raw
        # value itself; the original `data.pop(...)` crashed when `data`
        # was already a plain string.
        if isinstance(data, dict):
            inputs = data.pop("inputs", data)
        else:
            inputs = data

        img = self._load_image(inputs)

        # Preprocess to a (1, C, H, W) tensor on the model's device.
        image_input = self.preprocess(img).unsqueeze(0).to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)

        return {"feature_vector": image_features.tolist()[0]}

    @staticmethod
    def _load_image(inputs):
        """Resolve `inputs` to a PIL image.

        Supported forms: a URL string, a local file-path string, or a dict
        with an 'image' entry (path or file-like object).

        Raises:
            requests.HTTPError: if a URL download returns an error status.
        """
        if isinstance(inputs, str):
            if inputs.startswith(("http", "www")):
                # Bound the download so a hung host cannot stall the
                # endpoint, and surface HTTP errors instead of handing
                # an error page to PIL.
                response = requests.get(inputs, timeout=30)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content))
            # Non-URL string: treat as a local path. (The original code
            # fell into inputs['image'] here and raised TypeError on str.)
            return Image.open(inputs)
        # Dict payload — the original http/www check raised AttributeError
        # on dicts before ever reaching this branch.
        return Image.open(inputs['image'])
|
|
|