# new-test-model / handler.py
# Author: jeff-RQ
# Last commit: d94fef7 ("Update handler.py")
# Size: 1.07 kB
import base64
import io
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
class EndpointHandler:
    """Inference-endpoint handler wrapping a BLIP-2 model.

    The processor and model are loaded once from ``path`` at construction
    time. Each call decodes a base64-encoded image, optionally pairs it with a
    text prompt, runs ``generate`` and returns the decoded text.
    """

    def __init__(self, path=""):
        """Load the BLIP-2 processor and model from ``path``.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        # Previously hard-coded to "cuda"/fp16, which crashed on CPU-only hosts.
        self.device, self.dtype = self._select_device_and_dtype()
        self.processor = Blip2Processor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, torch_dtype=self.dtype
        )
        self.model.to(self.device)

    @staticmethod
    def _select_device_and_dtype() -> Tuple[str, torch.dtype]:
        """Return the ``(device, dtype)`` pair used for inference.

        Half precision is used only on CUDA. On CPU, fall back to float32,
        since fp16 CPU execution is slow or unsupported for some ops.
        """
        if torch.cuda.is_available():
            return "cuda", torch.float16
        return "cpu", torch.float32

    @staticmethod
    def _parse_payload(data: Dict[str, Any]) -> Tuple[str, Optional[str]]:
        """Extract ``(image_b64, text)`` from a request body.

        The fields may be nested under ``"inputs"`` or sit at the top level.
        ``text`` is ``None`` when absent, so the model runs without a prompt.
        The caller's dict is not modified.

        Raises:
            TypeError: if the payload is not a dict.
            KeyError: if the payload has no ``"image"`` field.
        """
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            raise TypeError("request payload must be a JSON object")
        if "image" not in payload:
            raise KeyError("request payload must contain a base64-encoded 'image' field")
        return payload["image"], payload.get("text")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Run BLIP-2 generation for a single request.

        Args:
            data: Request body with ``image`` (base64-encoded image bytes) and
                optional ``text`` (prompt/question). The fields can be at the
                top level or nested under ``inputs``.

        Returns:
            A one-element list ``[{"answer": <generated text>}]``.
        """
        image_b64, text = self._parse_payload(data)
        # Force 3-channel RGB so RGBA/grayscale/palette uploads are handled
        # regardless of the processor's conversion settings.
        image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
        # Bug fix: use the instance's device (was an undefined name `device`).
        # BatchFeature.to casts only floating-point tensors to the dtype.
        inputs = self.processor(images=image, text=text, return_tensors="pt").to(
            self.device, self.dtype
        )
        generated_ids = self.model.generate(**inputs)
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return [{"answer": generated_text}]