# hr-eval-api-v2/models/correctness.py
"""
正确性评估模块
使用Sentence-BERT计算语义相似度
"""
import json
from typing import Dict, List, Optional
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from config import MODEL_CONFIG, DATA_CONFIG, EVALUATION_CONFIG
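
# config.py is not part of this file; judging from how its keys are used
# below, the three dicts are assumed to look roughly like this (the values
# shown are illustrative placeholders, not the project's real settings):
#
#   MODEL_CONFIG = {"sentence_transformer_model": "<sentence-transformers model name>"}
#   DATA_CONFIG = {"knowledge_base_path": "<path to knowledge base JSON>"}
#   EVALUATION_CONFIG = {"correctness": {"threshold": 0.8, "partial_threshold": 0.6}}
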
class CorrectnessEvaluator:
    """Correctness evaluator."""

    def __init__(self):
        """Initialize the evaluator."""
        print("Loading Sentence-BERT model...")
        self.model = SentenceTransformer(MODEL_CONFIG["sentence_transformer_model"])
        print("Sentence-BERT model loaded")

        # Load the knowledge base
        self.knowledge_base = self._load_knowledge_base()
        print(f"Loaded {len(self.knowledge_base)} knowledge base Q&A entries")

        # Precompute embeddings for every knowledge base question once at
        # startup, so evaluate_turn only needs to encode a single utterance
        # and score it against these cached vectors
        print("Precomputing knowledge base embeddings...")
        self.qa_embeddings = self.model.encode(
            [qa['question'] for qa in self.knowledge_base]
        )
        print("Knowledge base embeddings computed")

    def _load_knowledge_base(self) -> List[Dict]:
        """Load the knowledge base Q&A entries."""
        try:
            with open(DATA_CONFIG["knowledge_base_path"], 'r', encoding='utf-8') as f:
                data = json.load(f)
            knowledge_base = []
            # Merge scenario-based and knowledge-based Q&A entries
            if "scenario_based" in data:
                knowledge_base.extend(data["scenario_based"])
            if "knowledge_based" in data:
                knowledge_base.extend(data["knowledge_based"])
            return knowledge_base
        except FileNotFoundError:
            print(f"Warning: knowledge base file not found: {DATA_CONFIG['knowledge_base_path']}")
            return []
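
    # The knowledge base JSON is assumed (from the loader above and the field
    # access elsewhere in this class) to look roughly like:
    #
    #   {
    #     "scenario_based":  [{"id": "...", "question": "...", "answer": "..."}],
    #     "knowledge_based": [{"id": "...", "question": "...", "answer": "..."}]
    #   }
    #
    # Only "question" and "id" are read in this module; the "answer" field is
    # a guess at the remaining per-entry content, which is simply passed
    # through in "matched_qa".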

    def evaluate_turn(
        self,
        utterance: str,
        context: Optional[Dict] = None
    ) -> Dict:
        """
        Evaluate the correctness of a single dialogue turn.

        Args:
            utterance: The HR Assistant's answer.
            context: Contextual information (currently unused).

        Returns:
            A dict with the evaluation result.
        """
        if not self.knowledge_base:
            return {
                "similarity": 0.0,
                "matched_qa": None,
                "level": "unknown",
                "is_correct": False,
                "reason": "Knowledge base not loaded",
                "error": "Knowledge base not loaded"
            }

        # Compute the embedding of the current answer
        utterance_embedding = self.model.encode([utterance])

        # Compute cosine similarity against every knowledge base question
        similarities = cosine_similarity(
            utterance_embedding,
            self.qa_embeddings
        )[0]

        # Take the best match
        max_similarity = similarities.max()
        matched_idx = similarities.argmax()

        # Grade the match against the configured thresholds
        threshold = EVALUATION_CONFIG["correctness"]["threshold"]
        partial_threshold = EVALUATION_CONFIG["correctness"]["partial_threshold"]
        if max_similarity >= threshold:
            level = "correct"
            reason = f"Answer closely matches the reference (similarity: {max_similarity:.2f})"
        elif max_similarity >= partial_threshold:
            level = "partial"
            reason = f"Answer partially matches the reference (similarity: {max_similarity:.2f})"
        else:
            level = "incorrect"
            reason = f"Answer does not match the reference (similarity: {max_similarity:.2f})"

        return {
            "similarity": float(max_similarity),
            "matched_qa": self.knowledge_base[matched_idx],
            "matched_qa_id": self.knowledge_base[matched_idx].get('id'),
            "level": level,
            "is_correct": level == "correct",
            "reason": reason
        }
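
    # Example of the payload evaluate_turn returns on a successful match
    # (field values here are illustrative, not real output):
    #
    #   {"similarity": 0.87, "matched_qa": {...}, "matched_qa_id": "k3",
    #    "level": "correct", "is_correct": True,
    #    "reason": "Answer closely matches the reference (similarity: 0.87)"}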

    def evaluate_dialogue(
        self,
        dialogue: List[Dict],
        speaker_filter: Optional[str] = "HR Assistant"
    ) -> Dict:
        """
        Evaluate the correctness of an entire dialogue.

        Args:
            dialogue: List of dialogue turns.
            speaker_filter: Only evaluate turns by this speaker
                (defaults to "HR Assistant").

        Returns:
            An aggregated evaluation result.
        """
        results = []
        for turn in dialogue:
            # Only evaluate the filtered speaker's turns
            if speaker_filter and turn.get("speaker") != speaker_filter:
                continue
            if "utterance" in turn:
                result = self.evaluate_turn(
                    turn["utterance"],
                    {"turn_id": turn.get("turn_id")}
                )
                result["turn_id"] = turn.get("turn_id")
                results.append(result)

        if not results:
            return {
                "score": 0.0,
                "avg_score": 0.0,
                "level": "unknown",
                "turns_evaluated": 0,
                "correct_turns": 0,
                "partial_turns": 0,
                "incorrect_turns": 0,
                "error_count": 0,
                "details": []
            }

        # Aggregate per-turn statistics
        avg_score = sum(r["similarity"] for r in results) / len(results)
        correct_turns = sum(1 for r in results if r["level"] == "correct")
        partial_turns = sum(1 for r in results if r["level"] == "partial")
        incorrect_turns = sum(1 for r in results if r["level"] == "incorrect")

        # Grade the overall dialogue using the configured thresholds
        if avg_score >= EVALUATION_CONFIG["correctness"]["threshold"]:
            overall_level = "good"
        elif avg_score >= EVALUATION_CONFIG["correctness"]["partial_threshold"]:
            overall_level = "fair"
        else:
            overall_level = "poor"

        return {
            "score": round(avg_score * 100, 2),
            "avg_score": round(avg_score, 4),
            "level": overall_level,
            "turns_evaluated": len(results),
            "correct_turns": correct_turns,
            "partial_turns": partial_turns,
            "incorrect_turns": incorrect_turns,
            "error_count": incorrect_turns,
            "details": results
        }

# Smoke test
if __name__ == "__main__":
    evaluator = CorrectnessEvaluator()

    # Single-turn evaluation test
    test_utterance = "How can I help you?"
    result = evaluator.evaluate_turn(test_utterance)
    print("Single-turn evaluation result:")
    print(f"  Similarity: {result['similarity']:.4f}")
    print(f"  Level: {result['level']}")
    # matched_qa is None when the knowledge base failed to load; fall back
    # safely instead of calling .get on None
    matched_qa = result.get('matched_qa') or {}
    print(f"  Matched Q&A: {matched_qa.get('question', 'N/A')}")