finhdev committed on
Commit
566678a
·
verified ·
1 Parent(s): e1369ab

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +96 -32
handler.py CHANGED
@@ -1,81 +1,145 @@
1
  # handler.py (repo root)
2
 
3
- # handler.py (repo root)
4
-
5
- import io, base64, torch
6
  from PIL import Image
7
- import open_clip
8
-
9
 
10
class EndpointHandler:
    """Zero-shot classifier for MobileCLIP-B (OpenCLIP) with a text-embedding cache.

    Client JSON::

        {
            "inputs": {
                "image": "<base64 PNG/JPEG>",
                "candidate_labels": ["cat", "dog", ...]
            }
        }

    Returns a list of ``{"label": ..., "score": ...}`` dicts sorted by
    descending probability, or a single ``{"error": ...}`` dict on bad input.
    """

    # ------------------------------------------------- #
    #                  INITIALISATION                    #
    # ------------------------------------------------- #
    def __init__(self, path: str = ""):
        # Checkpoint is expected to ship alongside the handler in the repo root.
        weights = f"{path}/mobileclip_b.pt"

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        self.model.eval()

        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # cache: {prompt -> normalized text embedding (1-D tensor) on self.device}
        self.label_cache: dict[str, torch.Tensor] = {}

    # ------------------------------------------------- #
    #                     INFERENCE                      #
    # ------------------------------------------------- #
    def __call__(self, data):
        # Hugging Face Inference Endpoints wrap the payload under "inputs";
        # fall back to the raw dict for direct invocation.
        payload = data.get("inputs", data)

        img_b64 = payload.get("image")
        if not img_b64:
            # Keep the handler's error contract instead of raising KeyError.
            return {"error": "image is missing"}
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # --- image ----
        try:
            image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # malformed base64 or non-image bytes
            return {"error": f"could not decode image: {exc}"}
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # --- text (with cache) ----
        missing = [l for l in labels if l not in self.label_cache]
        if missing:
            tokens = self.tokenizer(missing).to(self.device)
            with torch.no_grad():
                emb = self.model.encode_text(tokens)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            for l, e in zip(missing, emb):
                self.label_cache[l] = e
        txt_feat = torch.stack([self.label_cache[l] for l in labels])

        # --- forward & softmax ----
        # torch.cuda.amp.autocast() is deprecated and warns on CPU-only hosts;
        # use torch.autocast and only enable mixed precision when on CUDA.
        autocast = torch.autocast(device_type="cuda", enabled=self.device == "cuda")
        with torch.no_grad(), autocast:
            img_feat = self.model.encode_image(img_tensor)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            # 100x temperature scaling is the standard CLIP logit scale.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # --- sorted output ----
        return [
            {"label": l, "score": float(p)}
            for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # # handler.py (repo root)
80
  # import io, base64, torch
81
  # from PIL import Image
 
1
  # handler.py (repo root)
2
 
3
+ import io, base64, torch, open_clip
 
 
4
  from PIL import Image
 
 
5
 
6
class EndpointHandler:
    """MobileCLIP-B zero-shot classifier (OpenCLIP, ``pretrained="datacompdr"``).

    Expects JSON::

        {
            "inputs": {
                "image": "<base64 PNG/JPEG>",
                "candidate_labels": ["a photo of a cat", ...]
            }
        }

    Returns a list of ``{"label": ..., "score": ...}`` dicts sorted by
    descending probability, or a single ``{"error": ...}`` dict on bad input.
    """

    # ---------- initialisation (once per container) ----------
    def __init__(self, path=""):
        # `path` is kept for interface compatibility with the Inference
        # Endpoints handler contract; the checkpoint itself is resolved by
        # open_clip from the "datacompdr" pretrained tag, so no local
        # mobileclip_b.pt file is needed.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "mobileclip_b", pretrained="datacompdr"
        )
        self.model.eval()

        self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # Cache: {prompt -> normalized text embedding (1-D tensor) on self.device}
        self.label_cache: dict[str, torch.Tensor] = {}

    # -------------------- inference --------------------------
    def __call__(self, data):
        # Inference Endpoints wrap the payload under "inputs"; fall back to
        # the raw dict for direct invocation.
        payload = data.get("inputs", data)

        img_b64 = payload.get("image")
        if not img_b64:
            # Keep the handler's error contract instead of raising KeyError.
            return {"error": "image is missing"}
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # image -> tensor
        try:
            image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        except Exception as exc:  # malformed base64 or non-image bytes
            return {"error": f"could not decode image: {exc}"}
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # text cached embeddings
        missing = [l for l in labels if l not in self.label_cache]
        if missing:
            tok = self.tokenizer(missing).to(self.device)
            with torch.no_grad():
                emb = self.model.encode_text(tok)
                emb = emb / emb.norm(dim=-1, keepdim=True)
            for l, e in zip(missing, emb):
                self.label_cache[l] = e
        txt_feat = torch.stack([self.label_cache[l] for l in labels])

        # forward
        # torch.cuda.amp.autocast() is deprecated and warns on CPU-only hosts;
        # use torch.autocast and only enable mixed precision when on CUDA.
        autocast = torch.autocast(device_type="cuda", enabled=self.device == "cuda")
        with torch.no_grad(), autocast:
            img_feat = self.model.encode_image(img_tensor)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            # 100x temperature scaling is the standard CLIP logit scale.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # sorted result (highest score first)
        return [
            {"label": l, "score": float(p)}
            for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]
68
 
69
+ # import io, base64, torch
70
+ # from PIL import Image
71
+ # import open_clip
72
+
73
+
74
+ # class EndpointHandler:
75
+ # """
76
+ # Zero‑shot classifier for MobileCLIP‑B (OpenCLIP) with a text‑embedding cache.
77
+
78
+ # Client JSON:
79
+ # {
80
+ # "inputs": {
81
+ # "image": "<base64 PNG/JPEG>",
82
+ # "candidate_labels": ["cat", "dog", ...]
83
+ # }
84
+ # }
85
+ # """
86
+
87
+ # # ------------------------------------------------- #
88
+ # # INITIALISATION #
89
+ # # ------------------------------------------------- #
90
+ # def __init__(self, path: str = ""):
91
+ # weights = f"{path}/mobileclip_b.pt"
92
+
93
+ # self.model, _, self.preprocess = open_clip.create_model_and_transforms(
94
+ # "MobileCLIP-B", pretrained=weights
95
+ # )
96
+ # self.model.eval()
97
+
98
+ # self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
99
+ # self.device = "cuda" if torch.cuda.is_available() else "cpu"
100
+ # self.model.to(self.device)
101
+
102
+ # # cache: {prompt -> 1×512 tensor on device}
103
+ # self.label_cache: dict[str, torch.Tensor] = {}
104
+
105
+ # # ------------------------------------------------- #
106
+ # # INFERENCE #
107
+ # # ------------------------------------------------- #
108
+ # def __call__(self, data):
109
+ # payload = data.get("inputs", data)
110
+
111
+ # img_b64 = payload["image"]
112
+ # labels = payload.get("candidate_labels", [])
113
+ # if not labels:
114
+ # return {"error": "candidate_labels list is empty"}
115
+
116
+ # # --- image ----
117
+ # image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
118
+ # img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
119
+
120
+ # # --- text (with cache) ----
121
+ # missing = [l for l in labels if l not in self.label_cache]
122
+ # if missing:
123
+ # tokens = self.tokenizer(missing).to(self.device)
124
+ # with torch.no_grad():
125
+ # emb = self.model.encode_text(tokens)
126
+ # emb = emb / emb.norm(dim=-1, keepdim=True)
127
+ # for l, e in zip(missing, emb):
128
+ # self.label_cache[l] = e
129
+ # txt_feat = torch.stack([self.label_cache[l] for l in labels])
130
+
131
+ # # --- forward & softmax ----
132
+ # with torch.no_grad(), torch.cuda.amp.autocast():
133
+ # img_feat = self.model.encode_image(img_tensor)
134
+ # img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
135
+ # probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
136
+
137
+ # # --- sorted output ----
138
+ # return [
139
+ # {"label": l, "score": float(p)}
140
+ # for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
141
+ # ]
142
+
143
  # # handler.py (repo root)
144
  # import io, base64, torch
145
  # from PIL import Image