File size: 1,478 Bytes
16d5a75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from langgraph.graph import StateGraph, START, END
from .func import (
    State,
    trim_history,
    execute_tool,
    generate_answer_rag,
)
from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import InMemorySaver


class RAGAgentTemplate:
    """Assemble and compile a retrieval-augmented-generation agent graph.

    The graph wired by this template is::

        START -> trim_history -> generate_answer_rag
        generate_answer_rag --(last message has tool calls)--> execute_tool
        generate_answer_rag --(otherwise)--------------------> END
        execute_tool -> generate_answer_rag

    The node callables (``trim_history``, ``generate_answer_rag``,
    ``execute_tool``) and the ``State`` schema come from the sibling
    ``func`` module. Calling an instance returns the compiled graph with an
    in-memory checkpointer attached.
    """

    def __init__(self):
        self.builder = StateGraph(State)
        # Whether nodes/edges have already been registered on ``self.builder``.
        # Registering the same node twice raises, so wiring must happen once.
        self._wired = False

    @staticmethod
    def should_continue(state: State):
        """Choose the next step after ``generate_answer_rag``.

        Args:
            state: Graph state; only its ``"messages"`` entry is read.

        Returns:
            ``"execute_tool"`` when the most recent message carries tool
            calls, otherwise ``END``.
        """
        messages = state["messages"]
        last_message = messages[-1] if messages else None
        # Not every message type exposes ``tool_calls``; treat a missing
        # attribute (or an empty history) as "no tools requested".
        if getattr(last_message, "tool_calls", None):
            return "execute_tool"
        return END

    def node(self):
        """Register the three graph nodes on ``self.builder``."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("execute_tool", execute_tool)

    def edge(self):
        """Register the graph's edges on ``self.builder``.

        Leaving ``generate_answer_rag`` is decided solely by
        ``should_continue``; its ``END`` branch already terminates the run,
        so no separate unconditional edge to ``END`` is added.
        """
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "generate_answer_rag")
        self.builder.add_conditional_edges(
            "generate_answer_rag",
            self.should_continue,
            {
                END: END,
                "execute_tool": "execute_tool",
            },
        )
        # Tool results are fed back to the model for another answer attempt.
        self.builder.add_edge("execute_tool", "generate_answer_rag")

    def __call__(self) -> CompiledStateGraph:
        """Wire the graph (only on the first call) and compile it.

        Returns:
            The compiled graph. Each call constructs its own
            ``InMemorySaver`` checkpointer instance.
        """
        if not getattr(self, "_wired", False):
            self.node()
            self.edge()
            self._wired = True

        return self.builder.compile(checkpointer=InMemorySaver())


# Compiled RAG agent graph, built once at import time via RAGAgentTemplate.__call__.
# NOTE(review): presumably exposed as a module-level name so a LangGraph
# runtime/config can load it by import path — confirm before renaming or moving.
rag_agent_template_agent: CompiledStateGraph = RAGAgentTemplate()()