|
|
from typing import Dict, List, Any |
|
|
from tempfile import TemporaryDirectory |
|
|
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration |
|
|
from PIL import Image |
|
|
import torch |
|
|
import requests |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for LLaVA-NeXT image description.

    Downloads an image from a URL supplied in the request payload, runs it
    through ``llava-hf/llava-v1.6-mistral-7b-hf`` with an instruction prompt,
    and returns the decoded generation as a plain string.
    """

    # Hub checkpoint used when the toolkit does not supply a local path.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"

    def __init__(self, path: str = ""):
        """Load processor and model onto the best available device.

        Args:
            path: Optional local model directory (the HF inference toolkit
                passes one); falls back to the hub checkpoint when empty.
        """
        model_id = path or self.MODEL_ID
        self.processor = LlavaNextProcessor.from_pretrained(model_id)

        # BUG FIX: the valid torch device string is 'cuda', not 'gpu' --
        # model.to('gpu') raised RuntimeError on every CUDA machine, so the
        # original handler silently never used the GPU path correctly.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        model = LlavaNextForConditionalGeneration.from_pretrained(
            model_id,
            # fp16 halves memory on GPU; CPU inference requires fp32.
            torch_dtype=torch.float32 if device == 'cpu' else torch.float16,
            low_cpu_mem_usage=True,
        )
        model.to(device)

        self.model = model
        self.device = device

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            inputs (:obj: `dict`): expects {"image": <url str>, "prompt": <optional str>}
        Return:
            The generated description (or a human-readable error message)
            as a plain string, which will be serialized and returned.
        """
        inputs = data.pop("inputs", data)

        if not inputs:
            return f"Inputs not in payload got {data.keys()}"

        prompt = inputs.pop("prompt", None)
        image_url = inputs.pop("image", None)

        # Validation errors are returned as strings rather than raised:
        # this matches the endpoint's existing contract with its callers.
        if image_url is None:
            return "You need to upload an image URL for LLaVA to work."

        if prompt is None:
            prompt = "Can you describe this picture focusing on specifics visual artifacts and ambiance (objects, colors, person, athmosphere..). Please stay concise only output keywords and concepts detected."

        if not self.model:
            return "Model was not initialized"

        if not self.processor:
            return "Processor was not initialized"

        with TemporaryDirectory() as tmpdirname:
            # Timeout added so a stalled remote host cannot hang the
            # endpoint worker forever (the original call had none).
            response = requests.get(image_url, timeout=30)
            if response.status_code != 200:
                return "Failed to download the image."

            image_path = f"{tmpdirname}/image.jpg"
            with open(image_path, "wb") as f:
                f.write(response.content)

            # LEAK FIX: .convert("RGB") returns a detached copy, so the
            # previous `with Image.open(...).convert("RGB")` closed the
            # copy and leaked the underlying file handle. Close the opened
            # file explicitly and keep only the converted copy.
            with Image.open(image_path) as raw:
                image = raw.convert("RGB")

            # Mistral-style instruction template expected by this LLaVA-NeXT
            # checkpoint; <image> marks where image features are inserted.
            prompt = f"[INST] <image>\n{prompt} [/INST]"

            inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)

            output = self.model.generate(**inputs, max_new_tokens=100)

            clean = self.processor.decode(output[0], skip_special_tokens=True)

            return clean