finhdev committed on
Commit
233acb0
·
verified ·
1 Parent(s): 08ce2dc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -37
handler.py CHANGED
@@ -1,72 +1,54 @@
1
- # handler.py
2
- import io
3
- import base64
4
- import torch
5
  from PIL import Image
6
-
7
  import open_clip
8
- from open_clip import fuse_conv_bn_sequential
9
 
10
  class EndpointHandler:
11
  """
12
- Zero‑shot image classifier for MobileCLIP‑B (OpenCLIP).
13
-
14
- Expects JSON payload:
15
  {
16
- "image": "<base64‑encoded PNG/JPEG>",
17
  "candidate_labels": ["cat", "dog", ...]
18
  }
19
- Returns:
20
- [
21
- {"label": "cat", "score": 0.91},
22
- {"label": "dog", "score": 0.05},
23
- ...
24
- ]
25
  """
26
 
27
  def __init__(self, path: str = ""):
28
- # Path points to the repo root inside the container
29
- weights = f"{path}/mobileclip_b.pt"
30
-
31
- # Load model + transforms from OpenCLIP
32
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
33
  "MobileCLIP-B", pretrained=weights
34
  )
 
35
 
36
- # Fuse conv + BN for faster inference (same idea as MobileCLIP re‑param)
37
- self.model = fuse_conv_bn_sequential(self.model).eval()
38
-
39
- # Tokenizer for label prompts
40
  self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
41
-
42
- # Device selection
43
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
44
  self.model.to(self.device)
45
 
46
  def __call__(self, data):
47
- # 1. Parse request
48
- img_b64 = data["image"]
49
- labels = data.get("candidate_labels", [])
50
  if not labels:
51
  return {"error": "candidate_labels list is empty"}
52
 
53
- # 2. Decode & preprocess image
54
  image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
55
- image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
56
 
57
- # 3. Tokenize labels
58
  text_tokens = self.tokenizer(labels).to(self.device)
59
 
60
- # 4. Forward pass
61
  with torch.no_grad(), torch.cuda.amp.autocast():
62
- img_feat = self.model.encode_image(image_tensor)
63
  txt_feat = self.model.encode_text(text_tokens)
64
  img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
65
  txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
66
- probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
67
 
68
- # 5. Return sorted list
69
  return [
70
  {"label": l, "score": float(p)}
71
  for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
72
  ]
 
 
1
+ # handler.py (repo root)
2
+ import io, base64, torch
 
 
3
  from PIL import Image
 
4
  import open_clip
 
5
 
class EndpointHandler:
    """Zero-shot image classifier for MobileCLIP-B (OpenCLIP).

    Request payload, either bare or wrapped in an ``"inputs"`` key::

        {
            "image": "<base64-png/jpeg>",
            "candidate_labels": ["cat", "dog", ...]
        }

    Response: ``list[{"label": str, "score": float}]`` sorted by descending
    score, or ``{"error": str}`` when the request cannot be processed.
    """

    def __init__(self, path: str = ""):
        """Load the model, preprocessing transform and tokenizer.

        Args:
            path: Directory that contains ``mobileclip_b.pt``.
        """
        import os

        # os.path.join keeps the checkpoint path relative when ``path`` is
        # empty; f"{path}/..." would resolve to the filesystem root instead.
        weights = os.path.join(path, "mobileclip_b.pt")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained=weights
        )
        self.model.eval()

        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data):
        """Score ``candidate_labels`` against the supplied image.

        Args:
            data: Request dict holding ``"image"`` (base64 string) and
                ``"candidate_labels"`` (list of strings), optionally nested
                under an ``"inputs"`` key.

        Returns:
            A list of ``{"label", "score"}`` dicts sorted by descending score,
            or an ``{"error": ...}`` dict for an unusable request.
        """
        # Accept both the bare payload and the {"inputs": {...}} wrapper.
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            payload = data

        labels = payload.get("candidate_labels", [])
        if not labels:
            return {"error": "candidate_labels list is empty"}

        img_b64 = payload.get("image")
        if not img_b64:
            return {"error": "image is missing"}

        # Decode + preprocess image; report malformed input instead of raising.
        try:
            raw = base64.b64decode(img_b64)
            image = Image.open(io.BytesIO(raw)).convert("RGB")
        except (ValueError, OSError) as err:
            return {"error": f"could not decode image: {err}"}
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Tokenise labels
        text_tokens = self.tokenizer(labels).to(self.device)

        # Forward pass. Mixed precision is enabled only on CUDA, so CPU-only
        # hosts no longer hit the "CUDA is not available" autocast warning.
        use_amp = self.device == "cuda"
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=use_amp
        ):
            img_feat = self.model.encode_image(img_tensor)
            txt_feat = self.model.encode_text(text_tokens)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()

        # Return sorted results
        ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
        return [{"label": label, "score": float(prob)} for label, prob in ranked]
54
+