File size: 11,805 Bytes
399f3c6
 
 
 
 
9cce495
399f3c6
 
 
9cce495
399f3c6
 
 
9cce495
 
 
 
 
 
399f3c6
 
 
 
 
 
 
 
 
 
 
45bd829
399f3c6
 
 
 
 
9b75bde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399f3c6
 
4f5443a
399f3c6
 
 
 
 
9cce495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399f3c6
 
47b875d
399f3c6
 
 
 
 
 
 
9b75bde
 
 
 
 
 
 
 
 
 
399f3c6
 
11e6cae
ef805fe
 
11e6cae
 
ef805fe
 
399f3c6
 
 
 
 
 
 
 
9cce495
 
399f3c6
 
 
 
 
 
 
9cce495
399f3c6
 
 
9cce495
399f3c6
 
 
 
 
 
9cce495
399f3c6
f3ef5e1
399f3c6
 
 
9cce495
399f3c6
 
 
 
3f73db0
399f3c6
 
 
 
 
0990104
 
 
 
 
 
 
399f3c6
2d46508
399f3c6
2d46508
399f3c6
 
 
 
 
 
5ad083c
399f3c6
2d46508
399f3c6
 
 
3f73db0
399f3c6
5ad083c
399f3c6
0990104
 
 
9cce495
2d46508
399f3c6
 
9cce495
 
2d46508
 
9cce495
 
399f3c6
5ad083c
 
 
399f3c6
9cce495
399f3c6
 
9cce495
 
 
 
 
 
 
2d46508
 
9cce495
 
 
 
399f3c6
 
5ad083c
 
 
 
 
399f3c6
 
 
2d46508
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d46508
 
5ad083c
 
 
 
 
 
 
 
 
 
399f3c6
 
 
 
 
 
2d46508
 
399f3c6
 
 
 
 
2d46508
399f3c6
 
8008bd3
399f3c6
 
9cce495
 
ee3fb2e
2d46508
 
 
5ad083c
 
 
 
 
 
 
 
 
 
399f3c6
 
 
 
 
 
2d46508
 
399f3c6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
主应用程序入口
集成所有模块,构建工作流并运行自适应RAG系统
"""

import time
from langgraph.graph import END, StateGraph, START
from pprint import pprint

from config import setup_environment, validate_api_keys, ENABLE_GRAPHRAG
from document_processor import initialize_document_processor
from routers_and_graders import initialize_graders_and_router
from workflow_nodes import WorkflowNodes, GraphState
try:
    from knowledge_graph import initialize_knowledge_graph, initialize_community_summarizer
    from graph_retriever import initialize_graph_retriever
except ImportError:
    print("⚠️ 无法导入知识图谱模块,GraphRAG功能将不可用")
    ENABLE_GRAPHRAG = False


class AdaptiveRAGSystem:
    """自适应RAG系统主类"""
    
    def __init__(self):
        print("初始化自适应RAG系统...")
        
        # 设置环境和验证API密钥
        try:
            setup_environment()
            validate_api_keys()  # 验证API密钥是否正确设置
            print("✅ API密钥验证成功")
        except ValueError as e:
            print(f"❌ {e}")
            raise
        
        # 检查 Ollama 服务是否运行
        print("🔍 检查 Ollama 服务状态...")
        if not self._check_ollama_service():
            print("\n" + "="*60)
            print("❌ Ollama 服务未启动!")
            print("="*60)
            print("\n请先启动 Ollama 服务:")
            print("\n方法1: 在终端运行")
            print("  $ ollama serve")
            print("\n方法2: 在 Kaggle Notebook 中运行")
            print("  import subprocess")
            print("  subprocess.Popen(['ollama', 'serve'])")
            print("\n方法3: 使用快捷脚本")
            print("  %run KAGGLE_LOAD_OLLAMA.py")
            print("="*60)
            raise ConnectionError("Ollama 服务未运行,请先启动服务")
        
        print("✅ Ollama 服务运行正常")
        
        # 初始化文档处理器
        print("设置文档处理器...")
        self.doc_processor, self.vectorstore, self.retriever, self.doc_splits = initialize_document_processor()
        
        # 初始化评分器和路由器
        print("初始化评分器和路由器...")
        self.graders = initialize_graders_and_router()
        
        # 初始化知识图谱 (如果启用)
        self.graph_retriever = None
        if ENABLE_GRAPHRAG:
            print("初始化 GraphRAG...")
            try:
                kg = initialize_knowledge_graph()
                # 尝试加载已有的图谱数据
                try:
                    kg.load_from_file("knowledge_graph.json")
                except FileNotFoundError:
                    print("   未找到 existing knowledge_graph.json, 将使用空图谱")
                
                self.graph_retriever = initialize_graph_retriever(kg)
                print("✅ GraphRAG 初始化成功")
            except Exception as e:
                print(f"⚠️ GraphRAG 初始化失败: {e}")
        
        # 初始化工作流节点
        print("设置工作流节点...")
        # WorkflowNodes 将在 _build_workflow 中初始化
        
        # 构建工作流
        print("构建工作流图...")
        self.app = self._build_workflow()
        
        print("✅ 自适应RAG系统初始化完成!")
    
    def _check_ollama_service(self) -> bool:
        """检查 Ollama 服务是否运行"""
        import requests
        try:
            # 尝试连接 Ollama API
            response = requests.get('http://localhost:11434/api/tags', timeout=2)
            return response.status_code == 200
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            return False
    
    def _build_workflow(self):
        """构建工作流图"""
        # 创建工作流节点实例,传递DocumentProcessor实例和retriever
        self.workflow_nodes = WorkflowNodes(
            doc_processor=self.doc_processor,
            graders=self.graders,
            retriever=self.retriever
        )
        
        workflow = StateGraph(GraphState)
        
        # 定义节点
        workflow.add_node("web_search", self.workflow_nodes.web_search)
        workflow.add_node("retrieve", self.workflow_nodes.retrieve)
        workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
        workflow.add_node("generate", self.workflow_nodes.generate)
        workflow.add_node("transform_query", self.workflow_nodes.transform_query)
        workflow.add_node("decompose_query", self.workflow_nodes.decompose_query)
        workflow.add_node("prepare_next_query", self.workflow_nodes.prepare_next_query)
        
        # 构建图
        workflow.add_conditional_edges(
            START,
            self.workflow_nodes.route_question,
            {
                "web_search": "web_search",
                "vectorstore": "decompose_query", # 向量检索前先进行查询分解
            },
        )
        workflow.add_edge("web_search", "generate")
        workflow.add_edge("decompose_query", "retrieve")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.workflow_nodes.decide_to_generate,
            {
                "transform_query": "transform_query",
                "prepare_next_query": "prepare_next_query",
                "generate": "generate",
                "web_search": "web_search", # 添加 web_search 作为回退选项
            },
        )
        workflow.add_edge("transform_query", "retrieve")
        workflow.add_edge("prepare_next_query", "retrieve")
        workflow.add_conditional_edges(
            "generate",
            self.workflow_nodes.grade_generation_v_documents_and_question,
            {
                "not supported": "transform_query",  # 修复:有幻觉时重新转换查询,而不是再次生成
                "useful": END,
                "not useful": "transform_query",
            },
        )
        
        # 编译(设置递归限制以防止无限循环)
        return workflow.compile(
            checkpointer=None,
            interrupt_before=None,
            interrupt_after=None,
            debug=False
        )
    
    async def query(self, question: str, verbose: bool = True):
        """
        处理查询 (异步版本)
        
        Args:
            question (str): 用户问题
            verbose (bool): 是否显示详细输出
            
        Returns:
            dict: 包含最终答案和评估指标的字典
        """
        import asyncio
        print(f"\n🔍 处理问题: {question}")
        print("=" * 50)
        
        inputs = {"question": question, "retry_count": 0}  # 初始化重试计数器
        final_generation = None
        retrieval_metrics = None
        
        # 设置配置,增加递归限制
        config = {"recursion_limit": 50}  # 增加到 50,默认是 25
        
        print("\n🤖 思考过程:")
        async for output in self.app.astream(inputs, config=config):
            for key, value in output.items():
                if verbose:
                    # 简单的节点执行提示,模拟流式感
                    print(f"  ↳ 执行节点: {key}...", end="\r")
                    # 异步暂停
                    await asyncio.sleep(0.1) 
                    print(f"  ✅ 完成节点: {key}      ")
                    
                final_generation = value.get("generation", final_generation)
                # 保存检索评估指标
                if "retrieval_metrics" in value:
                    retrieval_metrics = value["retrieval_metrics"]
        
        print("\n" + "=" * 50)
        print("🎯 最终答案:")
        print("-" * 30)
        
        # 模拟流式输出效果 (打字机效果)
        if final_generation:
            import sys
            for char in final_generation:
                sys.stdout.write(char)
                sys.stdout.flush()
                # 异步暂停
                await asyncio.sleep(0.01) # 控制打字速度
            print() # 换行
        else:
            print("未生成答案")
            
        print("=" * 50)
        
        # 返回包含答案和评估指标的字典
        return {
            "answer": final_generation,
            "retrieval_metrics": retrieval_metrics
        }
    
    def interactive_mode(self):
        """交互模式,允许用户持续提问"""
        import asyncio
        print("\n🤖 欢迎使用自适应RAG系统!")
        print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
        print("-" * 50)
        
        while True:
            try:
                question = input("\n❓ 请输入您的问题: ").strip()
                
                if question.lower() in ['quit', 'exit', '退出', 'q']:
                    print("👋 感谢使用,再见!")
                    break
                
                if not question:
                    print("⚠️  请输入一个有效的问题")
                    continue
                
                # 使用 asyncio.run 执行异步查询
                result = asyncio.run(self.query(question))
                
                # 显示检索评估摘要
                if result.get("retrieval_metrics"):
                    metrics = result["retrieval_metrics"]
                    print("\n📊 检索评估摘要:")
                    print(f"   - 检索耗时: {metrics.get('latency', 0):.4f}秒")
                    print(f"   - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
                    print(f"   - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
                    print(f"   - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
                    print(f"   - MAP: {metrics.get('map_score', 0):.4f}")
                
            except KeyboardInterrupt:
                print("\n👋 感谢使用,再见!")
                break
            except Exception as e:
                print(f"❌ 发生错误: {e}")
                import traceback
                traceback.print_exc()
                print("请重试或输入 'quit' 退出")


def main():
    """主函数"""
    import asyncio
    try:
        # 初始化系统
        rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
        
        # 测试查询
        # test_question = "AlphaCodium论文讲的是什么?"
        test_question = "LangGraph的作者目前在哪家公司工作?"
        # test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
        
        # 使用 asyncio.run 执行异步查询
        result = asyncio.run(rag_system.query(test_question))
        
        # 显示测试查询的检索评估摘要
        if result.get("retrieval_metrics"):
            metrics = result["retrieval_metrics"]
            print("\n📊 测试查询检索评估摘要:")
            print(f"   - 检索耗时: {metrics.get('latency', 0):.4f}秒")
            print(f"   - 检索文档数: {metrics.get('retrieved_docs_count', 0)}")
            print(f"   - Precision@3: {metrics.get('precision_at_3', 0):.4f}")
            print(f"   - Recall@3: {metrics.get('recall_at_3', 0):.4f}")
            print(f"   - MAP: {metrics.get('map_score', 0):.4f}")
        
        # 启动交互模式
        rag_system.interactive_mode()
        
    except Exception as e:
        print(f"❌ 系统初始化失败: {e}")
        import traceback
        traceback.print_exc()
        print("请检查配置和依赖是否正确安装")


if __name__ == "__main__":
    main()