File size: 6,211 Bytes
cd262f8 ee22ab7 cd262f8 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 8c08231 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 8c08231 0e4ccf3 03387d9 0e4ccf3 8c08231 0e4ccf3 8c08231 0e4ccf3 8c08231 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 cd262f8 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 03387d9 0e4ccf3 cd262f8 0e4ccf3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | # handler.py
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
class EndpointHandler:
"""
Hugging Face Inference Endpoints sẽ tự động khởi tạo lớp này khi container startup:
handler = EndpointHandler(model_dir=<path to repository>)
Sau đó, mỗi khi có request, toolkit sẽ gọi handler(data) để chạy inference.
"""
def __init__(self, model_dir: str, **kwargs):
"""
model_dir: Đường dẫn tới thư mục mà HF Endpoint đã clone repo của bạn vào.
Trong trường hợp này, bạn load model trực tiếp từ Hugging Face Hub,
nên model_dir chỉ để “đáp ứng” signature. Bạn không cần dùng model_dir.
"""
# 1. Bắt buộc phải có token để truy cập gated repo
HF_TOKEN = os.getenv("HF_HUB_TOKEN")
if not HF_TOKEN:
raise RuntimeError(
"HF_HUB_TOKEN chưa được thiết lập. Vui lòng thêm biến môi trường HF_HUB_TOKEN trong phần Settings → Environment Variables của Endpoint."
)
# Dòng debug này in ra token (bạn có thể comment/xóa sau khi chắc chắn đã đúng)
print("DEBUG: HF_HUB_TOKEN =", HF_TOKEN)
# 2. Định nghĩa model gốc và model fork (của bạn)
BASE_MODEL = "meta-llama/Meta-Llama-3-8B"
HF_MODEL_NAME = "cuong4941/llama3-8b-summarize-vn"
# 3. Chọn device: 'cuda' nếu có GPU, else 'cpu'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEBUG: Using device = {DEVICE}")
# 4. Load tokenizer từ BASE_MODEL, kèm use_auth_token để băng qua gated repo
self.tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
trust_remote_code=True,
use_auth_token=HF_TOKEN
)
# 5. Cấu hình bitsandbytes 4-bit quantization
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
llm_int8_enable_fp32_cpu_offload=True
)
# 6. Load model 4-bit từ repo của bạn, kèm use_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
HF_MODEL_NAME,
quantization_config=quant_config,
device_map="auto", # Accelerate sẽ tự động phân bổ device
trust_remote_code=True,
use_auth_token=HF_TOKEN
)
self.model.eval()
# 7. Tạo pipeline text-generation (KHÔNG truyền thêm 'device' vào đây)
# Accelerate đã xử lý việc phân bổ thiết bị (CUDA hoặc CPU), nên không cần device param.
self.generator = pipeline(
task="text-generation",
model=self.model,
tokenizer=self.tokenizer,
trust_remote_code=True
)
def __call__(self, data: dict) -> dict:
"""
Mỗi khi có request, Endpoint sẽ gọi handler(data).
data: dict có structure như sau (tương ứng JSON payload của client):
{
"inputs": "<chuỗi bài báo muốn tóm tắt>",
"parameters": { # có thể bỏ qua nếu không truyền
"max_new_tokens": 150,
"temperature": 0.7,
"top_k": 50,
"top_p": 0.95,
"num_beams": 1,
"do_sample": True
# … bất kỳ tham số nào còn lại mà pipeline hỗ trợ
}
}
Trả về: dict, ví dụ {"generated_text": "..."} hoặc list of dict nếu bạn muốn trả danh sách outputs.
"""
# 1. Lấy inputs (chuỗi văn bản cần tóm tắt)
inputs = data.get("inputs")
if inputs is None:
raise ValueError("Không tìm thấy key 'inputs' trong payload. Vui lòng gửi JSON có field 'inputs'.")
# 2. Lấy parameters, nếu không có thì khởi tạo rỗng
parameters = data.get("parameters", {})
# 3. Đặt mặc định các tham số generation nếu client không gửi
max_new_tokens = parameters.get("max_new_tokens", 200)
temperature = parameters.get("temperature", 0.7)
top_k = parameters.get("top_k", None)
top_p = parameters.get("top_p", None)
num_beams = parameters.get("num_beams", None)
do_sample = parameters.get("do_sample", True)
# 4. Xây prompt y hệt như trong notebook của bạn
prompt_text = self._make_prompt(inputs)
# 5. Gọi pipeline để generate text
outputs = self.generator(
prompt_text,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
num_beams=num_beams,
do_sample=do_sample
)
# 6. Lấy kết quả đầu tiên từ list outputs
generated = outputs[0].get("generated_text")
if generated is None:
raise RuntimeError("Pipeline trả về output không chứa key 'generated_text'.")
# 7. Trả kết quả dưới dạng dictionary
return {"generated_text": generated}
@staticmethod
def _make_prompt(article_text: str) -> str:
"""
Build prompt theo format bạn đã train trong notebook:
- Dòng đầu: instruction “Tưởng tượng bạn là một chuyên gia về tài chính...”
- Dòng tiếp: nội dung bài báo
- Kết thúc bằng “Tóm tắt sau khi đọc bài báo:”
"""
instruction = (
"Tưởng tượng bạn là một chuyên gia về tài chính trong các lĩnh vực "
"kinh tế số, thị trường chứng khoán, bất động sản, doanh nghiệp, "
"tài chính ngân hàng"
)
return f"{instruction}\n\n{article_text}\n\nTóm tắt sau khi đọc bài báo:"
|