from functools import partial
from typing import TYPE_CHECKING

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition

from .state import AgentState
from .nodes import *  # provides the node callables and routing_logic used below

if TYPE_CHECKING:
    # Imported for annotations only, so runtime imports are unchanged.
    from langgraph.graph.state import CompiledStateGraph


def create_agent_graph(llm, vectordb, search_tools) -> "CompiledStateGraph":
    """Create the RAG agent graph.

    Wiring:

    * ``START`` -> ``router_node``.
    * ``router_node`` branches via ``routing_logic`` to one of
      ``vectordb_node``, ``web_search_agent_node`` or ``generate_node``.
    * ``web_search_agent_node`` branches via ``tools_condition``:
      ``"tools"`` goes to ``web_search_tool_node``; ``"__end__"`` is
      remapped to ``generate_node`` instead of terminating the graph.
    * ``vectordb_node`` and ``web_search_tool_node`` both lead to
      ``generate_node``, which leads to ``END``.

    Args:
        llm: Model object. Passed as ``llm=`` to ``router_node`` and
            ``generate_node``; ``llm.bind_tools(search_tools)`` is passed
            as ``llm=`` to ``web_search_agent_node``.
        vectordb: Passed as ``vectorstore=`` to ``vectordb_node``.
        search_tools: Tools bound to the model and wrapped in a ``ToolNode``.

    Returns:
        The compiled graph (the result of ``StateGraph.compile``), compiled
        with a fresh ``MemorySaver`` checkpointer created on each call.
    """
    graph = StateGraph(AgentState)
    checkpointer = MemorySaver()

    # The web-search agent gets a tool-aware model; the ToolNode executes
    # the same set of tools.
    llm_with_tools = llm.bind_tools(search_tools)
    web_search_tool_node = ToolNode(search_tools)

    # --- Nodes ---
    graph.add_node("router_node", partial(router_node, llm=llm))
    graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
    graph.add_node(
        "web_search_agent_node",
        partial(web_search_agent_node, llm=llm_with_tools),
    )
    graph.add_node("web_search_tool_node", web_search_tool_node)
    graph.add_node("generate_node", partial(generate_node, llm=llm))

    # --- Edges ---
    graph.add_edge(START, "router_node")

    graph.add_conditional_edges(
        "router_node",
        routing_logic,
        {
            "vectordb_node": "vectordb_node",
            "web_search_agent_node": "web_search_agent_node",
            "generate_node": "generate_node",
        },
    )

    graph.add_conditional_edges(
        "web_search_agent_node",
        tools_condition,
        {
            "tools": "web_search_tool_node",
            # The "__end__" branch is redirected so a final answer is still
            # produced by generate_node.
            "__end__": "generate_node",
        },
    )

    graph.add_edge("vectordb_node", "generate_node")
    graph.add_edge("web_search_tool_node", "generate_node")
    graph.add_edge("generate_node", END)

    agent_graph = graph.compile(
        checkpointer=checkpointer,
    )

    return agent_graph


if __name__ == "__main__":
    pass