"""
Correctness evaluation module.

Scores HR-assistant utterances by semantic similarity against a
reference knowledge base, using Sentence-BERT embeddings.
"""
import json
from typing import Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from config import MODEL_CONFIG, DATA_CONFIG, EVALUATION_CONFIG
| class CorrectnessEvaluator: | |
| """正确性评估器""" | |
| def __init__(self): | |
| """初始化评估器""" | |
| print("正在加载Sentence-BERT模型...") | |
| self.model = SentenceTransformer(MODEL_CONFIG["sentence_transformer_model"]) | |
| print("Sentence-BERT模型加载完成") | |
| # 加载知识库 | |
| self.knowledge_base = self._load_knowledge_base() | |
| print(f"已加载 {len(self.knowledge_base)} 条知识库Q&A") | |
| # 预计算知识库embeddings | |
| print("正在预计算知识库embeddings...") | |
| self.qa_embeddings = self.model.encode( | |
| [qa['question'] for qa in self.knowledge_base] | |
| ) | |
| print("知识库embeddings计算完成") | |
| def _load_knowledge_base(self) -> List[Dict]: | |
| """加载知识库Q&A""" | |
| try: | |
| with open(DATA_CONFIG["knowledge_base_path"], 'r', encoding='utf-8') as f: | |
| data = json.load(f) | |
| knowledge_base = [] | |
| # 合并场景型和知识型Q&A | |
| if "scenario_based" in data: | |
| knowledge_base.extend(data["scenario_based"]) | |
| if "knowledge_based" in data: | |
| knowledge_base.extend(data["knowledge_based"]) | |
| return knowledge_base | |
| except FileNotFoundError: | |
| print(f"警告: 知识库文件不存在: {DATA_CONFIG['knowledge_base_path']}") | |
| return [] | |
| def evaluate_turn( | |
| self, | |
| utterance: str, | |
| context: Optional[Dict] = None | |
| ) -> Dict: | |
| """ | |
| 评估单轮对话的正确性 | |
| Args: | |
| utterance: HR Assistant的回答 | |
| context: 上下文信息 | |
| Returns: | |
| 评估结果字典 | |
| """ | |
| if not self.knowledge_base: | |
| return { | |
| "similarity": 0.0, | |
| "matched_qa": None, | |
| "level": "unknown", | |
| "is_correct": False, | |
| "reason": "知识库未加载", | |
| "error": "知识库未加载" | |
| } | |
| # 计算当前回答的embedding | |
| utterance_embedding = self.model.encode([utterance]) | |
| # 计算与知识库的相似度 | |
| similarities = cosine_similarity( | |
| utterance_embedding, | |
| self.qa_embeddings | |
| )[0] | |
| # 获取最高分 | |
| max_similarity = similarities.max() | |
| matched_idx = similarities.argmax() | |
| # 判定等级 | |
| threshold = EVALUATION_CONFIG["correctness"]["threshold"] | |
| partial_threshold = EVALUATION_CONFIG["correctness"]["partial_threshold"] | |
| if max_similarity >= threshold: | |
| level = "correct" | |
| reason = f"回答与标准答案高度匹配 (相似度: {max_similarity:.2f})" | |
| elif max_similarity >= partial_threshold: | |
| level = "partial" | |
| reason = f"回答与标准答案部分匹配 (相似度: {max_similarity:.2f})" | |
| else: | |
| level = "incorrect" | |
| reason = f"回答与标准答案不匹配 (相似度: {max_similarity:.2f})" | |
| return { | |
| "similarity": float(max_similarity), | |
| "matched_qa": self.knowledge_base[matched_idx], | |
| "matched_qa_id": self.knowledge_base[matched_idx].get('id'), | |
| "level": level, | |
| "is_correct": level == "correct", | |
| "reason": reason | |
| } | |
| def evaluate_dialogue( | |
| self, | |
| dialogue: List[Dict], | |
| speaker_filter: Optional[str] = "HR Assistant" | |
| ) -> Dict: | |
| """ | |
| 评估整段对话的正确性 | |
| Args: | |
| dialogue: 对话列表 | |
| speaker_filter: 只评估特定说话人(默认HR Assistant) | |
| Returns: | |
| 整体评估结果 | |
| """ | |
| results = [] | |
| for turn in dialogue: | |
| # 只评估HR Assistant的回答 | |
| if speaker_filter and turn.get("speaker") != speaker_filter: | |
| continue | |
| if "utterance" in turn: | |
| result = self.evaluate_turn( | |
| turn["utterance"], | |
| {"turn_id": turn.get("turn_id")} | |
| ) | |
| result["turn_id"] = turn.get("turn_id") | |
| results.append(result) | |
| if not results: | |
| return { | |
| "score": 0.0, | |
| "avg_score": 0.0, | |
| "level": "unknown", | |
| "turns_evaluated": 0, | |
| "correct_turns": 0, | |
| "partial_turns": 0, | |
| "incorrect_turns": 0, | |
| "error_count": 0, | |
| "details": [] | |
| } | |
| # 统计 | |
| avg_score = sum(r["similarity"] for r in results) / len(results) | |
| correct_turns = sum(1 for r in results if r["level"] == "correct") | |
| partial_turns = sum(1 for r in results if r["level"] == "partial") | |
| incorrect_turns = sum(1 for r in results if r["level"] == "incorrect") | |
| # 判定整体等级(使用配置中的阈值) | |
| if avg_score >= EVALUATION_CONFIG["correctness"]["threshold"]: | |
| overall_level = "good" | |
| elif avg_score >= EVALUATION_CONFIG["correctness"]["partial_threshold"]: | |
| overall_level = "fair" | |
| else: | |
| overall_level = "poor" | |
| return { | |
| "score": round(avg_score * 100, 2), | |
| "avg_score": round(avg_score, 4), | |
| "level": overall_level, | |
| "turns_evaluated": len(results), | |
| "correct_turns": correct_turns, | |
| "partial_turns": partial_turns, | |
| "incorrect_turns": incorrect_turns, | |
| "error_count": incorrect_turns, | |
| "details": results | |
| } | |
# Manual smoke test: requires the embedding model and knowledge base
# to be available locally.
if __name__ == "__main__":
    evaluator = CorrectnessEvaluator()

    # Single-turn evaluation sanity check.
    test_utterance = "请问有什么可以帮您?"
    result = evaluator.evaluate_turn(test_utterance)

    print("单轮评估结果:")
    print(f" 相似度: {result['similarity']:.4f}")
    print(f" 等级: {result['level']}")
    # BUGFIX: when the knowledge base is empty, matched_qa is present but
    # None, so .get('matched_qa', {}) returns None and chaining .get()
    # raised AttributeError; coalesce None to {} before the second lookup.
    print(f" 匹配Q&A: {(result.get('matched_qa') or {}).get('question', 'N/A')}")