finhdev committed on
Commit
08ce2dc
·
verified ·
1 Parent(s): 52af22a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -22
handler.py CHANGED
@@ -1,56 +1,72 @@
1
- # handler.py – place in repo root
2
- import io, base64, torch
 
 
3
  from PIL import Image
4
 
5
  import open_clip
6
- from mobileclip.modules.common.mobileone import reparameterize_model
7
 
8
  class EndpointHandler:
9
  """
10
- Zero‑shot image classifier for MobileCLIP‑B using OpenCLIP.
11
- Expects JSON:
 
12
  {
13
  "image": "<base64‑encoded PNG/JPEG>",
14
  "candidate_labels": ["cat", "dog", ...]
15
  }
 
 
 
 
 
 
16
  """
 
17
  def __init__(self, path: str = ""):
18
- # Hugging Face Endpoints clones the repo into `path`.
19
- # The weights file is mobileclip_b.pt (already in the repo).
20
  weights = f"{path}/mobileclip_b.pt"
 
 
21
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
22
  "MobileCLIP-B", pretrained=weights
23
  )
24
 
25
- # Re‑parameterize once for faster inference (as per MobileCLIP docs)
26
- self.model = reparameterize_model(self.model)
27
- self.model.eval()
28
 
29
- # OpenCLIP tokenizer (same as CLIP)
30
  self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
31
 
 
32
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
  self.model.to(self.device)
34
 
35
  def __call__(self, data):
36
- # Decode input
37
- img_b64 = data["image"]
38
- labels = data.get("candidate_labels", [])
39
- image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
 
40
 
41
- # Preprocess
 
42
  image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
43
- text_tokens = self.tokenizer(labels).to(self.device)
44
 
 
 
 
 
45
  with torch.no_grad(), torch.cuda.amp.autocast():
46
  img_feat = self.model.encode_image(image_tensor)
47
  txt_feat = self.model.encode_text(text_tokens)
48
- img_feat /= img_feat.norm(dim=-1, keepdim=True)
49
- txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
50
  probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
51
 
 
52
  return [
53
- {"label": l, "score": float(p)} for l, p in sorted(
54
- zip(labels, probs), key=lambda x: x[1], reverse=True
55
- )
56
  ]
 
1
+ # handler.py
2
+ import io
3
+ import base64
4
+ import torch
5
  from PIL import Image
6
 
7
  import open_clip
8
+ from open_clip import fuse_conv_bn_sequential
9
 
10
  class EndpointHandler:
11
  """
12
+ Zero‑shot image classifier for MobileCLIP‑B (OpenCLIP).
13
+
14
+ Expects JSON payload:
15
  {
16
  "image": "<base64‑encoded PNG/JPEG>",
17
  "candidate_labels": ["cat", "dog", ...]
18
  }
19
+ Returns:
20
+ [
21
+ {"label": "cat", "score": 0.91},
22
+ {"label": "dog", "score": 0.05},
23
+ ...
24
+ ]
25
  """
26
+
27
  def __init__(self, path: str = ""):
28
+ # Path points to the repo root inside the container
 
29
  weights = f"{path}/mobileclip_b.pt"
30
+
31
+ # Load model + transforms from OpenCLIP
32
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
33
  "MobileCLIP-B", pretrained=weights
34
  )
35
 
36
+ # Fuse conv + BN for faster inference (same idea as MobileCLIP re‑param)
37
+ self.model = fuse_conv_bn_sequential(self.model).eval()
 
38
 
39
+ # Tokenizer for label prompts
40
  self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
41
 
42
+ # Device selection
43
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
44
  self.model.to(self.device)
45
 
46
  def __call__(self, data):
47
+ # 1. Parse request
48
+ img_b64 = data["image"]
49
+ labels = data.get("candidate_labels", [])
50
+ if not labels:
51
+ return {"error": "candidate_labels list is empty"}
52
 
53
+ # 2. Decode & preprocess image
54
+ image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
55
  image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
 
56
 
57
+ # 3. Tokenize labels
58
+ text_tokens = self.tokenizer(labels).to(self.device)
59
+
60
+ # 4. Forward pass
61
  with torch.no_grad(), torch.cuda.amp.autocast():
62
  img_feat = self.model.encode_image(image_tensor)
63
  txt_feat = self.model.encode_text(text_tokens)
64
+ img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
65
+ txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
66
  probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
67
 
68
+ # 5. Return sorted list
69
  return [
70
+ {"label": l, "score": float(p)}
71
+ for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
 
72
  ]