rag_agent / agent /graph.py
Cheh Kit Hong
changed rag method flags
94e0eef
from functools import partial

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from .state import AgentState
from .nodes import *
def create_agent_graph(llm, vectordb, search_tools) -> CompiledStateGraph:
    """Build and compile the RAG agent graph.

    Topology:
        START -> router_node
        router_node -> vectordb_node | web_search_agent_node | generate_node
        web_search_agent_node -> web_search_tool_node   (tool calls requested)
        web_search_agent_node -> generate_node          (no tool calls)
        vectordb_node -> generate_node
        web_search_tool_node -> generate_node
        generate_node -> END

    Args:
        llm: Chat model shared by the router and generator nodes. It must
            support ``bind_tools``; the tool-bound copy drives the web-search
            agent node.
        vectordb: Vector store handed to ``vectordb_node`` as its
            ``vectorstore`` keyword argument.
        search_tools: Tools bound to the web-search LLM and executed by the
            ``ToolNode`` registered as ``web_search_tool_node``.

    Returns:
        The compiled, runnable graph (not the ``StateGraph`` builder). It is
        compiled with a ``MemorySaver`` checkpointer that is created fresh on
        every call, so graphs from separate calls do not share checkpoints.
    """
    graph = StateGraph(AgentState)
    checkpointer = MemorySaver()

    # The web-search agent gets a tool-bound model; the router and generator
    # use the plain model.
    llm_with_tools = llm.bind_tools(search_tools)
    web_search_tool_node = ToolNode(search_tools)

    # --- Nodes ---
    # functools.partial pre-binds each node's dependencies, so the graph only
    # has to supply the state.
    graph.add_node("router_node", partial(router_node, llm=llm))
    graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
    graph.add_node(
        "web_search_agent_node",
        partial(web_search_agent_node, llm=llm_with_tools),
    )
    graph.add_node("web_search_tool_node", web_search_tool_node)
    graph.add_node("generate_node", partial(generate_node, llm=llm))

    # --- Edges ---
    graph.add_edge(START, "router_node")

    # routing_logic is not defined in this file (presumably it comes from the
    # `.nodes` star import). Its return value must be one of these keys.
    graph.add_conditional_edges(
        "router_node",
        routing_logic,
        {
            "vectordb_node": "vectordb_node",
            "web_search_agent_node": "web_search_agent_node",
            "generate_node": "generate_node",
        },
    )

    # tools_condition yields "tools" when the agent requested tool calls and
    # END otherwise. END is rerouted to generation instead of terminating, so
    # every path still produces an answer.
    graph.add_conditional_edges(
        "web_search_agent_node",
        tools_condition,
        {
            "tools": "web_search_tool_node",
            END: "generate_node",
        },
    )

    graph.add_edge("vectordb_node", "generate_node")
    # Tool results go straight to generation and are never fed back to the
    # agent, so at most one round of tool calls happens per invocation.
    graph.add_edge("web_search_tool_node", "generate_node")
    graph.add_edge("generate_node", END)

    return graph.compile(checkpointer=checkpointer)
if __name__ == "__main__":
    # No script entry point: this module only provides create_agent_graph().
    # Because of the relative imports at the top of the file, running it
    # directly as a script fails with ImportError before reaching this guard.
    ...