# Commit dabc20d (verified) by jacstrong:
# Update handler.py to accept a single image. The image is provided via a web URL.
from typing import Dict, Any
from PIL import Image
import torch
import requests
from io import BytesIO
from transformers import BlipForConditionalGeneration, BlipProcessor
# Inference device: the default CUDA GPU when one is visible to torch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Inference Endpoints handler that captions a single image with BLIP.

    The request names the image by URL. The handler downloads and decodes it,
    runs ``BlipForConditionalGeneration.generate`` on it and returns the
    decoded caption.
    """

    # Checkpoint loaded when no other ``model_id`` is given.
    DEFAULT_MODEL_ID = "Salesforce/blip-image-captioning-base"

    # Seconds to wait on the image server before giving up. Without a timeout,
    # a slow or unresponsive host would block this worker indefinitely.
    DOWNLOAD_TIMEOUT = 10

    def __init__(self, path="", model_id=DEFAULT_MODEL_ID):
        """Load the BLIP processor and captioning model onto ``device``.

        Args:
            path: Path passed in by the endpoint runtime. It is not used: the
                model is loaded from ``model_id`` instead.
            model_id: Identifier passed to ``from_pretrained`` for both the
                processor and the model. Defaults to ``DEFAULT_MODEL_ID``.
        """
        self.processor = BlipProcessor.from_pretrained(model_id)
        self.model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Caption the image whose URL is given in ``data``.

        Args:
            data: Request body. Must be a dict with:
                - "image": URL of the image to caption (required).
                - "parameters": optional dict of keyword arguments forwarded
                  to ``model.generate``.

        Returns:
            ``{"caption": <str>}`` on success, or ``{"error": <str>}`` when the
            request is malformed or the image cannot be downloaded/decoded.
        """
        if not isinstance(data, dict):
            return {"error": "Request body must be a JSON object."}
        image_url = data.get("image")
        # An explicit ``"parameters": null`` behaves like an absent key instead
        # of crashing on ``{**None}`` below.
        parameters = data.get("parameters") or {}
        if not image_url:
            return {"error": "Missing 'image' field in request body."}
        if not isinstance(parameters, dict):
            return {"error": "'parameters' must be a JSON object."}

        try:
            raw_image = self._fetch_image(image_url)
        except Exception as e:
            # Request boundary: report any download/decode failure to the
            # client rather than failing the whole request.
            return {"error": f"Failed to fetch image from URL: {str(e)}"}

        # Preprocess the image and move the pixel tensor to the model's device.
        inputs = self.processor(images=raw_image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(device)
        # Request parameters go first so a client-supplied key such as
        # "pixel_values" cannot replace the processed image tensor.
        generate_kwargs = {**parameters, **inputs}

        with torch.no_grad():
            out = self.model.generate(**generate_kwargs)

        # Decode the first generated sequence into text.
        caption = self.processor.decode(out[0], skip_special_tokens=True)
        return {"caption": caption}

    def _fetch_image(self, image_url):
        """Download ``image_url`` and decode it into an RGB PIL image.

        Raises whatever ``requests`` or PIL raise on network, HTTP or decode
        errors; the caller turns these into an error response.
        """
        # NOTE(review): the URL comes straight from the request, so the endpoint
        # will fetch any host it can reach (SSRF risk). Restrict hosts if the
        # endpoint is publicly exposed.
        response = requests.get(image_url, timeout=self.DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")