finhdev committed on
Commit
88c8e02
·
verified ·
1 Parent(s): 0c4c344

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +71 -0
handler.py CHANGED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py – place in repo root
2
+ import io, base64, torch
3
+ from PIL import Image
4
+ from transformers import CLIPModel, CLIPProcessor
5
+
6
+
7
class EndpointHandler:
    """
    Custom zero-shot image classifier replicating local OpenAI-CLIP logic.

    Expected client JSON:
        {
            "inputs": {
                "image": "<base64-encoded PNG/JPEG>",
                "candidate_labels": ["car", "teddy bear", ...]
            }
        }

    Returns a list of {"label": str, "score": float} dicts sorted by
    descending score, or a {"error": str} dict on bad input.
    """

    # -------- initialisation (runs once per container) --------
    def __init__(self, path: str = ""):
        """Load CLIP weights from `path` (repo root, pre-downloaded by HF)."""
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()

        # cache: {prompt -> L2-normalised text embedding (1D tensor) on device}
        # NOTE(review): unbounded — fine for a fixed label set, grows with
        # every new label otherwise.
        self.cache: dict[str, torch.Tensor] = {}

    # --------------------- inference --------------------------
    def __call__(self, data):
        """Run zero-shot classification on one base64-encoded image."""
        payload = data.get("inputs", data)  # unwrap HF inference envelope

        # ---- input validation: structured errors, not 500s -----
        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "image field is missing or empty"}
        names = payload.get("candidate_labels", [])
        if not names:
            return {"error": "candidate_labels list is empty"}

        # ---- prompt engineering identical to local code ----
        prompts = [f"a photo of a {p}" for p in names]

        # ---- text embeddings with cache --------------------
        missing = [p for p in prompts if p not in self.cache]
        if missing:
            txt_in = self.processor(
                text=missing, return_tensors="pt", padding=True
            ).to(self.device)
            with torch.no_grad():
                txt_emb = self.model.get_text_features(**txt_in)
                txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
            for p, e in zip(missing, txt_emb):
                self.cache[p] = e
        txt_feat = torch.stack([self.cache[p] for p in prompts])

        # ---- image preprocessing ---------------------------
        try:
            img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # malformed base64 or not a decodable image
            return {"error": f"could not decode image: {exc}"}
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)

        # Autocast only on CUDA: torch.cuda.amp.autocast is deprecated, and
        # entering it on a CPU-only container just emits warnings.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(device_type=self.device, enabled=use_amp):
            img_feat = self.model.get_image_features(**img_in)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # Under autocast img_feat may be fp16 while the cached txt_feat is
        # fp32; the matmul below would raise a dtype mismatch. Align dtypes.
        img_feat = img_feat.to(txt_feat.dtype)

        # ---- similarity & softmax (same as local code; 100 is the usual
        # CLIP logit scale) ----------------------------------
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # ---- return list sorted by descending score --------
        return [
            {"label": n, "score": float(p)}
            for n, p in sorted(zip(names, probs), key=lambda x: x[1], reverse=True)
        ]