Wind-xixi committed on
Commit
16a9c1b
·
verified ·
1 Parent(s): e50a344

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +67 -11
predictor.py CHANGED
@@ -2,20 +2,40 @@ import json
2
  import re
3
  import onnxruntime as ort
4
  import numpy as np
5
- from typing import List, Dict, Set
6
 
7
 
8
  score_map = {'A': 5, 'B': 4, 'C': 3, 'D': 2, 'E': 1}
9
 
10
 
11
  class SentenceExtractor:
12
- def __init__(self, eval_keywords_path: str, model_path: str = "distilled_model.onnx"):
 
 
 
 
 
 
 
 
 
 
 
 
13
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
14
  self.all_keywords = self._extract_all_keywords()
15
 
16
  self.ort_session = None
17
  self.input_name = None
18
  self.output_name = None
 
 
 
 
 
 
 
 
19
  try:
20
  self.ort_session = ort.InferenceSession(model_path)
21
  self.input_name = self.ort_session.get_inputs()[0].name
@@ -114,17 +134,48 @@ class SentenceExtractor:
114
  def _split_into_sentences(self, text: str) -> List[str]:
115
  if not text:
116
  return []
 
 
117
  normalized = re.sub(r'([。!?.!?])', r'\1\n', text)
118
  normalized = re.sub(r'[;;]\s*', ';\n', normalized)
119
  candidates = [s.strip() for s in re.split(r'[\r\n]+', normalized) if s.strip()]
120
- sentences: List[str] = []
 
 
121
  for s in candidates:
122
  if len(s) > 80 and not re.search(r'[。!?.!?;;]', s):
123
  parts = re.split(r'[,,]', s)
124
- sentences.extend([p.strip() for p in parts if p.strip()])
125
  else:
126
- sentences.append(s)
127
- return sentences
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  def _fuzzy_match_keyword(self, sentence: str, keyword: str) -> bool:
130
  """更严格的中文关键词匹配。
@@ -271,16 +322,21 @@ class SentenceExtractor:
271
 
272
  comprehensive_grade = "C"
273
  if relevant_sentences:
274
- avg_score = total_sentence_score / len(relevant_sentences)
275
- rounded_score = int(round(avg_score))
276
  reverse_map = {5: 'A', 4: 'B', 3: 'C', 2: 'D', 1: 'E'}
277
- comprehensive_grade = reverse_map.get(rounded_score, "C")
 
 
 
 
 
 
 
278
 
279
  word_scores = self._calculate_word_scores(text)
280
  final_grade = comprehensive_grade
281
- if word_scores["total_score"] > 0:
282
  final_grade = comprehensive_grade + "+"
283
- elif word_scores["total_score"] < 0:
284
  final_grade = comprehensive_grade + "-"
285
 
286
  return {
 
2
  import re
3
  import onnxruntime as ort
4
  import numpy as np
5
+ from typing import List, Dict, Set, Optional
6
 
7
 
8
  score_map = {'A': 5, 'B': 4, 'C': 3, 'D': 2, 'E': 1}
9
 
10
 
11
  class SentenceExtractor:
12
+ def __init__(
13
+ self,
14
+ eval_keywords_path: str,
15
+ model_path: str = "distilled_model.onnx",
16
+ *,
17
+ # 分句与聚合相关的可配置开关
18
+ merge_leading_punct: bool = True,
19
+ min_sentence_char_len: int = 6,
20
+ aggregation_mode: str = "max", # 可选:"max" | "mean"
21
+ # 加减号阈值(>0 / <0 为原逻辑;建议适度提高到 2/-2)
22
+ word_score_plus_threshold: int = 1,
23
+ word_score_minus_threshold: int = -1,
24
+ ):
25
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
26
  self.all_keywords = self._extract_all_keywords()
27
 
28
  self.ort_session = None
29
  self.input_name = None
30
  self.output_name = None
31
+ # 配置项
32
+ self.merge_leading_punct = merge_leading_punct
33
+ self.min_sentence_char_len = max(0, int(min_sentence_char_len))
34
+ self.aggregation_mode = aggregation_mode.lower().strip()
35
+ if self.aggregation_mode not in {"max", "mean"}:
36
+ self.aggregation_mode = "max"
37
+ self.word_score_plus_threshold = int(word_score_plus_threshold)
38
+ self.word_score_minus_threshold = int(word_score_minus_threshold)
39
  try:
40
  self.ort_session = ort.InferenceSession(model_path)
41
  self.input_name = self.ort_session.get_inputs()[0].name
 
134
  def _split_into_sentences(self, text: str) -> List[str]:
135
  if not text:
136
  return []
137
+
138
+ # 先按强终止符切分
139
  normalized = re.sub(r'([。!?.!?])', r'\1\n', text)
140
  normalized = re.sub(r'[;;]\s*', ';\n', normalized)
141
  candidates = [s.strip() for s in re.split(r'[\r\n]+', normalized) if s.strip()]
142
+
143
+ # 长句再按逗号细分
144
+ rough_sentences: List[str] = []
145
  for s in candidates:
146
  if len(s) > 80 and not re.search(r'[。!?.!?;;]', s):
147
  parts = re.split(r'[,,]', s)
148
+ rough_sentences.extend([p.strip() for p in parts if p.strip()])
149
  else:
150
+ rough_sentences.append(s)
151
+
152
+ # 合并以标点开头的碎片,并过滤超短句
153
+ sentences: List[str] = []
154
+ leading_punct_pattern = r'^[,,。;;::、\s]+'
155
+ for s in rough_sentences:
156
+ if self.merge_leading_punct and re.match(leading_punct_pattern, s):
157
+ # 去掉前缀标点后并入上一句
158
+ cleaned = re.sub(leading_punct_pattern, '', s)
159
+ if sentences:
160
+ sentences[-1] = f"{sentences[-1]}{cleaned}"
161
+ else:
162
+ if cleaned:
163
+ sentences.append(cleaned)
164
+ continue
165
+
166
+ # 过滤极短句(去标点长度)
167
+ plain = re.sub(r'[,,。;;::、!!??\s]', '', s)
168
+ if self.min_sentence_char_len > 0 and len(plain) < self.min_sentence_char_len:
169
+ # 不直接丢弃:若有上一句,合并
170
+ if sentences:
171
+ sentences[-1] = f"{sentences[-1]}{s}"
172
+ else:
173
+ sentences.append(s)
174
+ continue
175
+
176
+ sentences.append(s)
177
+
178
+ return [s.strip() for s in sentences if s and s.strip()]
179
 
180
  def _fuzzy_match_keyword(self, sentence: str, keyword: str) -> bool:
181
  """更严格的中文关键词匹配。
 
322
 
323
  comprehensive_grade = "C"
324
  if relevant_sentences:
 
 
325
  reverse_map = {5: 'A', 4: 'B', 3: 'C', 2: 'D', 1: 'E'}
326
+ if self.aggregation_mode == "max":
327
+ # 取最高等级(更鲁棒,避免短碎句拉低均值)
328
+ max_score = max(score_map.get(item["grade"], 3) for item in scored_sentences)
329
+ comprehensive_grade = reverse_map.get(max_score, "C")
330
+ else:
331
+ avg_score = total_sentence_score / len(relevant_sentences)
332
+ rounded_score = int(round(avg_score))
333
+ comprehensive_grade = reverse_map.get(rounded_score, "C")
334
 
335
  word_scores = self._calculate_word_scores(text)
336
  final_grade = comprehensive_grade
337
+ if word_scores["total_score"] > self.word_score_plus_threshold:
338
  final_grade = comprehensive_grade + "+"
339
+ elif word_scores["total_score"] < self.word_score_minus_threshold:
340
  final_grade = comprehensive_grade + "-"
341
 
342
  return {