"""Custom Hugging Face Inference Endpoint handler for BLIP image captioning."""

import base64
import io

import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# Seconds to wait on a remote image URL before giving up, so a slow or
# unresponsive host cannot block the request indefinitely.
_URL_TIMEOUT = 10

# Caption length budget used when the request does not override it via
# data["parameters"]["max_new_tokens"].
_DEFAULT_MAX_NEW_TOKENS = 50


class EndpointHandler:
    """Load a BLIP captioning model once and caption one image per call."""

    def __init__(self, path):
        """Load the processor and model from *path* and move the model to a device.

        Args:
            path: Model directory or hub id accepted by ``from_pretrained``.
        """
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        self.processor = BlipProcessor.from_pretrained(path)
        # Half precision is used only when CUDA is available; CPU stays float32.
        self.model = BlipForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16 if use_cuda else torch.float32
        )
        self.model.to(self.device)

    def _load_image(self, image_input):
        """Turn *image_input* into an RGB ``PIL.Image``.

        Accepted forms:
          * a ``PIL.Image.Image`` (presumably what the endpoint toolkit passes
            for binary image requests — confirm for your deployment),
          * raw ``bytes`` / ``bytearray`` of an encoded image file,
          * an ``http://`` or ``https://`` URL string,
          * a base64 string, optionally prefixed with ``data:<mime>;base64,``.

        Raises:
            ValueError: if the input type is unsupported or the base64 text is
                malformed.
            requests.HTTPError: if fetching the URL returns an error status.
            requests.RequestException: on other network failures or timeouts.
        """
        if isinstance(image_input, Image.Image):
            return image_input.convert("RGB")

        if isinstance(image_input, (bytes, bytearray)):
            return self._open_bytes(bytes(image_input))

        if isinstance(image_input, str):
            text = image_input.strip()

            # URL: download the full body with a timeout. Reading `.raw` would
            # skip gzip/deflate Content-Encoding decoding, and the response was
            # previously never closed.
            if text.startswith(("http://", "https://")):
                with requests.get(text, timeout=_URL_TIMEOUT) as response:
                    response.raise_for_status()
                    return self._open_bytes(response.content)

            # Base64: drop an optional data-URI header first. Lenient b64decode
            # would otherwise decode header characters (e.g. "/") as junk bytes.
            if text.startswith("data:") and "," in text:
                text = text.split(",", 1)[1]
            try:
                image_bytes = base64.b64decode(text)
            except ValueError as err:  # binascii.Error subclasses ValueError
                raise ValueError("Invalid base64 image data") from err
            return self._open_bytes(image_bytes)

        raise ValueError("Unsupported image input format")

    @staticmethod
    def _open_bytes(image_bytes):
        """Decode encoded image file bytes into an RGB ``PIL.Image``."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            # convert() returns a new image, so closing the source is safe.
            return img.convert("RGB")

    def __call__(self, data):
        """Caption the image in ``data["inputs"]``.

        Args:
            data: Request payload dict. ``data["inputs"]`` holds the image in
                any form accepted by ``_load_image``. The optional
                ``data["parameters"]`` dict is forwarded as keyword arguments
                to ``model.generate`` (e.g. ``max_new_tokens``, ``num_beams``);
                ``max_new_tokens`` defaults to 50.

        Returns:
            ``{"caption": <decoded caption string>}``.

        Raises:
            ValueError: if no (or an empty) image is provided, or the image
                input is unsupported or malformed.
        """
        image_input = data.get("inputs")
        is_empty_blob = isinstance(image_input, (str, bytes, bytearray)) and not image_input
        if image_input is None or is_empty_blob:
            raise ValueError("No image provided")

        image = self._load_image(image_input)

        # Move tensors to the model's device and cast floating tensors
        # (pixel_values) to the model's dtype. This follows the BLIP model
        # card's float16 usage instead of relying on the model to cast inputs.
        inputs = self.processor(images=image, return_tensors="pt").to(
            self.device, self.model.dtype
        )

        generate_kwargs = {"max_new_tokens": _DEFAULT_MAX_NEW_TOKENS}
        generate_kwargs.update(data.get("parameters") or {})

        output = self.model.generate(**inputs, **generate_kwargs)
        caption = self.processor.decode(output[0], skip_special_tokens=True)
        return {"caption": caption}