# handler.py — custom Hugging Face Inference Endpoints handler
# (author: zohaibterminator; commit 6e113e1, "Update handler.py")
from typing import Dict, Any
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import torch
import base64
import io
import os
from subprocess import run

# Pin Pillow at endpoint startup. List-form argv with the default shell=False
# avoids handing the command string to a shell (no interpolation surface).
# NOTE(review): PIL was already imported above, so the interpreter keeps the
# previously-loaded Pillow for this process — confirm the pin is still needed.
run(["pip3", "install", "pillow==9.4.0"], check=True)
class EndpointHandler:
    """Inference Endpoints handler for an image+text -> text model.

    Accepts a payload with a base64/bytes ``image`` and a ``text`` prompt,
    runs the chat-templated vision-language model, and returns the generated
    text (or an ``{"error": ...}`` dict — this handler never raises).
    """

    def __init__(self, path=""):
        # Token is read for parity with the original handler; the loaders
        # below do not pass it explicitly — presumably env-based auth is
        # picked up by huggingface_hub (TODO confirm).
        self.HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")
        # Load model 4-bit quantized, sharded across available devices.
        self.model = AutoModelForImageTextToText.from_pretrained(
            path,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)
        print("Model and processor loaded")

    def decode_image(self, image_input):
        """Convert a base64 string or raw bytes to an RGB PIL Image.

        Raises:
            ValueError: if ``image_input`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(image_input, str):
            image_data = base64.b64decode(image_input)
        elif isinstance(image_input, bytes):
            image_data = image_input
        else:
            raise ValueError("Image must be base64 string or bytes")
        # Force RGB so downstream processing never sees palette/alpha modes.
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one inference request.

        Args:
            data: payload with required keys ``image`` (base64 str or bytes)
                and ``text`` (the user prompt).

        Returns:
            ``{"generated_text": str}`` on success, ``{"error": str}`` on
            any validation or runtime failure (exceptions are swallowed so
            the endpoint always answers with JSON).
        """
        try:
            # Validate inputs before touching the model.
            if "image" not in data or "text" not in data:
                return {"error": "Both 'image' and 'text' are required"}
            image = self.decode_image(data["image"])
            # Single-turn chat with one image placeholder plus the prompt.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": data["text"]},
                    ],
                }
            ]
            input_text = self.processor.apply_chat_template(messages)
            inputs = self.processor(
                text=input_text,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to("cuda")
            # Sampled decoding; min_p prunes low-probability tokens.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                min_p=0.1,
                do_sample=True,
            )
            generated = self.processor.decode(
                outputs[0],
                skip_special_tokens=True,
            ).strip()
            # Heuristic: keep only the text after the last "assistant"
            # marker so prompt fragments are not echoed back.
            if "assistant" in generated:
                generated = generated.split("assistant")[-1].strip()
            return {"generated_text": generated}
        except Exception as e:
            return {"error": str(e)}