# NOTE(review): the three lines below are hosting-UI export residue
# ("Spaces: Paused") and are not Python — commented out so the file parses.
# Spaces:
# Paused
# Paused
| """ | |
| 主应用程序入口 | |
| 集成所有模块,构建工作流并运行自适应RAG系统 | |
| """ | |
| import time | |
| from langgraph.graph import END, StateGraph, START | |
| from pprint import pprint | |
| from config import setup_environment, validate_api_keys, ENABLE_GRAPHRAG | |
| from document_processor import initialize_document_processor | |
| from routers_and_graders import initialize_graders_and_router | |
| from workflow_nodes import WorkflowNodes, GraphState | |
| try: | |
| from knowledge_graph import initialize_knowledge_graph, initialize_community_summarizer | |
| from graph_retriever import initialize_graph_retriever | |
| except ImportError: | |
| print("⚠️ 无法导入知识图谱模块,GraphRAG功能将不可用") | |
| ENABLE_GRAPHRAG = False | |
class AdaptiveRAGSystem:
    """Main class of the adaptive RAG system.

    Wires together environment validation, the document processor
    (vector store + retriever), graders/router, optional GraphRAG, and a
    compiled LangGraph workflow. Exposes an async ``query`` API and a
    synchronous ``interactive_mode`` console loop.
    """

    def __init__(self):
        # NOTE: __init__ performs I/O (console output, an HTTP liveness probe,
        # and model/index initialization) and may raise ValueError or
        # ConnectionError before the instance is usable.
        print("初始化自适应RAG系统...")
        # Set up the environment and validate API keys.
        try:
            setup_environment()
            validate_api_keys()  # raises ValueError when keys are missing/invalid
            print("✅ API密钥验证成功")
        except ValueError as e:
            print(f"❌ {e}")
            raise
        # Verify the local Ollama service is reachable before loading anything heavy.
        print("🔍 检查 Ollama 服务状态...")
        if not self._check_ollama_service():
            print("\n" + "="*60)
            print("❌ Ollama 服务未启动!")
            print("="*60)
            print("\n请先启动 Ollama 服务:")
            print("\n方法1: 在终端运行")
            print("    $ ollama serve")
            print("\n方法2: 在 Kaggle Notebook 中运行")
            print("    import subprocess")
            print("    subprocess.Popen(['ollama', 'serve'])")
            print("\n方法3: 使用快捷脚本")
            print("    %run KAGGLE_LOAD_OLLAMA.py")
            print("="*60)
            raise ConnectionError("Ollama 服务未运行,请先启动服务")
        print("✅ Ollama 服务运行正常")
        # Initialize the document processor (vector store, retriever, splits).
        print("设置文档处理器...")
        self.doc_processor, self.vectorstore, self.retriever, self.doc_splits = initialize_document_processor()
        # Initialize the graders and the question router.
        print("初始化评分器和路由器...")
        self.graders = initialize_graders_and_router()
        # Initialize the knowledge graph (only when GraphRAG is enabled).
        self.graph_retriever = None
        if ENABLE_GRAPHRAG:
            print("初始化 GraphRAG...")
            try:
                kg = initialize_knowledge_graph()
                # Try to load previously persisted graph data.
                try:
                    kg.load_from_file("knowledge_graph.json")
                except FileNotFoundError:
                    print("   未找到 existing knowledge_graph.json, 将使用空图谱")
                self.graph_retriever = initialize_graph_retriever(kg)
                print("✅ GraphRAG 初始化成功")
            except Exception as e:
                # GraphRAG is best-effort: log the failure and continue without it.
                print(f"⚠️ GraphRAG 初始化失败: {e}")
        # Workflow nodes are created inside _build_workflow below.
        print("设置工作流节点...")
        # Build and compile the workflow graph.
        print("构建工作流图...")
        self.app = self._build_workflow()
        print("✅ 自适应RAG系统初始化完成!")

    def _check_ollama_service(self) -> bool:
        """Return True if the local Ollama HTTP API responds within 2 seconds."""
        import requests
        try:
            # Probe the tags endpoint; a 200 response means the server is up.
            response = requests.get('http://localhost:11434/api/tags', timeout=2)
            return response.status_code == 200
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            return False

    def _build_workflow(self):
        """Build and compile the LangGraph workflow graph.

        Returns:
            The compiled LangGraph application (supports ``astream``).
        """
        # Create the workflow-node container, passing the DocumentProcessor
        # instance, the graders, and the retriever.
        self.workflow_nodes = WorkflowNodes(
            doc_processor=self.doc_processor,
            graders=self.graders,
            retriever=self.retriever
        )
        workflow = StateGraph(GraphState)
        # Register the nodes.
        workflow.add_node("web_search", self.workflow_nodes.web_search)
        workflow.add_node("retrieve", self.workflow_nodes.retrieve)
        workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
        workflow.add_node("generate", self.workflow_nodes.generate)
        workflow.add_node("transform_query", self.workflow_nodes.transform_query)
        workflow.add_node("decompose_query", self.workflow_nodes.decompose_query)
        workflow.add_node("prepare_next_query", self.workflow_nodes.prepare_next_query)
        # Wire the graph.
        workflow.add_conditional_edges(
            START,
            self.workflow_nodes.route_question,
            {
                "web_search": "web_search",
                "vectorstore": "decompose_query",  # decompose the query before vector retrieval
            },
        )
        workflow.add_edge("web_search", "generate")
        workflow.add_edge("decompose_query", "retrieve")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.workflow_nodes.decide_to_generate,
            {
                "transform_query": "transform_query",
                "prepare_next_query": "prepare_next_query",
                "generate": "generate",
                "web_search": "web_search",  # web search kept as a fallback option
            },
        )
        workflow.add_edge("transform_query", "retrieve")
        workflow.add_edge("prepare_next_query", "retrieve")
        workflow.add_conditional_edges(
            "generate",
            self.workflow_nodes.grade_generation_v_documents_and_question,
            {
                "not supported": "transform_query",  # on hallucination, re-transform the query rather than regenerating
                "useful": END,
                "not useful": "transform_query",
            },
        )
        # Compile the graph. (The recursion limit that guards against infinite
        # loops is supplied per-call in query() via the config dict.)
        return workflow.compile(
            checkpointer=None,
            interrupt_before=None,
            interrupt_after=None,
            debug=False
        )

    async def query(self, question: str, verbose: bool = True):
        """Process a single question through the workflow (async version).

        Args:
            question (str): The user's question.
            verbose (bool): Whether to print per-node progress output.

        Returns:
            dict: ``{"answer": <final generated text or None>,
            "retrieval_metrics": <metrics dict or None>}``.
        """
        import asyncio
        print(f"\n🔍 处理问题: {question}")
        print("=" * 50)
        inputs = {"question": question, "retry_count": 0}  # initialize the retry counter
        final_generation = None
        retrieval_metrics = None
        # Raise the recursion limit so retry loops don't abort the run.
        config = {"recursion_limit": 50}  # raised to 50; the default is 25
        print("\n🤖 思考过程:")
        async for output in self.app.astream(inputs, config=config):
            for key, value in output.items():
                if verbose:
                    # Lightweight per-node progress display to simulate streaming.
                    print(f"   ↳ 执行节点: {key}...", end="\r")
                    # Brief async pause (also yields to the event loop).
                    await asyncio.sleep(0.1)
                    print(f"   ✅ 完成节点: {key}        ")
                final_generation = value.get("generation", final_generation)
                # Keep the latest retrieval evaluation metrics seen in the stream.
                if "retrieval_metrics" in value:
                    retrieval_metrics = value["retrieval_metrics"]
        print("\n" + "=" * 50)
        print("🎯 最终答案:")
        print("-" * 30)
        # Simulated streaming output (typewriter effect).
        if final_generation:
            import sys
            for char in final_generation:
                sys.stdout.write(char)
                sys.stdout.flush()
                # Async pause controls the typing speed.
                await asyncio.sleep(0.01)
            print()  # trailing newline
        else:
            print("未生成答案")
        print("=" * 50)
        # Return a dict containing both the answer and the evaluation metrics.
        return {
            "answer": final_generation,
            "retrieval_metrics": retrieval_metrics
        }

    def interactive_mode(self):
        """Interactive console loop: keep answering questions until the user quits."""
        import asyncio
        print("\n🤖 欢迎使用自适应RAG系统!")
        print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
        print("-" * 50)
        while True:
            try:
                question = input("\n❓ 请输入您的问题: ").strip()
                if question.lower() in ['quit', 'exit', '退出', 'q']:
                    print("👋 感谢使用,再见!")
                    break
                if not question:
                    print("⚠️ 请输入一个有效的问题")
                    continue
                # Run the async query from this synchronous loop.
                result = asyncio.run(self.query(question))
                # Print a summary of the retrieval evaluation metrics, if any.
                if result.get("retrieval_metrics"):
                    metrics = result["retrieval_metrics"]
                    print("\n📊 检索评估摘要:")
                    print(f"   - 检索耗时: {metrics.get('latency', 0):.4f}秒")
                    print(f"   - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
                    print(f"   - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
                    print(f"   - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
                    print(f"   - MAP: {metrics.get('map_score', 0):.4f}")
            except KeyboardInterrupt:
                print("\n👋 感谢使用,再见!")
                break
            except Exception as e:
                # Surface the error but keep the loop alive so the user can retry.
                print(f"❌ 发生错误: {e}")
                import traceback
                traceback.print_exc()
                print("请重试或输入 'quit' 退出")
def main():
    """Program entry point.

    Builds the AdaptiveRAGSystem, runs one smoke-test query, prints its
    retrieval-metric summary, then hands control to the interactive loop.
    """
    import asyncio
    try:
        # Full system initialization (may raise on bad config / missing services).
        system: AdaptiveRAGSystem = AdaptiveRAGSystem()
        # A fixed question exercises the whole pipeline before the REPL starts.
        test_question = "LangGraph的作者目前在哪家公司工作?"
        result = asyncio.run(system.query(test_question))
        # Summarize the smoke-test's retrieval evaluation, when available.
        metrics = result.get("retrieval_metrics")
        if metrics:
            print("\n📊 测试查询检索评估摘要:")
            print(f"   - 检索耗时: {metrics.get('latency', 0):.4f}秒")
            print(f"   - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
            print(f"   - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
            print(f"   - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
            print(f"   - MAP: {metrics.get('map_score', 0):.4f}")
        # Drop into the interactive question loop.
        system.interactive_mode()
    except Exception as e:
        print(f"❌ 系统初始化失败: {e}")
        import traceback
        traceback.print_exc()
        print("请检查配置和依赖是否正确安装")
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()