# adaptive_rag/main.py
"""
主应用程序入口
集成所有模块,构建工作流并运行自适应RAG系统
"""
import time
from langgraph.graph import END, StateGraph, START
from pprint import pprint
from config import setup_environment, validate_api_keys, ENABLE_GRAPHRAG
from document_processor import initialize_document_processor
from routers_and_graders import initialize_graders_and_router
from workflow_nodes import WorkflowNodes, GraphState
try:
from knowledge_graph import initialize_knowledge_graph, initialize_community_summarizer
from graph_retriever import initialize_graph_retriever
except ImportError:
print("⚠️ 无法导入知识图谱模块,GraphRAG功能将不可用")
ENABLE_GRAPHRAG = False
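
# Note: rebinding ENABLE_GRAPHRAG here shadows the value imported from config for the
# rest of this module, so the `if ENABLE_GRAPHRAG:` check in __init__ sees the fallback.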


class AdaptiveRAGSystem:
    """Main class for the adaptive RAG system."""

    def __init__(self):
        print("Initializing the adaptive RAG system...")

        # Set up the environment and validate the API keys
        try:
            setup_environment()
            validate_api_keys()  # Verify that the API keys are configured correctly
            print("✅ API keys validated")
        except ValueError as e:
            print(f"❌ {e}")
            raise

        # Check whether the Ollama service is running
        print("🔍 Checking Ollama service status...")
        if not self._check_ollama_service():
            print("\n" + "=" * 60)
            print("❌ The Ollama service is not running!")
            print("=" * 60)
            print("\nPlease start the Ollama service first:")
            print("\nOption 1: run in a terminal")
            print("  $ ollama serve")
            print("\nOption 2: run in a Kaggle notebook")
            print("  import subprocess")
            print("  subprocess.Popen(['ollama', 'serve'])")
            print("\nOption 3: use the helper script")
            print("  %run KAGGLE_LOAD_OLLAMA.py")
            print("=" * 60)
            raise ConnectionError("The Ollama service is not running; please start it first")
        print("✅ Ollama service is running")

        # Initialize the document processor
        print("Setting up the document processor...")
        self.doc_processor, self.vectorstore, self.retriever, self.doc_splits = initialize_document_processor()

        # Initialize the graders and the router
        print("Initializing graders and router...")
        self.graders = initialize_graders_and_router()

        # Initialize the knowledge graph (if enabled)
        self.graph_retriever = None
        if ENABLE_GRAPHRAG:
            print("Initializing GraphRAG...")
            try:
                kg = initialize_knowledge_graph()
                # Try to load previously built graph data
                try:
                    kg.load_from_file("knowledge_graph.json")
                except FileNotFoundError:
                    print("  No existing knowledge_graph.json found; starting with an empty graph")
                self.graph_retriever = initialize_graph_retriever(kg)
                print("✅ GraphRAG initialized")
            except Exception as e:
                print(f"⚠️ GraphRAG initialization failed: {e}")

        # Set up the workflow nodes
        print("Setting up workflow nodes...")
        # WorkflowNodes is instantiated inside _build_workflow

        # Build the workflow
        print("Building the workflow graph...")
        self.app = self._build_workflow()
        print("✅ Adaptive RAG system initialized!")

    def _check_ollama_service(self) -> bool:
        """Return True if the local Ollama service responds."""
        import requests
        try:
            # Probe the Ollama HTTP API (the /api/tags endpoint lists locally installed models)
            response = requests.get('http://localhost:11434/api/tags', timeout=2)
            return response.status_code == 200
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            return False

    def _build_workflow(self):
        """Build the LangGraph workflow."""
        # Create the workflow node container, passing the DocumentProcessor instance and the retriever
        self.workflow_nodes = WorkflowNodes(
            doc_processor=self.doc_processor,
            graders=self.graders,
            retriever=self.retriever
        )
        workflow = StateGraph(GraphState)
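
        # Graph topology (summary of the edges registered below):
        #   START --route_question--> web_search | decompose_query
        #   decompose_query -> retrieve -> grade_documents
        #   grade_documents --decide_to_generate--> generate | transform_query |
        #                                           prepare_next_query | web_search
        #   transform_query / prepare_next_query -> retrieve   (retry loop)
        #   web_search -> generate
        #   generate --answer grading--> END | transform_query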
        # Register the nodes
        workflow.add_node("web_search", self.workflow_nodes.web_search)
        workflow.add_node("retrieve", self.workflow_nodes.retrieve)
        workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
        workflow.add_node("generate", self.workflow_nodes.generate)
        workflow.add_node("transform_query", self.workflow_nodes.transform_query)
        workflow.add_node("decompose_query", self.workflow_nodes.decompose_query)
        workflow.add_node("prepare_next_query", self.workflow_nodes.prepare_next_query)

        # Wire the graph
        workflow.add_conditional_edges(
            START,
            self.workflow_nodes.route_question,
            {
                "web_search": "web_search",
                "vectorstore": "decompose_query",  # decompose the query before vector retrieval
            },
        )
        workflow.add_edge("web_search", "generate")
        workflow.add_edge("decompose_query", "retrieve")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.workflow_nodes.decide_to_generate,
            {
                "transform_query": "transform_query",
                "prepare_next_query": "prepare_next_query",
                "generate": "generate",
                "web_search": "web_search",  # web_search acts as a fallback option
            },
        )
        workflow.add_edge("transform_query", "retrieve")
        workflow.add_edge("prepare_next_query", "retrieve")
        workflow.add_conditional_edges(
            "generate",
            self.workflow_nodes.grade_generation_v_documents_and_question,
            {
                "not supported": "transform_query",  # fix: on hallucination, rewrite the query instead of regenerating
                "useful": END,
                "not useful": "transform_query",
            },
        )

        # Compile the graph (the recursion limit that guards against infinite loops
        # is passed per run via the config in query(), not here)
        return workflow.compile(
            checkpointer=None,
            interrupt_before=None,
            interrupt_after=None,
            debug=False
        )
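
    # Note (assumption: recent langgraph releases expose this): the compiled graph can
    # also be rendered for inspection, e.g. `print(self.app.get_graph().draw_mermaid())`.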

    async def query(self, question: str, verbose: bool = True):
        """
        Handle a single query (async version).

        Args:
            question (str): the user's question
            verbose (bool): whether to print detailed progress

        Returns:
            dict: the final answer plus retrieval evaluation metrics
        """
        import asyncio
        print(f"\n🔍 Processing question: {question}")
        print("=" * 50)
        inputs = {"question": question, "retry_count": 0}  # initialize the retry counter
        final_generation = None
        retrieval_metrics = None
        # Raise the recursion limit (default is 25) to allow longer retry loops
        config = {"recursion_limit": 50}
        print("\n🤖 Reasoning trace:")
        async for output in self.app.astream(inputs, config=config):
            for key, value in output.items():
                if verbose:
                    # Lightweight per-node progress indicator to give a streaming feel
                    print(f"  ↳ Running node: {key}...", end="\r")
                    # Brief asynchronous pause
                    await asyncio.sleep(0.1)
                    print(f"  ✅ Finished node: {key}   ")
                final_generation = value.get("generation", final_generation)
                # Keep the retrieval evaluation metrics if this node produced them
                if "retrieval_metrics" in value:
                    retrieval_metrics = value["retrieval_metrics"]
        print("\n" + "=" * 50)
        print("🎯 Final answer:")
        print("-" * 30)
        # Simulate streaming output (typewriter effect)
        if final_generation:
            import sys
            for char in final_generation:
                sys.stdout.write(char)
                sys.stdout.flush()
                # Small asynchronous pause to control the typing speed
                await asyncio.sleep(0.01)
            print()  # newline
        else:
            print("No answer was generated")
        print("=" * 50)
        # Return both the answer and the evaluation metrics
        return {
            "answer": final_generation,
            "retrieval_metrics": retrieval_metrics
        }
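
    # Expected shape of the retrieval_metrics dict consumed below (keys as read via
    # .get(); the values shown are illustrative, not real measurements):
    #   {"latency": 0.12, "retrieved_docs_count": 4,
    #    "precision_at_3": 0.67, "recall_at_3": 0.50, "map_score": 0.58}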

    def interactive_mode(self):
        """Interactive loop that lets the user keep asking questions."""
        import asyncio
        print("\n🤖 Welcome to the adaptive RAG system!")
        print("💡 Type a question to start a conversation, or 'quit' / 'exit' to leave")
        print("-" * 50)
        while True:
            try:
                question = input("\n❓ Enter your question: ").strip()
                if question.lower() in ['quit', 'exit', '退出', 'q']:
                    print("👋 Thanks for using the system, goodbye!")
                    break
                if not question:
                    print("⚠️ Please enter a valid question")
                    continue
                # Run the async query with asyncio.run
                result = asyncio.run(self.query(question))
                # Print a summary of the retrieval evaluation
                if result.get("retrieval_metrics"):
                    metrics = result["retrieval_metrics"]
                    print("\n📊 Retrieval evaluation summary:")
                    print(f"  - Retrieval latency: {metrics.get('latency', 0):.4f}s")
                    print(f"  - Documents retrieved: {metrics.get('retrieved_docs_count', 0)}")
                    print(f"  - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
                    print(f"  - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
                    print(f"  - MAP: {metrics.get('map_score', 0):.4f}")
            except KeyboardInterrupt:
                print("\n👋 Thanks for using the system, goodbye!")
                break
            except Exception as e:
                print(f"❌ An error occurred: {e}")
                import traceback
                traceback.print_exc()
                print("Please try again, or type 'quit' to exit")


def main():
    """Main entry point."""
    import asyncio
    try:
        # Initialize the system
        rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()

        # Test query
        # test_question = "What is the AlphaCodium paper about?"
        test_question = "Which company does the author of LangGraph currently work for?"
        # test_question = "Explain how embeddings work, ideally listing the concrete implementation steps"

        # Run the async query with asyncio.run
        result = asyncio.run(rag_system.query(test_question))

        # Print a retrieval evaluation summary for the test query
        if result.get("retrieval_metrics"):
            metrics = result["retrieval_metrics"]
            print("\n📊 Test query retrieval evaluation summary:")
            print(f"  - Retrieval latency: {metrics.get('latency', 0):.4f}s")
            print(f"  - Documents retrieved: {metrics.get('retrieved_docs_count', 0)}")
            print(f"  - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
            print(f"  - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
            print(f"  - MAP: {metrics.get('map_score', 0):.4f}")

        # Enter interactive mode
        rag_system.interactive_mode()
    except Exception as e:
        print(f"❌ System initialization failed: {e}")
        import traceback
        traceback.print_exc()
        print("Please check that the configuration and dependencies are installed correctly")


if __name__ == "__main__":
    main()