cuong4941's picture
fsdfs
8c08231
# handler.py
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
class EndpointHandler:
"""
Hugging Face Inference Endpoints sẽ tự động khởi tạo lớp này khi container startup:
handler = EndpointHandler(model_dir=<path to repository>)
Sau đó, mỗi khi có request, toolkit sẽ gọi handler(data) để chạy inference.
"""
def __init__(self, model_dir: str, **kwargs):
"""
model_dir: Đường dẫn tới thư mục mà HF Endpoint đã clone repo của bạn vào.
Trong trường hợp này, bạn load model trực tiếp từ Hugging Face Hub,
nên model_dir chỉ để “đáp ứng” signature. Bạn không cần dùng model_dir.
"""
# 1. Bắt buộc phải có token để truy cập gated repo
HF_TOKEN = os.getenv("HF_HUB_TOKEN")
if not HF_TOKEN:
raise RuntimeError(
"HF_HUB_TOKEN chưa được thiết lập. Vui lòng thêm biến môi trường HF_HUB_TOKEN trong phần Settings → Environment Variables của Endpoint."
)
# Dòng debug này in ra token (bạn có thể comment/xóa sau khi chắc chắn đã đúng)
print("DEBUG: HF_HUB_TOKEN =", HF_TOKEN)
# 2. Định nghĩa model gốc và model fork (của bạn)
BASE_MODEL = "meta-llama/Meta-Llama-3-8B"
HF_MODEL_NAME = "cuong4941/llama3-8b-summarize-vn"
# 3. Chọn device: 'cuda' nếu có GPU, else 'cpu'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEBUG: Using device = {DEVICE}")
# 4. Load tokenizer từ BASE_MODEL, kèm use_auth_token để băng qua gated repo
self.tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
trust_remote_code=True,
use_auth_token=HF_TOKEN
)
# 5. Cấu hình bitsandbytes 4-bit quantization
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
llm_int8_enable_fp32_cpu_offload=True
)
# 6. Load model 4-bit từ repo của bạn, kèm use_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
HF_MODEL_NAME,
quantization_config=quant_config,
device_map="auto", # Accelerate sẽ tự động phân bổ device
trust_remote_code=True,
use_auth_token=HF_TOKEN
)
self.model.eval()
# 7. Tạo pipeline text-generation (KHÔNG truyền thêm 'device' vào đây)
# Accelerate đã xử lý việc phân bổ thiết bị (CUDA hoặc CPU), nên không cần device param.
self.generator = pipeline(
task="text-generation",
model=self.model,
tokenizer=self.tokenizer,
trust_remote_code=True
)
def __call__(self, data: dict) -> dict:
"""
Mỗi khi có request, Endpoint sẽ gọi handler(data).
data: dict có structure như sau (tương ứng JSON payload của client):
{
"inputs": "<chuỗi bài báo muốn tóm tắt>",
"parameters": { # có thể bỏ qua nếu không truyền
"max_new_tokens": 150,
"temperature": 0.7,
"top_k": 50,
"top_p": 0.95,
"num_beams": 1,
"do_sample": True
# … bất kỳ tham số nào còn lại mà pipeline hỗ trợ
}
}
Trả về: dict, ví dụ {"generated_text": "..."} hoặc list of dict nếu bạn muốn trả danh sách outputs.
"""
# 1. Lấy inputs (chuỗi văn bản cần tóm tắt)
inputs = data.get("inputs")
if inputs is None:
raise ValueError("Không tìm thấy key 'inputs' trong payload. Vui lòng gửi JSON có field 'inputs'.")
# 2. Lấy parameters, nếu không có thì khởi tạo rỗng
parameters = data.get("parameters", {})
# 3. Đặt mặc định các tham số generation nếu client không gửi
max_new_tokens = parameters.get("max_new_tokens", 200)
temperature = parameters.get("temperature", 0.7)
top_k = parameters.get("top_k", None)
top_p = parameters.get("top_p", None)
num_beams = parameters.get("num_beams", None)
do_sample = parameters.get("do_sample", True)
# 4. Xây prompt y hệt như trong notebook của bạn
prompt_text = self._make_prompt(inputs)
# 5. Gọi pipeline để generate text
outputs = self.generator(
prompt_text,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
num_beams=num_beams,
do_sample=do_sample
)
# 6. Lấy kết quả đầu tiên từ list outputs
generated = outputs[0].get("generated_text")
if generated is None:
raise RuntimeError("Pipeline trả về output không chứa key 'generated_text'.")
# 7. Trả kết quả dưới dạng dictionary
return {"generated_text": generated}
@staticmethod
def _make_prompt(article_text: str) -> str:
"""
Build prompt theo format bạn đã train trong notebook:
- Dòng đầu: instruction “Tưởng tượng bạn là một chuyên gia về tài chính...”
- Dòng tiếp: nội dung bài báo
- Kết thúc bằng “Tóm tắt sau khi đọc bài báo:”
"""
instruction = (
"Tưởng tượng bạn là một chuyên gia về tài chính trong các lĩnh vực "
"kinh tế số, thị trường chứng khoán, bất động sản, doanh nghiệp, "
"tài chính ngân hàng"
)
return f"{instruction}\n\n{article_text}\n\nTóm tắt sau khi đọc bài báo:"