Wind-xixi committed on
Commit
566e91d
·
verified ·
1 Parent(s): 55af840

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +126 -98
predictor.py CHANGED
@@ -1,113 +1,141 @@
1
- # predictor.py
2
-
3
  import json
4
- import numpy as np
5
- import onnxruntime as ort
6
- from transformers import BertTokenizer
7
  import re
 
8
 
9
- class Predictor:
10
- def __init__(self):
11
  """
12
- 在服务启动时,一次性加载所有必要的模型和文件。
 
 
13
  """
14
- # 1. 加载分词器 (Tokenizer)
15
- # Hugging Face Spaces会自动下载git仓库中的所有文件到当前目录
16
- self.tokenizer = BertTokenizer.from_pretrained('.')
17
-
18
- # 2. 加载ONNX模型并创建推理会话
19
- self.ort_session = ort.InferenceSession('model_quantized.onnx')
20
-
21
- # 3. 加载关键词词集
22
- with open('evaluation_keywords2.json', 'r', encoding='utf-8') as f:
23
- self.keywords = json.load(f)
24
 
25
- # 4. 定义等级映射
26
- self.id2label = {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e'}
27
- self.label2score = {'a': 5, 'b': 4, 'c': 3, 'd': 2, 'e': 1} # 用于计算平均值
28
-
29
- def _extract_relevant_sentences(self, text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  """
31
- 根据关键词提取相关的句子。
 
 
32
  """
33
- # 使用正则表达式按标点符号分割句子,更准确
34
- sentences = re.split(r'[。!?]', text)
35
  relevant_sentences = []
36
- for sentence in sentences:
37
- if not sentence:
38
- continue
39
- for keyword in self.keywords:
40
- if keyword in sentence:
41
- relevant_sentences.append(sentence)
42
- break # 找到一个关键词就添加,避免重复
43
- return relevant_sentences
44
-
45
- def _predict_single_sentence(self, sentence):
46
- """
47
- 对单个句子进行模型推理,返回预测的等级标签。
48
- """
49
- # 使用分词器处理文本
50
- inputs = self.tokenizer(sentence, return_tensors="np", padding='max_length', truncation=True, max_length=128)
51
-
52
- # 准备ONNX模型的输入
53
- ort_inputs = {self.ort_session.get_inputs()[0].name: inputs['input_ids']}
54
 
55
- # 执行推理
56
- ort_outs = self.ort_session.run(None, ort_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # 处理输出结果
59
- prediction = np.argmax(ort_outs[0], axis=1)[0]
60
- return self.id2label[prediction]
61
-
62
- def predict(self, text):
63
  """
64
- 执行完整的预测流程:提取句子 -> 逐句评分 -> 计算平均等级。
65
- 这是暴露给app.py调用的主方法。
 
66
  """
67
- # 步骤1: 提取包含关键词的句子
68
- relevant_sentences = self._extract_relevant_sentences(text)
69
-
70
- if not relevant_sentences:
71
- return {
72
- "grade": "c", # 如果没有找到相关句子,返回一个默认的中间等级
73
- "summary": "文本中未检测到可用于评价的关键词句,无法进行有效分析。",
74
- "analyzed_sentences_count": 0
75
- }
76
-
77
- # 步骤2: 对每个相关句子进行评分
78
- scores = []
79
- for sentence in relevant_sentences:
80
- label = self._predict_single_sentence(sentence)
81
- scores.append(self.label2score[label])
82
-
83
- # 步骤3: 计算平均分并转换为最终等级
84
- if not scores:
85
- return {
86
- "grade": "c",
87
- "summary": "虽然找到相关句子,但模型未能给出评分。",
88
- "analyzed_sentences_count": len(relevant_sentences)
89
- }
90
-
91
- average_score = sum(scores) / len(scores)
92
 
93
- # 将平均分四舍五入后映射回最终等级
94
- final_grade = ""
95
- if average_score >= 4.5:
96
- final_grade = "a"
97
- elif average_score >= 3.5:
98
- final_grade = "b"
99
- elif average_score >= 2.5:
100
- final_grade = "c"
101
- elif average_score >= 1.5:
102
- final_grade = "d"
103
- else:
104
- final_grade = "e"
105
-
106
- # 步骤4: 生成总结性文本
107
- summary = f"系统分析了 {len(relevant_sentences)} 个关键句子,综合评定等级为“{final_grade.upper()}”。"
108
-
109
  return {
110
- "grade": final_grade,
111
- "summary": summary,
112
- "analyzed_sentences_count": len(relevant_sentences)
113
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
 
 
 
2
  import re
3
+ from typing import List, Dict, Set, Tuple
4
 
5
class SentenceExtractor:
    """Extract keyword-bearing sentences from text and group them by category.

    Two keyword sources are loaded at construction time:

    * the *main* keyword file, a JSON object mapping a category name to a
      list of keywords;
    * the *evaluation* keyword library, a JSON object mapping a category name
      to an object that maps a sentiment name to a list of keywords.

    Loading is best-effort: a missing or malformed file is reported on stdout
    and treated as an empty keyword set instead of raising.
    """

    # Categories and sentiment buckets looked up in the evaluation library.
    _EVAL_CATEGORIES = ("student_performance", "content_quality", "cross_scene")
    _SENTIMENTS = ("positive", "negative", "nature", "suggestion")

    # Sentence boundary rules:
    #   1. After a CJK terminator (。!?) split even when no whitespace
    #      follows, because Chinese text normally has none. The lookahead keeps
    #      runs such as "?!" and a closing quote/bracket attached to the
    #      sentence they end.
    #   2. After any listed punctuation mark (CJK or ASCII) split on the
    #      whitespace that follows it. ASCII marks require whitespace so that
    #      tokens like "3.14" stay intact.
    _SENTENCE_BOUNDARY = re.compile(
        r'(?<=[。!?])(?![。!?!?”’」』)])\s*'
        r'|(?<=[。!?,,.!?])\s+'
    )

    def __init__(self, main_keywords_path: str, eval_keywords_path: str):
        """Load both keyword sources and build the merged keyword set.

        :param main_keywords_path: path of the main keyword JSON file
        :param eval_keywords_path: path of the evaluation keyword library (JSON)
        """
        self.main_keywords = self._load_keywords(main_keywords_path)
        self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
        # Flat set of every known keyword, for quick membership checks.
        self.all_keywords = self._merge_all_keywords()

    @staticmethod
    def _load_json_dict(file_path: str, failure_label: str) -> dict:
        """Read a JSON object from ``file_path``; return ``{}`` on any failure.

        ``failure_label`` prefixes the message printed when loading fails.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError, TypeError) as e:
            # OSError: missing/unreadable file. ValueError: invalid JSON or
            # invalid UTF-8. TypeError: a non-path argument such as None.
            print(f"{failure_label}: {e}")
            return {}
        if not isinstance(data, dict):
            # Every consumer iterates with .items(); reject other JSON roots
            # here instead of crashing later.
            print(f"{failure_label}: expected a JSON object, got {type(data).__name__}")
            return {}
        return data

    def _load_keywords(self, file_path: str) -> Dict[str, List[str]]:
        """Load the main keyword file (category -> list of keywords)."""
        return self._load_json_dict(file_path, "加载主关键词文件失败")

    def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
        """Load the evaluation keyword library (category -> sentiment -> keywords)."""
        return self._load_json_dict(file_path, "加载评估关键词库失败")

    def _merge_all_keywords(self) -> Set[str]:
        """Return the union of all main and evaluation keywords."""
        keywords_set: Set[str] = set()

        for keywords in self.main_keywords.values():
            keywords_set.update(keywords)

        for types in self.eval_keywords.values():
            for keywords in types.values():
                keywords_set.update(keywords)

        return keywords_set

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split ``text`` into stripped, non-empty sentences.

        The punctuation mark that ends a sentence stays attached to it.
        """
        pieces = self._SENTENCE_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]

    def _extract_relevant_sentences(self, text: str) -> Tuple[List[str], Dict[str, object]]:
        """Collect the sentences of ``text`` that contain at least one keyword.

        :param text: input text
        :return: ``(relevant_sentences, categorized_sentences)``.
            ``relevant_sentences`` lists each matching sentence once, in order
            of first appearance. ``categorized_sentences`` has a ``"main"``
            list plus, for each evaluation category, a dict mapping each
            sentiment to its list of sentences.
        """
        relevant_sentences: List[str] = []
        categorized_sentences: Dict[str, object] = {"main": []}
        for category in self._EVAL_CATEGORIES:
            categorized_sentences[category] = {
                sentiment: [] for sentiment in self._SENTIMENTS
            }

        # Matching depends only on the sentence text, so a repeated sentence
        # would land in exactly the same buckets; handle each text once.
        processed: Set[str] = set()

        for sentence in self._split_into_sentences(text):
            if sentence in processed:
                continue
            processed.add(sentence)
            matched = False

            # Main keywords: one hit in any category is enough.
            # Empty keywords are skipped because "" is contained in every string.
            if any(
                keyword and keyword in sentence
                for keywords in self.main_keywords.values()
                for keyword in keywords
            ):
                categorized_sentences["main"].append(sentence)
                matched = True

            # Evaluation keywords: a sentence may fall into several buckets.
            for category in self._EVAL_CATEGORIES:
                if category not in self.eval_keywords:
                    continue
                buckets = self.eval_keywords[category]
                for sentiment in self._SENTIMENTS:
                    if sentiment not in buckets:
                        continue
                    if any(keyword and keyword in sentence for keyword in buckets[sentiment]):
                        categorized_sentences[category][sentiment].append(sentence)
                        matched = True

            if matched:
                relevant_sentences.append(sentence)

        return relevant_sentences, categorized_sentences

    def extract(self, text: str) -> Dict[str, object]:
        """Extract the keyword-related sentences of ``text``.

        :param text: input text
        :return: dict with keys ``"relevant_sentences"`` (list of sentences),
            ``"categorized_sentences"`` (grouping by category; an empty dict
            when ``text`` is empty) and ``"count"`` (number of relevant
            sentences).
        """
        if not text:
            # Keep the same keys as the normal path so callers can always
            # read result["count"].
            return {"relevant_sentences": [], "categorized_sentences": {}, "count": 0}

        relevant_sentences, categorized_sentences = self._extract_relevant_sentences(text)

        return {
            "relevant_sentences": relevant_sentences,
            "categorized_sentences": categorized_sentences,
            "count": len(relevant_sentences),
        }
119
+
120
# Usage example: run this module directly to try the extractor on a sample.
if __name__ == "__main__":
    # Assumes the main keyword file is named main_keywords.json and sits in
    # the working directory next to evaluation_keywords2.json.
    demo_extractor = SentenceExtractor(
        main_keywords_path="main_keywords.json",
        eval_keywords_path="evaluation_keywords2.json",
    )

    sample_text = """
    该学生表现优异,团队合作能力强,在项目中展现了很强的创新能力。
    但代码质量不高,存在安全漏洞,需要加强测试。
    项目文档完整,符合行业标准,具有很好的应用价值。
    建议加强代码审查,提高系统安全性,优化算法效率。
    """

    outcome = demo_extractor.extract(sample_text)
    matched_sentences = outcome['relevant_sentences']

    # Numbered list of every sentence that matched a keyword.
    print(f"提取到 {outcome['count']} 个相关句子:")
    position = 1
    for matched in matched_sentences:
        print(f"{position}. {matched}")
        position += 1

    # Full grouping, pretty-printed with the Chinese text left unescaped.
    print("\n按类别分组:")
    grouped_json = json.dumps(
        outcome['categorized_sentences'], ensure_ascii=False, indent=2
    )
    print(grouped_json)