finhdev committed on
Commit
652b877
·
verified ·
1 Parent(s): 98ef5e5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -23
handler.py CHANGED
@@ -1,60 +1,53 @@
1
 
2
- import io, base64, torch
3
  from PIL import Image
4
  import open_clip
5
- # Make sure the mobileclip library is installed in your Hugging Face environment
6
- # You might need to add it to your requirements.txt
7
  from mobileclip.modules.common.mobileone import reparameterize_model
8
 
9
class EndpointHandler:
    """Zero-shot classifier for MobileCLIP-B (OpenCLIP).

    Loads local MobileCLIP-B weights, fuses the MobileOne branches for
    faster inference, and classifies a base64-encoded image against a
    caller-supplied list of candidate labels.
    """

    def __init__(self, path: str = ""):
        # Weights are expected next to the handler as mobileclip_b.pt.
        weights = f"{path}/mobileclip_b.pt"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        self.model.eval()

        # Fuse MobileOne re-parameterizable branches into plain convs —
        # required for fast inference with MobileCLIP.
        self.model = reparameterize_model(self.model)

        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data):
        """Classify one image against ``candidate_labels``.

        Args:
            data: request dict, optionally wrapped in an ``inputs``
                envelope, containing ``image`` (base64 string) and
                ``candidate_labels`` (list of str).

        Returns:
            List of ``{"label", "score"}`` dicts sorted by descending
            score, or ``{"error": ...}`` when labels are missing.
        """
        # Unwrap Hugging Face's `inputs` envelope if present.
        payload = data.get("inputs", data)
        img_b64 = payload["image"]
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # Decode & preprocess image.
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Tokenise labels.
        text_tokens = self.tokenizer(labels).to(self.device)

        # Forward pass. torch.cuda.amp.autocast is deprecated; use
        # torch.autocast, and only enable mixed precision on CUDA so
        # CPU-only hosts neither warn nor change numerics.
        with torch.no_grad(), torch.autocast("cuda", enabled=self.device == "cuda"):
            img_feat = self.model.encode_image(img_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            # 100 ≈ exp(logit_scale), the standard CLIP temperature.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # Sorted output, highest score first.
        return [
            {"label": l, "score": float(p)}
            for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]
 
 
58
  # # handler.py (repo root)
59
  # import io, base64, torch
60
  # from PIL import Image
 
1
 
2
+ import contextlib, io, base64, torch
3
  from PIL import Image
4
  import open_clip
 
 
5
  from mobileclip.modules.common.mobileone import reparameterize_model
6
 
 
 
 
 
7
 
8
class EndpointHandler:
    """Zero-shot classifier for MobileCLIP-B (OpenCLIP).

    Loads local MobileCLIP-B weights, fuses the MobileOne branches for
    faster inference, and classifies a base64-encoded image against a
    caller-supplied list of candidate labels.
    """

    def __init__(self, path: str = ""):
        # You can also pass pretrained='datacompdr' to let OpenCLIP download.
        weights = f"{path}/mobileclip_b.pt"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        self.model.eval()
        # Fuse MobileOne re-parameterizable branches into plain convs.
        self.model = reparameterize_model(self.model)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

    def __call__(self, data):
        """Classify one image against ``candidate_labels``.

        Args:
            data: request dict, optionally wrapped in an ``inputs``
                envelope, containing ``image`` (base64 string) and
                ``candidate_labels`` (list of str).

        Returns:
            List of ``{"label", "score"}`` dicts sorted by descending
            score, or ``{"error": ...}`` on a malformed payload.
        """
        payload = data.get("inputs", data)
        img_b64 = payload.get("image")
        if not img_b64:
            # Structured error, consistent with the empty-labels case,
            # instead of an uncaught KeyError.
            return {"error": "image field is missing or empty"}
        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        # ---------------- decode inputs ----------------
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        text_tokens = self.tokenizer(labels).to(self.device)

        # ---------------- forward pass -----------------
        # torch.cuda.amp.autocast is deprecated; use torch.autocast on
        # CUDA and a no-op context on CPU-only hosts.
        autocast_ctx = (
            torch.autocast("cuda")
            if self.device.startswith("cuda")
            else contextlib.nullcontext()
        )
        with torch.no_grad(), autocast_ctx:
            img_feat = self.model.encode_image(img_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
            txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
            # 100 ≈ exp(logit_scale), the standard CLIP temperature.
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # Highest score first.
        return [
            {"label": l, "score": float(p)}
            for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
        ]
49
+
50
+
51
  # # handler.py (repo root)
52
  # import io, base64, torch
53
  # from PIL import Image