from typing import Dict, Any, List

from transformers import AutoProcessor, BlipForConditionalGeneration
from PIL import Image
import torch
import io
import base64


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load model and processor
        self.processor = AutoProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base", use_fast=False
        )
        self.model = BlipForConditionalGeneration.from_pretrained(path)
        self.model.eval()
        self.default_args = {
            "max_new_tokens": 30,
            "temperature": 0.4,
            "do_sample": True,
            "top_k": 40,
            "top_p": 0.4,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (dict): {
                "inputs": base64-encoded image,
                "generation_args": optional generation parameters
            }

        Returns:
            List[Dict[str, str]]: generated caption or error
        """
        image_data = data.get("inputs")
        if image_data is None:
            return [{"error": "Missing 'inputs' key"}]

        # Lightweight "wake" ping used to warm up the endpoint without running inference
        if image_data == "wake":
            return [{"status": "woken"}]

        # Merge caller-supplied generation args over the defaults
        args = data.get("generation_args", {})
        generation_args = self.default_args.copy()
        for k in self.default_args:
            if k in args and args[k] is not None:
                generation_args[k] = args[k]

        # Decode base64 image
        try:
            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
        except Exception as e:
            return [{"error": f"Image decoding failed: {str(e)}"}]

        # Model inference
        try:
            inputs = self.processor(image, return_tensors="pt")
            with torch.no_grad():
                output_tokens = self.model.generate(**inputs, **generation_args)
            caption = self.processor.decode(output_tokens[0], skip_special_tokens=True)
            return [{"generated_caption": caption}]
        except Exception as e:
            return [{"error": f"Inference failed: {str(e)}"}]
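

# A minimal local smoke test for the handler above; a sketch, not part of the
# deployed endpoint. It assumes a sample image at "example.jpg" (hypothetical
# path) and falls back to the base "Salesforce/blip-image-captioning-base"
# checkpoint for `path`, since in a deployed endpoint the runtime passes the
# repository path instead. The payload mirrors what the endpoint receives:
# a base64-encoded image under "inputs" plus optional "generation_args".
if __name__ == "__main__":
    handler = EndpointHandler(path="Salesforce/blip-image-captioning-base")

    with open("example.jpg", "rb") as f:  # hypothetical sample image
        encoded = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "inputs": encoded,
        "generation_args": {"max_new_tokens": 20, "temperature": 0.7},
    }
    print(handler(payload))  # e.g. [{"generated_caption": "..."}]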