File size: 1,742 Bytes
30ee88a
 
 
 
 
 
 
 
0fc97a4
 
30ee88a
 
 
 
 
aa018e3
0fc97a4
 
 
 
 
aa018e3
 
0fc97a4
 
 
 
 
 
 
 
 
 
aa018e3
0fc97a4
aa018e3
 
 
 
 
 
 
94e0eef
 
0fc97a4
 
30ee88a
0fc97a4
aa018e3
30ee88a
0fc97a4
30ee88a
 
 
 
 
0fc97a4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from functools import partial
from typing import TYPE_CHECKING

from langgraph.graph import START, StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode, tools_condition

from .state import AgentState
from .nodes import *

if TYPE_CHECKING:
    from langgraph.graph.state import CompiledStateGraph


def create_agent_graph(
    llm,
    vectordb,
    search_tools,
    *,
    checkpointer=None,
) -> "CompiledStateGraph":
    """Build and compile the RAG agent graph.

    Topology::

        START -> router_node
        router_node --routing_logic--> vectordb_node
                                     | web_search_agent_node
                                     | generate_node
        web_search_agent_node --tools_condition--> web_search_tool_node
                                                 | generate_node
        vectordb_node        -> generate_node
        web_search_tool_node -> generate_node
        generate_node        -> END

    Args:
        llm: Chat model used by ``router_node`` and ``generate_node``. It
            must support ``bind_tools``; the tool-bound copy drives
            ``web_search_agent_node``.
        vectordb: Vector store handed to ``vectordb_node`` as its
            ``vectorstore`` keyword argument.
        search_tools: Tools bound to the web-search agent's LLM and executed
            by the ``ToolNode`` registered as ``web_search_tool_node``.
        checkpointer: Optional checkpoint saver passed to ``compile``.
            When omitted, a fresh in-memory ``MemorySaver`` is created for
            this graph, so saved state does not outlive the process. Pass a
            persistent saver to keep it.

    Returns:
        The compiled graph returned by ``StateGraph.compile``, not the
        uncompiled ``StateGraph`` builder.
    """
    graph = StateGraph(AgentState)
    if checkpointer is None:
        checkpointer = MemorySaver()

    # The web-search agent gets a tool-bound LLM so it can emit tool calls.
    # The ToolNode executes those calls using the same tool list.
    llm_with_tools = llm.bind_tools(search_tools)
    web_search_tool_node = ToolNode(search_tools)

    # --- Nodes ---
    graph.add_node("router_node", partial(router_node, llm=llm))
    graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
    graph.add_node(
        "web_search_agent_node",
        partial(web_search_agent_node, llm=llm_with_tools),
    )
    graph.add_node("web_search_tool_node", web_search_tool_node)
    graph.add_node("generate_node", partial(generate_node, llm=llm))

    # --- Edges ---
    graph.add_edge(START, "router_node")

    # routing_logic (presumably provided by .nodes) must return one of
    # these keys. Any other value has no matching destination.
    graph.add_conditional_edges(
        "router_node",
        routing_logic,
        {
            "vectordb_node": "vectordb_node",
            "web_search_agent_node": "web_search_agent_node",
            "generate_node": "generate_node",
        },
    )

    # tools_condition yields "tools" when the agent requested tool calls and
    # END otherwise. The END branch is redirected to generation instead of
    # terminating the run.
    graph.add_conditional_edges(
        "web_search_agent_node",
        tools_condition,
        {
            "tools": "web_search_tool_node",
            END: "generate_node",
        },
    )

    graph.add_edge("vectordb_node", "generate_node")
    # Tool results flow straight to generation rather than back to the
    # agent, so the web-search agent gets a single round of tool calls.
    graph.add_edge("web_search_tool_node", "generate_node")

    graph.add_edge("generate_node", END)

    return graph.compile(checkpointer=checkpointer)

if __name__ == "__main__":
    # Deliberately a no-op: running this file directly does nothing.
    # The module only defines create_agent_graph(). Its relative imports
    # (".state", ".nodes") resolve only when it is loaded as part of its
    # package.
    ...