""" GraphRAG检索器 实现基于知识图谱的检索策略,包括本地查询和全局查询 """ from typing import List, Dict, Set, Tuple try: from langchain_core.prompts import PromptTemplate except ImportError: from langchain.prompts import PromptTemplate from langchain_community.chat_models import ChatOllama from langchain_core.output_parsers import StrOutputParser, JsonOutputParser from knowledge_graph import KnowledgeGraph from config import LOCAL_LLM class GraphRetriever: """基于知识图谱的检索器""" def __init__(self, knowledge_graph: KnowledgeGraph): self.kg = knowledge_graph self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3) # 实体识别提示 self.entity_recognition_prompt = PromptTemplate( template="""从以下问题中识别关键实体和概念: 问题: {question} 已知实体示例: {sample_entities} 请识别问题中提到的实体,返回JSON格式: {{ "entities": ["实体1", "实体2", ...] }} 只返回JSON,不要其他内容。 """, input_variables=["question", "sample_entities"] ) # 全局查询生成提示 self.global_query_prompt = PromptTemplate( template="""你是一个知识图谱分析专家。基于以下社区摘要,回答用户问题。 用户问题: {question} 相关社区摘要: {community_summaries} 请基于这些摘要提供一个综合性的答案。如果摘要中没有相关信息,请说明。 答案: """, input_variables=["question", "community_summaries"] ) # 本地查询生成提示 self.local_query_prompt = PromptTemplate( template="""基于以下实体及其关系信息,回答用户问题。 用户问题: {question} 相关实体信息: {entity_info} 实体间的关系: {relations} 请基于这些信息提供答案。 答案: """, input_variables=["question", "entity_info", "relations"] ) self.entity_recognition_chain = self.entity_recognition_prompt | self.llm | JsonOutputParser() self.global_query_chain = self.global_query_prompt | self.llm | StrOutputParser() self.local_query_chain = self.local_query_prompt | self.llm | StrOutputParser() def recognize_entities(self, question: str) -> List[str]: """ 从问题中识别实体 Args: question: 用户问题 Returns: 识别到的实体列表 """ # 获取一些示例实体 sample_entities = list(self.kg.entities.keys())[:10] sample_text = ", ".join(sample_entities) try: result = self.entity_recognition_chain.invoke({ "question": question, "sample_entities": sample_text }) entities = result.get("entities", []) # 匹配到图谱中的实体 matched_entities = [] for entity in entities: # 精确匹配 if entity in self.kg.entities: matched_entities.append(entity) else: # 模糊匹配 for kg_entity in self.kg.entities.keys(): if entity.lower() in kg_entity.lower() or kg_entity.lower() in entity.lower(): matched_entities.append(kg_entity) break print(f"🔍 识别到实体: {matched_entities}") return matched_entities except Exception as e: print(f"❌ 实体识别失败: {e}") return [] def local_query(self, question: str, max_hops: int = 2, top_k: int = 10) -> str: """ 本地查询 - 基于问题中的实体及其邻域进行检索 适用场景: 针对特定实体的详细问题 例如: "AlphaCodium的作者是谁?" Args: question: 用户问题 max_hops: 最大跳数 top_k: 返回的最大实体数 Returns: 答案文本 """ print(f"\n🔎 执行本地查询...") # 1. 识别问题中的实体 mentioned_entities = self.recognize_entities(question) if not mentioned_entities: return "未能在知识图谱中找到相关实体。" # 2. 获取实体的邻域 relevant_entities = set() for entity in mentioned_entities: neighbors = self.kg.get_node_neighbors(entity, depth=max_hops) relevant_entities.update(neighbors) relevant_entities = list(relevant_entities)[:top_k] # 3. 收集实体信息 entity_info_list = [] for entity in relevant_entities: info = self.kg.get_entity_info(entity) if info: entity_info_list.append( f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', '无描述')}" ) # 4. 收集关系信息 relation_list = [] for u, v, data in self.kg.graph.edges(data=True): if u in relevant_entities and v in relevant_entities: relation_list.append( f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}: {data.get('description', '')}" ) entity_info_text = "\n".join(entity_info_list) if entity_info_list else "无相关实体信息" relations_text = "\n".join(relation_list[:20]) if relation_list else "无相关关系" # 5. 生成答案 try: answer = self.local_query_chain.invoke({ "question": question, "entity_info": entity_info_text, "relations": relations_text }) print(f"✅ 本地查询完成") return answer.strip() except Exception as e: print(f"❌ 本地查询失败: {e}") return "查询失败,请重试。" def global_query(self, question: str, top_k_communities: int = 5) -> str: """ 全局查询 - 基于社区摘要进行高层次查询 适用场景: 需要整体理解的概括性问题 例如: "这些文档主要讨论什么主题?" Args: question: 用户问题 top_k_communities: 使用的社区数量 Returns: 答案文本 """ print(f"\n🌍 执行全局查询...") if not self.kg.community_summaries: return "知识图谱尚未生成社区摘要,请先运行索引流程。" # 获取社区摘要 community_summaries = [] for cid, summary in list(self.kg.community_summaries.items())[:top_k_communities]: community_summaries.append(f"社区 {cid}:\n{summary}\n") summaries_text = "\n".join(community_summaries) # 生成答案 try: answer = self.global_query_chain.invoke({ "question": question, "community_summaries": summaries_text }) print(f"✅ 全局查询完成") return answer.strip() except Exception as e: print(f"❌ 全局查询失败: {e}") return "查询失败,请重试。" def hybrid_query(self, question: str) -> Dict[str, str]: """ 混合查询 - 同时执行本地和全局查询,返回两种结果 Args: question: 用户问题 Returns: 包含本地和全局查询结果的字典 """ print(f"\n🔀 执行混合查询...") local_answer = self.local_query(question) global_answer = self.global_query(question) return { "local": local_answer, "global": global_answer, "question": question } def smart_query(self, question: str) -> str: """ 智能查询 - 根据问题类型自动选择查询策略 Args: question: 用户问题 Returns: 答案文本 """ # 判断问题类型 question_lower = question.lower() # 包含具体实体名称的问题 -> 本地查询 mentioned_entities = self.recognize_entities(question) if mentioned_entities: print("📍 检测到具体实体,使用本地查询") return self.local_query(question) # 概括性问题 -> 全局查询 global_keywords = ["主要", "总体", "概述", "整体", "主题", "讨论", "内容", "what", "overview", "main", "topics"] if any(keyword in question_lower for keyword in global_keywords): print("🌐 检测到概括性问题,使用全局查询") return self.global_query(question) # 默认使用本地查询 print("📍 使用本地查询作为默认策略") return self.local_query(question) def initialize_graph_retriever(knowledge_graph: KnowledgeGraph): """初始化GraphRAG检索器""" return GraphRetriever(knowledge_graph)