finhdev committed on
Commit
cd0f6ce
·
verified ·
1 Parent(s): 88c8e02

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +23 -38
handler.py CHANGED
@@ -1,71 +1,56 @@
1
- # handler.py – place in repo root
2
  import io, base64, torch
3
  from PIL import Image
4
  from transformers import CLIPModel, CLIPProcessor
5
 
6
-
7
  class EndpointHandler:
8
  """
9
- Custom zero‑shot classifier replicating local OpenAI‑CLIP logic.
10
-
11
- Client JSON must look like:
12
- {
13
- "inputs": {
14
- "image": "<base64 PNG/JPEG>",
15
- "candidate_labels": ["car", "teddy bear", ...]
16
- }
17
  }
 
18
  """
19
 
20
- # -------- initialisation (runs once per container) --------
21
- def __init__(self, path: str = ""):
22
- # `path` points to ./ (repo root) where HF already downloaded weights
23
  self.model = CLIPModel.from_pretrained(path)
24
  self.processor = CLIPProcessor.from_pretrained(path)
25
-
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.model.to(self.device).eval()
 
28
 
29
- # cache: {prompt → 1×768 tensor on device}
30
- self.cache: dict[str, torch.Tensor] = {}
31
-
32
- # --------------------- inference --------------------------
33
  def __call__(self, data):
34
- payload = data.get("inputs", data) # unwrap HF envelope
35
-
36
  img_b64 = payload["image"]
37
- names = payload.get("candidate_labels", [])
38
- if not names:
39
  return {"error": "candidate_labels list is empty"}
40
 
41
- # ---- prompt engineering identical to local code ----
42
- prompts = [f"a photo of a {p}" for p in names]
43
-
44
- # ---- text embeddings with cache --------------------
45
  missing = [p for p in prompts if p not in self.cache]
46
  if missing:
47
- txt_in = self.processor(text=missing, return_tensors="pt",
48
- padding=True).to(self.device)
49
  with torch.no_grad():
50
- txt_emb = self.model.get_text_features(**txt_in)
51
- txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
52
- for p, e in zip(missing, txt_emb):
53
  self.cache[p] = e
54
  txt_feat = torch.stack([self.cache[p] for p in prompts])
55
 
56
- # ---- image preprocessing ---------------------------
57
  img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
58
  img_in = self.processor(images=img, return_tensors="pt").to(self.device)
59
-
60
  with torch.no_grad(), torch.cuda.amp.autocast():
61
  img_feat = self.model.get_image_features(**img_in)
62
- img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
63
 
64
- # ---- similarity & softmax (same as local) ----------
65
  probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
66
 
67
- # ---- return sorted list ----------------------------
68
  return [
69
- {"label": n, "score": float(p)}
70
- for n, p in sorted(zip(names, probs), key=lambda x: x[1], reverse=True)
71
  ]
 
1
# handler.py
import io, base64, torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    """
    CLIP ViT-L/14 zero-shot classifier.

    Expects JSON: {
        "inputs": {
            "image": "<base64>",
            "candidate_labels": ["prompt-1", "prompt-2", ...]
        }
    }

    Returns a list of {"label", "score"} dicts sorted by score, descending,
    or an {"error": ...} dict when the request payload is invalid.
    """

    def __init__(self, path=""):
        # `path` is the repo root where the endpoint runtime placed the weights.
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()
        self.cache: dict[str, torch.Tensor] = {}  # prompt -> emb

    def __call__(self, data):
        payload = data.get("inputs", data)

        # --- input validation (consistent error dicts) -------
        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "image field is missing or empty"}
        prompts = payload.get("candidate_labels", [])
        if not prompts:
            return {"error": "candidate_labels list is empty"}

        # --- text embeddings with per-process cache ----------
        missing = [p for p in prompts if p not in self.cache]
        if missing:
            tok = self.processor(text=missing, return_tensors="pt",
                                 padding=True).to(self.device)
            with torch.no_grad():
                emb = self.model.get_text_features(**tok)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            for p, e in zip(missing, emb):
                self.cache[p] = e
        txt_feat = torch.stack([self.cache[p] for p in prompts])

        # --- image embedding ---------------------------------
        try:
            img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # bad base64 / corrupt image data
            return {"error": f"could not decode image: {exc}"}
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)

        # Autocast only on CUDA: torch.cuda.amp.autocast() is deprecated and
        # warns / no-ops on CPU-only hosts; torch.autocast is the modern API.
        with torch.no_grad(), torch.autocast(device_type="cuda",
                                             enabled=self.device == "cuda"):
            img_feat = self.model.get_image_features(**img_in)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        # Under autocast img_feat may be fp16 while the cached text features
        # are fp32; a mixed-dtype matmul outside the autocast region raises,
        # so align dtypes before computing similarities.
        img_feat = img_feat.to(txt_feat.dtype)

        # --- similarity & softmax (identical to local) -------
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        return [
            {"label": p, "score": float(s)}
            for p, s in sorted(zip(prompts, probs), key=lambda x: x[1], reverse=True)
        ]