# NOTE: removed non-Python residue (file-size line, blame hashes, line-number
# column) that was accidentally captured into the source and broke parsing.
from typing import Dict, Any
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import torch
import base64
import io
import os
from subprocess import run
# HACK: pin pillow at startup. Use an argument list (shell=False) instead of a
# shell string so the command is not subject to shell parsing/injection.
# NOTE(review): PIL is imported above *before* this install runs, so the pinned
# version cannot take effect in this process — the pin should really live in
# requirements.txt (or this call should precede the PIL import). Confirm intent.
run(["pip3", "install", "pillow==9.4.0"], check=True)
class EndpointHandler:
    """Inference-endpoint handler for an image+text-to-text model.

    Loads a 4-bit quantized model and its processor once at startup, then
    serves request dicts of the form
    ``{"image": <base64 str | raw bytes>, "text": <prompt str>}``.
    """

    def __init__(self, path=""):
        # Token for gated/private Hugging Face repos. Previously read but
        # never used — now forwarded to from_pretrained so such repos load.
        self.HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")
        # Load model and processor with 4-bit quantization (reduces VRAM).
        self.model = AutoModelForImageTextToText.from_pretrained(
            path,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            token=self.HF_READ_TOKEN,
        )
        self.processor = AutoProcessor.from_pretrained(
            path,
            token=self.HF_READ_TOKEN,
        )
        print("Model and processor loaded")

    def decode_image(self, image_input):
        """Convert a base64 string or raw bytes to an RGB PIL Image.

        Raises:
            ValueError: if ``image_input`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(image_input, str):
            image_data = base64.b64decode(image_input)
        elif isinstance(image_input, bytes):
            image_data = image_input
        else:
            raise ValueError("Image must be base64 string or bytes")
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    def __call__(self, data):  # FIX: original was missing the ':' (syntax error)
        """Handle one request.

        Returns ``{"generated_text": str}`` on success, or ``{"error": str}``
        on any failure (missing keys, bad image, generation errors).
        """
        try:
            # Validate inputs
            if "image" not in data or "text" not in data:
                return {"error": "Both 'image' and 'text' are required"}
            # Process image
            image = self.decode_image(data["image"])
            # Prepare chat template; add_generation_prompt appends the
            # assistant turn header so the model generates a reply.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": data["text"]}
                    ]
                }
            ]
            input_text = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
            )
            # Process inputs; use the model's own device rather than a
            # hard-coded "cuda" (device_map="auto" picks the placement).
            inputs = self.processor(
                text=input_text,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.model.device)
            # Generate response
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                min_p=0.1,
                do_sample=True
            )
            # Decode and clean output
            generated = self.processor.decode(
                outputs[0],
                skip_special_tokens=True
            ).strip()
            # Remove prompt fragments: keep only text after the last
            # "assistant" marker emitted by the chat template.
            if "assistant" in generated:
                generated = generated.split("assistant")[-1].strip()
            return {"generated_text": generated}
        except Exception as e:
            # Endpoint contract: never raise — report the error as JSON.
            return {"error": str(e)}