Florence2 / handler.py
SebastianRuff's picture
Update handler.py
4ffd51d verified
raw
history blame
1.83 kB
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import requests
import torch
class EndpointHandler:
    """Inference handler that runs a vision-language model on an image URL.

    The model and processor are loaded once from ``model_dir``. Each call then
    downloads the requested image, runs beam-search generation with the given
    task prompt, and returns the decoded text of the first generated sequence.
    """

    def __init__(self, model_dir):
        """Load the model and processor from *model_dir*.

        Args:
            model_dir: Path or hub identifier passed to ``from_pretrained``.
        """
        # Prefer the GPU when one is available; otherwise fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # trust_remote_code=True lets transformers run the model/processor code
        # bundled with the checkpoint.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
        ).eval().to(device)
        self.processor = AutoProcessor.from_pretrained(
            model_dir,
            trust_remote_code=True,
        )
        self.device = device

    def __call__(self, data):
        """Generate text for a single request.

        Args:
            data: Request mapping. Reads ``"image_url"`` (required) and
                ``"task_prompt"`` (defaults to ``"<MORE_DETAILED_CAPTION>"``).

        Returns:
            ``{"caption": text}``, where ``text`` is the first generated sequence
            decoded with special tokens skipped.

        Raises:
            ValueError: if ``"image_url"`` is missing or not a non-empty string.
            requests.RequestException: if downloading the image fails.
        """
        task_prompt = data.get("task_prompt", "<MORE_DETAILED_CAPTION>")
        image_url = data.get("image_url")

        image = self.load_image(image_url)

        # Tokenize the prompt and preprocess the image, then move the tensors
        # to the same device as the model.
        inputs = self.processor(
            text=task_prompt,
            images=image,
            return_tensors="pt",
        ).to(self.device)

        # Inference only: disable autograd so no computation graph is built.
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
            )

        # Decode only. No task-specific post-processing is applied, so the raw
        # decoded text of the first sequence is returned.
        generated_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]
        return {"caption": generated_text}

    def load_image(self, image_url, timeout=10.0):
        """Download *image_url* and return it as an RGB ``PIL.Image``.

        Args:
            image_url: HTTP(S) URL of the image.
            timeout: Seconds to wait on the server; passed to ``requests.get``.

        Returns:
            The decoded image converted to RGB mode.

        Raises:
            ValueError: if *image_url* is not a non-empty string.
            requests.HTTPError: if the server responds with an error status.
        """
        if not isinstance(image_url, str) or not image_url:
            raise ValueError("'image_url' must be a non-empty URL string")
        # The response is a context manager, so the connection is released even
        # if decoding fails. The timeout keeps a stalled server from hanging the
        # worker forever.
        with requests.get(image_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # Have urllib3 undo any Content-Encoding (e.g. gzip) on the raw stream.
            response.raw.decode_content = True
            # convert() forces PIL to read the pixel data while the stream is open.
            image = Image.open(response.raw).convert("RGB")
        return image