import io, base64, torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler: zero-shot image
    classification with CLIP.

    Expects a JSON payload of the form
    ``{"image": "<base64 PNG/JPEG>", "candidate_labels": ["cat", "dog"]}``,
    either at the top level or nested under ``"inputs"`` (the shape HF
    Inference Endpoints actually deliver).
    """

    def __init__(self, path=""):
        """Load the CLIP model and processor from *path* onto GPU if available.

        Args:
            path: model directory or hub id passed to ``from_pretrained``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(path).to(device)
        # This handler only ever does inference — disable dropout/batchnorm
        # training behavior explicitly.
        self.model.eval()
        self.processor = CLIPProcessor.from_pretrained(path)
        self.device = device

    def __call__(self, data):
        """Score the image against each candidate label.

        Args:
            data: dict with keys ``"image"`` (base64-encoded PNG/JPEG) and
                ``"candidate_labels"`` (non-empty list of strings), possibly
                nested under ``data["inputs"]``.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, one per
            label, in input order; scores are a softmax and sum to 1.

        Raises:
            ValueError: if ``candidate_labels`` is missing or empty.
            KeyError: if ``"image"`` is missing from the payload.
        """
        # HF Inference Endpoints wrap the user payload under "inputs";
        # accept both the wrapped and the bare form for compatibility.
        if "image" not in data and isinstance(data.get("inputs"), dict):
            data = data["inputs"]

        img_b64 = data["image"]
        labels  = data.get("candidate_labels", [])
        if not labels:
            # A softmax over zero logits is undefined — fail loudly with a
            # clear message instead of a cryptic processor/model error.
            raise ValueError("candidate_labels must be a non-empty list of strings")

        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")

        inputs = self.processor(text=labels, images=image,
                                return_tensors="pt", padding=True).to(self.device)
        # inference_mode skips autograd graph construction entirely —
        # lower memory use and faster than a plain forward pass.
        with torch.inference_mode():
            logits = self.model(**inputs).logits_per_image
        probs = logits.softmax(dim=-1)[0].tolist()
        return [{"label": l, "score": float(p)} for l, p in zip(labels, probs)]