""" GraphRAG检索器 实现基于知识图谱的检索策略,包括本地查询和全局查询 """ from typing import List, Dict, Set, Tuple import time import networkx as nx try: from langchain_core.documents import Document except ImportError: try: from langchain_core.documents import Document except ImportError: from langchain.schema import Document try: from langchain_core.prompts import PromptTemplate except ImportError: try: from langchain_core.prompts import PromptTemplate except ImportError: from langchain.prompts import PromptTemplate from langchain_community.chat_models import ChatOllama from langchain_core.output_parsers import StrOutputParser, JsonOutputParser from knowledge_graph import KnowledgeGraph from config import LOCAL_LLM from retrieval_evaluation import RetrievalEvaluator, RetrievalResult from routers_and_graders import HallucinationGrader class GraphRetriever: """基于知识图谱的检索器""" def __init__(self, knowledge_graph: KnowledgeGraph): self.kg = knowledge_graph self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3) self.hallucination_grader = HallucinationGrader() # 实体识别提示 self.entity_recognition_prompt = PromptTemplate( template="""从以下问题中识别关键实体和概念: 问题: {question} 已知实体示例: {sample_entities} 请识别问题中提到的实体,返回JSON格式: {{ "entities": ["实体1", "实体2", ...] }} 只返回JSON,不要其他内容。 """, input_variables=["question", "sample_entities"] ) # 全局查询生成提示 self.global_query_prompt = PromptTemplate( template="""你是一个知识图谱分析专家。基于以下社区摘要,回答用户问题。 用户问题: {question} 相关社区摘要: {community_summaries} 请基于这些摘要提供一个综合性的答案。如果摘要中没有相关信息,请说明。 答案: """, input_variables=["question", "community_summaries"] ) # 本地查询生成提示 self.local_query_prompt = PromptTemplate( template="""基于以下实体及其关系信息,回答用户问题。 用户问题: {question} 相关实体信息: {entity_info} 实体间的关系: {relations} 请基于这些信息提供答案。 答案: """, input_variables=["question", "entity_info", "relations"] ) self.entity_recognition_chain = self.entity_recognition_prompt | self.llm | JsonOutputParser() self.global_query_chain = self.global_query_prompt | self.llm | StrOutputParser() self.local_query_chain = self.local_query_prompt | self.llm | StrOutputParser() def recognize_entities(self, question: str) -> List[str]: """ 从问题中识别实体 Args: question: 用户问题 Returns: 识别到的实体列表 """ # 获取一些示例实体 sample_entities = list(self.kg.entities.keys())[:10] sample_text = ", ".join(sample_entities) try: result = self.entity_recognition_chain.invoke({ "question": question, "sample_entities": sample_text }) entities = result.get("entities", []) # 匹配到图谱中的实体 matched_entities = [] for entity in entities: # 精确匹配 if entity in self.kg.entities: matched_entities.append(entity) else: # 模糊匹配 for kg_entity in self.kg.entities.keys(): if entity.lower() in kg_entity.lower() or kg_entity.lower() in entity.lower(): matched_entities.append(kg_entity) break print(f"🔍 识别到实体: {matched_entities}") return matched_entities except Exception as e: print(f"❌ 实体识别失败: {e}") return [] def _normalize_map(self, values: Dict[str, float], keys: List[str]) -> Dict[str, float]: arr = [values.get(k, 0.0) for k in keys] if not arr: return {k: 0.0 for k in keys} mn = min(arr) mx = max(arr) if mx == mn: return {k: 0.5 for k in keys} return {k: (values.get(k, 0.0) - mn) / (mx - mn) for k in keys} def _rank_entities(self, mentioned_entities: List[str], candidate_entities: List[str]) -> List[str]: G = self.kg.graph nodes = list(set(candidate_entities) | set(mentioned_entities)) if not nodes: return [] subG = G.subgraph(nodes) deg = nx.degree_centrality(subG) btw = nx.betweenness_centrality(subG, normalized=True) weight_to_mentioned = {} path_prox = {} for n in candidate_entities: w_sum = 0.0 best_len = None for m in 
                if G.has_edge(n, m):
                    data = G.get_edge_data(n, m)
                    if isinstance(data, dict):
                        w_sum += float(data.get('weight', 1.0))
                    else:
                        w_sum += 1.0
                try:
                    l = nx.shortest_path_length(G, source=m, target=n)
                    if best_len is None or l < best_len:
                        best_len = l
                except nx.NetworkXNoPath:
                    pass
            weight_to_mentioned[n] = w_sum
            path_prox[n] = 0.0 if best_len is None else 1.0 / (1.0 + best_len)

        deg_n = self._normalize_map(deg, candidate_entities)
        btw_n = self._normalize_map(btw, candidate_entities)
        w_n = self._normalize_map(weight_to_mentioned, candidate_entities)
        prox_n = self._normalize_map(path_prox, candidate_entities)

        # score = 0.3 * degree + 0.3 * betweenness + 0.2 * edge weight + 0.2 * path proximity
        scores = {}
        for n in candidate_entities:
            scores[n] = (0.3 * deg_n.get(n, 0.0) + 0.3 * btw_n.get(n, 0.0)
                         + 0.2 * w_n.get(n, 0.0) + 0.2 * prox_n.get(n, 0.0))

        ranked = sorted(candidate_entities, key=lambda x: scores.get(x, 0.0), reverse=True)
        return ranked

    def local_query(self, question: str, max_hops: int = 2, top_k: int = 10) -> str:
        """
        Local query - retrieve based on the entities mentioned in the question and their neighborhood.

        Suited for: detailed questions about specific entities,
        e.g. "Who are the authors of AlphaCodium?"

        Args:
            question: user question
            max_hops: maximum number of hops
            top_k: maximum number of entities to keep

        Returns:
            Answer text
        """
        print(f"\n🔎 Running local query...")

        # 1. Recognize entities in the question
        mentioned_entities = self.recognize_entities(question)

        if not mentioned_entities:
            return "No matching entities were found in the knowledge graph."

        # 2. Collect the neighborhood of each mentioned entity
        relevant_entities = set()
        for entity in mentioned_entities:
            neighbors = self.kg.get_node_neighbors(entity, depth=max_hops)
            relevant_entities.update(neighbors)

        ranked_entities = self._rank_entities(mentioned_entities, list(relevant_entities))
        relevant_entities = ranked_entities[:top_k]

        # 3. Collect entity information
        entity_info_list = []
        for entity in relevant_entities:
            info = self.kg.get_entity_info(entity)
            if info:
                entity_info_list.append(
                    f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', 'No description')}"
                )

        # 4. Collect relation information
        relation_list = []
        for u, v, data in self.kg.graph.edges(data=True):
            if u in relevant_entities and v in relevant_entities:
                relation_list.append(
                    f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}: {data.get('description', '')}"
                )

        entity_info_text = "\n".join(entity_info_list) if entity_info_list else "No relevant entity information"
        relations_text = "\n".join(relation_list[:20]) if relation_list else "No relevant relations"

        # 5. Generate the answer
        try:
            answer = self.local_query_chain.invoke({
                "question": question,
                "entity_info": entity_info_text,
                "relations": relations_text
            })
            print(f"✅ Local query finished")
            return answer.strip()
        except Exception as e:
            print(f"❌ Local query failed: {e}")
            return "Query failed, please try again."

    def global_query(self, question: str, top_k_communities: int = 5) -> str:
        """
        Global query - high-level query based on community summaries.

        Suited for: broad questions that require an overall understanding,
        e.g. "What topics do these documents mainly discuss?"

        Args:
            question: user question
            top_k_communities: number of communities to use

        Returns:
            Answer text
        """
        print(f"\n🌍 Running global query...")

        if not self.kg.community_summaries:
            return "The knowledge graph has no community summaries yet; run the indexing pipeline first."

        # Collect community summaries
        community_summaries = []
        for cid, summary in list(self.kg.community_summaries.items())[:top_k_communities]:
            community_summaries.append(f"Community {cid}:\n{summary}\n")

        summaries_text = "\n".join(community_summaries)

        # Generate the answer
        try:
            answer = self.global_query_chain.invoke({
                "question": question,
                "community_summaries": summaries_text
            })
            print(f"✅ Global query finished")
            return answer.strip()
        except Exception as e:
            print(f"❌ Global query failed: {e}")
            return "Query failed, please try again."

    def local_query_with_metrics(self, question: str, max_hops: int = 2, top_k: int = 10,
                                 k_values: List[int] = [1, 3, 5]) -> tuple:
        """Local query that also evaluates retrieval quality and grades hallucination."""
        print(f"\n🔎 Running local query with evaluation...")
        start_t = time.time()

        mentioned_entities = self.recognize_entities(question)
        if not mentioned_entities:
            return "No matching entities were found in the knowledge graph.", {
                "error": "no_entities",
                "latency": 0.0,
                "retrieved_docs_count": 0
            }

        relevant_entities = set()
        for entity in mentioned_entities:
            neighbors = self.kg.get_node_neighbors(entity, depth=max_hops)
            relevant_entities.update(neighbors)

        ranked_entities = self._rank_entities(mentioned_entities, list(relevant_entities))
        relevant_entities = ranked_entities[:top_k]

        entity_info_list = []
        for entity in relevant_entities:
            info = self.kg.get_entity_info(entity)
            if info:
                entity_info_list.append(
                    f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', 'No description')}"
                )

        relation_list = []
        for u, v, data in self.kg.graph.edges(data=True):
            if u in relevant_entities and v in relevant_entities:
                relation_list.append(
                    f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}: {data.get('description', '')}"
                )

        entity_info_text = "\n".join(entity_info_list) if entity_info_list else "No relevant entity information"
        relations_text = "\n".join(relation_list[:20]) if relation_list else "No relevant relations"

        try:
            answer = self.local_query_chain.invoke({
                "question": question,
                "entity_info": entity_info_text,
                "relations": relations_text
            }).strip()
        except Exception:
            answer = "Query failed, please try again."

        # Wrap the retrieved entities as documents for evaluation
        retrieved_docs = []
        for entity in relevant_entities:
            info = self.kg.get_entity_info(entity) or {"name": entity}
            content = f"{info.get('name', entity)} {info.get('type', '')} {info.get('description', '')}".strip()
            retrieved_docs.append(Document(page_content=content, metadata={"entity": info.get('name', entity)}))

        try:
            hallucination_grade = self.hallucination_grader.grade(answer, retrieved_docs)
        except Exception:
            hallucination_grade = "unknown"

        # Treat the mentioned entities themselves as the relevant set
        relevant_docs = []
        for entity in mentioned_entities:
            info = self.kg.get_entity_info(entity) or {"name": entity}
            content = f"{info.get('name', entity)} {info.get('type', '')} {info.get('description', '')}".strip()
            relevant_docs.append(Document(page_content=content, metadata={"entity": info.get('name', entity)}))

        latency = time.time() - start_t

        try:
            evaluator = RetrievalEvaluator()
            result = RetrievalResult(query=question, retrieved_docs=retrieved_docs,
                                     relevant_docs=relevant_docs, retrieval_time=latency)
            metrics_obj = evaluator.evaluate_retrieval([result], k_values=k_values)
            metrics = {
                "precision_at_1": metrics_obj.precision_at_k.get(1, 0),
                "precision_at_3": metrics_obj.precision_at_k.get(3, 0),
                "precision_at_5": metrics_obj.precision_at_k.get(5, 0),
                "recall_at_1": metrics_obj.recall_at_k.get(1, 0),
                "recall_at_3": metrics_obj.recall_at_k.get(3, 0),
                "recall_at_5": metrics_obj.recall_at_k.get(5, 0),
                "map_score": metrics_obj.map_score,
                "mrr": metrics_obj.mrr,
                "latency": metrics_obj.latency,
                "retrieved_docs_count": len(retrieved_docs),
                "hallucination": hallucination_grade
            }
        except Exception:
            metrics = {
                "latency": latency,
                "retrieved_docs_count": len(retrieved_docs),
                "hallucination": hallucination_grade
            }

        return answer, metrics

    def global_query_with_metrics(self, question: str, top_k_communities: int = 5,
                                  k_values: List[int] = [1, 3, 5]) -> tuple:
        """Global query that also evaluates retrieval quality and grades hallucination."""
        print(f"\n🌍 Running global query with evaluation...")
        start_t = time.time()

        mentioned_entities = self.recognize_entities(question)

        if not self.kg.community_summaries:
            return "The knowledge graph has no community summaries yet; run the indexing pipeline first.", {
                "error": "no_summaries",
                "latency": 0.0,
                "retrieved_docs_count": 0
            }

        community_summaries = []
        for cid, summary in list(self.kg.community_summaries.items())[:top_k_communities]:
            community_summaries.append((cid, summary))

        summaries_text = "\n".join([f"Community {cid}:\n{summary}\n" for cid, summary in community_summaries])

        try:
            answer = self.global_query_chain.invoke({
                "question": question,
                "community_summaries": summaries_text
            }).strip()
        except Exception:
            answer = "Query failed, please try again."

        retrieved_docs = []
        for cid, summary in community_summaries:
            retrieved_docs.append(Document(page_content=summary, metadata={"community_id": str(cid)}))

        try:
            hallucination_grade = self.hallucination_grader.grade(answer, retrieved_docs)
        except Exception:
            hallucination_grade = "unknown"

        # A community summary counts as relevant if it mentions a recognized entity or a query token
        relevant_docs = []
        query_tokens = [t for t in question.split() if t]
        for cid, summary in community_summaries:
            ok = False
            for ent in mentioned_entities:
                if ent and ent.lower() in summary.lower():
                    ok = True
                    break
            if not ok:
                for t in query_tokens:
                    if t and t.lower() in summary.lower():
                        ok = True
                        break
            if ok:
                relevant_docs.append(Document(page_content=summary, metadata={"community_id": str(cid)}))

        latency = time.time() - start_t

        try:
            evaluator = RetrievalEvaluator()
            result = RetrievalResult(query=question, retrieved_docs=retrieved_docs,
                                     relevant_docs=relevant_docs, retrieval_time=latency)
            metrics_obj = evaluator.evaluate_retrieval([result], k_values=k_values)
            metrics = {
                "precision_at_1": metrics_obj.precision_at_k.get(1, 0),
                "precision_at_3": metrics_obj.precision_at_k.get(3, 0),
                "precision_at_5": metrics_obj.precision_at_k.get(5, 0),
                "recall_at_1": metrics_obj.recall_at_k.get(1, 0),
                "recall_at_3": metrics_obj.recall_at_k.get(3, 0),
                "recall_at_5": metrics_obj.recall_at_k.get(5, 0),
                "map_score": metrics_obj.map_score,
                "mrr": metrics_obj.mrr,
                "latency": metrics_obj.latency,
                "retrieved_docs_count": len(retrieved_docs),
                "hallucination": hallucination_grade
            }
        except Exception:
            metrics = {
                "latency": latency,
                "retrieved_docs_count": len(retrieved_docs),
                "hallucination": hallucination_grade
            }

        return answer, metrics

    def hybrid_query_with_metrics(self, question: str) -> Dict[str, str]:
        """Run both local and global queries with evaluation and return both results."""
        print(f"\n🔀 Running hybrid query with evaluation...")
        local_answer, local_metrics = self.local_query_with_metrics(question)
        global_answer, global_metrics = self.global_query_with_metrics(question)
        return {
            "local": local_answer,
            "global": global_answer,
            "local_hallucination": local_metrics.get("hallucination"),
            "global_hallucination": global_metrics.get("hallucination"),
            "local_metrics": local_metrics,
            "global_metrics": global_metrics,
            "question": question
        }

    def hybrid_query(self, question: str) -> Dict[str, str]:
        """
        Hybrid query - run both the local and the global query and return both results.

        Args:
            question: user question

        Returns:
            Dict containing the local and global query results
        """
        print(f"\n🔀 Running hybrid query...")

        local_answer = self.local_query(question)
        global_answer = self.global_query(question)

        return {
            "local": local_answer,
            "global": global_answer,
            "question": question
        }

    def smart_query(self, question: str) -> str:
        """
        Smart query - automatically choose a query strategy based on the question type.

        Args:
            question: user question

        Returns:
            Answer text
        """
        # Determine the question type
        question_lower = question.lower()
        # Questions that mention specific entities -> local query
        mentioned_entities = self.recognize_entities(question)
        if mentioned_entities:
            print("📍 Specific entities detected, using local query")
            return self.local_query(question)

        # Broad questions -> global query (cue words kept in both Chinese and English)
        global_keywords = ["主要", "总体", "概述", "整体", "主题", "讨论", "内容",
                           "what", "overview", "main", "topics"]
        if any(keyword in question_lower for keyword in global_keywords):
            print("🌐 Broad question detected, using global query")
            return self.global_query(question)

        # Default to a local query
        print("📍 Defaulting to local query")
        return self.local_query(question)


def initialize_graph_retriever(knowledge_graph: KnowledgeGraph):
    """Initialize a GraphRAG retriever"""
    return GraphRetriever(knowledge_graph)
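

# Minimal usage sketch (assumptions noted below): shows how a caller might drive the
# retriever. It assumes a populated KnowledgeGraph is available from the project's
# indexing pipeline; the no-argument `KnowledgeGraph()` call and the sample questions
# are placeholders, not part of this module's API.
if __name__ == "__main__":
    kg = KnowledgeGraph()  # placeholder: substitute the graph built by the indexing pipeline
    retriever = initialize_graph_retriever(kg)

    # Entity-centric question -> smart_query should route to the local strategy
    print(retriever.smart_query("Who are the authors of AlphaCodium?"))

    # Broad question -> smart_query should route to the global strategy
    print(retriever.smart_query("What topics do these documents mainly discuss?"))

    # Local query with retrieval metrics and hallucination grading
    answer, metrics = retriever.local_query_with_metrics("Who are the authors of AlphaCodium?")
    print(answer)
    print(metrics)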