File size: 1,476 Bytes
e22ef17 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import requests
import torch
import base64
import io
class EndpointHandler:
    """Callable handler that generates a caption for an image with a BLIP model.

    The handler is constructed once with the path of a pretrained BLIP
    checkpoint and is then called with a request payload of the form
    ``{"inputs": <image>}``. It returns ``{"caption": <str>}``.
    """

    # Seconds to wait for a remote image before giving up; without a timeout a
    # stalled server would block the handler indefinitely.
    _REQUEST_TIMEOUT = 10

    # Upper bound on the number of tokens generated for a caption.
    _MAX_NEW_TOKENS = 50

    def __init__(self, path):
        """Load the BLIP processor and model stored at *path*.

        The model runs in float16 on CUDA when a GPU is available and in
        float32 on the CPU otherwise.
        """
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        self.dtype = torch.float16 if use_cuda else torch.float32

        self.processor = BlipProcessor.from_pretrained(path)
        self.model = BlipForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=self.dtype,
        )
        self.model.to(self.device)

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 string, tolerating a ``data:...;base64,`` URI prefix."""
        if payload.startswith("data:"):
            # b64decode silently discards the non-alphabet characters of the
            # header and would decode the remaining letters as image data, so
            # the header has to be cut off explicitly.
            _header, separator, encoded = payload.partition(",")
            if separator:
                payload = encoded
        return base64.b64decode(payload)

    def _load_image(self, image_input):
        """Turn *image_input* into an RGB PIL image.

        Supported inputs:
          * ``str`` starting with ``http://`` or ``https://`` - downloaded;
          * any other ``str`` - decoded as base64 (optionally a data URI);
          * ``bytes`` / ``bytearray`` - treated as an encoded image file;
          * ``PIL.Image.Image`` - used directly.

        Raises:
            ValueError: if the input is of any other type (or the base64
                payload is malformed).
            requests.RequestException: if the download fails, times out or
                returns an HTTP error status.
        """
        if isinstance(image_input, str):
            # ":" is not part of the base64 alphabet, so a full scheme prefix
            # cannot be confused with an encoded image.
            if image_input.startswith(("http://", "https://")):
                # NOTE(security): the URL comes from the request payload, so
                # this fetch can reach any host visible from the server.
                # Restrict allowed hosts if the endpoint is exposed publicly.
                response = requests.get(
                    image_input, timeout=self._REQUEST_TIMEOUT
                )
                # Fail on 4xx/5xx instead of handing an error page to PIL.
                response.raise_for_status()
                raw_bytes = response.content
            else:
                raw_bytes = self._decode_base64(image_input)
            return Image.open(io.BytesIO(raw_bytes)).convert("RGB")

        if isinstance(image_input, (bytes, bytearray)):
            return Image.open(io.BytesIO(image_input)).convert("RGB")

        if isinstance(image_input, Image.Image):
            return image_input.convert("RGB")

        raise ValueError("Unsupported image input format")

    def __call__(self, data):
        """Caption the image found under ``data["inputs"]``.

        Args:
            data: mapping with an ``"inputs"`` entry holding the image in one
                of the formats accepted by ``_load_image``.

        Returns:
            ``{"caption": <generated caption>}``.

        Raises:
            ValueError: if no image is provided or its format is unsupported.
        """
        image_input = data.get("inputs")
        if image_input is None:
            raise ValueError("No image provided")

        image = self._load_image(image_input)

        # Cast the floating-point inputs to the model's dtype as well as moving
        # them to its device: on CUDA the model is float16 while the processor
        # emits float32 pixel values, which would otherwise fail in generate().
        inputs = self.processor(images=image, return_tensors="pt").to(
            self.device, self.dtype
        )
        output = self.model.generate(
            **inputs, max_new_tokens=self._MAX_NEW_TOKENS
        )
        caption = self.processor.decode(output[0], skip_special_tokens=True)
        return {"caption": caption}
|