File size: 6,211 Bytes
cd262f8
 
ee22ab7
cd262f8
0e4ccf3
 
03387d9
 
 
 
 
 
 
 
0e4ccf3
03387d9
 
 
0e4ccf3
8c08231
0e4ccf3
 
03387d9
 
 
 
 
 
 
0e4ccf3
 
 
03387d9
0e4ccf3
 
 
03387d9
0e4ccf3
 
 
 
 
 
8c08231
0e4ccf3
 
 
 
 
 
 
 
03387d9
0e4ccf3
 
 
8c08231
0e4ccf3
 
 
 
 
8c08231
 
0e4ccf3
 
 
 
8c08231
0e4ccf3
 
03387d9
0e4ccf3
03387d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e4ccf3
03387d9
 
 
 
 
 
 
0e4ccf3
03387d9
0e4ccf3
 
 
 
 
 
 
03387d9
 
0e4ccf3
03387d9
0e4ccf3
03387d9
cd262f8
 
 
 
 
03387d9
0e4ccf3
 
03387d9
 
 
 
0e4ccf3
03387d9
0e4ccf3
 
 
 
 
03387d9
 
 
 
0e4ccf3
 
 
 
 
cd262f8
0e4ccf3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# handler.py

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig

class EndpointHandler:
    """
    Hugging Face Inference Endpoints sẽ tự động khởi tạo lớp này khi container startup:
        handler = EndpointHandler(model_dir=<path to repository>)
    Sau đó, mỗi khi có request, toolkit sẽ gọi handler(data) để chạy inference.
    """

    def __init__(self, model_dir: str, **kwargs):
        """
        model_dir: Đường dẫn tới thư mục mà HF Endpoint đã clone repo của bạn vào.
                   Trong trường hợp này, bạn load model trực tiếp từ Hugging Face Hub,
                   nên model_dir chỉ để “đáp ứng” signature. Bạn không cần dùng model_dir.
        """
        # 1. Bắt buộc phải có token để truy cập gated repo
        HF_TOKEN = os.getenv("HF_HUB_TOKEN")
        if not HF_TOKEN:
            raise RuntimeError(
                "HF_HUB_TOKEN chưa được thiết lập. Vui lòng thêm biến môi trường HF_HUB_TOKEN trong phần Settings → Environment Variables của Endpoint."
            )
        # Dòng debug này in ra token (bạn có thể comment/xóa sau khi chắc chắn đã đúng)
        print("DEBUG: HF_HUB_TOKEN =", HF_TOKEN)

        # 2. Định nghĩa model gốc và model fork (của bạn)
        BASE_MODEL    = "meta-llama/Meta-Llama-3-8B"
        HF_MODEL_NAME = "cuong4941/llama3-8b-summarize-vn"

        # 3. Chọn device: 'cuda' nếu có GPU, else 'cpu'
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"DEBUG: Using device = {DEVICE}")

        # 4. Load tokenizer từ BASE_MODEL, kèm use_auth_token để băng qua gated repo
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            use_auth_token=HF_TOKEN
        )

        # 5. Cấu hình bitsandbytes 4-bit quantization
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            llm_int8_enable_fp32_cpu_offload=True
        )

        # 6. Load model 4-bit từ repo của bạn, kèm use_auth_token
        self.model = AutoModelForCausalLM.from_pretrained(
            HF_MODEL_NAME,
            quantization_config=quant_config,
            device_map="auto",            # Accelerate sẽ tự động phân bổ device
            trust_remote_code=True,
            use_auth_token=HF_TOKEN
        )
        self.model.eval()

        # 7. Tạo pipeline text-generation (KHÔNG truyền thêm 'device' vào đây)
        #    Accelerate đã xử lý việc phân bổ thiết bị (CUDA hoặc CPU), nên không cần device param.
        self.generator = pipeline(
            task="text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            trust_remote_code=True
        )

    def __call__(self, data: dict) -> dict:
        """
        Mỗi khi có request, Endpoint sẽ gọi handler(data).
        data: dict có structure như sau (tương ứng JSON payload của client):
            {
              "inputs": "<chuỗi bài báo muốn tóm tắt>",
              "parameters": {    # có thể bỏ qua nếu không truyền
                "max_new_tokens": 150,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
                "num_beams": 1,
                "do_sample": True
                # … bất kỳ tham số nào còn lại mà pipeline hỗ trợ
              }
            }
        Trả về: dict, ví dụ {"generated_text": "..."} hoặc list of dict nếu bạn muốn trả danh sách outputs.
        """
        # 1. Lấy inputs (chuỗi văn bản cần tóm tắt)
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Không tìm thấy key 'inputs' trong payload. Vui lòng gửi JSON có field 'inputs'.")

        # 2. Lấy parameters, nếu không có thì khởi tạo rỗng
        parameters = data.get("parameters", {})

        # 3. Đặt mặc định các tham số generation nếu client không gửi
        max_new_tokens = parameters.get("max_new_tokens", 200)
        temperature    = parameters.get("temperature", 0.7)
        top_k          = parameters.get("top_k", None)
        top_p          = parameters.get("top_p", None)
        num_beams      = parameters.get("num_beams", None)
        do_sample      = parameters.get("do_sample", True)

        # 4. Xây prompt y hệt như trong notebook của bạn
        prompt_text = self._make_prompt(inputs)

        # 5. Gọi pipeline để generate text
        outputs = self.generator(
            prompt_text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            num_beams=num_beams,
            do_sample=do_sample
        )

        # 6. Lấy kết quả đầu tiên từ list outputs
        generated = outputs[0].get("generated_text")
        if generated is None:
            raise RuntimeError("Pipeline trả về output không chứa key 'generated_text'.")

        # 7. Trả kết quả dưới dạng dictionary
        return {"generated_text": generated}

    @staticmethod
    def _make_prompt(article_text: str) -> str:
        """
        Build prompt theo format bạn đã train trong notebook:
        - Dòng đầu: instruction “Tưởng tượng bạn là một chuyên gia về tài chính...”
        - Dòng tiếp: nội dung bài báo
        - Kết thúc bằng “Tóm tắt sau khi đọc bài báo:”
        """
        instruction = (
            "Tưởng tượng bạn là một chuyên gia về tài chính trong các lĩnh vực "
            "kinh tế số, thị trường chứng khoán, bất động sản, doanh nghiệp, "
            "tài chính ngân hàng"
        )
        return f"{instruction}\n\n{article_text}\n\nTóm tắt sau khi đọc bài báo:"