|
|
from typing import Dict, Any |
|
|
from transformers import AutoModelForImageTextToText, AutoProcessor |
|
|
from PIL import Image |
|
|
import torch |
|
|
import base64 |
|
|
import io |
|
|
import os |
|
|
from subprocess import run |
|
|
|
|
|
# Pin Pillow to a known-good version at container start-up.
# Use an argument list (shell=False) so nothing is interpreted by a shell,
# per subprocess best practice; check=True still fails fast on a non-zero exit.
run(["pip3", "install", "pillow==9.4.0"], check=True)
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for an image+text-to-text model.

    Loads the model 4-bit quantized with an automatic device map, then serves
    requests of the form ``{"image": <base64 str or bytes>, "text": <prompt>}``
    via ``__call__``.
    """

    def __init__(self, path=""):
        # Token for gated/private repos; None (unset) is fine for public models.
        self.HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")

        self.model = AutoModelForImageTextToText.from_pretrained(
            path,
            load_in_4bit=True,  # NOTE(review): deprecated flag; consider
                                # quantization_config=BitsAndBytesConfig(...)
            torch_dtype=torch.float16,
            device_map="auto",
            token=self.HF_READ_TOKEN,  # FIX: was read but never used
        )
        self.processor = AutoProcessor.from_pretrained(
            path,
            token=self.HF_READ_TOKEN,  # FIX: same — authenticate both downloads
        )
        print("Model and processor loaded")

    def decode_image(self, image_input):
        """Convert a base64 string or raw bytes to an RGB PIL Image.

        Args:
            image_input: base64-encoded image string, or raw image bytes.

        Returns:
            PIL.Image.Image in RGB mode.

        Raises:
            ValueError: if image_input is neither str nor bytes.
        """
        if isinstance(image_input, str):
            image_data = base64.b64decode(image_input)
        elif isinstance(image_input, bytes):
            image_data = image_input
        else:
            raise ValueError("Image must be base64 string or bytes")
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    def __call__(self, data):  # FIX: original line was missing the colon (SyntaxError)
        """Handle one inference request.

        Args:
            data: dict with required keys "image" (base64 str or bytes)
                and "text" (prompt string).

        Returns:
            ``{"generated_text": str}`` on success, ``{"error": str}`` on failure.
        """
        try:
            if "image" not in data or "text" not in data:
                return {"error": "Both 'image' and 'text' are required"}

            image = self.decode_image(data["image"])

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": data["text"]},
                    ],
                }
            ]
            # FIX: add_generation_prompt appends the assistant-turn header so
            # the model answers the user instead of continuing their message.
            # (The split("assistant") below already assumed that header exists.)
            input_text = self.processor.apply_chat_template(
                messages, add_generation_prompt=True
            )

            # FIX: follow the model's actual placement instead of hard-coding
            # "cuda" — device_map="auto" may put weights elsewhere.
            inputs = self.processor(
                text=input_text,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(self.model.device)

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                min_p=0.1,
                do_sample=True,
            )

            generated = self.processor.decode(
                outputs[0],
                skip_special_tokens=True,
            ).strip()

            # The decoded sequence still contains the prompt; keep only the
            # portion after the last assistant marker when present.
            if "assistant" in generated:
                generated = generated.split("assistant")[-1].strip()

            return {"generated_text": generated}

        except Exception as e:
            # Endpoint boundary: report the failure as JSON rather than
            # letting the exception turn into an opaque 500.
            return {"error": str(e)}