File size: 5,848 Bytes
9bd77e9
4727905
a5262a6
507b3a1
 
 
 
9bd77e9
 
7de7db9
bd9671f
11ffb79
7de7db9
4233874
 
507b3a1
4233874
507b3a1
7de7db9
507b3a1
 
7de7db9
 
4233874
507b3a1
7de7db9
507b3a1
 
 
 
4233874
 
507b3a1
4233874
507b3a1
9bd77e9
4233874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5262a6
507b3a1
e83d9e4
4233874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e83d9e4
 
4233874
 
a5262a6
11ffb79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d029ab
4233874
 
 
 
 
a5262a6
e83d9e4
9d029ab
e83d9e4
9d029ab
4233874
a5262a6
e83d9e4
11ffb79
4233874
 
e83d9e4
 
 
a5262a6
 
4233874
e83d9e4
a5262a6
4233874
 
7de7db9
 
4233874
 
 
 
a5262a6
11ffb79
507b3a1
4233874
7de7db9
 
 
4233874
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any, List
import logging

# Configure root logging once at import time; per-module logger per convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    """Inference endpoint handler for the FrierenChatbotV1 causal LM.

    Loads the tokenizer and model eagerly at construction time and exposes
    the ``__call__(data) -> [result]`` interface expected by Hugging Face
    inference endpoints. ``initialize`` is a serving-framework hook that
    re-loads the model if it is missing.
    """

    def __init__(self, model_dir: str = None):
        """Create the handler and load tokenizer + model onto the best device.

        Args:
            model_dir: Local path or hub repo id; defaults to the published
                "homer7676/FrierenChatbotV1" repository when omitted.

        Raises:
            Exception: re-raised after logging if model/tokenizer loading fails.
        """
        self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"初始化 EndpointHandler,使用設備: {self.device}")

        # Load eagerly so the first request does not pay the model-load cost.
        try:
            logger.info("開始載入 tokenizer 和模型")
            self._load_model()
            logger.info("模型和 tokenizer 載入完成")
        except Exception as e:
            logger.error(f"初始化錯誤: {str(e)}")
            raise

    def _load_model(self) -> None:
        """Load tokenizer and model onto ``self.device``.

        Shared by ``__init__`` and ``initialize`` (previously duplicated).
        Sets ``self.tokenizer`` and ``self.model`` and puts the model in
        eval mode.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir,
            trust_remote_code=True
        )
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # that padded batches can be built.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_dir,
            trust_remote_code=True,
            # fp16 only on GPU; CPU inference stays in fp32 for stability.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle one request: preprocess, run inference, wrap result in a list.

        Returns a single-element list containing either
        ``{"generated_text": ...}`` or ``{"error": ...}`` — errors are
        reported in-band rather than raised, matching endpoint conventions.
        """
        try:
            # Guard against being called before the model finished loading.
            if self.tokenizer is None or self.model is None:
                raise RuntimeError("Tokenizer or model not initialized")

            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return [outputs]
        except Exception as e:
            logger.error(f"處理過程錯誤: {str(e)}")
            return [{"error": str(e)}]

    def initialize(self, context):
        """Serving-framework hook: re-load model/tokenizer if they are missing.

        Args:
            context: framework-provided context object (unused here).
        """
        if self.tokenizer is None or self.model is None:
            logger.info("在 initialize 中重新初始化模型")
            try:
                self._load_model()
                logger.info("模型重新初始化完成")
            except Exception as e:
                logger.error(f"模型重新初始化錯誤: {str(e)}")
                raise

    @staticmethod
    def _coerce_inputs(inputs):
        """Best-effort normalization of a raw payload into a dict.

        A JSON string is parsed; any other string becomes
        ``{"message": <string>}``; dicts (and anything else) pass through
        unchanged. Previously this logic was duplicated inline twice.
        """
        if isinstance(inputs, str):
            try:
                import json
                inputs = json.loads(inputs)
            except json.JSONDecodeError:
                inputs = {"message": inputs}
        return inputs

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Build the persona prompt from the payload and generate a reply.

        Accepts a dict, a JSON string, a bare message string, or any of
        those nested under an ``"inputs"`` key. Returns
        ``{"generated_text": ...}`` on success or ``{"error": ...}`` on
        failure.
        """
        logger.info("開始執行推理")
        try:
            # Normalize the payload, then unwrap a possible "inputs" envelope
            # (which may itself be a JSON string) and normalize again.
            inputs = self._coerce_inputs(inputs)
            if isinstance(inputs, dict) and "inputs" in inputs:
                inputs = self._coerce_inputs(inputs["inputs"])

            message = inputs.get("message", "")
            context = inputs.get("context", "")
            logger.info(f"處理消息: {message}, 上下文: {context}")

            prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
1. 身份設定:
 - 千年精靈魔法師
 - 態度溫柔但帶著些許嘲諷
 - 說話優雅且有距離感
2. 重要關係:
 - 弗蘭梅是我的師傅
 - 費倫是我的學生
 - 欣梅爾是我的摯友
 - 海塔是我的故友
3. 回答規則:
 - 使用繁體中文
 - 必須提供具體詳細的內容
 - 保持回答的連貫性和完整性
相關資訊:{context}
用戶:{message}
芙莉蓮:"""

            # Defensive: inference may be called directly, bypassing __call__'s guard.
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized")

            tokens = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **tokens,
                    max_new_tokens=150,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # The decoded text contains the prompt; keep only what follows
            # the final persona marker.
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = response.split("芙莉蓮:")[-1].strip()

            if not response:
                response = "唔...讓我思考一下如何回答你的問題。"

            logger.info(f"生成回應: {response}")
            return {"generated_text": response}

        except Exception as e:
            logger.error(f"推理過程錯誤: {str(e)}")
            return {"error": str(e)}

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Log and pass through the raw request payload unchanged."""
        logger.info(f"預處理輸入數據: {data}")
        return data