testmobileclip / handler.py
finhdev's picture
Create handler.py
9188b68 verified
raw
history blame
961 Bytes
import io, base64, torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
class EndpointHandler:
    """Zero-shot image classification handler built on a CLIP checkpoint.

    The handler loads a ``CLIPModel`` and its ``CLIPProcessor`` once at
    construction time and then scores a base64-encoded image against a
    caller-supplied list of text labels on every call.
    """

    def __init__(self, path=""):
        """Load the model and processor found at *path*.

        Args:
            path: Directory or identifier passed to ``from_pretrained`` for
                both the model and the processor.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(path).to(device)
        # This handler only ever runs inference, so select eval behaviour
        # explicitly rather than relying on the loader's default.
        self.model.eval()
        self.processor = CLIPProcessor.from_pretrained(path)
        self.device = device

    def __call__(self, data):
        """Score an image against candidate labels.

        Args:
            data: Mapping shaped like
                ``{"image": "<base64 PNG/JPEG>", "candidate_labels": ["cat", "dog"]}``.
                ``candidate_labels`` may be omitted, ``None`` or empty, and a
                single string is treated as a one-element list.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, one per
            candidate label in the order given. The scores are a softmax over
            the labels. An empty list is returned when no labels are supplied.

        Raises:
            KeyError: If ``data`` has no ``"image"`` entry.
        """
        img_b64 = data["image"]

        labels = data.get("candidate_labels") or []
        if isinstance(labels, str):
            # A bare string would otherwise be zipped character by character.
            labels = [labels]
        else:
            labels = list(labels)
        if not labels:
            # Nothing to score against; skip decoding and the forward pass.
            return []

        raw_bytes = base64.b64decode(img_b64)
        image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")

        inputs = self.processor(
            text=labels,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(self.device)

        # Inference only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            logits = self.model(**inputs).logits_per_image

        probs = logits.softmax(dim=-1)[0].tolist()
        return [
            {"label": label, "score": float(prob)}
            for label, prob in zip(labels, probs)
        ]