finhdev committed on
Commit
270502a
·
verified ·
1 Parent(s): e53312a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +96 -30
handler.py CHANGED
@@ -1,34 +1,27 @@
1
  # handler.py
2
- import io, base64, torch
3
  from PIL import Image
4
  from transformers import CLIPModel, CLIPProcessor
5
 
6
  class EndpointHandler:
7
- """
8
- CLIP ViT‑L/14 zero‑shot classifier.
9
- Expects JSON: {
10
- "inputs": {
11
- "image": "<base64>",
12
- "candidate_labels": ["prompt‑1", "prompt‑2", ...]
13
- }
14
- }
15
- """
16
-
17
  def __init__(self, path=""):
18
- self.model = CLIPModel.from_pretrained(path)
19
- self.processor = CLIPProcessor.from_pretrained(path)
20
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
  self.model.to(self.device).eval()
22
- self.cache: dict[str, torch.Tensor] = {} # prompt -> emb
23
 
 
24
  def __call__(self, data):
25
- payload = data.get("inputs", data)
26
- img_b64 = payload["image"]
27
- prompts = payload.get("candidate_labels", [])
28
- if not prompts:
29
- return {"error": "candidate_labels list is empty"}
 
30
 
31
- # --- text embeddings with per‑process cache ----------
 
32
  missing = [p for p in prompts if p not in self.cache]
33
  if missing:
34
  tok = self.processor(text=missing, return_tensors="pt",
@@ -39,25 +32,98 @@ class EndpointHandler:
39
  for p, e in zip(missing, emb):
40
  self.cache[p] = e
41
  txt_feat = torch.stack([self.cache[p] for p in prompts])
 
42
 
43
- # --- image embedding ---------------------------------
 
44
  img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
45
  img_in = self.processor(images=img, return_tensors="pt").to(self.device)
46
-
 
 
 
47
  with torch.no_grad(), torch.cuda.amp.autocast():
48
  img_feat = self.model.get_image_features(**img_in)
49
-
50
  img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
51
- # txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
52
-
53
- img_feat = img_feat.float() # ← add these two lines
54
- txt_feat = txt_feat.float() # ←
55
-
56
  probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
 
57
 
58
-
 
 
 
 
59
 
 
60
  return [
61
  {"label": p, "score": float(s)}
62
  for p, s in sorted(zip(prompts, probs), key=lambda x: x[1], reverse=True)
63
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # handler.py
2
+ import io, base64, time, torch
3
  from PIL import Image
4
  from transformers import CLIPModel, CLIPProcessor
5
 
6
class EndpointHandler:
    """
    CLIP zero-shot image classifier for an inference endpoint.

    Expects JSON of the form:
        {"inputs": {"image": "<base64>", "candidate_labels": ["a", "b", ...]}}
    (a flat payload without the "inputs" wrapper is also accepted).

    Returns a list of {"label": str, "score": float} dicts sorted by
    descending score, or {"error": ...} when no candidate labels are given.
    """

    def __init__(self, path=""):
        # Load model + processor from the checkpoint directory given by `path`.
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device).eval()
        # Per-process cache: prompt text -> L2-normalized text embedding.
        # NOTE(review): unbounded — fine for a fixed label set, grows forever
        # with arbitrary prompts.
        self.cache: dict[str, torch.Tensor] = {}

    # -------------------------------------------------------
    def __call__(self, data):
        T = {}  # stage name -> elapsed milliseconds
        t0 = time.perf_counter()

        payload = data.get("inputs", data)
        img_b64 = payload["image"]
        prompts = payload.get("candidate_labels", [])
        if not prompts:
            # Guard restored from the previous revision: without it a missing
            # key raises KeyError and an empty list produces a meaningless
            # softmax over zero labels.
            return {"error": "candidate_labels list is empty"}

        # —— text embeddings (per-process cache) ————————————————
        t = time.perf_counter()
        missing = [p for p in prompts if p not in self.cache]
        if missing:
            tok = self.processor(text=missing, return_tensors="pt",
                                 padding=True).to(self.device)
            with torch.no_grad():
                emb = self.model.get_text_features(**tok)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            for p, e in zip(missing, emb):
                self.cache[p] = e
        txt_feat = torch.stack([self.cache[p] for p in prompts])
        T["encode_text"] = (time.perf_counter() - t) * 1000  # ms

        # —— image preprocessing ————————————————
        t = time.perf_counter()
        img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_in = self.processor(images=img, return_tensors="pt").to(self.device)
        T["decode_resize"] = (time.perf_counter() - t) * 1000

        # —— image embedding ————————————————
        t = time.perf_counter()
        # torch.cuda.amp.autocast() is deprecated and warns/no-ops on CPU;
        # use the unified API and enable mixed precision only on CUDA.
        with torch.no_grad(), torch.autocast(device_type="cuda",
                                             enabled=(self.device == "cuda")):
            img_feat = self.model.get_image_features(**img_in)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        # Cast both sides to fp32 so fp16 image features (from autocast)
        # can matmul with the cached fp32 text features.
        img_feat = img_feat.float()
        txt_feat = txt_feat.float()
        T["encode_image"] = (time.perf_counter() - t) * 1000

        # —— similarity & softmax ————————————————
        t = time.perf_counter()
        # 100 is CLIP's fixed logit scale; [0] selects the single image.
        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
        T["similarity_softmax"] = (time.perf_counter() - t) * 1000

        # —— log timings ————————————————
        total = (time.perf_counter() - t0) * 1000
        print(f"[CLIP timings] total={total:.1f} ms | " +
              " | ".join(f"{k}={v:.1f}" for k, v in T.items()),
              flush=True)

        # —— build response ————————————————
        return [
            {"label": p, "score": float(s)}
            for p, s in sorted(zip(prompts, probs), key=lambda x: x[1], reverse=True)
        ]
67
+
68
+ # import io, base64, torch
69
+ # from PIL import Image
70
+ # from transformers import CLIPModel, CLIPProcessor
71
+
72
+ # class EndpointHandler:
73
+ # """
74
+ # CLIP ViT‑L/14 zero‑shot classifier.
75
+ # Expects JSON: {
76
+ # "inputs": {
77
+ # "image": "<base64>",
78
+ # "candidate_labels": ["prompt‑1", "prompt‑2", ...]
79
+ # }
80
+ # }
81
+ # """
82
+
83
+ # def __init__(self, path=""):
84
+ # self.model = CLIPModel.from_pretrained(path)
85
+ # self.processor = CLIPProcessor.from_pretrained(path)
86
+ # self.device = "cuda" if torch.cuda.is_available() else "cpu"
87
+ # self.model.to(self.device).eval()
88
+ # self.cache: dict[str, torch.Tensor] = {} # prompt -> emb
89
+
90
+ # def __call__(self, data):
91
+ # payload = data.get("inputs", data)
92
+ # img_b64 = payload["image"]
93
+ # prompts = payload.get("candidate_labels", [])
94
+ # if not prompts:
95
+ # return {"error": "candidate_labels list is empty"}
96
+
97
+ # # --- text embeddings with per‑process cache ----------
98
+ # missing = [p for p in prompts if p not in self.cache]
99
+ # if missing:
100
+ # tok = self.processor(text=missing, return_tensors="pt",
101
+ # padding=True).to(self.device)
102
+ # with torch.no_grad():
103
+ # emb = self.model.get_text_features(**tok)
104
+ # emb = emb / emb.norm(dim=-1, keepdim=True)
105
+ # for p, e in zip(missing, emb):
106
+ # self.cache[p] = e
107
+ # txt_feat = torch.stack([self.cache[p] for p in prompts])
108
+
109
+ # # --- image embedding ---------------------------------
110
+ # img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
111
+ # img_in = self.processor(images=img, return_tensors="pt").to(self.device)
112
+
113
+ # with torch.no_grad(), torch.cuda.amp.autocast():
114
+ # img_feat = self.model.get_image_features(**img_in)
115
+
116
+ # img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
117
+ # # txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
118
+
119
+ # img_feat = img_feat.float() # ← add these two lines
120
+ # txt_feat = txt_feat.float() # ←
121
+
122
+ # probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
123
+
124
+
125
+
126
+ # return [
127
+ # {"label": p, "score": float(s)}
128
+ # for p, s in sorted(zip(prompts, probs), key=lambda x: x[1], reverse=True)
129
+ # ]