"""Correctness evaluation module.

Uses Sentence-BERT embeddings to score HR Assistant answers by semantic
similarity against a knowledge base of reference Q&A pairs.
"""

import json
from typing import Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from config import MODEL_CONFIG, DATA_CONFIG, EVALUATION_CONFIG


class CorrectnessEvaluator:
    """Correctness evaluator.

    Scores an utterance by its maximum cosine similarity to the questions
    in the loaded knowledge base, then maps that similarity to a level
    ("correct" / "partial" / "incorrect") using configured thresholds.
    """

    def __init__(self):
        """Load the Sentence-BERT model and knowledge base, and precompute embeddings."""
        print("正在加载Sentence-BERT模型...")
        self.model = SentenceTransformer(MODEL_CONFIG["sentence_transformer_model"])
        print("Sentence-BERT模型加载完成")

        # Load the reference Q&A knowledge base from disk.
        self.knowledge_base = self._load_knowledge_base()
        print(f"已加载 {len(self.knowledge_base)} 条知识库Q&A")

        # Precompute embeddings for every knowledge-base question once, so
        # evaluate_turn() only has to encode the single incoming utterance.
        print("正在预计算知识库embeddings...")
        self.qa_embeddings = self.model.encode(
            [qa['question'] for qa in self.knowledge_base]
        )
        print("知识库embeddings计算完成")

    def _load_knowledge_base(self) -> List[Dict]:
        """Load the Q&A knowledge base from the configured JSON file.

        Returns:
            Merged list of the file's "scenario_based" and "knowledge_based"
            entries; an empty list if the file does not exist.
        """
        try:
            with open(DATA_CONFIG["knowledge_base_path"], 'r', encoding='utf-8') as f:
                data = json.load(f)

            knowledge_base = []
            # Merge scenario-based and knowledge-based Q&A sections.
            if "scenario_based" in data:
                knowledge_base.extend(data["scenario_based"])
            if "knowledge_based" in data:
                knowledge_base.extend(data["knowledge_based"])
            return knowledge_base
        except FileNotFoundError:
            # Missing knowledge base is non-fatal: callers get "unknown" results.
            print(f"警告: 知识库文件不存在: {DATA_CONFIG['knowledge_base_path']}")
            return []

    def evaluate_turn(
        self,
        utterance: str,
        context: Optional[Dict] = None
    ) -> Dict:
        """Evaluate the correctness of a single dialogue turn.

        Args:
            utterance: The HR Assistant's answer text.
            context: Optional context info (currently unused by the scoring).

        Returns:
            Dict with similarity score, the best-matching Q&A entry and its id,
            the judged level ("correct"/"partial"/"incorrect"), a boolean
            is_correct flag, and a human-readable reason.
        """
        if not self.knowledge_base:
            return {
                "similarity": 0.0,
                "matched_qa": None,
                "level": "unknown",
                "is_correct": False,
                "reason": "知识库未加载",
                "error": "知识库未加载"
            }

        # Embed the current answer.
        utterance_embedding = self.model.encode([utterance])

        # Cosine similarity of the answer against every knowledge-base question.
        similarities = cosine_similarity(
            utterance_embedding,
            self.qa_embeddings
        )[0]

        # Best match across the knowledge base.
        max_similarity = similarities.max()
        matched_idx = similarities.argmax()

        # Map the similarity score to a level using configured thresholds.
        threshold = EVALUATION_CONFIG["correctness"]["threshold"]
        partial_threshold = EVALUATION_CONFIG["correctness"]["partial_threshold"]

        if max_similarity >= threshold:
            level = "correct"
            reason = f"回答与标准答案高度匹配 (相似度: {max_similarity:.2f})"
        elif max_similarity >= partial_threshold:
            level = "partial"
            reason = f"回答与标准答案部分匹配 (相似度: {max_similarity:.2f})"
        else:
            level = "incorrect"
            reason = f"回答与标准答案不匹配 (相似度: {max_similarity:.2f})"

        return {
            "similarity": float(max_similarity),
            "matched_qa": self.knowledge_base[matched_idx],
            "matched_qa_id": self.knowledge_base[matched_idx].get('id'),
            "level": level,
            "is_correct": level == "correct",
            "reason": reason
        }

    def evaluate_dialogue(
        self,
        dialogue: List[Dict],
        speaker_filter: Optional[str] = "HR Assistant"
    ) -> Dict:
        """Evaluate the correctness of an entire dialogue.

        Args:
            dialogue: List of turn dicts (expects "speaker", "utterance",
                and optionally "turn_id" keys).
            speaker_filter: Only evaluate turns by this speaker
                (defaults to "HR Assistant"; pass None to evaluate all turns).

        Returns:
            Aggregate result dict: 0-100 score, average similarity, overall
            level, per-level turn counts, and per-turn detail results.
        """
        results = []
        for turn in dialogue:
            # Only evaluate the filtered speaker's turns (HR Assistant by default).
            if speaker_filter and turn.get("speaker") != speaker_filter:
                continue
            if "utterance" in turn:
                result = self.evaluate_turn(
                    turn["utterance"],
                    {"turn_id": turn.get("turn_id")}
                )
                result["turn_id"] = turn.get("turn_id")
                results.append(result)

        if not results:
            # Nothing matched the filter: return a neutral, zeroed summary.
            return {
                "score": 0.0,
                "avg_score": 0.0,
                "level": "unknown",
                "turns_evaluated": 0,
                "correct_turns": 0,
                "partial_turns": 0,
                "incorrect_turns": 0,
                "error_count": 0,
                "details": []
            }

        # Aggregate statistics over the evaluated turns.
        avg_score = sum(r["similarity"] for r in results) / len(results)
        correct_turns = sum(1 for r in results if r["level"] == "correct")
        partial_turns = sum(1 for r in results if r["level"] == "partial")
        incorrect_turns = sum(1 for r in results if r["level"] == "incorrect")

        # Judge the overall level using the same configured thresholds
        # as evaluate_turn (hoisted here for a single lookup each).
        threshold = EVALUATION_CONFIG["correctness"]["threshold"]
        partial_threshold = EVALUATION_CONFIG["correctness"]["partial_threshold"]
        if avg_score >= threshold:
            overall_level = "good"
        elif avg_score >= partial_threshold:
            overall_level = "fair"
        else:
            overall_level = "poor"

        return {
            "score": round(avg_score * 100, 2),
            "avg_score": round(avg_score, 4),
            "level": overall_level,
            "turns_evaluated": len(results),
            "correct_turns": correct_turns,
            "partial_turns": partial_turns,
            "incorrect_turns": incorrect_turns,
            "error_count": incorrect_turns,
            "details": results
        }


# Manual smoke test: evaluate one sample utterance and print the result.
if __name__ == "__main__":
    evaluator = CorrectnessEvaluator()

    # Single-turn evaluation test.
    test_utterance = "请问有什么可以帮您?"
    result = evaluator.evaluate_turn(test_utterance)
    print("单轮评估结果:")
    print(f"  相似度: {result['similarity']:.4f}")
    print(f"  等级: {result['level']}")
    print(f"  匹配Q&A: {result.get('matched_qa', {}).get('question', 'N/A')}")