finhdev committed on
Commit
dc1caec
·
verified ·
1 Parent(s): 147df04

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +8 -85
handler.py CHANGED
@@ -1,20 +1,14 @@
1
 
2
- # handler.py (repo root)
3
  import io, base64, torch
4
  from PIL import Image
5
  import open_clip
 
 
 
6
 
7
  class EndpointHandler:
8
  """
9
- Zeroshot classifier for MobileCLIPB (OpenCLIP).
10
-
11
- Expected client JSON *to the endpoint*:
12
- {
13
- "inputs": {
14
- "image": "<base64 PNG/JPEG>",
15
- "candidate_labels": ["cat", "dog", ...]
16
- }
17
- }
18
  """
19
 
20
  def __init__(self, path: str = ""):
@@ -24,11 +18,15 @@ class EndpointHandler:
24
  )
25
  self.model.eval()
26
 
 
 
 
27
  self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
28
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
29
  self.model.to(self.device)
30
 
31
  def __call__(self, data):
 
32
  # ── unwrap Hugging Face's `inputs` envelope ───────────
33
  payload = data.get("inputs", data)
34
 
@@ -57,81 +55,6 @@ class EndpointHandler:
57
  {"label": l, "score": float(p)}
58
  for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
59
  ]
60
-
61
- # import io, base64, torch
62
- # from PIL import Image
63
- # import open_clip
64
-
65
-
66
- # class EndpointHandler:
67
- # """
68
- # Zero‑shot classifier for MobileCLIP‑B (OpenCLIP) with a text‑embedding cache.
69
-
70
- # Client JSON:
71
- # {
72
- # "inputs": {
73
- # "image": "<base64 PNG/JPEG>",
74
- # "candidate_labels": ["cat", "dog", ...]
75
- # }
76
- # }
77
- # """
78
-
79
- # # ------------------------------------------------- #
80
- # # INITIALISATION #
81
- # # ------------------------------------------------- #
82
- # def __init__(self, path: str = ""):
83
- # weights = f"{path}/mobileclip_b.pt"
84
-
85
- # self.model, _, self.preprocess = open_clip.create_model_and_transforms(
86
- # "MobileCLIP-B", pretrained=weights
87
- # )
88
- # self.model.eval()
89
-
90
- # self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
91
- # self.device = "cuda" if torch.cuda.is_available() else "cpu"
92
- # self.model.to(self.device)
93
-
94
- # # cache: {prompt -> 1×512 tensor on device}
95
- # self.label_cache: dict[str, torch.Tensor] = {}
96
-
97
- # # ------------------------------------------------- #
98
- # # INFERENCE #
99
- # # ------------------------------------------------- #
100
- # def __call__(self, data):
101
- # payload = data.get("inputs", data)
102
-
103
- # img_b64 = payload["image"]
104
- # labels = payload.get("candidate_labels", [])
105
- # if not labels:
106
- # return {"error": "candidate_labels list is empty"}
107
-
108
- # # --- image ----
109
- # image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
110
- # img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
111
-
112
- # # --- text (with cache) ----
113
- # missing = [l for l in labels if l not in self.label_cache]
114
- # if missing:
115
- # tokens = self.tokenizer(missing).to(self.device)
116
- # with torch.no_grad():
117
- # emb = self.model.encode_text(tokens)
118
- # emb = emb / emb.norm(dim=-1, keepdim=True)
119
- # for l, e in zip(missing, emb):
120
- # self.label_cache[l] = e
121
- # txt_feat = torch.stack([self.label_cache[l] for l in labels])
122
-
123
- # # --- forward & softmax ----
124
- # with torch.no_grad(), torch.cuda.amp.autocast():
125
- # img_feat = self.model.encode_image(img_tensor)
126
- # img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
127
- # probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
128
-
129
- # # --- sorted output ----
130
- # return [
131
- # {"label": l, "score": float(p)}
132
- # for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
133
- # ]
134
-
135
  # # handler.py (repo root)
136
  # import io, base64, torch
137
  # from PIL import Image
 
1
 
 
2
  import io, base64, torch
3
  from PIL import Image
4
  import open_clip
5
+ # Make sure the mobileclip library is installed in your Hugging Face environment
6
+ # You might need to add it to your requirements.txt
7
+ from mobileclip.modules.common.mobileone import reparameterize_model
8
 
9
  class EndpointHandler:
10
  """
11
+ Zero-shot classifier for MobileCLIP-B (OpenCLIP).
 
 
 
 
 
 
 
 
12
  """
13
 
14
  def __init__(self, path: str = ""):
 
18
  )
19
  self.model.eval()
20
 
21
+ # *** THIS IS THE CRUCIAL ADDITION ***
22
+ self.model = reparameterize_model(self.model)
23
+
24
  self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
25
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
26
  self.model.to(self.device)
27
 
28
  def __call__(self, data):
29
+ # ... (the rest of your __call__ method remains the same)
30
  # ── unwrap Hugging Face's `inputs` envelope ───────────
31
  payload = data.get("inputs", data)
32
 
 
55
  {"label": l, "score": float(p)}
56
  for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
57
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # # handler.py (repo root)
59
  # import io, base64, torch
60
  # from PIL import Image