# blip-yoda-captioning / handler.py
# Last updated by vkao8264 (commit ea2341d, verified): "Update handler.py"
from typing import Dict, Any, List
from transformers import AutoProcessor, BlipForConditionalGeneration
from PIL import Image
import torch
import io
import base64
class EndpointHandler:
    """Inference Endpoints handler for BLIP image captioning.

    Pre/post-processing uses the base Salesforce BLIP processor; the caption
    model weights are loaded from ``path`` (the deployed repository, which
    presumably holds a fine-tuned BLIP checkpoint — confirm against the repo).
    """

    def __init__(self, path: str = ""):
        # Processor from the base checkpoint, model weights from the deployed repo.
        self.processor = AutoProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base", use_fast=False
        )
        self.model = BlipForConditionalGeneration.from_pretrained(path)
        self.model.eval()
        # Run on GPU when available; __call__ moves inputs to the same device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # Default sampling parameters. Callers may override any of these —
        # and only these — via the "generation_args" payload key.
        self.default_args = {
            "max_new_tokens": 30,
            "temperature": 0.4,
            "do_sample": True,
            "top_k": 40,
            "top_p": 0.4,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a caption for a base64-encoded image.

        Args:
            data (dict): {
                "inputs": base64-encoded image, or the literal string "wake"
                    to ping/warm the endpoint,
                "generation_args": optional dict overriding any of the
                    default generation parameters (unknown keys and None
                    values are ignored),
            }

        Returns:
            List[Dict[str, str]]:
                [{"generated_caption": ...}] on success,
                [{"status": "woken"}] for a wake ping,
                [{"error": ...}] on a missing key, decode failure, or
                inference failure.
        """
        image_data = data.get("inputs")
        if image_data is None:
            return [{"error": "Missing 'inputs' key"}]

        # Wake-up ping: lets a client warm up a scaled-to-zero endpoint.
        if image_data == "wake":
            return [{"status": "woken"}]

        # Merge caller overrides into the defaults. Only keys already present
        # in the defaults are honored (whitelist), and None values are
        # ignored. The `or {}` guards against an explicit
        # `"generation_args": None` in the payload, which previously raised
        # an uncaught TypeError on `k in args`.
        overrides = data.get("generation_args") or {}
        generation_args = {
            key: overrides[key] if overrides.get(key) is not None else default
            for key, default in self.default_args.items()
        }

        # Decode the base64 payload into an RGB PIL image.
        try:
            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
        except Exception as e:
            return [{"error": f"Image decoding failed: {str(e)}"}]

        # Model inference. Inputs are moved to the model's device (fixes the
        # original's CPU-only placement / potential device mismatch).
        try:
            inputs = self.processor(image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output_tokens = self.model.generate(**inputs, **generation_args)
            caption = self.processor.decode(output_tokens[0], skip_special_tokens=True)
            return [{"generated_caption": caption}]
        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]