| | |
| |
|
| | import os |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig |
| |
|
| | class EndpointHandler: |
| | """ |
| | Hugging Face Inference Endpoints sẽ tự động khởi tạo lớp này khi container startup: |
| | handler = EndpointHandler(model_dir=<path to repository>) |
| | Sau đó, mỗi khi có request, toolkit sẽ gọi handler(data) để chạy inference. |
| | """ |
| |
|
| | def __init__(self, model_dir: str, **kwargs): |
| | """ |
| | model_dir: Đường dẫn tới thư mục mà HF Endpoint đã clone repo của bạn vào. |
| | Trong trường hợp này, bạn load model trực tiếp từ Hugging Face Hub, |
| | nên model_dir chỉ để “đáp ứng” signature. Bạn không cần dùng model_dir. |
| | """ |
| | |
| | HF_TOKEN = os.getenv("HF_HUB_TOKEN") |
| | if not HF_TOKEN: |
| | raise RuntimeError( |
| | "HF_HUB_TOKEN chưa được thiết lập. Vui lòng thêm biến môi trường HF_HUB_TOKEN trong phần Settings → Environment Variables của Endpoint." |
| | ) |
| | |
| | print("DEBUG: HF_HUB_TOKEN =", HF_TOKEN) |
| |
|
| | |
| | BASE_MODEL = "meta-llama/Meta-Llama-3-8B" |
| | HF_MODEL_NAME = "cuong4941/llama3-8b-summarize-vn" |
| |
|
| | |
| | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| | print(f"DEBUG: Using device = {DEVICE}") |
| |
|
| | |
| | self.tokenizer = AutoTokenizer.from_pretrained( |
| | BASE_MODEL, |
| | trust_remote_code=True, |
| | use_auth_token=HF_TOKEN |
| | ) |
| |
|
| | |
| | quant_config = BitsAndBytesConfig( |
| | load_in_4bit=True, |
| | bnb_4bit_use_double_quant=True, |
| | bnb_4bit_compute_dtype=torch.bfloat16, |
| | bnb_4bit_quant_type="nf4", |
| | llm_int8_enable_fp32_cpu_offload=True |
| | ) |
| |
|
| | |
| | self.model = AutoModelForCausalLM.from_pretrained( |
| | HF_MODEL_NAME, |
| | quantization_config=quant_config, |
| | device_map="auto", |
| | trust_remote_code=True, |
| | use_auth_token=HF_TOKEN |
| | ) |
| | self.model.eval() |
| |
|
| | |
| | |
| | self.generator = pipeline( |
| | task="text-generation", |
| | model=self.model, |
| | tokenizer=self.tokenizer, |
| | trust_remote_code=True |
| | ) |
| |
|
| | def __call__(self, data: dict) -> dict: |
| | """ |
| | Mỗi khi có request, Endpoint sẽ gọi handler(data). |
| | data: dict có structure như sau (tương ứng JSON payload của client): |
| | { |
| | "inputs": "<chuỗi bài báo muốn tóm tắt>", |
| | "parameters": { # có thể bỏ qua nếu không truyền |
| | "max_new_tokens": 150, |
| | "temperature": 0.7, |
| | "top_k": 50, |
| | "top_p": 0.95, |
| | "num_beams": 1, |
| | "do_sample": True |
| | # … bất kỳ tham số nào còn lại mà pipeline hỗ trợ |
| | } |
| | } |
| | Trả về: dict, ví dụ {"generated_text": "..."} hoặc list of dict nếu bạn muốn trả danh sách outputs. |
| | """ |
| | |
| | inputs = data.get("inputs") |
| | if inputs is None: |
| | raise ValueError("Không tìm thấy key 'inputs' trong payload. Vui lòng gửi JSON có field 'inputs'.") |
| |
|
| | |
| | parameters = data.get("parameters", {}) |
| |
|
| | |
| | max_new_tokens = parameters.get("max_new_tokens", 200) |
| | temperature = parameters.get("temperature", 0.7) |
| | top_k = parameters.get("top_k", None) |
| | top_p = parameters.get("top_p", None) |
| | num_beams = parameters.get("num_beams", None) |
| | do_sample = parameters.get("do_sample", True) |
| |
|
| | |
| | prompt_text = self._make_prompt(inputs) |
| |
|
| | |
| | outputs = self.generator( |
| | prompt_text, |
| | max_new_tokens=max_new_tokens, |
| | temperature=temperature, |
| | top_k=top_k, |
| | top_p=top_p, |
| | num_beams=num_beams, |
| | do_sample=do_sample |
| | ) |
| |
|
| | |
| | generated = outputs[0].get("generated_text") |
| | if generated is None: |
| | raise RuntimeError("Pipeline trả về output không chứa key 'generated_text'.") |
| |
|
| | |
| | return {"generated_text": generated} |
| |
|
| | @staticmethod |
| | def _make_prompt(article_text: str) -> str: |
| | """ |
| | Build prompt theo format bạn đã train trong notebook: |
| | - Dòng đầu: instruction “Tưởng tượng bạn là một chuyên gia về tài chính...” |
| | - Dòng tiếp: nội dung bài báo |
| | - Kết thúc bằng “Tóm tắt sau khi đọc bài báo:” |
| | """ |
| | instruction = ( |
| | "Tưởng tượng bạn là một chuyên gia về tài chính trong các lĩnh vực " |
| | "kinh tế số, thị trường chứng khoán, bất động sản, doanh nghiệp, " |
| | "tài chính ngân hàng" |
| | ) |
| | return f"{instruction}\n\n{article_text}\n\nTóm tắt sau khi đọc bài báo:" |
| |
|