File size: 3,264 Bytes
58f78c1 89a2202 58f78c1 5ee030f 89a2202 58f78c1 89a2202 58f78c1 89a2202 58f78c1 89a2202 58f78c1 5ee030f 2eccf7e 5ee030f 5a65a07 5ee030f 2eccf7e 5ee030f 5a65a07 5ee030f 2eccf7e 5a65a07 2eccf7e 5a65a07 2eccf7e 5a65a07 2eccf7e 5a65a07 2eccf7e 89a2202 58f78c1 2eccf7e dd7c1fd 2eccf7e 58f78c1 2eccf7e b2f905d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import base64
from io import BytesIO
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a merged LLaVA model.

    Loads a self-contained (already-merged) LLaVA checkpoint from `path` and
    serves both multimodal (image + text) and text-only requests via
    `__call__`.
    """

    def __init__(self, path=""):
        # 'path' is a self-contained directory with the complete, merged model.
        print("Loading model and processor from local path...")
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            path,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        print("✅ Model loaded successfully.")

    def __call__(self, data: dict) -> dict:
        """Handle one inference request.

        Expected payload (under ``data["inputs"]`` or ``data`` itself):
            prompt (str, optional): user prompt; defaults to an image caption request.
            image_b64 (str, optional): base64-encoded image. If present, the
                request is treated as multimodal.
            max_new_tokens (int, optional): generation budget, default 200.

        Returns:
            dict with ``generated_text``, or ``error`` if image decoding fails.
        """
        payload = data.pop("inputs", data)
        # Robustness: some clients send a bare string as "inputs"; treat it
        # as the prompt instead of crashing on .pop().
        if isinstance(payload, str):
            payload = {"prompt": payload}
        prompt_text = payload.pop("prompt", "Describe the image in detail.")
        image_b64 = payload.pop("image_b64", None)
        max_new_tokens = payload.pop("max_new_tokens", 200)

        image = None
        if image_b64:
            try:
                image_bytes = base64.b64decode(image_b64)
                # convert("RGB") guards against RGBA/palette/grayscale uploads,
                # which the image processor cannot normalize directly.
                image = Image.open(BytesIO(image_bytes)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to decode or open base64 image: {e}"}

        if image is not None:
            # --- Case 1: Multimodal (Image + Text) ---
            print("Processing multimodal request...")
            prompt = f"USER: <image>\n{prompt_text} ASSISTANT:"
            inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)
        else:
            # --- Case 2: Text-Only ---
            print("Processing text-only request...")
            prompt = f"USER: {prompt_text} ASSISTANT:"
            # First, process the text to get input_ids.
            inputs = self.processor(text=prompt, return_tensors="pt")
            # BUG FIX: transformers image processors expose crop_size/size as
            # direct attributes — there is no `.config` attribute, so the old
            # `image_processor.config` raised AttributeError on every
            # text-only request. Read the dimensions directly, falling back
            # from crop_size to size for processors that lack a center crop.
            image_processor = self.processor.image_processor
            dims = getattr(image_processor, "crop_size", None) or image_processor.size
            height = dims.get("height") or dims.get("shortest_edge")
            width = dims.get("width") or dims.get("shortest_edge")
            # CLIPImageProcessor does not define num_channels; 3 (RGB) is the
            # correct value for LLaVA's vision tower.
            num_channels = getattr(image_processor, "num_channels", 3)
            # NOTE(review): feeding a zero dummy image with no <image> token
            # relies on version-specific LlavaForConditionalGeneration
            # behavior — confirm against the pinned transformers version.
            dummy_pixel_values = torch.zeros(
                (1, num_channels, height, width),
                dtype=self.model.dtype,
                device=self.model.device,
            )
            inputs["pixel_values"] = dummy_pixel_values
            # Ensure the entire input dictionary is on the correct device.
            inputs = inputs.to(self.model.device)

        # Generation is identical for both cases.
        with torch.no_grad():
            output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        full_response = self.processor.decode(output[0], skip_special_tokens=True)
        # Keep only the assistant's turn from the decoded transcript.
        assistant_response = full_response.split("ASSISTANT:")[-1].strip()
        return {"generated_text": assistant_response}