# llava-finetuned / handler.py
# Last commit: 65535dd "Update handler.py" (author: Abdulmateen, verified)
import base64
from io import BytesIO

import requests
import torch
from peft import PeftModel
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration


class EndpointHandler:
    """Inference-endpoint handler for a LoRA-finetuned LLaVA-1.5 model.

    The base model is loaded 4-bit quantized and the LoRA adapter found at
    ``path`` is applied on top of it with PEFT. Each call takes a base64
    image plus an optional text prompt and returns the generated answer.
    """

    def __init__(self, path: str = "", base_model_id: str = "llava-hf/llava-v1.5-7b"):
        """Load the processor, the quantized base model and the LoRA adapter.

        Args:
            path: Local path to the LoRA adapter repository.
            base_model_id: Hub id of the base LLaVA checkpoint the adapter
                was trained on.
        """
        lora_model_path = path

        print("Loading processor...")
        # llava-hf checkpoints are natively supported by transformers, so
        # trust_remote_code is unnecessary (and would allow Hub code execution).
        self.processor = AutoProcessor.from_pretrained(base_model_id)

        print("Loading base model...")
        # Passing load_in_4bit straight to from_pretrained is deprecated; use an
        # explicit quantization config. Compute in fp16 to match torch_dtype
        # (bitsandbytes' 4-bit compute dtype otherwise defaults to fp32).
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.model = LlavaForConditionalGeneration.from_pretrained(
            base_model_id,
            quantization_config=quantization_config,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        print(f"Loading LoRA adapters from: {lora_model_path}...")
        # The adapters wrap the quantized base model; they are not merged
        # into its weights.
        self.model = PeftModel.from_pretrained(self.model, lora_model_path)
        print("✅ Model and adapters loaded successfully.")

    @staticmethod
    def _strip_data_uri(image_b64):
        """Drop a leading ``data:<mime>;base64,`` prefix, if present."""
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            return image_b64.partition(",")[2]
        return image_b64

    @classmethod
    def _decode_image(cls, image_b64):
        """Decode a base64 payload into a fully loaded RGB PIL image."""
        image_bytes = base64.b64decode(cls._strip_data_uri(image_b64))
        # Image.open is lazy; convert() forces the pixels to be decoded now so
        # corrupt data fails inside the caller's try block. It also normalises
        # RGBA / greyscale / palette images to RGB.
        return Image.open(BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: dict) -> dict:
        """Answer one image + prompt request.

        Keys read from ``data`` (the dict is not modified):
            image_b64: base64-encoded image bytes; a ``data:<mime>;base64,``
                prefix is accepted. Required.
            prompt: instruction text; defaults to "Describe the image in detail."
            max_new_tokens: positive integer generation budget; defaults to 200.

        Returns:
            ``{"generated_text": str}`` on success, otherwise ``{"error": str}``.
        """
        prompt_text = data.get("prompt", "Describe the image in detail.")
        image_b64 = data.get("image_b64")
        raw_max_new_tokens = data.get("max_new_tokens", 200)

        if not image_b64:
            return {"error": "No image provided. Please use the 'image_b64' key."}

        # JSON clients may send numbers as strings or floats; reject anything
        # that is not a positive integer before it reaches generate().
        try:
            max_new_tokens = int(raw_max_new_tokens)
        except (TypeError, ValueError, OverflowError):
            max_new_tokens = 0
        if max_new_tokens < 1:
            return {
                "error": (
                    "'max_new_tokens' must be a positive integer, "
                    f"got {raw_max_new_tokens!r}."
                )
            }

        try:
            image = self._decode_image(image_b64)
        except Exception as e:  # request boundary: report any decode failure to the client
            return {"error": f"Failed to decode or open base64 image: {e}"}

        prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"
        # Move inputs to the model's device instead of hard-coding "cuda"
        # (device_map="auto" decides placement). The dtype argument casts only
        # floating tensors (pixel_values) to fp16; input_ids stay integer.
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(
            self.model.device, torch.float16
        )

        with torch.no_grad():
            output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # generate() returns prompt + continuation. Decode only the new tokens,
        # so an "ASSISTANT:" inside the answer cannot truncate it.
        prompt_length = inputs["input_ids"].shape[-1]
        assistant_response = self.processor.decode(
            output[0][prompt_length:], skip_special_tokens=True
        ).strip()
        return {"generated_text": assistant_response}