# Source: llava-next-inference / handler.py (Hugging Face Hub file view)
# Last commit: "fix and log inputs" (de4c8be) by eBoreal
# File size: 2.7 kB
from io import BytesIO
from tempfile import TemporaryDirectory
from typing import Dict, List, Any

import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
class EndpointHandler:
    """Inference handler that answers a prompt about a remote image with LLaVA-NeXT.

    The processor and model are loaded once in ``__init__``; each call
    downloads the image referenced by the payload, runs generation and
    returns the decoded text. Every failure path returns a plain error
    string instead of raising, so callers always receive a ``str``.
    """

    # Checkpoint used for both the processor and the model weights.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"

    # Prompt used when the payload carries none. Kept verbatim: changing the
    # wording would change what the model generates.
    DEFAULT_PROMPT = "Can you describe this picture focusing on specifics visual artifacts and ambiance (objects, colors, person, athmosphere..). Please stay concise only output keywords and concepts detected."

    # Seconds to wait for the image host before giving up on the download.
    DOWNLOAD_TIMEOUT = 30

    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 100

    def __init__(self, path: str = ""):
        """Load the processor and model and move the model to the best device.

        Args:
            path: Model directory supplied by the serving runtime. Accepted so
                the handler can be constructed with or without it; the weights
                are always loaded from ``MODEL_ID``.
        """
        self.processor = LlavaNextProcessor.from_pretrained(self.MODEL_ID)
        device = self._select_device()
        model = LlavaNextForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            # Half precision only when running on the GPU.
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            low_cpu_mem_usage=True,
        )
        model.to(device)
        self.model = model
        self.device = device

    @staticmethod
    def _select_device() -> str:
        """Return the torch device string to run on: ``"cuda"`` or ``"cpu"``."""
        # torch has no "gpu" device type; CUDA devices are addressed as "cuda".
        return "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data: Dict[str, Any]) -> str:
        """Describe the image referenced by the request payload.

        Args:
            data: Request payload. The arguments are read from
                ``data["inputs"]`` when that key exists, otherwise from
                ``data`` itself. Recognised keys are ``"image"`` (URL of the
                image, required) and ``"prompt"`` (optional instruction;
                ``DEFAULT_PROMPT`` is used when absent).

        Returns:
            The decoded text of the first generated sequence, or a
            human-readable error message when the request cannot be served.
        """
        # get inputs
        inputs = data.pop("inputs", data)
        if not inputs:
            return f"Inputs not in payload got {data.keys()}"
        if not isinstance(inputs, dict):
            return "Inputs must be a JSON object with an 'image' URL and an optional 'prompt'."

        # Read the additional fields without mutating the caller's payload.
        prompt = inputs.get("prompt")
        image_url = inputs.get("image")

        if image_url is None:
            return "You need to upload an image URL for LLaVA to work."
        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        if getattr(self, "model", None) is None:
            return "Model was not initialized"
        if getattr(self, "processor", None) is None:
            return "Processor was not initialized"

        # Download the image.
        # NOTE(security): the URL comes straight from the request, so the
        # server will fetch whatever address the client names. Restrict the
        # allowed hosts/schemes upstream if this endpoint is exposed publicly.
        try:
            response = requests.get(image_url, timeout=self.DOWNLOAD_TIMEOUT)
        except requests.RequestException:
            return "Failed to download the image."
        if response.status_code != 200:
            return "Failed to download the image."

        # Decode in memory; the context manager closes the source image once
        # the RGB copy has been made.
        try:
            with Image.open(BytesIO(response.content)) as raw_image:
                image = raw_image.convert("RGB")
        except OSError:
            return "Downloaded file is not a readable image."

        formatted_prompt = f"[INST] <image>\n{prompt} [/INST]"
        # Keyword arguments keep the call independent of the processor's
        # positional parameter order.
        model_inputs = self.processor(
            text=formatted_prompt, images=image, return_tensors="pt"
        ).to(self.device)
        output = self.model.generate(
            **model_inputs, max_new_tokens=self.MAX_NEW_TOKENS
        )
        return self.processor.decode(output[0], skip_special_tokens=True)