File size: 3,234 Bytes
3061b42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import base64
import io
import logging
logger = logging.getLogger(__name__)
class EndpointHandler():
def __init__(self, path=""):
"""
ฟังก์ชันนี้จะทำงานแค่ครั้งเดียวตอนเริ่มต้น Endpoint เพื่อโหลดโมเดลรอไว้
"""
logger.info("Initializing model...")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# โหลดโมเดลและ processor จาก path ที่ Hugging Face ส่งมาให้
self.model = AutoModelForCausalLM.from_pretrained(
path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map=self.device,
# revision="b98e57b" # อาจจะต้องใช้ถ้ามีปัญหา
)
self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
logger.info("Model initialized successfully.")
def __call__(self, data):
"""
ฟังก์ชันนี้จะทำงานทุกครั้งที่มี request ส่งเข้ามาที่ API
"""
logger.info("Processing new request...")
# ดึงข้อมูลจาก request
inputs = data.pop("inputs", data)
image_b64 = inputs.get("image") # แนะนำให้ใช้ key ชื่อ "image"
if not image_b64:
return {"error": "Missing 'image' key with base64 encoded string in inputs."}
try:
# แปลง base64 string กลับเป็นรูปภาพ
image_bytes = base64.b64decode(image_b64)
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
except Exception as e:
logger.error(f"Error decoding image: {e}")
return {"error": f"Invalid base64 image data. {e}"}
# สร้าง Prompt ตามรูปแบบที่โมเดลต้องการ
prompt = "<|user|>\n<image>\n<|assistant|>"
# เตรียมข้อมูลสำหรับส่งเข้าโมเดล
model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, torch.bfloat16)
# รันโมเดลเพื่อสร้าง text
generated_ids = self.model.generate(
input_ids=model_inputs["input_ids"],
pixel_values=model_inputs["pixel_values"],
max_new_tokens=2048,
do_sample=False,
num_beams=1
)
# ถอดรหัสผลลัพธ์
generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
response_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
logger.info("Request processed successfully.")
# ส่งผลลัพธ์กลับในรูปแบบ JSON
return {"generated_text": response_text} |