homer7676 commited on
Commit
7de7db9
·
verified ·
1 Parent(s): bd9671f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +55 -25
handler.py CHANGED
@@ -7,61 +7,56 @@ logging.basicConfig(level=logging.INFO)
7
  logger = logging.getLogger(__name__)
8
 
9
  class EndpointHandler:
10
- def __init__(self, model_dir: str = None): # 修改這裡,添加類型提示和默認值
11
- logger.info(f"初始化 EndpointHandler,model_dir: {model_dir}")
12
  self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
13
  self.tokenizer = None
14
  self.model = None
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
- logger.info(f"使用設備: {self.device}")
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
19
  try:
20
  inputs = self.preprocess(data)
21
  outputs = self.inference(inputs)
 
 
 
22
  return [outputs]
23
  except Exception as e:
24
  logger.error(f"處理過程錯誤: {str(e)}")
25
  return [{"error": str(e)}]
26
 
27
  def initialize(self, context):
28
- """初始化模型和 tokenizer"""
29
  logger.info("開始初始化模型")
30
  try:
31
  self.tokenizer = AutoTokenizer.from_pretrained(
32
- self.model_dir, # 使用 model_dir
33
  trust_remote_code=True
34
  )
35
 
 
 
 
36
  self.model = AutoModelForCausalLM.from_pretrained(
37
- self.model_dir, # 使用 model_dir
38
  trust_remote_code=True,
39
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
40
  ).to(self.device)
41
 
42
- if self.tokenizer.pad_token is None:
43
- self.tokenizer.pad_token = self.tokenizer.eos_token
44
-
45
  self.model.eval()
46
  logger.info("模型初始化完成")
47
  except Exception as e:
48
  logger.error(f"模型載入錯誤: {str(e)}")
49
  raise
50
 
51
- def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
52
- """預處理輸入數據"""
53
- inputs = data.pop("inputs", data)
54
- if not isinstance(inputs, dict):
55
- inputs = {"message": inputs}
56
- return inputs
57
-
58
  def inference(self, inputs: Dict[str, Any]) -> Dict[str, str]:
59
- """執行推理"""
60
  logger.info("開始執行推理")
61
  try:
62
  message = inputs.get("message", "")
63
  context = inputs.get("context", "")
 
64
 
 
65
  prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
66
  1. 身份設定:
67
  - 千年精靈魔法師
@@ -80,29 +75,57 @@ class EndpointHandler:
80
  用戶:{message}
81
  芙莉蓮:"""
82
 
83
- inputs = self.tokenizer(
 
 
 
 
84
  prompt,
 
85
  return_tensors="pt",
86
  padding=True,
87
  truncation=True,
88
  max_length=2048
89
- ).to(self.device)
 
 
 
 
 
 
90
 
 
91
  with torch.no_grad():
92
  outputs = self.model.generate(
93
- **inputs,
 
94
  max_new_tokens=256,
95
  temperature=0.7,
96
  top_p=0.9,
97
  top_k=50,
98
  do_sample=True,
99
  pad_token_id=self.tokenizer.pad_token_id,
100
- eos_token_id=self.tokenizer.eos_token_id
 
 
101
  )
 
 
102
 
103
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
104
- response = response.split("芙莉蓮:")[-1].strip()
105
- logger.info("生成回應完成")
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  return {
108
  "generated_text": response
@@ -110,4 +133,11 @@ class EndpointHandler:
110
 
111
  except Exception as e:
112
  logger.error(f"推理過程錯誤: {str(e)}")
113
- return {"error": str(e)}
 
 
 
 
 
 
 
 
7
  logger = logging.getLogger(__name__)
8
 
9
  class EndpointHandler:
10
+ def __init__(self, model_dir: str = None):
 
11
  self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
12
  self.tokenizer = None
13
  self.model = None
14
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ logger.info(f"初始化 EndpointHandler,使用設備: {self.device}")
16
 
17
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
18
  try:
19
  inputs = self.preprocess(data)
20
  outputs = self.inference(inputs)
21
+ # 確保輸出不為空
22
+ if not outputs or "generated_text" not in outputs:
23
+ raise ValueError("No text was generated")
24
  return [outputs]
25
  except Exception as e:
26
  logger.error(f"處理過程錯誤: {str(e)}")
27
  return [{"error": str(e)}]
28
 
29
  def initialize(self, context):
 
30
  logger.info("開始初始化模型")
31
  try:
32
  self.tokenizer = AutoTokenizer.from_pretrained(
33
+ self.model_dir,
34
  trust_remote_code=True
35
  )
36
 
37
+ if self.tokenizer.pad_token is None:
38
+ self.tokenizer.pad_token = self.tokenizer.eos_token
39
+
40
  self.model = AutoModelForCausalLM.from_pretrained(
41
+ self.model_dir,
42
  trust_remote_code=True,
43
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
44
  ).to(self.device)
45
 
 
 
 
46
  self.model.eval()
47
  logger.info("模型初始化完成")
48
  except Exception as e:
49
  logger.error(f"模型載入錯誤: {str(e)}")
50
  raise
51
 
 
 
 
 
 
 
 
52
  def inference(self, inputs: Dict[str, Any]) -> Dict[str, str]:
 
53
  logger.info("開始執行推理")
54
  try:
55
  message = inputs.get("message", "")
56
  context = inputs.get("context", "")
57
+ logger.info(f"處理訊息: {message}")
58
 
59
+ # 構建提示詞
60
  prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
61
  1. 身份設定:
62
  - 千年精靈魔法師
 
75
  用戶:{message}
76
  芙莉蓮:"""
77
 
78
+ # 記錄提示詞長度
79
+ logger.info(f"提示詞長度: {len(prompt)}")
80
+
81
+ # Tokenize
82
+ encoding = self.tokenizer.encode_plus(
83
  prompt,
84
+ add_special_tokens=True,
85
  return_tensors="pt",
86
  padding=True,
87
  truncation=True,
88
  max_length=2048
89
+ )
90
+
91
+ # 移動到正確的設備
92
+ input_ids = encoding["input_ids"].to(self.device)
93
+ attention_mask = encoding["attention_mask"].to(self.device)
94
+
95
+ logger.info(f"輸入 token 數量: {input_ids.shape[-1]}")
96
 
97
+ # 生成回應
98
  with torch.no_grad():
99
  outputs = self.model.generate(
100
+ input_ids=input_ids,
101
+ attention_mask=attention_mask,
102
  max_new_tokens=256,
103
  temperature=0.7,
104
  top_p=0.9,
105
  top_k=50,
106
  do_sample=True,
107
  pad_token_id=self.tokenizer.pad_token_id,
108
+ eos_token_id=self.tokenizer.eos_token_id,
109
+ num_return_sequences=1,
110
+ no_repeat_ngram_size=3
111
  )
112
+
113
+ logger.info(f"生成的 token 數量: {outputs.shape[-1]}")
114
 
115
+ # 解碼回應
116
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
117
+
118
+ # 分離出模型的回應部分
119
+ if "芙莉蓮:" in full_response:
120
+ response = full_response.split("芙莉蓮:")[-1].strip()
121
+ else:
122
+ response = full_response.split("用戶:")[-1].strip()
123
+
124
+ logger.info(f"生成回應長度: {len(response)}")
125
+
126
+ # 確保回應不為空
127
+ if not response:
128
+ response = "抱歉,我似乎有點恍神了。能請你再說一次嗎?"
129
 
130
  return {
131
  "generated_text": response
 
133
 
134
  except Exception as e:
135
  logger.error(f"推理過程錯誤: {str(e)}")
136
+ raise
137
+
138
+ def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
139
+ logger.info(f"預處理輸入數據: {data}")
140
+ inputs = data.pop("inputs", data)
141
+ if not isinstance(inputs, dict):
142
+ inputs = {"message": inputs}
143
+ return inputs