Upload handler.py

#54
by kas1293 - opened
Files changed (1) hide show
  1. handler.py +33 -0
handler.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from PIL import Image
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ import torch
5
+ import base64
6
+ from io import BytesIO
7
+
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path=""):
11
+ self.model = CLIPModel.from_pretrained(path)
12
+ self.processor = CLIPProcessor.from_pretrained(path)
13
+ self.model.eval()
14
+
15
+ def _to_image(self, x) -> Image.Image:
16
+ if isinstance(x, Image.Image):
17
+ return x.convert("RGB")
18
+ if isinstance(x, (bytes, bytearray)):
19
+ return Image.open(BytesIO(x)).convert("RGB")
20
+ if isinstance(x, str):
21
+ return Image.open(BytesIO(base64.b64decode(x))).convert("RGB")
22
+ if isinstance(x, dict) and "image" in x:
23
+ return self._to_image(x["image"])
24
+ raise ValueError("Unsupported image input")
25
+
26
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
27
+ inputs = data.get("inputs", data)
28
+ image = self._to_image(inputs)
29
+ proc = self.processor(images=image, return_tensors="pt")
30
+ with torch.no_grad():
31
+ feats = self.model.get_image_features(**proc)
32
+ feats = feats / feats.norm(p=2, dim=-1, keepdim=True) # L2 normalize
33
+ return {"embedding": feats[0].tolist(), "dim": int(feats.shape[-1])}