homer7676 commited on
Commit
4de6589
·
verified ·
1 Parent(s): 4727905

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -22
handler.py CHANGED
@@ -3,16 +3,14 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from typing import Dict, Any
4
  import re
5
 
6
- # 簡繁轉換字典
7
  SIMPLIFIED_TO_TRADITIONAL = {
8
  '发': '發', '书': '書', '记': '記', '亚': '亞', '欧': '歐', '韩': '韓', '边': '邊',
9
  '恒': '恆', '说': '說', '话': '話', '东': '東', '车': '車', '马': '馬', '样': '樣',
10
  '风': '風', '专': '專', '万': '萬', '劳': '勞', '动': '動', '习': '習', '头': '頭',
11
  '们': '們', '为': '為', '产': '產', '场': '場', '实': '實', '观': '觀', '见': '見',
12
  '师': '師', '长': '長', '识': '識', '电': '電', '图': '圖', '华': '華', '龙': '龍',
13
- '师': '師', '变': '變', '问': '問', '岁': '歲', '义': '義', '': '', '': '',
14
- '': '', '': '', '': '', '': '', '': '', '': '', '': '',
15
- '带': '帶', '难': '難'
16
  }
17
 
18
  class EndpointHandler:
@@ -23,7 +21,6 @@ class EndpointHandler:
23
  self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
24
 
25
  def initialize(self, context):
26
- """初始化模型和 tokenizer"""
27
  try:
28
  self.tokenizer = AutoTokenizer.from_pretrained(
29
  self.model_dir,
@@ -36,9 +33,8 @@ class EndpointHandler:
36
  self.model = AutoModelForCausalLM.from_pretrained(
37
  self.model_dir,
38
  trust_remote_code=True,
39
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
40
- device_map="auto"
41
- )
42
 
43
  self.model.eval()
44
 
@@ -47,14 +43,12 @@ class EndpointHandler:
47
  raise
48
 
49
  def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
50
- """預處理輸入數據"""
51
  inputs = data.pop("inputs", data)
52
  if not isinstance(inputs, dict):
53
  inputs = {"message": inputs}
54
  return inputs
55
 
56
  def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
57
- """執行推理"""
58
  try:
59
  message = inputs.get("message", "")
60
  context = inputs.get("context", "")
@@ -67,11 +61,7 @@ class EndpointHandler:
67
  truncation=True,
68
  max_length=2048,
69
  padding=True
70
- )
71
-
72
- for key in encoding:
73
- if isinstance(encoding[key], torch.Tensor):
74
- encoding[key] = encoding[key].to(self.device)
75
 
76
  with torch.no_grad():
77
  outputs = self.model.generate(
@@ -83,8 +73,6 @@ class EndpointHandler:
83
  top_k=50,
84
  do_sample=True,
85
  repetition_penalty=1.2,
86
- pad_token_id=self.tokenizer.pad_token_id,
87
- eos_token_id=self.tokenizer.eos_token_id,
88
  num_beams=4,
89
  early_stopping=True
90
  )
@@ -92,7 +80,6 @@ class EndpointHandler:
92
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
93
  response = response.split("芙莉蓮:")[-1].strip()
94
  response = self._process_response(response)
95
-
96
  return {"response": response}
97
 
98
  except Exception as e:
@@ -100,7 +87,6 @@ class EndpointHandler:
100
  return {"response": "抱歉,在處理您的請求時發生了錯誤。請稍後再試。", "error": str(e)}
101
 
102
  def _build_prompt(self, context: str, query: str) -> str:
103
- """構建提示詞"""
104
  return f"""你是芙莉蓮,需要遵守以下規則回答:
105
  1. 身份設定:
106
  - 千年精靈魔法師
@@ -120,13 +106,11 @@ class EndpointHandler:
120
  芙莉蓮:"""
121
 
122
  def _convert_to_traditional(self, text: str) -> str:
123
- """將簡體轉換為繁體"""
124
  for simplified, traditional in SIMPLIFIED_TO_TRADITIONAL.items():
125
  text = text.replace(simplified, traditional)
126
  return text
127
 
128
  def _process_response(self, response: str) -> str:
129
- """處理回應文本"""
130
  if not response or not response.strip():
131
  return "抱歉,我現在有點恍神,請你再問一次好嗎?"
132
 
@@ -139,5 +123,4 @@ class EndpointHandler:
139
  return response
140
 
141
  def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
142
- """後處理輸出數據"""
143
  return data
 
3
  from typing import Dict, Any
4
  import re
5
 
 
6
  SIMPLIFIED_TO_TRADITIONAL = {
7
  '发': '發', '书': '書', '记': '記', '亚': '亞', '欧': '歐', '韩': '韓', '边': '邊',
8
  '恒': '恆', '说': '說', '话': '話', '东': '東', '车': '車', '马': '馬', '样': '樣',
9
  '风': '風', '专': '專', '万': '萬', '劳': '勞', '动': '動', '习': '習', '头': '頭',
10
  '们': '們', '为': '為', '产': '產', '场': '場', '实': '實', '观': '觀', '见': '見',
11
  '师': '師', '长': '長', '识': '識', '电': '電', '图': '圖', '华': '華', '龙': '龍',
12
+ '变': '變', '问': '問', '岁': '歲', '义': '義', '': '', '': '', '乐': '樂',
13
+ '': '', '': '', '': '', '': '', '': '', '': '', '': ''
 
14
  }
15
 
16
  class EndpointHandler:
 
21
  self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
22
 
23
  def initialize(self, context):
 
24
  try:
25
  self.tokenizer = AutoTokenizer.from_pretrained(
26
  self.model_dir,
 
33
  self.model = AutoModelForCausalLM.from_pretrained(
34
  self.model_dir,
35
  trust_remote_code=True,
36
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
37
+ ).to(self.device)
 
38
 
39
  self.model.eval()
40
 
 
43
  raise
44
 
45
  def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
46
  inputs = data.pop("inputs", data)
47
  if not isinstance(inputs, dict):
48
  inputs = {"message": inputs}
49
  return inputs
50
 
51
  def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
 
52
  try:
53
  message = inputs.get("message", "")
54
  context = inputs.get("context", "")
 
61
  truncation=True,
62
  max_length=2048,
63
  padding=True
64
+ ).to(self.device)
 
 
 
 
65
 
66
  with torch.no_grad():
67
  outputs = self.model.generate(
 
73
  top_k=50,
74
  do_sample=True,
75
  repetition_penalty=1.2,
 
 
76
  num_beams=4,
77
  early_stopping=True
78
  )
 
80
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
81
  response = response.split("芙莉蓮:")[-1].strip()
82
  response = self._process_response(response)
 
83
  return {"response": response}
84
 
85
  except Exception as e:
 
87
  return {"response": "抱歉,在處理您的請求時發生了錯誤。請稍後再試。", "error": str(e)}
88
 
89
  def _build_prompt(self, context: str, query: str) -> str:
 
90
  return f"""你是芙莉蓮,需要遵守以下規則回答:
91
  1. 身份設定:
92
  - 千年精靈魔法師
 
106
  芙莉蓮:"""
107
 
108
  def _convert_to_traditional(self, text: str) -> str:
 
109
  for simplified, traditional in SIMPLIFIED_TO_TRADITIONAL.items():
110
  text = text.replace(simplified, traditional)
111
  return text
112
 
113
  def _process_response(self, response: str) -> str:
 
114
  if not response or not response.strip():
115
  return "抱歉,我現在有點恍神,請你再問一次好嗎?"
116
 
 
123
  return response
124
 
125
  def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
126
  return data