# handler.py – place in repo root
import io
import base64

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    """
    Custom zero-shot classifier replicating local OpenAI-CLIP logic.

    Client JSON must look like:
        {
            "inputs": {
                "image": "<base64-encoded image bytes, optionally a data: URL>",
                "candidate_labels": ["car", "teddy bear", ...]
            }
        }
    A payload without the "inputs" envelope is accepted too.

    Returns a list of {"label", "score"} dicts sorted by descending score, or
    an {"error": ...} dict when the request is malformed.
    """

    # -------- initialisation (runs once per container) --------
    def __init__(self, path: str = "", max_cache_size: int = 10_000):
        """
        Load CLIP weights/processor from ``path`` and move the model to GPU if present.

        Args:
            path: directory holding the model weights.
            max_cache_size: maximum number of prompt embeddings kept in the
                text cache. Labels are client-controlled, so the cache is
                bounded to stop memory from growing with every new label.
        """
        # `path` points to ./ (repo root) where HF already downloaded weights
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()
        # cache: {prompt → 1-D L2-normalised text embedding on self.device};
        # dict insertion order doubles as recency order for LRU eviction
        self.cache: dict[str, torch.Tensor] = {}
        self.max_cache_size = max_cache_size

    # ---------------------- helpers ---------------------------
    @staticmethod
    def _decode_image(img_b64):
        """Decode a base64 string (plain or ``data:...;base64,`` URL) into an RGB image."""
        if isinstance(img_b64, str) and img_b64.startswith("data:"):
            # b64decode silently discards non-alphabet characters, so a
            # data-URL header would otherwise leak most of its letters into
            # the decoded bytes and corrupt the image.
            img_b64 = img_b64.partition(",")[2]
        raw = base64.b64decode(img_b64)
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def _text_features(self, prompts):
        """
        Return a (len(prompts), D) tensor of L2-normalised text embeddings.

        Embeddings are served from ``self.cache`` when possible; only prompts
        not yet cached are run through the text encoder, in a single batch.
        """
        # dict.fromkeys de-duplicates while preserving order
        missing = list(dict.fromkeys(p for p in prompts if p not in self.cache))
        if missing:
            # truncation=True: labels longer than CLIP's text context would
            # otherwise make the text encoder raise
            txt_in = self.processor(
                text=missing, return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            with torch.no_grad():
                txt_emb = self.model.get_text_features(**txt_in)
            txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
            for p, e in zip(missing, txt_emb):
                # clone so each cached row owns its storage; a view would pin
                # the whole batch tensor in memory until every row is evicted
                self.cache[p] = e.clone()

        rows = []
        for p in prompts:
            emb = self.cache.pop(p)
            self.cache[p] = emb  # re-insert → marks p as most recently used
            rows.append(emb)
        txt_feat = torch.stack(rows)

        # Evict least recently used entries. This runs after stacking so the
        # current request's embeddings are never dropped before use.
        while self.cache and len(self.cache) > self.max_cache_size:
            self.cache.pop(next(iter(self.cache)))
        return txt_feat

    # --------------------- inference --------------------------
    def __call__(self, data):
        """
        Classify one image against the request's candidate labels.

        Args:
            data: request dict, either ``{"inputs": {...}}`` or the inner payload.

        Returns:
            List of ``{"label": ..., "score": float}`` sorted by descending
            score (softmax over the labels), or ``{"error": str}`` for a bad
            request.
        """
        payload = data.get("inputs", data)  # unwrap HF envelope
        names = payload.get("candidate_labels", [])
        if isinstance(names, str):
            # a lone label string would otherwise be iterated per character
            names = [names] if names else []
        if not names:
            return {"error": "candidate_labels list is empty"}

        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "image field is missing or empty"}
        try:
            img = self._decode_image(img_b64)
        except (ValueError, OSError) as exc:
            # binascii.Error (bad base64) subclasses ValueError;
            # PIL.UnidentifiedImageError (not an image) subclasses OSError
            return {"error": f"could not decode image: {exc}"}

        # ---- prompt engineering identical to local code ----
        prompts = [f"a photo of a {p}" for p in names]

        # ---- text embeddings with cache --------------------
        txt_feat = self._text_features(prompts)

        # ---- image preprocessing / encoding ----------------
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)
        # torch.cuda.amp.autocast() is deprecated and warns on CPU-only hosts;
        # enable mixed precision only when actually running on the GPU.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
            img_feat = self.model.get_image_features(**img_in)
        # Under autocast the image features come out in reduced precision,
        # while the cached text features were computed without autocast.
        # matmul rejects mixed dtypes, so align them before normalising.
        img_feat = img_feat.to(txt_feat.dtype)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # ---- similarity & softmax (same as local) ----------
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # ---- return sorted list ----------------------------
        return [
            {"label": n, "score": float(p)}
            for n, p in sorted(zip(names, probs), key=lambda x: x[1], reverse=True)
        ]