manevamarija committed on
Commit
eef716f
·
verified ·
1 Parent(s): 28cb177

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +36 -14
handler.py CHANGED
@@ -1,27 +1,49 @@
1
-
2
  from typing import Dict, List, Any
3
  from PIL import Image
4
  from io import BytesIO
5
- from transformers import pipeline
6
  import base64
 
 
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
- self.pipeline = pipeline("zero-shot-image-classification", model="openai/clip-vit-large-patch14-336")
 
 
11
 
12
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
  """
14
- data args:
15
- images (:obj:`string`)
16
- candiates (:obj:`list`)
17
- Return:
18
- A :obj:`list`:. The list contains items that are dicts like {"label": "XXX", "score": 0.82}
 
 
 
 
 
19
  """
20
- inputs = data.pop("inputs", data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # decode base64 image to PIL
23
- image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
 
 
24
 
25
- # run prediction with provided candiates
26
- prediction = self.pipeline(images=[image], candidate_labels=inputs["candiates"])
27
- return prediction[0]
 
 
1
  from typing import Dict, List, Any
2
  from PIL import Image
3
  from io import BytesIO
 
4
  import base64
5
+ import torch
6
+ from transformers import CLIPProcessor, CLIPModel
7
 
8
class EndpointHandler():
    """Inference-endpoint handler: zero-shot image classification with CLIP.

    Scores a base64-encoded image against caller-supplied candidate labels
    using raw cosine similarity between CLIP image and text embeddings.
    """

    def __init__(self, path=""):
        # NOTE(review): `path` is the conventional model-directory argument for
        # HF inference endpoints but is unused here; the model id is hard-coded.
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # Disable dropout / switch to inference mode for deterministic scores.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: {
                "inputs": {
                    "image": base64-encoded image bytes,
                    "candidates": list of label strings
                        (the historical misspelling "candiates" is also accepted
                        for backward compatibility)
                }
            }
            If "inputs" is absent, `data` itself is treated as the payload.

        Returns:
            A list of {"label": str, "score": float} dicts sorted by score,
            highest first. Scores are raw cosine similarities (roughly in
            [-1, 1]), NOT softmax probabilities.

        Raises:
            KeyError: if "image" (or both candidate keys) is missing.
            ValueError: if the candidate-label list is empty.
        """
        inputs = data.get("inputs", data)

        # Decode the base64 payload and force 3-channel RGB, which the CLIP
        # processor expects (handles grayscale/RGBA/palette uploads).
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")

        # Prefer the correctly spelled key, but keep accepting the original
        # misspelled one so existing clients don't break. A missing key still
        # surfaces as KeyError, matching the original behavior.
        if "candidates" in inputs:
            categories = inputs["candidates"]
        else:
            categories = inputs["candiates"]
        if not categories:
            # Fail fast with a clear message instead of an opaque tokenizer error.
            raise ValueError("candidate label list must be non-empty")

        # Tokenize the labels and preprocess the image in a single call.
        processed = self.processor(text=categories, images=image, return_tensors="pt", padding=True)
        with torch.no_grad():
            image_features = self.model.get_image_features(processed["pixel_values"])
            text_features = self.model.get_text_features(
                processed["input_ids"], attention_mask=processed["attention_mask"]
            )

        # L2-normalize both embeddings so the dot product below is cosine similarity.
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

        similarity = (image_features @ text_features.T).squeeze(0)  # shape: (num_labels,)

        # One {"label", "score"} dict per candidate, best match first.
        result = [{"label": label, "score": score.item()} for label, score in zip(categories, similarity)]
        return sorted(result, key=lambda x: x["score"], reverse=True)