# llava-merged / handler.py
# Custom inference handler for the merged LLaVA model.
# (Author: Abdulmateen — commit 5a65a07)
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import base64
from io import BytesIO
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a merged LLaVA model.

    Handles two request shapes via ``__call__``:
      * multimodal: ``{"prompt": str, "image_b64": <base64 image>, ...}``
      * text-only:  ``{"prompt": str, ...}``
    The payload may arrive either directly or wrapped under ``"inputs"``.
    """

    def __init__(self, path: str = ""):
        """Load the processor and 4-bit quantized model.

        `path` is a self-contained directory with the complete, merged model.
        """
        print("Loading model and processor from local path...")
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        # NOTE(review): passing `load_in_4bit=True` directly is deprecated in
        # recent transformers in favor of
        # `quantization_config=BitsAndBytesConfig(load_in_4bit=True)`;
        # kept as-is here to match the pinned runtime — confirm before upgrading.
        self.model = LlavaForConditionalGeneration.from_pretrained(
            path,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        print("✅ Model loaded successfully.")

    def __call__(self, data: dict) -> dict:
        """Run one generation request.

        Parameters
        ----------
        data : dict
            The request body. Either the payload itself or ``{"inputs": payload}``.
            Recognized payload keys: ``prompt`` (str), ``image_b64`` (base64 str),
            ``max_new_tokens`` (int, default 200).

        Returns
        -------
        dict
            ``{"generated_text": str}`` on success, ``{"error": str}`` when the
            base64 image cannot be decoded/opened.
        """
        # Use .get() (not .pop()) so the caller's dict is never mutated.
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            # HF endpoints commonly send "inputs" as a bare string prompt;
            # the original .pop() call would raise AttributeError here.
            payload = {"prompt": payload}
        prompt_text = payload.get("prompt", "Describe the image in detail.")
        max_new_tokens = payload.get("max_new_tokens", 200)

        image = None
        image_b64 = payload.get("image_b64")
        if image_b64:
            try:
                image_bytes = base64.b64decode(image_b64)
                # Force RGB: RGBA/palette PNGs are rejected by the CLIP
                # image processor.
                image = Image.open(BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to decode or open base64 image: {e}"}

        if image is not None:
            # --- Case 1: Multimodal (Image + Text) ---
            print("Processing multimodal request...")
            prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"
            inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        else:
            # --- Case 2: Text-Only ---
            # LlavaForConditionalGeneration accepts pixel_values=None, so no
            # dummy image tensor is needed. (The previous code read
            # `image_processor.config`, an attribute image processors do not
            # have, which raised AttributeError on every text-only request.)
            print("Processing text-only request...")
            prompt = f"USER: {prompt_text} ASSISTANT:"
            inputs = self.processor(text=prompt, return_tensors="pt")

        # Move all tensors (input_ids, attention_mask, pixel_values if any)
        # onto the model's device.
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        full_response = self.processor.decode(output[0], skip_special_tokens=True)
        # Everything after the last "ASSISTANT:" marker is the model's reply.
        assistant_response = full_response.split("ASSISTANT:")[-1].strip()
        return {"generated_text": assistant_response}