|
|
from typing import Dict, Any, List
|
|
|
from transformers import AutoProcessor, BlipForConditionalGeneration
|
|
|
from PIL import Image
|
|
|
import torch
|
|
|
import io
|
|
|
import base64
|
|
|
|
|
|
class EndpointHandler:
    """HuggingFace inference-endpoint handler for BLIP image captioning.

    Accepts a base64-encoded image and returns a generated caption. Supports
    a lightweight "wake" ping so callers can warm the endpoint without
    running inference.
    """

    def __init__(self, path: str = ""):
        """Load the BLIP processor and the fine-tuned captioning model.

        Args:
            path: Directory containing the fine-tuned model weights
                (the deployment checkpoint path supplied by the endpoint).
        """
        # The processor (tokenizer + image transforms) comes from the base
        # hub repo; only the model weights are loaded from the deployment path.
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=False)
        self.model = BlipForConditionalGeneration.from_pretrained(path)
        self.model.eval()

        # Default sampling parameters. Callers may override any of these keys
        # via data["generation_args"]; unknown keys are ignored (whitelist).
        self.default_args = {
            "max_new_tokens": 30,
            "temperature": 0.4,
            "do_sample": True,
            "top_k": 40,
            "top_p": 0.4,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (dict): {
                "inputs": base64-encoded image,
                "generation_args": optional generation parameters
            }

        Returns:
            List[Dict[str, str]]: generated caption or error
        """
        image_data = data.get("inputs")
        if image_data is None:
            return [{"error": "Missing 'inputs' key"}]

        # Warm-up / health-check request: skip inference entirely.
        if image_data == "wake":
            return [{"status": "woken"}]

        # Merge caller overrides onto the defaults. Only keys present in
        # default_args are honoured. `or {}` guards against an explicit
        # null "generation_args" value, which previously raised TypeError
        # on `k in args` and crashed the request.
        args = data.get("generation_args") or {}
        generation_args = dict(self.default_args)
        for k in self.default_args:
            if args.get(k) is not None:
                generation_args[k] = args[k]

        try:
            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
        except Exception as e:
            return [{"error": f"Image decoding failed: {str(e)}"}]

        try:
            inputs = self.processor(image, return_tensors="pt")
            # Inference only; no gradients needed.
            with torch.no_grad():
                output_tokens = self.model.generate(**inputs, **generation_args)
            caption = self.processor.decode(output_tokens[0], skip_special_tokens=True)
            return [{"generated_caption": caption}]
        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]