Files changed (1) hide show
  1. Handler.py +43 -0
Handler.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BlipProcessor, BlipForConditionalGeneration
2
+ from PIL import Image
3
+ import requests
4
+ import torch
5
+ import base64
6
+ import io
7
+
8
+
9
class EndpointHandler:
    """Inference-endpoint handler for BLIP image captioning.

    Loads a BLIP processor/model from *path* and generates an English caption
    for an image supplied either as an HTTP(S) URL or as a base64-encoded
    string (bare base64 or a ``data:<mime>;base64,...`` data URI).
    """

    def __init__(self, path):
        """Load the processor and model from *path*; move the model to GPU when available.

        Args:
            path: local directory or hub id accepted by ``from_pretrained``.
        """
        self.processor = BlipProcessor.from_pretrained(path)
        # fp16 halves GPU memory; fall back to fp32 on CPU where half-precision
        # inference is poorly supported.
        self.model = BlipForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()  # inference only; disable dropout etc.

    @staticmethod
    def _decode_base64(image_input):
        """Return the decoded bytes of a base64 payload, tolerating a data-URI prefix.

        Args:
            image_input: plain base64 ("iVBOR...") or a data URI
                ("data:image/png;base64,iVBOR...").

        Returns:
            The raw decoded bytes.
        """
        # Strip an optional "data:<mime>;base64," prefix so browser-produced
        # data URIs work as well as bare base64 strings.
        if image_input.startswith("data:"):
            _, _, image_input = image_input.partition(",")
        return base64.b64decode(image_input)

    def _load_image(self, image_input):
        """Load a PIL RGB image from a URL or a base64 string.

        Args:
            image_input: "http(s)://..." URL, bare base64, or a base64 data URI.

        Returns:
            A ``PIL.Image.Image`` in RGB mode.

        Raises:
            ValueError: if *image_input* is not a string.
        """
        if not isinstance(image_input, str):
            raise ValueError("Unsupported image input format")

        if image_input.startswith("http"):
            # Download the full body instead of the raw stream: `.raw` bypasses
            # content-encoding (e.g. gzip) decoding and can hand PIL garbage.
            # A timeout bounds the wait on slow hosts; raise_for_status fails
            # loudly on 4xx/5xx instead of feeding an HTML error page to PIL.
            response = requests.get(image_input, timeout=30)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")

        return Image.open(io.BytesIO(self._decode_base64(image_input))).convert("RGB")

    def __call__(self, data):
        """Generate a caption for the image in ``data["inputs"]``.

        Args:
            data: request payload; ``data["inputs"]`` is a URL or base64 image.

        Returns:
            ``{"caption": <generated caption string>}``.

        Raises:
            ValueError: if no image is provided or the input format is unsupported.
        """
        image_input = data.get("inputs")
        if image_input is None:
            raise ValueError("No image provided")

        image = self._load_image(image_input)

        # Cast pixel values to the model's dtype: the processor emits float32,
        # which would crash against the fp16 weights on GPU. This matches the
        # official BLIP usage: processor(...).to(device, torch.float16).
        inputs = self.processor(images=image, return_tensors="pt").to(
            self.device, self.model.dtype
        )

        # No autograd bookkeeping needed for generation.
        with torch.inference_mode():
            output = self.model.generate(**inputs, max_new_tokens=50)

        caption = self.processor.decode(output[0], skip_special_tokens=True)
        return {"caption": caption}