File size: 10,675 Bytes
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67e46c9
 
69629dd
 
 
 
 
 
67e46c9
69629dd
67e46c9
 
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f144ed
399f3c6
 
 
 
 
 
 
 
9f144ed
 
 
 
 
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""
GraphRAG集成示例
展示如何在自适应RAG系统中使用知识图谱功能
"""

import os
from pprint import pprint

from config import (
    setup_environment, 
    ENABLE_GRAPHRAG,
    GRAPHRAG_INDEX_PATH,
    GRAPHRAG_BATCH_SIZE
)
from document_processor import initialize_document_processor
from graph_indexer import initialize_graph_indexer
from graph_retriever import initialize_graph_retriever


class AdaptiveRAGWithGraph:
    """Adaptive RAG system that combines vector retrieval with a GraphRAG knowledge graph.

    Construction sets up the environment and the document processor. If
    GraphRAG is requested by the caller AND enabled by the ``ENABLE_GRAPHRAG``
    config flag, it also loads or builds the knowledge-graph index and creates
    a graph retriever.
    """

    def __init__(self, enable_graphrag=True, rebuild_index=False):
        """Initialise all components, printing progress to stdout.

        Args:
            enable_graphrag: Request GraphRAG support; it is only switched on
                when ``ENABLE_GRAPHRAG`` is also truthy.
            rebuild_index: Rebuild the graph index even if one already exists
                at ``GRAPHRAG_INDEX_PATH``.

        Raises:
            ValueError: Propagated from ``setup_environment()`` after the
                message has been printed.
        """
        print("🚀 初始化集成GraphRAG的自适应RAG系统...")
        print("=" * 60)

        # Environment setup; configuration errors are reported, then re-raised.
        try:
            setup_environment()
            print("✅ 环境配置完成")
        except ValueError as e:
            print(f"❌ {e}")
            raise

        print("\n📚 初始化文档处理器...")
        (
            self.doc_processor,
            self.vectorstore,
            self.retriever,
            self.doc_splits,
        ) = initialize_document_processor()

        # GraphRAG components stay None unless GraphRAG is actually enabled.
        self.enable_graphrag = enable_graphrag and ENABLE_GRAPHRAG
        self.graph_indexer = None
        self.graph_retriever = None
        self.knowledge_graph = None

        if self.enable_graphrag:
            self._setup_graphrag(rebuild_index)

        print("\n" + "=" * 60)
        print("✅ 系统初始化完成!")
        print("=" * 60)

    def _setup_graphrag(self, rebuild_index=False):
        """Load or build the knowledge graph, then create the graph retriever.

        An existing index at ``GRAPHRAG_INDEX_PATH`` is loaded unless
        ``rebuild_index`` is true. Otherwise the index is built from the
        document chunks (fetched if necessary) and saved to that path.

        Raises:
            Exception: Re-raised if the document chunks needed for indexing
                cannot be prepared.
        """
        print("\n🔷 设置GraphRAG组件...")

        self.graph_indexer = initialize_graph_indexer()

        index_exists = os.path.exists(GRAPHRAG_INDEX_PATH)

        if index_exists and not rebuild_index:
            print(f"📂 发现现有索引: {GRAPHRAG_INDEX_PATH}")
            print("   加载现有索引...")
            self.knowledge_graph = self.graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
        else:
            if rebuild_index:
                print("🔄 重新构建索引...")
            else:
                print("📝 首次构建索引...")

            # initialize_document_processor() may hand back no chunks
            # (presumably when an existing vector store was reused — confirm).
            # An empty list is treated like None so the graph is never built
            # silently from zero documents.
            if not self.doc_splits:
                try:
                    docs_from_vs = self.doc_processor.get_all_documents_from_vectorstore()
                    if docs_from_vs:
                        self.doc_splits = docs_from_vs
                    else:
                        # Fall back to loading and splitting the raw documents.
                        docs = self.doc_processor.load_documents()
                        self.doc_splits = self.doc_processor.split_documents(docs)
                except Exception as e:
                    print(f"   ❌ 准备GraphRAG文档块失败: {e}")
                    raise

            self.knowledge_graph = self.graph_indexer.index_documents(
                documents=self.doc_splits,
                batch_size=GRAPHRAG_BATCH_SIZE,
                save_path=GRAPHRAG_INDEX_PATH,
            )

        self.graph_retriever = initialize_graph_retriever(self.knowledge_graph)
        print("✅ GraphRAG组件设置完成")

    @staticmethod
    def _print_query_header(title, question):
        """Print the framed banner shown before each query (title + question)."""
        rule = "=" * 60
        print(f"\n{rule}")
        print(title)
        print(f"问题: {question}")
        print(rule)

    def query_vector_only(self, question: str) -> str:
        """Answer with vector retrieval only.

        Prints a preview of up to three retrieved chunks.

        Returns:
            The retrieved documents formatted by ``doc_processor.format_docs``.
        """
        self._print_query_header("🔍 向量检索模式", question)

        docs = self.retriever.get_relevant_documents(question)

        print(f"\n📄 检索到 {len(docs)} 个文档片段:")
        for i, doc in enumerate(docs[:3], 1):
            print(f"\n片段 {i}:")
            print(f"{doc.page_content[:200]}...")

        return self.doc_processor.format_docs(docs)

    def query_graph_local(self, question: str) -> str:
        """Answer via the graph retriever's local (entity-centred) query.

        Returns:
            The answer, or the literal "GraphRAG未启用" when GraphRAG is off.
        """
        if not self.enable_graphrag:
            return "GraphRAG未启用"

        self._print_query_header("🔎 图谱本地查询模式", question)

        answer = self.graph_retriever.local_query(question)

        print("\n💡 答案:")
        print(answer)

        return answer

    def query_graph_global(self, question: str) -> str:
        """Answer via the graph retriever's global query.

        Returns:
            The answer, or the literal "GraphRAG未启用" when GraphRAG is off.
        """
        if not self.enable_graphrag:
            return "GraphRAG未启用"

        self._print_query_header("🌍 图谱全局查询模式", question)

        answer = self.graph_retriever.global_query(question)

        print("\n💡 答案:")
        print(answer)

        return answer

    def query_hybrid(self, question: str) -> dict:
        """Run vector retrieval and graph queries side by side.

        Returns:
            A dict with the question, a vector-retrieval summary (doc count
            plus at most 500 chars of context) and the local/global graph
            answers with their optional hallucination/metric entries. Returns
            ``{"error": "GraphRAG未启用"}`` when GraphRAG is off.
        """
        if not self.enable_graphrag:
            return {"error": "GraphRAG未启用"}

        self._print_query_header("🔀 混合查询模式", question)

        # Vector side: only the top three chunks go into the context preview.
        vector_docs = self.retriever.get_relevant_documents(question)
        vector_context = self.doc_processor.format_docs(vector_docs[:3])

        graph_results = self.graph_retriever.hybrid_query_with_metrics(question)

        result = {
            "question": question,
            "vector_retrieval": {
                "doc_count": len(vector_docs),
                "context": (vector_context[:500] + "...") if len(vector_context) > 500 else vector_context,
            },
            "graph_local": graph_results["local"],
            "graph_global": graph_results["global"],
            # Optional extras: missing keys become None rather than raising.
            "graph_local_hallucination": graph_results.get("local_hallucination"),
            "graph_global_hallucination": graph_results.get("global_hallucination"),
            "graph_local_metrics": graph_results.get("local_metrics"),
            "graph_global_metrics": graph_results.get("global_metrics"),
        }

        print("\n📊 结果汇总:")
        print(f"  • 向量检索: {len(vector_docs)} 个文档")
        print("  • 图谱本地查询完成")
        print("  • 图谱全局查询完成")

        return result

    def query_smart(self, question: str) -> str:
        """Let the graph retriever choose the query strategy.

        Falls back to ``query_vector_only`` when GraphRAG is disabled.
        """
        if not self.enable_graphrag:
            return self.query_vector_only(question)

        self._print_query_header("🧠 智能查询模式", question)

        answer = self.graph_retriever.smart_query(question)

        print("\n💡 答案:")
        print(answer)

        return answer

    def get_graph_statistics(self):
        """Print and return the knowledge graph's statistics.

        Returns:
            The dict from ``knowledge_graph.get_statistics()``, or None when
            GraphRAG is disabled or no graph has been built.
        """
        if not self.enable_graphrag or not self.knowledge_graph:
            print("GraphRAG未启用或图谱未构建")
            return None

        stats = self.knowledge_graph.get_statistics()

        print("\n" + "=" * 60)
        print("📊 知识图谱统计信息")
        print("=" * 60)
        print(f"节点数: {stats['num_nodes']}")
        print(f"边数: {stats['num_edges']}")
        print(f"社区数: {stats['num_communities']}")
        print(f"图密度: {stats['density']:.4f}")
        print("\n实体类型分布:")
        for etype, count in stats['entity_types'].items():
            print(f"  • {etype}: {count}")
        print("=" * 60)

        return stats

    def interactive_mode(self):
        """Run a read–eval–print loop over stdin until the user quits.

        Any unrecognised mode falls through to the smart query. Errors raised
        by a query are printed and the loop continues. Ctrl-C and end of
        input (EOF) both end the session.
        """
        print("\n" + "=" * 60)
        print("🤖 欢迎使用GraphRAG增强的自适应RAG系统!")
        print("=" * 60)
        print("\n查询模式:")
        print("  1️⃣  vector   - 仅向量检索")
        print("  2️⃣  local    - 图谱本地查询")
        print("  3️⃣  global   - 图谱全局查询")
        print("  4️⃣  hybrid   - 混合查询")
        print("  5️⃣  smart    - 智能查询(推荐)")
        print("  6️⃣  stats    - 显示图谱统计")
        print("  7️⃣  quit     - 退出")
        print("-" * 60)

        while True:
            try:
                mode = input("\n选择模式 (1-7): ").strip()

                if mode in ('7', 'quit', 'exit', '退出', 'q'):
                    print("👋 感谢使用,再见!")
                    break

                if mode in ('6', 'stats'):
                    self.get_graph_statistics()
                    continue

                question = input("❓ 请输入问题: ").strip()

                if not question:
                    print("⚠️  请输入有效问题")
                    continue

                if mode in ('1', 'vector'):
                    self.query_vector_only(question)
                elif mode in ('2', 'local'):
                    self.query_graph_local(question)
                elif mode in ('3', 'global'):
                    self.query_graph_global(question)
                elif mode in ('4', 'hybrid'):
                    result = self.query_hybrid(question)
                    pprint(result)
                else:  # anything else defaults to the smart query
                    self.query_smart(question)

            except (KeyboardInterrupt, EOFError):
                # EOFError means stdin is closed (Ctrl-D / exhausted pipe).
                # If the generic handler below caught it instead, every later
                # input() would raise again and the loop would never end.
                print("\n👋 感谢使用,再见!")
                break
            except Exception as e:
                print(f"❌ 发生错误: {e}")
                print("请重试或输入 'quit' 退出")


def main():
    """Script entry point: build the system, show stats, run demo queries, go interactive.

    Failures are printed together with a traceback instead of propagating.
    Initialisation failures and failures during the demo/interactive phase
    are reported separately, so a failing query is not mislabelled as an
    initialisation error.
    """
    import traceback

    # A missing index is built automatically on first run. Set
    # rebuild_index=True only to force a re-index when an index already
    # exists at GRAPHRAG_INDEX_PATH.
    try:
        rag_system = AdaptiveRAGWithGraph(
            enable_graphrag=True,
            rebuild_index=False,  # True forces re-indexing
        )
    except Exception as e:
        print(f"❌ 系统初始化失败: {e}")
        traceback.print_exc()
        return

    try:
        rag_system.get_graph_statistics()

        print("\n" + "=" * 60)
        print("🧪 测试查询示例")
        print("=" * 60)

        # Example 1: local (entity-focused) graph query.
        rag_system.query_graph_local("LLM Agent的主要组成部分是什么?")

        # Example 2: global (corpus-wide) graph query.
        rag_system.query_graph_global("这些文档主要讨论了什么主题?")

        rag_system.interactive_mode()

    except Exception as e:
        print(f"❌ 运行过程中发生错误: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    # Run the demo only when executed as a script; importing the module is side-effect free.
    main()