|
|
from functools import partial

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from .state import AgentState

# Star import must provide router_node, vectordb_node, web_search_agent_node,
# generate_node and routing_logic, which create_agent_graph uses unqualified.
from .nodes import *
|
|
|
|
|
|
|
|
def create_agent_graph(
    llm,
    vectordb,
    search_tools,
    *,
    checkpointer=None,
) -> CompiledStateGraph:
    """Build and compile the RAG agent graph.

    Graph topology::

        START -> router_node
        router_node --(routing_logic)--> vectordb_node
                                       | web_search_agent_node
                                       | generate_node
        web_search_agent_node --(tools_condition)--> web_search_tool_node
                                                   | generate_node
        vectordb_node        -> generate_node
        web_search_tool_node -> generate_node
        generate_node        -> END

    Args:
        llm: Language model handed to ``router_node`` and ``generate_node``.
            It must support ``bind_tools`` because a tool-bound copy drives
            ``web_search_agent_node``.
        vectordb: Vector store handed to ``vectordb_node`` as its
            ``vectorstore`` keyword argument.
        search_tools: Tools bound to the web-search agent's model and
            executed by the ``web_search_tool_node`` ``ToolNode``.
        checkpointer: Checkpoint saver passed to ``StateGraph.compile``.
            When omitted (or ``None``), a fresh in-memory ``MemorySaver`` is
            created. That saver keeps checkpoints only in this process, so
            pass a persistent or shared saver to keep conversation threads
            across graph rebuilds.

    Returns:
        The compiled graph (not the ``StateGraph`` builder).
    """
    if checkpointer is None:
        checkpointer = MemorySaver()

    graph = StateGraph(AgentState)

    # The web-search agent decides which tools to call, so it gets a copy of
    # the model with the tool schemas bound; the ToolNode executes the calls.
    llm_with_tools = llm.bind_tools(search_tools)
    web_search_tool_node = ToolNode(search_tools)

    # Dependencies are injected with functools.partial so LangGraph can call
    # each node with only the graph state.
    graph.add_node("router_node", partial(router_node, llm=llm))
    graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
    graph.add_node(
        "web_search_agent_node",
        partial(web_search_agent_node, llm=llm_with_tools),
    )
    graph.add_node("web_search_tool_node", web_search_tool_node)
    graph.add_node("generate_node", partial(generate_node, llm=llm))

    graph.add_edge(START, "router_node")

    # routing_logic must return one of these keys; each maps to the node of
    # the same name.
    graph.add_conditional_edges(
        "router_node",
        routing_logic,
        {
            "vectordb_node": "vectordb_node",
            "web_search_agent_node": "web_search_agent_node",
            "generate_node": "generate_node",
        },
    )

    # tools_condition yields "tools" when the agent requested tool calls and
    # END otherwise. With no tool calls we still generate an answer instead
    # of terminating the run.
    graph.add_conditional_edges(
        "web_search_agent_node",
        tools_condition,
        {
            "tools": "web_search_tool_node",
            END: "generate_node",
        },
    )

    graph.add_edge("vectordb_node", "generate_node")
    # Tool output goes straight to generation: the agent gets a single
    # tool-calling round and does not see the results before generate_node.
    graph.add_edge("web_search_tool_node", "generate_node")
    graph.add_edge("generate_node", END)

    return graph.compile(checkpointer=checkpointer)
|
|
|
|
|
if __name__ == "__main__":
    # Placeholder entry point: this module only defines create_agent_graph and
    # has nothing to execute on its own. Because it uses relative imports
    # (``from .state`` / ``from .nodes``), it can only be run as part of its
    # package via ``python -m <package>.<module>``, not as a standalone script.
    ...