# ============================================================================
# agent_langgraph.py — LangGraph backend (supervisor + task nodes + edges)
# ============================================================================
#
# CONTRACT: BACKEND_NAME, get_client, run, build_code_snippets
#
# PATTERN — THE SUPERVISOR STATE GRAPH
# ------------------------------------
# Unlike the tool-calling loop in agent_py.py, LangGraph makes the control
# flow an EXPLICIT graph with named nodes and directed edges. This is
# the "supervisor" pattern: one router node dispatches work to one of
# several specialized task agents, each with a scoped set of tools.
#
# Nodes:
#   supervisor  — decides which task agent to call next, or to stop
#   math_agent  — handles arithmetic tools (add, multiply)
#   info_agent  — handles weather + ML paper catalog lookups
#   respond     — writes the final user-facing reply from accumulated results
#
# Edges:
#   START       -> supervisor
#   supervisor  -> math_agent   (conditional)
#   supervisor  -> info_agent   (conditional)
#   supervisor  -> respond      (conditional)
#   math_agent  -> supervisor   (loop back)
#   info_agent  -> supervisor   (loop back)
#   respond     -> END
#
# IMPORT NOTE
# -----------
# Imports langchain_mistralai and langgraph. If either is missing,
# importing this module raises ImportError and app.py hides the
# LangGraph mode from the dropdown.
# ============================================================================
import json
import os
from operator import add as _list_merge
from typing import Annotated, TypedDict

from langchain_mistralai import ChatMistralAI
from langgraph.graph import END, START, StateGraph

from parameters import MAX_AGENT_STEPS, MAX_TOKENS, MODEL, TEMPERATURE
from tools import TOOL_FUNCTIONS, TOOL_SCHEMAS

BACKEND_NAME = "LangGraph Agent"


# ----------------------------------------------------------------
# Tool scoping — the tool names each task agent is given.
# ----------------------------------------------------------------
MATH_TOOLS = {"add", "multiply"}
INFO_TOOLS = {
    "get_weather",
    "search_ml_examples",
    "ml_paper_info",
    "list_ml_papers",
}


# ----------------------------------------------------------------
# Graph state
# ----------------------------------------------------------------
class AgentState(TypedDict):
    """State dict that flows through every node of the graph.

    ``steps`` and ``tool_results`` are annotated with ``operator.add`` as
    their reducer, so each list a node returns is CONCATENATED onto the
    current value instead of replacing it. All other keys are simply
    overwritten by whichever node returns them.
    """

    user_message: str  # the original request; no node returns this key
    steps: Annotated[list, _list_merge]  # display log, one dict per event
    tool_results: Annotated[list, _list_merge]  # {tool, args, result} dicts
    next_action: str  # routing decision written by the supervisor
    reply: str  # final answer written by the respond node
    iteration: int  # supervisor call counter used by the safety cap


# ----------------------------------------------------------------
# Client
# ----------------------------------------------------------------
def get_client(api_key):
    """Create the LangChain ``ChatMistralAI`` model shared by all nodes.

    A non-blank ``api_key`` argument takes precedence; otherwise the
    ``MISTRAL_API_KEY`` environment variable is used (empty string if unset).
    Model name, temperature and token limit come from ``parameters``.
    """
    explicit_key = (api_key or "").strip()
    if explicit_key:
        resolved_key = explicit_key
    else:
        resolved_key = os.environ.get("MISTRAL_API_KEY", "")
    return ChatMistralAI(
        model=MODEL,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        mistral_api_key=resolved_key,
    )


# ----------------------------------------------------------------
# NODE: supervisor
# Reads the user message plus any prior tool results and decides
# whether to dispatch to math_agent, info_agent, or respond.
# Uses simple prompt-based routing (ask for one word back) which is
# more reliable across providers than function-calling for this.
# ----------------------------------------------------------------
def supervisor_node(state, client):
    """Route to the next sub-agent: ``math``, ``info`` or ``respond``.

    Every call increments ``iteration``. Once it exceeds ``MAX_AGENT_STEPS``
    the LLM is not consulted and the graph is forced to ``respond``, so a
    supervisor <-> agent loop cannot run forever.

    Args:
        state: Current graph state; reads ``user_message``,
            ``tool_results`` and ``iteration``.
        client: Chat model exposing ``invoke(prompt)``.

    Returns:
        Partial state update with ``next_action``, ``iteration`` and a
        one-element ``steps`` list (concatenated by the state reducer).
    """
    iteration = state.get("iteration", 0) + 1

    # Safety cap — prevent infinite loops.
    if iteration > MAX_AGENT_STEPS:
        return {
            "next_action": "respond",
            "iteration": iteration,
            "steps": [{
                "step": iteration,
                "type": "limit",
                "tool": "supervisor",
                "args": "-",
                "result": "max iterations reached",
            }],
        }

    prior = state.get("tool_results", [])
    prior_summary = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior
        else "none yet"
    )
    supervisor_prompt = (
        "You are a supervisor routing tasks to specialized sub-agents.\n\n"
        f"Original user message: {state['user_message']}\n\n"
        f"Prior tool results:\n{prior_summary}\n\n"
        "Available sub-agents:\n"
        " math — handles arithmetic (add, multiply)\n"
        " info — handles weather lookups and the ML paper catalog\n"
        " respond — emit the final answer to the user "
        "(choose this when all needed information has been gathered)\n\n"
        "Reply with EXACTLY ONE WORD: math, info, or respond."
    )
    resp = client.invoke(supervisor_prompt)
    text = (getattr(resp, "content", "") or "").strip().lower()

    # The model may wrap its one-word answer in extra text (e.g.
    # "respond - no math needed"). Testing keywords in a fixed priority
    # order would misroute that to math, so pick the keyword that appears
    # FIRST. Plain substring search keeps answers like "information"
    # working; anything unrecognised falls back to "respond".
    positions = {word: text.find(word) for word in ("math", "info", "respond")}
    mentioned = {word: pos for word, pos in positions.items() if pos >= 0}
    action = min(mentioned, key=mentioned.get) if mentioned else "respond"

    return {
        "next_action": action,
        "iteration": iteration,
        "steps": [{
            "step": iteration,
            "type": "llm_call",
            "tool": "supervisor",
            "args": state["user_message"][:80],
            "result": f"routed to {action}",
        }],
    }


# ----------------------------------------------------------------
# Helper used by both task nodes — bind a scoped set of tools and
# make one LLM call, then execute whatever tool calls come back.
# ---------------------------------------------------------------- def _run_task_agent(state, client, tool_names, agent_label): scoped_schemas = [ {"type": "function", "function": s["function"]} for s in TOOL_SCHEMAS if s["function"]["name"] in tool_names ] model_with_tools = client.bind_tools(scoped_schemas) prior = state.get("tool_results", []) prior_str = ( "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior) if prior else "none" ) prompt = ( f"User asked: {state['user_message']}\n" f"Prior tool results:\n{prior_str}\n\n" f"You are the {agent_label}. Call the appropriate tool to make " f"progress on the part of the request that falls in your scope." ) resp = model_with_tools.invoke(prompt) iteration = state.get("iteration", 0) new_steps = [] new_results = [] for tc in (getattr(resp, "tool_calls", []) or []): name = tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None) args = tc.get("args", {}) if isinstance(tc, dict) else getattr(tc, "args", {}) if name in TOOL_FUNCTIONS: result = TOOL_FUNCTIONS[name](**args) new_steps.append({ "step": iteration, "type": "tool_call", "tool": name, "args": json.dumps(args, default=str), "result": str(result), }) new_results.append({ "tool": name, "args": json.dumps(args, default=str), "result": str(result), }) if not new_steps: # The task agent decided not to call any tool — record a no-op. 
new_steps.append({ "step": iteration, "type": "tool_call", "tool": agent_label, "args": state["user_message"][:80], "result": "no tool call made", }) return {"steps": new_steps, "tool_results": new_results} # ---------------------------------------------------------------- # NODE: math_agent — scoped to arithmetic tools # ---------------------------------------------------------------- def math_agent_node(state, client): return _run_task_agent(state, client, MATH_TOOLS, "math_agent") # ---------------------------------------------------------------- # NODE: info_agent — scoped to weather + ML catalog tools # ---------------------------------------------------------------- def info_agent_node(state, client): return _run_task_agent(state, client, INFO_TOOLS, "info_agent") # ---------------------------------------------------------------- # NODE: respond — synthesize the final reply from accumulated results # ---------------------------------------------------------------- def respond_node(state, client): prior = state.get("tool_results", []) prior_summary = ( "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior) if prior else "no tools were called" ) prompt = ( f"User asked: {state['user_message']}\n\n" f"Tool results gathered:\n{prior_summary}\n\n" "Write a clear, direct reply to the user based on these results." 
) resp = client.invoke(prompt) reply = (getattr(resp, "content", "") or "").strip() iteration = state.get("iteration", 0) + 1 return { "reply": reply, "steps": [{ "step": iteration, "type": "final", "tool": "respond", "args": "-", "result": reply, }], } # ---------------------------------------------------------------- # ROUTER: conditional edge function from supervisor # ---------------------------------------------------------------- def route_from_supervisor(state): action = state.get("next_action", "respond") if action == "math": return "math_agent" if action == "info": return "info_agent" return "respond" # ---------------------------------------------------------------- # Graph builder — compiled on every run so the client is captured in closures # ---------------------------------------------------------------- def _build_graph(client): graph = StateGraph(AgentState) graph.add_node("supervisor", lambda s: supervisor_node(s, client)) graph.add_node("math_agent", lambda s: math_agent_node(s, client)) graph.add_node("info_agent", lambda s: info_agent_node(s, client)) graph.add_node("respond", lambda s: respond_node(s, client)) graph.add_edge(START, "supervisor") graph.add_conditional_edges( "supervisor", route_from_supervisor, { "math_agent": "math_agent", "info_agent": "info_agent", "respond": "respond", }, ) graph.add_edge("math_agent", "supervisor") graph.add_edge("info_agent", "supervisor") graph.add_edge("respond", END) return graph.compile() def run(client, user_message): """Build and execute the state graph end-to-end.""" graph = _build_graph(client) initial_state = { "user_message": user_message, "steps": [], "tool_results": [], "next_action": "", "reply": "", "iteration": 0, } final_state = graph.invoke( initial_state, config={"recursion_limit": MAX_AGENT_STEPS * 4}, ) # Renumber steps sequentially for display steps = final_state.get("steps", []) for i, s in enumerate(steps, start=1): s["step"] = i return { "reply": final_state.get("reply", "") or "", 
"steps": steps, "extracted": { "tool_results": final_state.get("tool_results", []), "total_iterations": final_state.get("iteration", 0), }, } def build_code_snippets(user_message, steps): lines = [ "# Backend: LangGraph (supervisor pattern)", "# Explicit state graph with supervisor node + 2 task nodes + respond node.", f"# User message: {user_message}", "", "from typing import TypedDict, Annotated", "from operator import add", "from langgraph.graph import StateGraph, START, END", "from langchain_mistralai import ChatMistralAI", "", "class AgentState(TypedDict):", " user_message: str", " steps: Annotated[list, add] # concat across nodes", " tool_results: Annotated[list, add] # concat across nodes", " next_action: str # 'math', 'info', or 'respond'", " reply: str", " iteration: int", "", "# --- Build the graph ---", "graph = StateGraph(AgentState)", "graph.add_node('supervisor', supervisor_node)", "graph.add_node('math_agent', math_agent_node)", "graph.add_node('info_agent', info_agent_node)", "graph.add_node('respond', respond_node)", "", "graph.add_edge(START, 'supervisor')", "graph.add_conditional_edges(", " 'supervisor', route_from_supervisor,", " {", " 'math_agent': 'math_agent',", " 'info_agent': 'info_agent',", " 'respond': 'respond',", " },", ")", "graph.add_edge('math_agent', 'supervisor') # loop back", "graph.add_edge('info_agent', 'supervisor') # loop back", "graph.add_edge('respond', END)", "", "compiled = graph.compile()", f"final = compiled.invoke({{'user_message': {user_message!r}, ...}})", "reply = final['reply']", "", "# ---------- actual step log ----------", ] for s in steps: lines.append(f"# Step {s['step']} [{s['type']}] node/tool={s['tool']}") lines.append(f"# args: {s['args']}") lines.append(f"# result: {s['result']}") return "\n".join(lines)