Wind-xixi committed on
Commit
10261fa
·
verified ·
1 Parent(s): 352ef77

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +59 -29
predictor.py CHANGED
@@ -8,7 +8,7 @@ score_map = {'A': 5, 'B': 4, 'C': 3, 'D': 2, 'E': 1}
8
 
9
 
10
  class SentenceExtractor:
11
- def __init__(self, main_keywords_path: str, eval_keywords_path: str, model_path: str = "model_quantized.onnx"):
12
  """
13
  初始化句子提取器,加载主关键词、评估关键词库和评分模型
14
  :param main_keywords_path: 主关键词JSON文件路径
@@ -19,50 +19,80 @@ class SentenceExtractor:
19
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
20
  self.all_keywords = self._merge_all_keywords()
21
 
22
- # 加载ONNX评分模型
23
- self.ort_session = ort.InferenceSession(model_path)
24
- self.input_name = self.ort_session.get_inputs()[0].name
25
- self.output_name = self.ort_session.get_outputs()[0].name
 
 
 
 
 
 
 
 
 
26
 
27
  def _preprocess_text(self, text: str) -> np.ndarray:
28
  """
29
- 预处理文本,为模型输入做准备
30
- 注意:这里需要根据实际模型的输入要求进行调整
31
  """
32
- # 示例预处理 - 实际实现需与训练模型时的预处理一致
33
- # 这里假设模型接受固定长度的词向量或嵌入
34
- # 以下为示例代码,需根据实际模型修改
35
  max_seq_length = 128
36
- # 简单的哈希特征示例(实际应使用与模型匹配的预处理)
37
  features = np.zeros((1, max_seq_length), dtype=np.float32)
38
- for i, char in enumerate(text[:max_seq_length]):
39
- features[0, i] = hash(char) % 1000 / 1000.0
40
  return features
41
 
42
- def _predict_grade(self, text: str) -> str:
43
- """
44
- 使用ONNX模型预测文本评分等级
45
- :param text: 输入句子
46
- :return: 评分等级(A/B/C/D/E)
47
- """
48
  try:
49
- # 预处理文本
 
50
  input_data = self._preprocess_text(text)
51
-
52
- # 模型推理
53
  outputs = self.ort_session.run([self.output_name], {self.input_name: input_data})
54
-
55
- # 解析模型输出获取等级
56
- # 假设模型输出是概率分布,取最大概率对应的等级
57
  predictions = outputs[0]
58
- grade_index = np.argmax(predictions)
59
-
60
- # 将索引映射到等级A-E
61
  grades = ['A', 'B', 'C', 'D', 'E']
62
  return grades[grade_index]
63
  except Exception as e:
64
  print(f"模型预测出错: {e}")
65
- return "C" # 出错时返回默认等级
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def _load_keywords(self, file_path: str) -> Dict[str, List[str]]:
68
  """加载主关键词文件"""
 
8
 
9
 
10
  class SentenceExtractor:
11
+ def __init__(self, main_keywords_path: str, eval_keywords_path: str, model_path: str = "model_quantized.onnx", use_model: bool = False):
12
  """
13
  初始化句子提取器,加载主关键词、评估关键词库和评分模型
14
  :param main_keywords_path: 主关键词JSON文件路径
 
19
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
20
  self.all_keywords = self._merge_all_keywords()
21
 
22
+ # 加载ONNX评分模型(可选)
23
+ self.use_model = use_model
24
+ self.ort_session = None
25
+ self.input_name = None
26
+ self.output_name = None
27
+ if self.use_model:
28
+ try:
29
+ self.ort_session = ort.InferenceSession(model_path)
30
+ self.input_name = self.ort_session.get_inputs()[0].name
31
+ self.output_name = self.ort_session.get_outputs()[0].name
32
+ except Exception as e:
33
+ print(f"ONNX 模型加载失败,回退到启发式打分: {e}")
34
+ self.use_model = False
35
 
36
  def _preprocess_text(self, text: str) -> np.ndarray:
37
  """
38
+ 预处理文本(占位实现)。若启用 ONNX,请根据训练时的 tokenizer/embedding 改造。
 
39
  """
 
 
 
40
  max_seq_length = 128
 
41
  features = np.zeros((1, max_seq_length), dtype=np.float32)
42
+ for i, ch in enumerate(text[:max_seq_length]):
43
+ features[0, i] = (ord(ch) % 256) / 255.0
44
  return features
45
 
46
+ def _predict_grade_with_model(self, text: str) -> str:
 
 
 
 
 
47
  try:
48
+ if not self.ort_session:
49
+ return "C"
50
  input_data = self._preprocess_text(text)
 
 
51
  outputs = self.ort_session.run([self.output_name], {self.input_name: input_data})
 
 
 
52
  predictions = outputs[0]
53
+ grade_index = int(np.argmax(predictions))
 
 
54
  grades = ['A', 'B', 'C', 'D', 'E']
55
  return grades[grade_index]
56
  except Exception as e:
57
  print(f"模型预测出错: {e}")
58
+ return "C"
59
+
60
+ def _predict_grade_heuristic(self, text: str) -> str:
61
+ score = 0
62
+ hit_any = False
63
+ for category in ["student_performance", "content_quality", "cross_scene"]:
64
+ cat_dict = self.eval_keywords.get(category, {})
65
+ for sentiment, weight in [["positive", 2], ["suggestion", 1], ["negative", -2], ["nature", 0]]:
66
+ for kw in cat_dict.get(sentiment, []):
67
+ if kw and kw in text:
68
+ score += weight
69
+ hit_any = True
70
+ if not hit_any:
71
+ for _, kws in self.main_keywords.items():
72
+ if any(kw in text for kw in kws):
73
+ return "C"
74
+ return "C"
75
+
76
+ if score >= 3:
77
+ return "A"
78
+ if score >= 1:
79
+ return "B"
80
+ if score == 0:
81
+ return "C"
82
+ if score <= -3:
83
+ return "E"
84
+ return "D"
85
+
86
+ def _predict_grade(self, text: str) -> str:
87
+ grade = self._predict_grade_heuristic(text)
88
+ if self.use_model:
89
+ model_grade = self._predict_grade_with_model(text)
90
+ # 简单融合策略:若模型比启发式高两档以上,则提升一档
91
+ order = {"A":5, "B":4, "C":3, "D":2, "E":1}
92
+ if order.get(model_grade,3) - order.get(grade,3) >= 2:
93
+ return ["A","B","C","D","E"][max(0, 5 - (order.get(grade,3)+1))]
94
+ return grade
95
+ return grade
96
 
97
  def _load_keywords(self, file_path: str) -> Dict[str, List[str]]:
98
  """加载主关键词文件"""