File size: 5,690 Bytes
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8008bd3
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
主应用程序入口
集成所有模块,构建工作流并运行自适应RAG系统
"""

from langgraph.graph import END, StateGraph, START
from pprint import pprint

from config import setup_environment
from document_processor import initialize_document_processor
from routers_and_graders import initialize_graders_and_router
from workflow_nodes import WorkflowNodes, GraphState


class AdaptiveRAGSystem:
    """Adaptive RAG system: wires the document store, graders and workflow nodes
    into a LangGraph state graph and exposes one-shot and interactive querying.

    Attributes (set by ``__init__``):
        doc_processor, vectorstore, retriever: the three values returned by
            ``initialize_document_processor()``.
        graders: the value returned by ``initialize_graders_and_router()``;
            handed to ``WorkflowNodes``.
        workflow_nodes: ``WorkflowNodes`` instance supplying the node and
            routing callables used by the graph.
        app: the compiled workflow (result of ``StateGraph.compile()``).
    """

    # Inputs (after strip() + lower()) that end the interactive loop.
    _EXIT_COMMANDS = frozenset({"quit", "exit", "退出", "q"})

    def __init__(self):
        """Initialize every component and compile the workflow graph.

        Raises:
            ValueError: re-raised from ``setup_environment()`` after its message
                is printed (the success message indicates this step validates
                the API keys).
        """
        print("初始化自适应RAG系统...")

        # Set up the environment and validate the API keys.
        try:
            setup_environment()
        except ValueError as e:
            print(f"❌ {e}")
            raise
        else:
            print("✅ API密钥验证成功")

        # Document processor, vector store and retriever.
        print("设置文档处理器...")
        self.doc_processor, self.vectorstore, self.retriever = initialize_document_processor()

        # Graders and question router.
        print("初始化评分器和路由器...")
        self.graders = initialize_graders_and_router()

        # Workflow nodes operate on the retriever and graders built above.
        print("设置工作流节点...")
        self.workflow_nodes = WorkflowNodes(self.retriever, self.graders)

        # Assemble and compile the graph.
        print("构建工作流图...")
        self.app = self._build_workflow()

        print("✅ 自适应RAG系统初始化完成!")

    def _build_workflow(self):
        """Build and compile the workflow graph.

        Graph shape:
            START --route_question--> web_search | retrieve ("vectorstore")
            web_search -> generate
            retrieve -> grade_documents
            grade_documents --decide_to_generate--> transform_query | generate
            transform_query -> retrieve
            generate --grade_generation_v_documents_and_question-->
                generate ("not supported") | END ("useful")
                | transform_query ("not useful")

        Returns:
            The compiled graph returned by ``StateGraph.compile()``.
        """
        workflow = StateGraph(GraphState)

        # Nodes.
        workflow.add_node("web_search", self.workflow_nodes.web_search)
        workflow.add_node("retrieve", self.workflow_nodes.retrieve)
        workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
        workflow.add_node("generate", self.workflow_nodes.generate)
        workflow.add_node("transform_query", self.workflow_nodes.transform_query)

        # Entry routing: the router's "vectorstore" label maps to the retrieve node.
        workflow.add_conditional_edges(
            START,
            self.workflow_nodes.route_question,
            {
                "web_search": "web_search",
                "vectorstore": "retrieve",
            },
        )
        workflow.add_edge("web_search", "generate")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.workflow_nodes.decide_to_generate,
            {
                "transform_query": "transform_query",
                "generate": "generate",
            },
        )
        workflow.add_edge("transform_query", "retrieve")
        # NOTE(review): the generate->generate and transform_query->retrieve
        # cycles have no explicit iteration cap here; termination depends on the
        # graders (and presumably LangGraph's recursion limit) — confirm.
        workflow.add_conditional_edges(
            "generate",
            self.workflow_nodes.grade_generation_v_documents_and_question,
            {
                "not supported": "generate",
                "useful": END,
                "not useful": "transform_query",
            },
        )

        return workflow.compile()

    def query(self, question: str, verbose: bool = True):
        """Run one question through the compiled workflow and return its answer.

        Args:
            question: The user's question, passed as the ``"question"`` input key.
            verbose: If True, print the name of each node as its output arrives.

        Returns:
            The last ``"generation"`` value emitted by any node, or None if no
            node produced one.
        """
        print(f"\n🔍 处理问题: {question}")
        print("=" * 50)

        inputs = {"question": question}
        final_generation = None

        # Each streamed item is a dict keyed by node name; its value is that
        # node's output.
        for output in self.app.stream(inputs):
            for key, value in output.items():
                if verbose:
                    # Plain print: pprint on a str shows its repr (quotes and
                    # literal "\n" escapes) rather than the text itself.
                    print(f"节点 '{key}':")
                # Skip nodes that produced no update rather than crashing on None.
                if value:
                    final_generation = value.get("generation", final_generation)
            if verbose:
                print("\n---\n")

        print("🎯 最终答案:")
        print("-" * 30)
        print(final_generation)
        print("=" * 50)

        return final_generation

    def interactive_mode(self):
        """Read questions from stdin in a loop and answer each with ``query``.

        The loop ends on an exit command, Ctrl-C, or end of input (EOF). Any
        other exception raised while answering is printed and the loop continues.
        """
        print("\n🤖 欢迎使用自适应RAG系统!")
        print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
        print("-" * 50)

        while True:
            try:
                question = input("\n❓ 请输入您的问题: ").strip()

                if question.lower() in self._EXIT_COMMANDS:
                    print("👋 感谢使用,再见!")
                    break

                if not question:
                    print("⚠️  请输入一个有效的问题")
                    continue

                self.query(question)

            except (KeyboardInterrupt, EOFError):
                # EOFError must be caught here: it subclasses Exception, so the
                # generic handler below would swallow it and input() would raise
                # it again on every iteration once stdin is closed.
                print("\n👋 感谢使用,再见!")
                break
            except Exception as e:
                print(f"❌ 发生错误: {e}")
                print("请重试或输入 'quit' 退出")


def main():
    """Script entry point: build the system, run a demo query, then go interactive.

    Initialization failures and runtime failures are reported separately, so an
    error raised by the demo query or the interactive loop is not misreported as
    an initialization problem.
    """
    try:
        # Construction performs environment/API-key setup and builds the graph.
        rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
    except Exception as e:
        print(f"❌ 系统初始化失败: {e}")
        print("请检查配置和依赖是否正确安装")
        return

    try:
        # Demo question ("explain how embeddings work, ideally listing the
        # concrete implementation steps"). Alternative demo question:
        # test_question = "AlphaCodium论文讲的是什么?"
        test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
        rag_system.query(test_question)

        # Then hand control to the user.
        rag_system.interactive_mode()
    except Exception as e:
        print(f"❌ 运行出错: {e}")


# Start the demo query and interactive session only when run as a script.
if __name__ == '__main__':
    main()