Spaces:
Sleeping
Sleeping
Update predictor.py
Browse files- predictor.py +36 -8
predictor.py
CHANGED
|
@@ -32,13 +32,28 @@ class SentenceExtractor:
|
|
| 32 |
|
| 33 |
def _preprocess_text(self, text: str) -> np.ndarray:
|
| 34 |
"""
|
| 35 |
-
预处理文本为模型输入格式
|
| 36 |
"""
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def _predict_grade_with_model(self, text: str) -> str:
|
| 44 |
"""
|
|
@@ -47,8 +62,20 @@ class SentenceExtractor:
|
|
| 47 |
try:
|
| 48 |
if not self.ort_session:
|
| 49 |
return "C"
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
predictions = outputs[0]
|
| 53 |
grade_index = int(np.argmax(predictions))
|
| 54 |
grades = ['A', 'B', 'C', 'D', 'E']
|
|
@@ -261,4 +288,5 @@ if __name__ == "__main__":
|
|
| 261 |
|
| 262 |
for i, item in enumerate(result['scored_sentences'], 1):
|
| 263 |
print(f"句子{i}加评分等级:{item['sentence']} - {item['grade']}")
|
|
|
|
| 264 |
|
|
|
|
| 32 |
|
| 33 |
def _preprocess_text(self, text: str) -> np.ndarray:
|
| 34 |
"""
|
| 35 |
+
预处理文本为模型输入格式 - 使用BERT tokenizer
|
| 36 |
"""
|
| 37 |
+
try:
|
| 38 |
+
from transformers import AutoTokenizer
|
| 39 |
+
# 使用与学生模型相同的tokenizer
|
| 40 |
+
tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
|
| 41 |
+
inputs = tokenizer(
|
| 42 |
+
text,
|
| 43 |
+
truncation=True,
|
| 44 |
+
padding=True,
|
| 45 |
+
max_length=512,
|
| 46 |
+
return_tensors='np'
|
| 47 |
+
)
|
| 48 |
+
return inputs
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f"Tokenizer预处理失败: {e}")
|
| 51 |
+
# 降级到简单字符编码
|
| 52 |
+
max_seq_length = 128
|
| 53 |
+
features = np.zeros((1, max_seq_length), dtype=np.float32)
|
| 54 |
+
for i, ch in enumerate(text[:max_seq_length]):
|
| 55 |
+
features[0, i] = (ord(ch) % 256) / 255.0
|
| 56 |
+
return features
|
| 57 |
|
| 58 |
def _predict_grade_with_model(self, text: str) -> str:
|
| 59 |
"""
|
|
|
|
| 62 |
try:
|
| 63 |
if not self.ort_session:
|
| 64 |
return "C"
|
| 65 |
+
inputs = self._preprocess_text(text)
|
| 66 |
+
|
| 67 |
+
# 检查是否是tokenizer输出格式
|
| 68 |
+
if isinstance(inputs, dict) and 'input_ids' in inputs:
|
| 69 |
+
# BERT tokenizer格式
|
| 70 |
+
input_data = {
|
| 71 |
+
'input_ids': inputs['input_ids'],
|
| 72 |
+
'attention_mask': inputs['attention_mask']
|
| 73 |
+
}
|
| 74 |
+
else:
|
| 75 |
+
# 简单字符编码格式
|
| 76 |
+
input_data = {self.input_name: inputs}
|
| 77 |
+
|
| 78 |
+
outputs = self.ort_session.run([self.output_name], input_data)
|
| 79 |
predictions = outputs[0]
|
| 80 |
grade_index = int(np.argmax(predictions))
|
| 81 |
grades = ['A', 'B', 'C', 'D', 'E']
|
|
|
|
| 288 |
|
| 289 |
for i, item in enumerate(result['scored_sentences'], 1):
|
| 290 |
print(f"句子{i}加评分等级:{item['sentence']} - {item['grade']}")
|
| 291 |
+
|
| 292 |
|