manevamarija's picture
Upload model with custom handler
28cb177 verified
raw
history blame
922 Bytes
import base64
from io import BytesIO
from typing import Any, Dict, List, Tuple

from PIL import Image
from transformers import pipeline
class EndpointHandler:
    """Custom inference handler: zero-shot image classification with CLIP.

    Each request carries one base64-encoded image and a list of candidate
    labels; the handler returns the pipeline's label/score results for it.
    """

    # Checkpoint loaded when ``__init__`` is not given an explicit ``model``.
    # Kept identical to the previously hard-coded id so deployments are unchanged.
    DEFAULT_MODEL = "openai/clip-vit-large-patch14-336"

    def __init__(self, path="", model=None):
        """Build the zero-shot image-classification pipeline.

        Args:
            path: Argument supplied by the serving runtime (presumably the
                repository directory — confirm). Not used here: the model is
                loaded from ``model`` / ``DEFAULT_MODEL`` instead.
            model: Optional model id or local path overriding ``DEFAULT_MODEL``.
        """
        self.pipeline = pipeline(
            "zero-shot-image-classification",
            model=model if model is not None else self.DEFAULT_MODEL,
        )

    @staticmethod
    def _parse_request(data: Dict[str, Any]) -> Tuple[bytes, List[str]]:
        """Extract the decoded image bytes and candidate labels from a payload.

        Args:
            data: Either ``{"inputs": {...}}`` or the inner dict itself.

        Returns:
            ``(image_bytes, candidate_labels)``.

        Raises:
            KeyError: if ``image`` or the candidate labels are missing.
        """
        # ``get`` rather than ``pop``: never mutate the caller's request dict.
        inputs = data.get("inputs", data)

        encoded = inputs["image"]
        # A data URI ("data:image/png;base64,....") would otherwise be fed to
        # b64decode whole; its non-alphabet chars are dropped silently and the
        # remaining header letters corrupt the decoded bytes.
        if isinstance(encoded, str) and encoded.startswith("data:"):
            encoded = encoded.partition(",")[2]
        image_bytes = base64.b64decode(encoded)

        if "candidates" in inputs:
            labels = inputs["candidates"]
        elif "candiates" in inputs:
            # Legacy misspelled key, still accepted so existing clients keep working.
            labels = inputs["candiates"]
        else:
            raise KeyError("request is missing 'candidates' (list of label strings)")
        # The pipeline iterates over the labels, so a bare string would be
        # split into single characters.
        if isinstance(labels, str):
            labels = [labels]
        return image_bytes, labels

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image against the supplied candidate labels.

        Args:
            data: Request payload, either ``{"inputs": {...}}`` or the inner
                dict directly, containing:
                  - ``image``: base64-encoded image bytes (a ``data:`` URI
                    prefix is tolerated).
                  - ``candidates``: list of label strings (the legacy key
                    ``candiates`` is also accepted).

        Returns:
            A list of dicts like ``{"label": "XXX", "score": 0.82}`` for the
            single image, as produced by the pipeline.
        """
        image_bytes, candidate_labels = self._parse_request(data)
        image = Image.open(BytesIO(image_bytes))
        # The pipeline is called with a one-element batch; unwrap its result.
        prediction = self.pipeline(images=[image], candidate_labels=candidate_labels)
        return prediction[0]