homer7676 commited on
Commit
507b3a1
·
verified ·
1 Parent(s): 3135779

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +54 -34
handler.py CHANGED
@@ -1,45 +1,58 @@
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from typing import Dict, Any
 
 
 
 
 
4
 
5
  class EndpointHandler:
6
- def __init__(self):
 
7
  self.tokenizer = None
8
  self.model = None
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
-
 
11
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
12
  """使 handler 可調用"""
13
- inputs = self.preprocess(data)
14
- outputs = self.inference(inputs)
15
- return self.postprocess(outputs)
16
 
17
  def initialize(self, context):
18
  """初始化模型和 tokenizer"""
19
- self.tokenizer = AutoTokenizer.from_pretrained(
20
- "homer7676/FrierenChatbotV1",
21
- trust_remote_code=True
22
- )
23
- self.model = AutoModelForCausalLM.from_pretrained(
24
- "homer7676/FrierenChatbotV1",
25
- trust_remote_code=True,
26
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
27
- ).to(self.device)
28
- self.model.eval()
29
-
30
- def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
31
- """預處理輸入數據"""
32
- inputs = data.pop("inputs", data)
33
- if not isinstance(inputs, dict):
34
- inputs = {"message": inputs}
35
- return inputs
 
 
 
 
36
 
37
  def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
38
  """執行推理"""
 
39
  try:
40
  message = inputs.get("message", "")
41
  context = inputs.get("context", "")
42
- prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
 
 
43
  1. 身份設定:
44
  - 千年精靈魔法師
45
  - 態度溫柔但帶著些許嘲諷
@@ -57,14 +70,19 @@ class EndpointHandler:
57
  用戶:{message}
58
  芙莉蓮:"""
59
 
 
 
 
 
60
  inputs = self.tokenizer(
61
- prompt,
62
  return_tensors="pt",
63
  padding=True,
64
  truncation=True,
65
  max_length=2048
66
  ).to(self.device)
67
-
 
68
  with torch.no_grad():
69
  outputs = self.model.generate(
70
  **inputs,
@@ -72,21 +90,23 @@ class EndpointHandler:
72
  temperature=0.7,
73
  top_p=0.9,
74
  top_k=50,
75
- do_sample=True,
76
- repetition_penalty=1.2,
77
- pad_token_id=self.tokenizer.pad_token_id,
78
- eos_token_id=self.tokenizer.eos_token_id
79
  )
80
-
81
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
82
  response = response.split("芙莉蓮:")[-1].strip()
 
83
 
84
  return {"generated_text": response}
85
 
86
  except Exception as e:
87
- print(f"推理過程錯誤: {str(e)}")
88
  return {"error": str(e)}
89
 
90
- def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
91
- """處理輸數據"""
92
- return data
 
 
 
 
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from typing import Dict, Any
4
+ import logging
5
+
6
+ # 設置日誌
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
 
10
  class EndpointHandler:
11
+ def __init__(self, model_dir=None):
12
+ logger.info("初始化 EndpointHandler")
13
  self.tokenizer = None
14
  self.model = None
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ logger.info(f"使用設備: {self.device}")
17
+
18
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
19
  """使 handler 可調用"""
20
+ logger.info("調用 __call__ 方法")
21
+ return self.inference(self.preprocess(data))
 
22
 
23
  def initialize(self, context):
24
  """初始化模型和 tokenizer"""
25
+ logger.info("開始初始化模型")
26
+ try:
27
+ self.tokenizer = AutoTokenizer.from_pretrained(
28
+ "homer7676/FrierenChatbotV1",
29
+ trust_remote_code=True
30
+ )
31
+ logger.info("Tokenizer 載入成功")
32
+
33
+ self.model = AutoModelForCausalLM.from_pretrained(
34
+ "homer7676/FrierenChatbotV1",
35
+ trust_remote_code=True,
36
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
37
+ ).to(self.device)
38
+ logger.info("模型載入成功")
39
+
40
+ self.model.eval()
41
+ logger.info("模型初始化完成")
42
+
43
+ except Exception as e:
44
+ logger.error(f"模型載入錯誤: {str(e)}")
45
+ raise
46
 
47
  def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
48
  """執行推理"""
49
+ logger.info("開始執行推理")
50
  try:
51
  message = inputs.get("message", "")
52
  context = inputs.get("context", "")
53
+ logger.info(f"收到訊息: {message}")
54
+
55
+ input_text = f"""你是芙莉蓮,需要遵守以下規則回答:
56
  1. 身份設定:
57
  - 千年精靈魔法師
58
  - 態度溫柔但帶著些許嘲諷
 
70
  用戶:{message}
71
  芙莉蓮:"""
72
 
73
+ # 記錄 token 數量
74
+ tokens = self.tokenizer.encode(input_text)
75
+ logger.info(f"輸入 token 數量: {len(tokens)}")
76
+
77
  inputs = self.tokenizer(
78
+ input_text,
79
  return_tensors="pt",
80
  padding=True,
81
  truncation=True,
82
  max_length=2048
83
  ).to(self.device)
84
+
85
+ logger.info("開始生成回應")
86
  with torch.no_grad():
87
  outputs = self.model.generate(
88
  **inputs,
 
90
  temperature=0.7,
91
  top_p=0.9,
92
  top_k=50,
93
+ do_sample=True
 
 
 
94
  )
95
+
96
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
97
  response = response.split("芙莉蓮:")[-1].strip()
98
+ logger.info(f"生成回應完成,長度: {len(response)}")
99
 
100
  return {"generated_text": response}
101
 
102
  except Exception as e:
103
+ logger.error(f"推理過程錯誤: {str(e)}")
104
  return {"error": str(e)}
105
 
106
+ def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
107
+ """處理輸數據"""
108
+ logger.info(f"預處理輸入數據: {data}")
109
+ inputs = data.pop("inputs", data)
110
+ if not isinstance(inputs, dict):
111
+ inputs = {"message": inputs}
112
+ return inputs