toy3003's picture
Create handler.py
3061b42 verified
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import base64
import io
import logging
logger = logging.getLogger(__name__)
class EndpointHandler():
def __init__(self, path=""):
"""
ฟังก์ชันนี้จะทำงานแค่ครั้งเดียวตอนเริ่มต้น Endpoint เพื่อโหลดโมเดลรอไว้
"""
logger.info("Initializing model...")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# โหลดโมเดลและ processor จาก path ที่ Hugging Face ส่งมาให้
self.model = AutoModelForCausalLM.from_pretrained(
path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map=self.device,
# revision="b98e57b" # อาจจะต้องใช้ถ้ามีปัญหา
)
self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
logger.info("Model initialized successfully.")
def __call__(self, data):
"""
ฟังก์ชันนี้จะทำงานทุกครั้งที่มี request ส่งเข้ามาที่ API
"""
logger.info("Processing new request...")
# ดึงข้อมูลจาก request
inputs = data.pop("inputs", data)
image_b64 = inputs.get("image") # แนะนำให้ใช้ key ชื่อ "image"
if not image_b64:
return {"error": "Missing 'image' key with base64 encoded string in inputs."}
try:
# แปลง base64 string กลับเป็นรูปภาพ
image_bytes = base64.b64decode(image_b64)
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
except Exception as e:
logger.error(f"Error decoding image: {e}")
return {"error": f"Invalid base64 image data. {e}"}
# สร้าง Prompt ตามรูปแบบที่โมเดลต้องการ
prompt = "<|user|>\n<image>\n<|assistant|>"
# เตรียมข้อมูลสำหรับส่งเข้าโมเดล
model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, torch.bfloat16)
# รันโมเดลเพื่อสร้าง text
generated_ids = self.model.generate(
input_ids=model_inputs["input_ids"],
pixel_values=model_inputs["pixel_values"],
max_new_tokens=2048,
do_sample=False,
num_beams=1
)
# ถอดรหัสผลลัพธ์
generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
response_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
logger.info("Request processed successfully.")
# ส่งผลลัพธ์กลับในรูปแบบ JSON
return {"generated_text": response_text}