""" GraphRAG集成示例 展示如何在自适应RAG系统中使用知识图谱功能 """ import os from pprint import pprint from config import ( setup_environment, ENABLE_GRAPHRAG, GRAPHRAG_INDEX_PATH, GRAPHRAG_BATCH_SIZE ) from document_processor import initialize_document_processor from graph_indexer import initialize_graph_indexer from graph_retriever import initialize_graph_retriever class AdaptiveRAGWithGraph: """集成GraphRAG的自适应RAG系统""" def __init__(self, enable_graphrag=True, rebuild_index=False): print("🚀 初始化集成GraphRAG的自适应RAG系统...") print("="*60) # 设置环境 try: setup_environment() print("✅ 环境配置完成") except ValueError as e: print(f"❌ {e}") raise # 初始化文档处理器 print("\n📚 初始化文档处理器...") self.doc_processor, self.vectorstore, self.retriever, self.doc_splits = \ initialize_document_processor() # GraphRAG组件 self.enable_graphrag = enable_graphrag and ENABLE_GRAPHRAG self.graph_indexer = None self.graph_retriever = None self.knowledge_graph = None if self.enable_graphrag: self._setup_graphrag(rebuild_index) print("\n" + "="*60) print("✅ 系统初始化完成!") print("="*60) def _setup_graphrag(self, rebuild_index=False): """设置GraphRAG组件""" print("\n🔷 设置GraphRAG组件...") # 初始化索引器 self.graph_indexer = initialize_graph_indexer() # 检查是否已有索引 index_exists = os.path.exists(GRAPHRAG_INDEX_PATH) if index_exists and not rebuild_index: print(f"📂 发现现有索引: {GRAPHRAG_INDEX_PATH}") print(" 加载现有索引...") self.knowledge_graph = self.graph_indexer.load_index(GRAPHRAG_INDEX_PATH) else: if rebuild_index: print("🔄 重新构建索引...") else: print("📝 首次构建索引...") if self.doc_splits is None: try: docs_from_vs = self.doc_processor.get_all_documents_from_vectorstore() if docs_from_vs: self.doc_splits = docs_from_vs else: docs = self.doc_processor.load_documents() self.doc_splits = self.doc_processor.split_documents(docs) except Exception as e: print(f" ❌ 准备GraphRAG文档块失败: {e}") raise # 构建索引 self.knowledge_graph = self.graph_indexer.index_documents( documents=self.doc_splits, batch_size=GRAPHRAG_BATCH_SIZE, save_path=GRAPHRAG_INDEX_PATH ) # 初始化检索器 self.graph_retriever = initialize_graph_retriever(self.knowledge_graph) print("✅ GraphRAG组件设置完成") def query_vector_only(self, question: str) -> str: """仅使用向量检索""" print(f"\n{'='*60}") print(f"🔍 向量检索模式") print(f"问题: {question}") print(f"{'='*60}") docs = self.retriever.get_relevant_documents(question) print(f"\n📄 检索到 {len(docs)} 个文档片段:") for i, doc in enumerate(docs[:3], 1): print(f"\n片段 {i}:") print(f"{doc.page_content[:200]}...") return self.doc_processor.format_docs(docs) def query_graph_local(self, question: str) -> str: """使用图谱本地查询""" if not self.enable_graphrag: return "GraphRAG未启用" print(f"\n{'='*60}") print(f"🔎 图谱本地查询模式") print(f"问题: {question}") print(f"{'='*60}") answer = self.graph_retriever.local_query(question) print(f"\n💡 答案:") print(answer) return answer def query_graph_global(self, question: str) -> str: """使用图谱全局查询""" if not self.enable_graphrag: return "GraphRAG未启用" print(f"\n{'='*60}") print(f"🌍 图谱全局查询模式") print(f"问题: {question}") print(f"{'='*60}") answer = self.graph_retriever.global_query(question) print(f"\n💡 答案:") print(answer) return answer def query_hybrid(self, question: str) -> dict: """混合查询:向量 + 图谱""" if not self.enable_graphrag: return {"error": "GraphRAG未启用"} print(f"\n{'='*60}") print(f"🔀 混合查询模式") print(f"问题: {question}") print(f"{'='*60}") # 向量检索 vector_docs = self.retriever.get_relevant_documents(question) vector_context = self.doc_processor.format_docs(vector_docs[:3]) # 图谱查询 graph_results = self.graph_retriever.hybrid_query_with_metrics(question) result = { "question": question, "vector_retrieval": { "doc_count": len(vector_docs), "context": vector_context[:500] + "..." if len(vector_context) > 500 else vector_context }, "graph_local": graph_results["local"], "graph_global": graph_results["global"], "graph_local_hallucination": graph_results.get("local_hallucination"), "graph_global_hallucination": graph_results.get("global_hallucination"), "graph_local_metrics": graph_results.get("local_metrics"), "graph_global_metrics": graph_results.get("global_metrics") } print("\n📊 结果汇总:") print(f" • 向量检索: {len(vector_docs)} 个文档") print(f" • 图谱本地查询完成") print(f" • 图谱全局查询完成") return result def query_smart(self, question: str) -> str: """智能查询:自动选择最佳策略""" if not self.enable_graphrag: return self.query_vector_only(question) print(f"\n{'='*60}") print(f"🧠 智能查询模式") print(f"问题: {question}") print(f"{'='*60}") answer = self.graph_retriever.smart_query(question) print(f"\n💡 答案:") print(answer) return answer def get_graph_statistics(self): """获取知识图谱统计信息""" if not self.enable_graphrag or not self.knowledge_graph: print("GraphRAG未启用或图谱未构建") return stats = self.knowledge_graph.get_statistics() print("\n" + "="*60) print("📊 知识图谱统计信息") print("="*60) print(f"节点数: {stats['num_nodes']}") print(f"边数: {stats['num_edges']}") print(f"社区数: {stats['num_communities']}") print(f"图密度: {stats['density']:.4f}") print("\n实体类型分布:") for etype, count in stats['entity_types'].items(): print(f" • {etype}: {count}") print("="*60) return stats def interactive_mode(self): """交互模式""" print("\n" + "="*60) print("🤖 欢迎使用GraphRAG增强的自适应RAG系统!") print("="*60) print("\n查询模式:") print(" 1️⃣ vector - 仅向量检索") print(" 2️⃣ local - 图谱本地查询") print(" 3️⃣ global - 图谱全局查询") print(" 4️⃣ hybrid - 混合查询") print(" 5️⃣ smart - 智能查询(推荐)") print(" 6️⃣ stats - 显示图谱统计") print(" 7️⃣ quit - 退出") print("-"*60) while True: try: mode = input("\n选择模式 (1-7): ").strip() if mode in ['7', 'quit', 'exit', '退出', 'q']: print("👋 感谢使用,再见!") break if mode in ['6', 'stats']: self.get_graph_statistics() continue question = input("❓ 请输入问题: ").strip() if not question: print("⚠️ 请输入有效问题") continue if mode in ['1', 'vector']: self.query_vector_only(question) elif mode in ['2', 'local']: self.query_graph_local(question) elif mode in ['3', 'global']: self.query_graph_global(question) elif mode in ['4', 'hybrid']: result = self.query_hybrid(question) pprint(result) else: # 默认智能模式 self.query_smart(question) except KeyboardInterrupt: print("\n👋 感谢使用,再见!") break except Exception as e: print(f"❌ 发生错误: {e}") print("请重试或输入 'quit' 退出") def main(): """主函数""" try: # 初始化系统(首次运行设置rebuild_index=True) rag_system = AdaptiveRAGWithGraph( enable_graphrag=True, rebuild_index=False # 设为True重新构建索引 ) # 显示图谱统计 rag_system.get_graph_statistics() # 测试查询 print("\n" + "="*60) print("🧪 测试查询示例") print("="*60) # 示例1: 本地查询 rag_system.query_graph_local("LLM Agent的主要组成部分是什么?") # 示例2: 全局查询 rag_system.query_graph_global("这些文档主要讨论了什么主题?") # 启动交互模式 rag_system.interactive_mode() except Exception as e: print(f"❌ 系统初始化失败: {e}") import traceback traceback.print_exc() if __name__ == "__main__": main()