from typing import Dict, List, Any
from tempfile import TemporaryDirectory

import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration


class EndpointHandler:
    """Custom Inference Endpoints handler serving llava-v1.6-mistral-7b.

    Loads the LLaVA-NeXT processor and model once at startup, then answers
    requests of the form ``{"inputs": {"image": <url>, "prompt": <optional str>}}``
    with a generated text description of the image.
    """

    def __init__(self):
        """Load processor and model, placing the model on GPU when available."""
        self.processor = LlavaNextProcessor.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf"
        )
        # BUG FIX: valid torch device strings are 'cuda' / 'cpu' — the original
        # 'gpu' is not a recognized device and model.to('gpu') raises a
        # RuntimeError on every CUDA machine.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            # float16 halves memory on GPU; CPU kernels generally need float32.
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            low_cpu_mem_usage=True,
        )
        model.to(device)
        self.model = model
        self.device = device

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: request payload; expects ``data["inputs"]`` to be a dict with
                an ``"image"`` URL and an optional ``"prompt"`` string.

        Return:
            The generated description as a string on success, or a plain error
            string on bad input / download failure (kept for backward
            compatibility with existing callers, despite the List[Dict]
            annotation).
        """
        # get inputs; fall back to the whole payload if "inputs" is absent
        inputs = data.pop("inputs", data)
        if not inputs:
            return f"Inputs not in payload got {data.keys()}"

        # get additional data fields
        prompt = inputs.pop("prompt", None)
        image_url = inputs.pop("image", None)
        if image_url is None:
            return "You need to upload an image URL for LLaVA to work."
        if prompt is None:
            prompt = "Can you describe this picture focusing on specifics visual artifacts and ambiance (objects, colors, person, athmosphere..). Please stay concise only output keywords and concepts detected."

        # Defensive guards (always set by __init__, kept as cheap sanity checks).
        if not self.model:
            return "Model was not initialized"
        if not self.processor:
            return "Processor was not initialized"

        # Create a temporary directory for the downloaded image
        with TemporaryDirectory() as tmpdirname:
            # Download the image. BUG FIX: requests.get without a timeout can
            # hang the endpoint worker forever on a stalled remote host.
            response = requests.get(image_url, timeout=30)
            if response.status_code != 200:
                return "Failed to download the image."

            # Define the path for the downloaded image
            image_path = f"{tmpdirname}/image.jpg"
            with open(image_path, "wb") as f:
                f.write(response.content)

            # Open the downloaded image and run generation
            with Image.open(image_path).convert("RGB") as image:
                # BUG FIX: the LLaVA-NeXT Mistral template requires the
                # "<image>" placeholder so the processor can splice the image
                # features into the token sequence; the original template
                # omitted it ("[INST] \n{prompt} [/INST]").
                prompt = f"[INST] <image>\n{prompt} [/INST]"
                model_inputs = self.processor(
                    prompt, image, return_tensors="pt"
                ).to(self.device)
                output = self.model.generate(**model_inputs, max_new_tokens=100)
                clean = self.processor.decode(output[0], skip_special_tokens=True)
                return clean