Wind-xixi committed on
Commit
dbcf6db
·
verified ·
1 Parent(s): cec571c

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +36 -8
predictor.py CHANGED
@@ -32,13 +32,28 @@ class SentenceExtractor:
32
 
33
  def _preprocess_text(self, text: str) -> np.ndarray:
34
  """
35
- 预处理文本为模型输入格式
36
  """
37
- max_seq_length = 128
38
- features = np.zeros((1, max_seq_length), dtype=np.float32)
39
- for i, ch in enumerate(text[:max_seq_length]):
40
- features[0, i] = (ord(ch) % 256) / 255.0
41
- return features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def _predict_grade_with_model(self, text: str) -> str:
44
  """
@@ -47,8 +62,20 @@ class SentenceExtractor:
47
  try:
48
  if not self.ort_session:
49
  return "C"
50
- input_data = self._preprocess_text(text)
51
- outputs = self.ort_session.run([self.output_name], {self.input_name: input_data})
 
 
 
 
 
 
 
 
 
 
 
 
52
  predictions = outputs[0]
53
  grade_index = int(np.argmax(predictions))
54
  grades = ['A', 'B', 'C', 'D', 'E']
@@ -261,4 +288,5 @@ if __name__ == "__main__":
261
 
262
  for i, item in enumerate(result['scored_sentences'], 1):
263
  print(f"句子{i}加评分等级:{item['sentence']} - {item['grade']}")
 
264
 
 
32
 
33
  def _preprocess_text(self, text: str) -> np.ndarray:
34
  """
35
+ 预处理文本为模型输入格式 - 使用BERT tokenizer
36
  """
37
+ try:
38
+ from transformers import AutoTokenizer
39
+ # 使用与学生模型相同的tokenizer
40
+ tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
41
+ inputs = tokenizer(
42
+ text,
43
+ truncation=True,
44
+ padding=True,
45
+ max_length=512,
46
+ return_tensors='np'
47
+ )
48
+ return inputs
49
+ except Exception as e:
50
+ print(f"Tokenizer预处理失败: {e}")
51
+ # 降级到简单字符编码
52
+ max_seq_length = 128
53
+ features = np.zeros((1, max_seq_length), dtype=np.float32)
54
+ for i, ch in enumerate(text[:max_seq_length]):
55
+ features[0, i] = (ord(ch) % 256) / 255.0
56
+ return features
57
 
58
  def _predict_grade_with_model(self, text: str) -> str:
59
  """
 
62
  try:
63
  if not self.ort_session:
64
  return "C"
65
+ inputs = self._preprocess_text(text)
66
+
67
+ # 检查是否是tokenizer输出格式
68
+ if isinstance(inputs, dict) and 'input_ids' in inputs:
69
+ # BERT tokenizer格式
70
+ input_data = {
71
+ 'input_ids': inputs['input_ids'],
72
+ 'attention_mask': inputs['attention_mask']
73
+ }
74
+ else:
75
+ # 简单字符编码格式
76
+ input_data = {self.input_name: inputs}
77
+
78
+ outputs = self.ort_session.run([self.output_name], input_data)
79
  predictions = outputs[0]
80
  grade_index = int(np.argmax(predictions))
81
  grades = ['A', 'B', 'C', 'D', 'E']
 
288
 
289
  for i, item in enumerate(result['scored_sentences'], 1):
290
  print(f"句子{i}加评分等级:{item['sentence']} - {item['grade']}")
291
+
292