moooji committed on
Commit
403ae51
·
1 Parent(s): e818162

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -0
handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ import torch
4
+ import base64
5
+ from io import BytesIO
6
+ from transformers import CLIPProcessor, CLIPModel
7
+
8
# Check for a GPU: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
11
class EndpointHandler():
    """Inference-endpoint handler returning CLIP embeddings.

    Accepts a payload whose ``inputs`` key holds either a base64-encoded
    image (key ``"image"``) or a text string (key ``"text"``) and returns
    the corresponding CLIP embedding as a plain list of floats.
    """

    def __init__(self, path=""):
        # `path` is accepted for interface compatibility but unused:
        # the handler pins the public checkpoint explicitly.
        # Fix: the original computed a device flag but never moved the
        # model to it — place the model on GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(
            "openai/clip-vit-large-patch14"
        ).to(self.device)
        self.model.eval()  # inference only: disable dropout / train-mode layers
        self.processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-large-patch14"
        )

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """Embed the image or text found in ``data["inputs"]``.

        Args:
            data: request payload; ``data["inputs"]`` (or ``data`` itself
                as a fallback) must contain ``"image"`` (base64 string)
                or ``"text"`` (string).

        Returns:
            The embedding vector as a list of floats.

        Raises:
            ValueError: if neither ``"image"`` nor ``"text"`` is present.
        """
        inputs = data.pop("inputs", data)

        if "image" in inputs:
            # Decode the base64 payload into a PIL image.
            image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
            processed = self.processor(
                images=image, text=None, return_tensors="pt", padding=True
            )
            # no_grad: skip autograd bookkeeping during pure inference.
            with torch.no_grad():
                image_embeds = self.model.get_image_features(
                    pixel_values=processed["pixel_values"].to(self.device)
                )
            return image_embeds[0].tolist()

        if "text" in inputs:
            processed = self.processor(
                images=None, text=inputs["text"], return_tensors="pt", padding=True
            )
            with torch.no_grad():
                text_embeds = self.model.get_text_features(
                    input_ids=processed["input_ids"].to(self.device),
                    attention_mask=processed["attention_mask"].to(self.device),
                )
            return text_embeds[0].tolist()

        # ValueError is an Exception subclass, so existing broad handlers
        # still catch it while the error type is now meaningful.
        raise ValueError("No 'image' or 'text' provided")