adaptive_rag / main_graphrag.py
lanny xu
add cuda
69629dd
"""
GraphRAG集成示例
展示如何在自适应RAG系统中使用知识图谱功能
"""
import os
from pprint import pprint
from config import (
setup_environment,
ENABLE_GRAPHRAG,
GRAPHRAG_INDEX_PATH,
GRAPHRAG_BATCH_SIZE
)
from document_processor import initialize_document_processor
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever
class AdaptiveRAGWithGraph:
"""集成GraphRAG的自适应RAG系统"""
def __init__(self, enable_graphrag=True, rebuild_index=False):
print("🚀 初始化集成GraphRAG的自适应RAG系统...")
print("="*60)
# 设置环境
try:
setup_environment()
print("✅ 环境配置完成")
except ValueError as e:
print(f"❌ {e}")
raise
# 初始化文档处理器
print("\n📚 初始化文档处理器...")
self.doc_processor, self.vectorstore, self.retriever, self.doc_splits = \
initialize_document_processor()
# GraphRAG组件
self.enable_graphrag = enable_graphrag and ENABLE_GRAPHRAG
self.graph_indexer = None
self.graph_retriever = None
self.knowledge_graph = None
if self.enable_graphrag:
self._setup_graphrag(rebuild_index)
print("\n" + "="*60)
print("✅ 系统初始化完成!")
print("="*60)
def _setup_graphrag(self, rebuild_index=False):
"""设置GraphRAG组件"""
print("\n🔷 设置GraphRAG组件...")
# 初始化索引器
self.graph_indexer = initialize_graph_indexer()
# 检查是否已有索引
index_exists = os.path.exists(GRAPHRAG_INDEX_PATH)
if index_exists and not rebuild_index:
print(f"📂 发现现有索引: {GRAPHRAG_INDEX_PATH}")
print(" 加载现有索引...")
self.knowledge_graph = self.graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
else:
if rebuild_index:
print("🔄 重新构建索引...")
else:
print("📝 首次构建索引...")
if self.doc_splits is None:
try:
docs_from_vs = self.doc_processor.get_all_documents_from_vectorstore()
if docs_from_vs:
self.doc_splits = docs_from_vs
else:
docs = self.doc_processor.load_documents()
self.doc_splits = self.doc_processor.split_documents(docs)
except Exception as e:
print(f" ❌ 准备GraphRAG文档块失败: {e}")
raise
# 构建索引
self.knowledge_graph = self.graph_indexer.index_documents(
documents=self.doc_splits,
batch_size=GRAPHRAG_BATCH_SIZE,
save_path=GRAPHRAG_INDEX_PATH
)
# 初始化检索器
self.graph_retriever = initialize_graph_retriever(self.knowledge_graph)
print("✅ GraphRAG组件设置完成")
def query_vector_only(self, question: str) -> str:
"""仅使用向量检索"""
print(f"\n{'='*60}")
print(f"🔍 向量检索模式")
print(f"问题: {question}")
print(f"{'='*60}")
docs = self.retriever.get_relevant_documents(question)
print(f"\n📄 检索到 {len(docs)} 个文档片段:")
for i, doc in enumerate(docs[:3], 1):
print(f"\n片段 {i}:")
print(f"{doc.page_content[:200]}...")
return self.doc_processor.format_docs(docs)
def query_graph_local(self, question: str) -> str:
"""使用图谱本地查询"""
if not self.enable_graphrag:
return "GraphRAG未启用"
print(f"\n{'='*60}")
print(f"🔎 图谱本地查询模式")
print(f"问题: {question}")
print(f"{'='*60}")
answer = self.graph_retriever.local_query(question)
print(f"\n💡 答案:")
print(answer)
return answer
def query_graph_global(self, question: str) -> str:
"""使用图谱全局查询"""
if not self.enable_graphrag:
return "GraphRAG未启用"
print(f"\n{'='*60}")
print(f"🌍 图谱全局查询模式")
print(f"问题: {question}")
print(f"{'='*60}")
answer = self.graph_retriever.global_query(question)
print(f"\n💡 答案:")
print(answer)
return answer
def query_hybrid(self, question: str) -> dict:
"""混合查询:向量 + 图谱"""
if not self.enable_graphrag:
return {"error": "GraphRAG未启用"}
print(f"\n{'='*60}")
print(f"🔀 混合查询模式")
print(f"问题: {question}")
print(f"{'='*60}")
# 向量检索
vector_docs = self.retriever.get_relevant_documents(question)
vector_context = self.doc_processor.format_docs(vector_docs[:3])
# 图谱查询
graph_results = self.graph_retriever.hybrid_query_with_metrics(question)
result = {
"question": question,
"vector_retrieval": {
"doc_count": len(vector_docs),
"context": vector_context[:500] + "..." if len(vector_context) > 500 else vector_context
},
"graph_local": graph_results["local"],
"graph_global": graph_results["global"],
"graph_local_hallucination": graph_results.get("local_hallucination"),
"graph_global_hallucination": graph_results.get("global_hallucination"),
"graph_local_metrics": graph_results.get("local_metrics"),
"graph_global_metrics": graph_results.get("global_metrics")
}
print("\n📊 结果汇总:")
print(f" • 向量检索: {len(vector_docs)} 个文档")
print(f" • 图谱本地查询完成")
print(f" • 图谱全局查询完成")
return result
def query_smart(self, question: str) -> str:
"""智能查询:自动选择最佳策略"""
if not self.enable_graphrag:
return self.query_vector_only(question)
print(f"\n{'='*60}")
print(f"🧠 智能查询模式")
print(f"问题: {question}")
print(f"{'='*60}")
answer = self.graph_retriever.smart_query(question)
print(f"\n💡 答案:")
print(answer)
return answer
def get_graph_statistics(self):
"""获取知识图谱统计信息"""
if not self.enable_graphrag or not self.knowledge_graph:
print("GraphRAG未启用或图谱未构建")
return
stats = self.knowledge_graph.get_statistics()
print("\n" + "="*60)
print("📊 知识图谱统计信息")
print("="*60)
print(f"节点数: {stats['num_nodes']}")
print(f"边数: {stats['num_edges']}")
print(f"社区数: {stats['num_communities']}")
print(f"图密度: {stats['density']:.4f}")
print("\n实体类型分布:")
for etype, count in stats['entity_types'].items():
print(f" • {etype}: {count}")
print("="*60)
return stats
def interactive_mode(self):
"""交互模式"""
print("\n" + "="*60)
print("🤖 欢迎使用GraphRAG增强的自适应RAG系统!")
print("="*60)
print("\n查询模式:")
print(" 1️⃣ vector - 仅向量检索")
print(" 2️⃣ local - 图谱本地查询")
print(" 3️⃣ global - 图谱全局查询")
print(" 4️⃣ hybrid - 混合查询")
print(" 5️⃣ smart - 智能查询(推荐)")
print(" 6️⃣ stats - 显示图谱统计")
print(" 7️⃣ quit - 退出")
print("-"*60)
while True:
try:
mode = input("\n选择模式 (1-7): ").strip()
if mode in ['7', 'quit', 'exit', '退出', 'q']:
print("👋 感谢使用,再见!")
break
if mode in ['6', 'stats']:
self.get_graph_statistics()
continue
question = input("❓ 请输入问题: ").strip()
if not question:
print("⚠️ 请输入有效问题")
continue
if mode in ['1', 'vector']:
self.query_vector_only(question)
elif mode in ['2', 'local']:
self.query_graph_local(question)
elif mode in ['3', 'global']:
self.query_graph_global(question)
elif mode in ['4', 'hybrid']:
result = self.query_hybrid(question)
pprint(result)
else: # 默认智能模式
self.query_smart(question)
except KeyboardInterrupt:
print("\n👋 感谢使用,再见!")
break
except Exception as e:
print(f"❌ 发生错误: {e}")
print("请重试或输入 'quit' 退出")
def main():
"""主函数"""
try:
# 初始化系统(首次运行设置rebuild_index=True)
rag_system = AdaptiveRAGWithGraph(
enable_graphrag=True,
rebuild_index=False # 设为True重新构建索引
)
# 显示图谱统计
rag_system.get_graph_statistics()
# 测试查询
print("\n" + "="*60)
print("🧪 测试查询示例")
print("="*60)
# 示例1: 本地查询
rag_system.query_graph_local("LLM Agent的主要组成部分是什么?")
# 示例2: 全局查询
rag_system.query_graph_global("这些文档主要讨论了什么主题?")
# 启动交互模式
rag_system.interactive_mode()
except Exception as e:
print(f"❌ 系统初始化失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()