Wind-xixi committed on
Commit
c501664
·
verified ·
1 Parent(s): 75c3d28

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +97 -126
predictor.py CHANGED
@@ -1,142 +1,113 @@
 
 
1
  import json
 
 
 
2
  import re
3
- from typing import List, Dict, Set, Tuple
4
 
5
class SentenceExtractor:
    """Extract the sentences of a text that mention configured keywords.

    Two keyword sources are used:

    * the *main* keywords, a JSON mapping ``category -> [keyword, ...]``;
    * the *evaluation* keyword library, a JSON mapping
      ``category -> sentiment type -> [keyword, ...]``.
    """

    # Evaluation categories and sentiment types that are scanned, in the
    # order in which they appear in the categorized result.
    _EVAL_CATEGORIES = ("student_performance", "content_quality", "cross_scene")
    _SENTIMENTS = ("positive", "negative", "nature", "suggestion")

    # A sentence boundary is a run of whitespace that directly follows a
    # terminator character. Punctuation that is NOT followed by whitespace
    # does not split the text (this keeps e.g. "3.14" in one piece).
    _SENTENCE_BOUNDARY = re.compile(r'(?<=[。!?,.!?])\s+')

    def __init__(self, main_keywords_path: str, eval_keywords_path: str):
        """Load both keyword files and build the merged keyword set.

        :param main_keywords_path: path of the main keyword JSON file
        :param eval_keywords_path: path of the evaluation keyword library (JSON)
        """
        self.main_keywords = self._load_keywords(main_keywords_path)
        self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
        # Union of every keyword from both sources, for fast membership tests.
        self.all_keywords = self._merge_all_keywords()

    @staticmethod
    def _load_json(file_path: str, failure_prefix: str) -> dict:
        """Read a JSON file; on any failure print a message and return ``{}``.

        Loading is deliberately best-effort: a missing or malformed file
        leaves the extractor usable with an empty keyword table.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"{failure_prefix}: {e}")
            return {}

    def _load_keywords(self, file_path: str) -> Dict[str, List[str]]:
        """Load the main keyword file."""
        return self._load_json(file_path, "加载主关键词文件失败")

    def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
        """Load the evaluation keyword library (evaluation_keywords2.json)."""
        return self._load_json(file_path, "加载评估关键词库失败")

    def _merge_all_keywords(self) -> Set[str]:
        """Collect every keyword from both sources into one set."""
        merged: Set[str] = set()

        # Main keywords: category -> keyword list.
        for keywords in self.main_keywords.values():
            merged.update(keywords)

        # Evaluation library: category -> sentiment type -> keyword list.
        for types in self.eval_keywords.values():
            for keywords in types.values():
                merged.update(keywords)

        return merged

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split ``text`` into stripped, non-empty sentences."""
        pieces = self._SENTENCE_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]

    def _extract_relevant_sentences(self, text: str) -> Tuple[List[str], Dict[str, object]]:
        """Find the sentences of ``text`` that contain any keyword.

        :param text: input text
        :return: ``(relevant_sentences, categorized_sentences)`` where the
            second item maps ``"main"`` to a list of sentences and every
            evaluation category to a ``sentiment -> sentences`` dict
        """
        sentences = self._split_into_sentences(text)
        relevant_sentences: List[str] = []
        categorized_sentences: Dict[str, object] = {"main": []}
        for category in self._EVAL_CATEGORIES:
            categorized_sentences[category] = {s: [] for s in self._SENTIMENTS}

        for sentence in sentences:
            # A sentence is recorded under "main" once, no matter how many
            # main keywords it contains.
            if any(
                keyword in sentence
                for keywords in self.main_keywords.values()
                for keyword in keywords
            ):
                relevant_sentences.append(sentence)
                categorized_sentences["main"].append(sentence)

            # Evaluation keywords: a sentence may land in several
            # category/sentiment buckets, but at most once per bucket.
            for category in self._EVAL_CATEGORIES:
                if category not in self.eval_keywords:
                    continue
                by_sentiment = self.eval_keywords[category]

                for sentiment in self._SENTIMENTS:
                    if sentiment not in by_sentiment:
                        continue

                    bucket = categorized_sentences[category][sentiment]
                    if sentence in bucket:
                        continue
                    if any(keyword in sentence for keyword in by_sentiment[sentiment]):
                        # Keep the flat list free of repeats of this sentence.
                        if sentence not in relevant_sentences:
                            relevant_sentences.append(sentence)
                        bucket.append(sentence)

        return relevant_sentences, categorized_sentences

    def extract(self, text: str) -> Dict[str, object]:
        """Extract the keyword-related sentences of ``text``.

        :param text: input text
        :return: dict with ``relevant_sentences``, ``categorized_sentences``
            and ``count``; for empty input the lists/dicts are empty and
            ``count`` is 0
        """
        if not text:
            # "count" is included here as well so callers can always read it.
            return {"relevant_sentences": [], "categorized_sentences": {}, "count": 0}

        relevant_sentences, categorized_sentences = self._extract_relevant_sentences(text)

        return {
            "relevant_sentences": relevant_sentences,
            "categorized_sentences": categorized_sentences,
            "count": len(relevant_sentences)
        }
 
 
 
 
 
 
 
119
 
120
- # 使用示例
121
if __name__ == "__main__":
    # Demo run. Both keyword files are looked up relative to the current
    # working directory.
    sentence_extractor = SentenceExtractor(
        eval_keywords_path="evaluation_keywords2.json",
        main_keywords_path="main_keywords.json",
    )

    sample_text = """
    该学生表现优异,团队合作能力强,在项目中展现了很强的创新能力。
    但代码质量不高,存在安全漏洞,需要加强测试。
    项目文档完整,符合行业标准,具有很好的应用价值。
    建议加强代码审查,提高系统安全性,优化算法效率。
    """

    outcome = sentence_extractor.extract(sample_text)

    # Flat, numbered list of every matching sentence.
    print(f"提取到 {outcome['count']} 个相关句子:")
    for position, matched in enumerate(outcome['relevant_sentences'], start=1):
        print(f"{position}. {matched}")

    # The same sentences grouped by category / sentiment, as pretty JSON.
    print("\n按类别分组:")
    grouped_json = json.dumps(outcome['categorized_sentences'], ensure_ascii=False, indent=2)
    print(grouped_json)
142
 
 
 
 
 
 
 
1
+ # predictor.py
2
+
3
  import json
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from transformers import BertTokenizer
7
  import re
 
8
 
9
class Predictor:
    """Grade a piece of evaluation text with a BERT classifier run via ONNX.

    Pipeline: keep the sentences that mention a known evaluation keyword,
    classify each of them into a grade ``'a'``..``'e'``, then average the
    per-sentence scores into one final grade.
    """

    # Sentence terminators: ideographic full stop, ASCII '!' / '?', and the
    # full-width '!' (U+FF01) / '?' (U+FF1F) used in Chinese text.
    _SENTENCE_SPLIT = re.compile(r'[。!?\uff01\uff1f]')

    # (minimum average score, grade), checked from best to worst. Anything
    # below the last threshold is graded 'e'.
    _GRADE_THRESHOLDS = ((4.5, 'a'), (3.5, 'b'), (2.5, 'c'), (1.5, 'd'))

    def __init__(self):
        """Load tokenizer, model and keywords once, at service start-up.

        All files are read from the current working directory.
        """
        # 1. Tokenizer. Per the original author's note, Hugging Face Spaces
        #    checks every file of the git repository out into the current
        #    directory, hence the '.' path.
        self.tokenizer = BertTokenizer.from_pretrained('.')

        # 2. ONNX model and its inference session.
        self.ort_session = ort.InferenceSession('model_quantized.onnx')

        # 3. Keyword vocabulary. Kept exactly as loaded; it is flattened on
        #    use because the file may nest keywords under category/sentiment.
        with open('evaluation_keywords2.json', 'r', encoding='utf-8') as f:
            self.keywords = json.load(f)

        # 4. Class index -> grade label, and grade label -> numeric score
        #    (the score is what gets averaged).
        self.id2label = {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e'}
        self.label2score = {'a': 5, 'b': 4, 'c': 3, 'd': 2, 'e': 1}

    @staticmethod
    def _flatten_keywords(node):
        """Return every keyword string found in a possibly nested JSON value.

        * a string is a keyword (empty strings are dropped);
        * a list/tuple/set is searched item by item;
        * a dict is searched through its container values, so grouping keys
          such as a category name are NOT treated as keywords; a key whose
          value is a scalar is itself treated as the keyword (flat
          ``keyword -> weight`` style mappings).
        """
        if isinstance(node, str):
            return [node] if node else []
        found = []
        if isinstance(node, dict):
            for key, value in node.items():
                if isinstance(value, (dict, list, tuple, set)):
                    found.extend(Predictor._flatten_keywords(value))
                elif isinstance(key, str) and key:
                    found.append(key)
        elif isinstance(node, (list, tuple, set)):
            for item in node:
                found.extend(Predictor._flatten_keywords(item))
        return found

    @classmethod
    def _score_to_grade(cls, average_score):
        """Map an average score (1..5) to a grade, rounding half up."""
        for minimum, grade in cls._GRADE_THRESHOLDS:
            if average_score >= minimum:
                return grade
        return 'e'

    def _extract_relevant_sentences(self, text):
        """Return the sentences of ``text`` that contain at least one keyword.

        Sentences are stripped of surrounding whitespace; each matching
        sentence is returned once, in text order.
        """
        keywords = self._flatten_keywords(self.keywords)
        relevant_sentences = []
        for raw_sentence in self._SENTENCE_SPLIT.split(text):
            sentence = raw_sentence.strip()
            if sentence and any(keyword in sentence for keyword in keywords):
                relevant_sentences.append(sentence)
        return relevant_sentences

    def _predict_single_sentence(self, sentence):
        """Run the model on one sentence and return its grade label."""
        encoded = self.tokenizer(
            sentence,
            return_tensors="np",
            padding='max_length',
            truncation=True,
            max_length=128,
        )

        # Feed every input the model declares that the tokenizer produced
        # (e.g. attention_mask / token_type_ids in addition to input_ids).
        input_specs = self.ort_session.get_inputs()
        ort_inputs = {
            spec.name: encoded[spec.name]
            for spec in input_specs
            if spec.name in encoded
        }
        # If the model's first input uses a non-standard name, it still
        # receives the token ids.
        if input_specs and input_specs[0].name not in ort_inputs:
            ort_inputs[input_specs[0].name] = encoded['input_ids']

        ort_outs = self.ort_session.run(None, ort_inputs)

        # NOTE(review): assumes the first output holds per-class scores of
        # shape (batch, num_classes) with batch == 1 - confirm against the
        # exported model.
        prediction = int(np.argmax(ort_outs[0], axis=1)[0])
        return self.id2label[prediction]

    def predict(self, text):
        """Run the whole pipeline: extract sentences, score each, average.

        This is the entry point intended for app.py.

        :param text: text to grade; ``None`` or an empty string is treated
            as containing no relevant sentence
        :return: dict with ``grade`` ('a'..'e'), ``summary`` (display text)
            and ``analyzed_sentences_count``
        """
        # Step 1: keep only the sentences that mention a keyword.
        relevant_sentences = self._extract_relevant_sentences(text) if text else []

        if not relevant_sentences:
            return {
                "grade": "c",  # neutral middle grade when nothing can be analysed
                "summary": "文本中未检测到可用于评价的关键词句,无法进行有效分析。",
                "analyzed_sentences_count": 0
            }

        # Step 2: score every relevant sentence.
        scores = [
            self.label2score[self._predict_single_sentence(sentence)]
            for sentence in relevant_sentences
        ]

        # Step 3: average and map back to a grade. The list is non-empty
        # here, so the division is safe.
        average_score = sum(scores) / len(scores)
        final_grade = self._score_to_grade(average_score)

        # Step 4: human-readable summary.
        summary = f"系统分析了 {len(relevant_sentences)} 个关键句子,综合评定等级为“{final_grade.upper()}”。"

        return {
            "grade": final_grade,
            "summary": summary,
            "analyzed_sentences_count": len(relevant_sentences)
        }