# handler.py
import base64
import io

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    """CLIP ViT-L/14 zero-shot image classifier for a custom inference endpoint.

    Expected request JSON::

        {
            "inputs": {
                "image": "<base64-encoded image bytes, optionally a data: URL>",
                "candidate_labels": ["prompt-1", "prompt-2", ...]
            }
        }

    The ``"inputs"`` wrapper is optional; a bare payload dict is accepted too.

    On success the handler returns a list of ``{"label": str, "score": float}``
    dicts sorted by descending score (a softmax over the candidate labels).
    Request-level problems are reported as ``{"error": "<message>"}``.
    """

    def __init__(self, path=""):
        """Load the CLIP model and processor from *path*; use CUDA when available."""
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()
        # prompt -> L2-normalised float32 text embedding, kept on self.device.
        # Per process and unbounded: it grows with every distinct prompt seen.
        self.cache: dict[str, torch.Tensor] = {}

    def __call__(self, data):
        """Classify one image against the request's candidate labels.

        Args:
            data: Request dict, either ``{"inputs": {...}}`` or the inner payload.

        Returns:
            A list of ``{"label", "score"}`` dicts sorted by descending score,
            or ``{"error": str}`` when the request cannot be processed.
        """
        payload = data.get("inputs", data)

        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "image field is missing or empty"}

        prompts = payload.get("candidate_labels", [])
        if isinstance(prompts, str):
            # A single bare string would otherwise be iterated char by char.
            prompts = [prompts]
        if not prompts:
            return {"error": "candidate_labels list is empty"}

        try:
            img = self._decode_image(img_b64)
        except (ValueError, OSError) as err:
            # binascii.Error (bad base64) subclasses ValueError; PIL's
            # UnidentifiedImageError / truncated-image errors are OSErrors.
            return {"error": f"could not decode image: {err}"}

        txt_feat = self._text_features(prompts)
        img_feat = self._image_features(img)

        # --- similarity & softmax (fixed 100x temperature, as in the local run)
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
        ranked = sorted(zip(prompts, probs), key=lambda x: x[1], reverse=True)
        return [{"label": p, "score": float(s)} for p, s in ranked]

    @staticmethod
    def _decode_image(img_b64):
        """Decode a base64 string (or ``data:<mime>;base64,...`` URL) to an RGB image."""
        if isinstance(img_b64, str) and img_b64.startswith("data:"):
            # Drop the data-URL header: its letters are valid base64 characters
            # and the non-validating decoder would turn them into garbage bytes.
            img_b64 = img_b64.partition(",")[2]
        raw = base64.b64decode(img_b64)
        # convert() forces PIL's lazy decode, so corrupt data fails here.
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def _text_features(self, prompts):
        """Return a stack of normalised float32 text embeddings, one row per prompt."""
        # dict.fromkeys dedupes while keeping order: repeated prompts encode once.
        missing = [p for p in dict.fromkeys(prompts) if p not in self.cache]
        if missing:
            tok = self.processor(
                text=missing, return_tensors="pt", padding=True
            ).to(self.device)
            with torch.no_grad():
                emb = self.model.get_text_features(**tok).float()
            emb = emb / emb.norm(dim=-1, keepdim=True)
            self.cache.update(zip(missing, emb))
        return torch.stack([self.cache[p] for p in prompts])

    def _image_features(self, img):
        """Return the normalised float32 embedding of a single image (batch of 1)."""
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)
        # Mixed precision only on CUDA (the old torch.cuda.amp.autocast() warned
        # and disabled itself on CPU). Autocast yields fp16, so cast back to
        # float32: matmul with the float32 text embeddings rejects mixed dtypes.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        ):
            img_feat = self.model.get_image_features(**img_in)
        img_feat = img_feat.float()
        return img_feat / img_feat.norm(dim=-1, keepdim=True)