kumararvindibs's picture
Upload 4 files
ed105db verified
raw
history blame
2.33 kB
import requests
from typing import Dict, Any
from PIL import Image
import torch
import base64
import io
from transformers import BlipForConditionalGeneration, BlipProcessor
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class EndpointHandler():
    """Inference Endpoints handler for BLIP image captioning.

    Expects a payload of the form::

        {"inputs": {"images": [<base64-encoded image bytes>, ...],
                    "texts":  ["optional prompt per image", ...]}}

    and returns ``{"captions": [str, ...]}`` (plus an ``"error"`` key on
    failure, with ``"captions": []``).
    """

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints handler contract; the
        # checkpoint is always pulled from the Hub by name here.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
        self.model.eval()

    @staticmethod
    def _decode_images(encoded_images):
        """Decode base64 payloads into RGB PIL images.

        BUG FIX: the original code opened only ``encoded_images[0]``,
        never base64-decoded anything (despite importing ``base64``), and
        round-tripped the single image through a throwaway PNG buffer.
        Here every element is decoded, and converted to RGB so palette /
        RGBA / grayscale inputs don't break the processor.
        """
        images = []
        for encoded in encoded_images:
            raw_bytes = base64.b64decode(encoded)
            images.append(Image.open(io.BytesIO(raw_bytes)).convert("RGB"))
        return images

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        input_data = data.get("inputs", {})
        encoded_images = input_data.get("images")
        if not encoded_images:
            return {"captions": [], "error": "No images provided"}
        texts = input_data.get("texts") or ["a photography of"] * len(encoded_images)
        # BUG FIX: a caller-supplied `texts` shorter than `images` used to
        # silently drop trailing images via zip(); pad with the last prompt.
        if len(texts) < len(encoded_images):
            texts = list(texts) + [texts[-1]] * (len(encoded_images) - len(texts))
        try:
            raw_images = self._decode_images(encoded_images)
            processed = [
                self.processor(image, text, return_tensors="pt")
                for image, text in zip(raw_images, texts)
            ]
            # Batch all per-image tensors; `max_new_tokens` is forwarded to
            # generate() alongside the model inputs.
            batch = {
                "pixel_values": torch.cat([inp["pixel_values"] for inp in processed], dim=0).to(device),
                "input_ids": torch.cat([inp["input_ids"] for inp in processed], dim=0).to(device),
                "attention_mask": torch.cat([inp["attention_mask"] for inp in processed], dim=0).to(device),
                "max_new_tokens": 40,
            }
            with torch.no_grad():
                out = self.model.generate(**batch)
            captions = self.processor.batch_decode(out, skip_special_tokens=True)
            return {"captions": captions}
        except Exception as e:
            # Boundary handler: report the failure to the caller instead of
            # letting the endpoint 500 with an opaque traceback.
            print(f"Error during processing: {str(e)}")
            return {"captions": [], "error": str(e)}