|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import json
|
| from typing import TypedDict, Annotated
|
| from operator import add as _list_merge
|
|
|
| from langchain_mistralai import ChatMistralAI
|
| from langgraph.graph import StateGraph, START, END
|
|
|
| from parameters import MODEL, TEMPERATURE, MAX_TOKENS, MAX_AGENT_STEPS
|
| from tools import TOOL_FUNCTIONS, TOOL_SCHEMAS
|
|
|
|
|
# Human-readable name of this backend, exported at module level.
BACKEND_NAME = "LangGraph Agent"

# Tool names each task sub-agent may use. ``_run_task_agent`` binds the model
# only to the tool schemas whose names appear in the relevant set.
MATH_TOOLS = {"multiply", "add"}
INFO_TOOLS = {"list_ml_papers", "ml_paper_info", "search_ml_examples", "get_weather"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state threaded through the supervisor graph's nodes.

    Fields annotated with ``_list_merge`` (``operator.add``) act as LangGraph
    reducers: each node's returned list is concatenated onto the existing one
    rather than replacing it. The remaining fields are plain values that a
    node overwrites by returning them.
    """

    # The user's original request; set in ``run`` and only read by the nodes.
    user_message: str

    # Append-only trace of supervisor decisions, tool calls and the final reply.
    steps: Annotated[list, _list_merge]

    # Append-only list of {"tool", "args", "result"} dicts fed back into prompts.
    tool_results: Annotated[list, _list_merge]

    # Route chosen by the supervisor: "math", "info" or "respond".
    next_action: str

    # Final answer text written by the respond node.
    reply: str

    # Count of supervisor invocations; only the supervisor node returns it.
    iteration: int
|
|
|
|
|
|
|
|
|
|
|
def get_client(api_key):
    """Build the ChatMistralAI chat model shared by every graph node.

    LangGraph has no model client of its own, so LangChain's
    ``ChatMistralAI`` is used.

    Args:
        api_key: Mistral API key. ``None`` or a blank/whitespace-only value
            falls back to the ``MISTRAL_API_KEY`` environment variable, or
            ``""`` when that variable is unset.

    Returns:
        A ``ChatMistralAI`` configured with ``MODEL``, ``TEMPERATURE`` and
        ``MAX_TOKENS`` from ``parameters``.
    """
    supplied = (api_key or "").strip()
    if not supplied:
        supplied = os.environ.get("MISTRAL_API_KEY", "")
    return ChatMistralAI(
        mistral_api_key=supplied,
        model=MODEL,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def supervisor_node(state, client):
    """Decide which sub-agent (or the final responder) should run next.

    Increments ``iteration``. Once it exceeds ``MAX_AGENT_STEPS`` the model is
    not consulted and the graph is forced to ``respond``. Otherwise the model
    is shown the user message plus every tool result gathered so far and asked
    for a single routing word.

    Args:
        state: Current ``AgentState`` mapping.
        client: Chat model exposing ``invoke`` (e.g. ``ChatMistralAI``).

    Returns:
        Partial state update with ``next_action`` ("math", "info" or
        "respond"), the new ``iteration`` and a single trace step.
    """
    iteration = state.get("iteration", 0) + 1

    # Hard stop so a model that never picks "respond" cannot loop forever.
    if iteration > MAX_AGENT_STEPS:
        return {
            "next_action": "respond",
            "iteration": iteration,
            "steps": [{
                "step": iteration, "type": "limit", "tool": "supervisor",
                "args": "-", "result": "max iterations reached",
            }],
        }

    prior = state.get("tool_results", [])
    prior_summary = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior else "none yet"
    )

    supervisor_prompt = (
        "You are a supervisor routing tasks to specialized sub-agents.\n\n"
        f"Original user message: {state['user_message']}\n\n"
        f"Prior tool results:\n{prior_summary}\n\n"
        "Available sub-agents:\n"
        " math — handles arithmetic (add, multiply)\n"
        " info — handles weather lookups and the ML paper catalog\n"
        " respond — emit the final answer to the user "
        "(choose this when all needed information has been gathered)\n\n"
        "Reply with EXACTLY ONE WORD: math, info, or respond."
    )

    resp = client.invoke(supervisor_prompt)
    text = (getattr(resp, "content", "") or "").strip().lower()

    # The model is asked for one word but often adds punctuation, markdown or
    # an explanation such as "respond - no more math needed". A bare
    # substring test checked "math" first and misrouted such replies. Use
    # the first whole word that is a valid route instead. Only when no such
    # word exists, fall back to the lenient substring match (e.g. "mathematics").
    routes = ("math", "info", "respond")
    words = "".join(ch if ch.isalpha() else " " for ch in text).split()
    action = next((w for w in words if w in routes), None)
    if action is None:
        if "math" in text:
            action = "math"
        elif "info" in text:
            action = "info"
        else:
            action = "respond"

    return {
        "next_action": action,
        "iteration": iteration,
        "steps": [{
            "step": iteration,
            "type": "llm_call",
            "tool": "supervisor",
            "args": state["user_message"][:80],
            "result": f"routed to {action}",
        }],
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_task_agent(state, client, tool_names, agent_label):
    """Run one scoped sub-agent turn and execute the tool calls it requests.

    The model is bound only to the schemas in ``TOOL_SCHEMAS`` whose names are
    in ``tool_names``. It is prompted with the user's message and every tool
    result gathered so far. Each requested in-scope call is executed through
    ``TOOL_FUNCTIONS``.

    Args:
        state: Current ``AgentState`` mapping.
        client: Chat model exposing ``bind_tools`` (e.g. ``ChatMistralAI``).
        tool_names: Names of the tools this sub-agent is allowed to call.
        agent_label: Name used in the prompt and in the placeholder step
            recorded when no tool is called.

    Returns:
        Partial state update ``{"steps": [...], "tool_results": [...]}``. The
        ``add`` reducers on ``AgentState`` append both lists to the state.
        A tool that raises yields a result string starting with ``"error: "``
        instead of propagating the exception.
    """
    scoped_schemas = [
        {"type": "function", "function": s["function"]}
        for s in TOOL_SCHEMAS
        if s["function"]["name"] in tool_names
    ]
    model_with_tools = client.bind_tools(scoped_schemas)

    prior = state.get("tool_results", [])
    prior_str = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior else "none"
    )

    prompt = (
        f"User asked: {state['user_message']}\n"
        f"Prior tool results:\n{prior_str}\n\n"
        f"You are the {agent_label}. Call the appropriate tool to make "
        f"progress on the part of the request that falls in your scope."
    )

    resp = model_with_tools.invoke(prompt)
    iteration = state.get("iteration", 0)

    new_steps = []
    new_results = []
    for tc in (getattr(resp, "tool_calls", []) or []):
        # Tool calls may arrive as dicts (LangChain ToolCall) or objects.
        if isinstance(tc, dict):
            name = tc.get("name")
            args = tc.get("args") or {}
        else:
            name = getattr(tc, "name", None)
            args = getattr(tc, "args", None) or {}

        # The model only sees the scoped schemas but can still emit a
        # hallucinated name; never execute a tool outside this agent's scope.
        if name not in tool_names or name not in TOOL_FUNCTIONS:
            continue

        try:
            result = TOOL_FUNCTIONS[name](**args)
        except Exception as exc:
            # Tool boundary. Arguments are model-generated and may not match
            # the tool's signature, or the tool itself may fail. Record the
            # error as the result so the supervisor and responder can see it
            # instead of the whole graph run aborting.
            result = f"error: {type(exc).__name__}: {exc}"

        args_str = json.dumps(args, default=str)
        result_str = str(result)
        new_steps.append({
            "step": iteration,
            "type": "tool_call",
            "tool": name,
            "args": args_str,
            "result": result_str,
        })
        new_results.append({
            "tool": name,
            "args": args_str,
            "result": result_str,
        })

    if not new_steps:
        # Keep a visible trace entry even when the agent did nothing; it is
        # not added to tool_results, so prompts are unaffected.
        new_steps.append({
            "step": iteration,
            "type": "tool_call",
            "tool": agent_label,
            "args": state["user_message"][:80],
            "result": "no tool call made",
        })

    return {"steps": new_steps, "tool_results": new_results}
|
|
|
|
|
|
|
|
|
|
|
def math_agent_node(state, client):
    """Run the arithmetic sub-agent, limited to the tools in ``MATH_TOOLS``.

    Returns the partial state update produced by ``_run_task_agent``.
    """
    return _run_task_agent(
        state,
        client,
        tool_names=MATH_TOOLS,
        agent_label="math_agent",
    )
|
|
|
|
|
|
|
|
|
|
|
def info_agent_node(state, client):
    """Run the information sub-agent, limited to the tools in ``INFO_TOOLS``.

    Returns the partial state update produced by ``_run_task_agent``.
    """
    return _run_task_agent(
        state,
        client,
        tool_names=INFO_TOOLS,
        agent_label="info_agent",
    )
|
|
|
|
|
|
|
|
|
|
|
def respond_node(state, client):
    """Ask the model for the final user-facing answer.

    The prompt lists the user's message and every gathered tool result, or
    notes that no tools were called.

    Args:
        state: Current ``AgentState`` mapping.
        client: Chat model exposing ``invoke``.

    Returns:
        ``{"reply": <stripped model text>, "steps": [<one "final" step>]}``.
        The step number is ``iteration + 1``; ``iteration`` itself is not
        returned, so it is left unchanged in the state.
    """
    gathered = state.get("tool_results", [])
    if gathered:
        summary = "\n".join(
            f"- {item['tool']}({item['args']}) -> {item['result']}"
            for item in gathered
        )
    else:
        summary = "no tools were called"

    message = client.invoke(
        f"User asked: {state['user_message']}\n\n"
        f"Tool results gathered:\n{summary}\n\n"
        "Write a clear, direct reply to the user based on these results."
    )
    reply = (getattr(message, "content", "") or "").strip()

    final_step = {
        "step": state.get("iteration", 0) + 1,
        "type": "final",
        "tool": "respond",
        "args": "-",
        "result": reply,
    }
    return {"reply": reply, "steps": [final_step]}
|
|
|
|
|
|
|
|
|
|
|
def route_from_supervisor(state):
    """Map the supervisor's ``next_action`` to the name of the next graph node.

    "math" goes to ``math_agent`` and "info" goes to ``info_agent``. Anything
    else, including a missing value, goes to ``respond``.
    """
    chosen = state.get("next_action", "respond")
    for keyword, node in (("math", "math_agent"), ("info", "info_agent")):
        if chosen == keyword:
            return node
    return "respond"
|
|
|
|
|
|
|
|
|
|
|
def _build_graph(client):
    """Wire and compile the supervisor graph.

    Topology: START -> supervisor. The supervisor conditionally routes to
    math_agent, info_agent or respond. Both task agents loop back to the
    supervisor, and respond -> END.

    Args:
        client: Chat model passed to every node function.

    Returns:
        The compiled graph, ready for ``invoke``.
    """
    def _bound(node_fn):
        # Close over the shared client so each node has the single-argument
        # ``(state) -> update`` shape that ``add_node`` is given.
        return lambda state: node_fn(state, client)

    graph = StateGraph(AgentState)
    for node_name, node_fn in (
        ("supervisor", supervisor_node),
        ("math_agent", math_agent_node),
        ("info_agent", info_agent_node),
        ("respond", respond_node),
    ):
        graph.add_node(node_name, _bound(node_fn))

    graph.add_edge(START, "supervisor")
    graph.add_conditional_edges(
        "supervisor",
        route_from_supervisor,
        {target: target for target in ("math_agent", "info_agent", "respond")},
    )
    for worker in ("math_agent", "info_agent"):
        graph.add_edge(worker, "supervisor")
    graph.add_edge("respond", END)

    return graph.compile()
|
|
|
|
|
def run(client, user_message):
    """Build and execute the state graph end-to-end.

    Args:
        client: Chat model used by every node (see ``get_client``).
        user_message: The user's request.

    Returns:
        A dict with three keys:

        - ``reply``: the final reply text, or ``""``.
        - ``steps``: the trace, renumbered 1..N in execution order.
        - ``extracted``: ``{"tool_results": [...], "total_iterations": int}``,
          where ``total_iterations`` is the supervisor's iteration count.
    """
    graph = _build_graph(client)

    initial_state = {
        "user_message": user_message,
        "steps": [],
        "tool_results": [],
        "next_action": "",
        "reply": "",
        "iteration": 0,
    }

    # Worst case in super-steps:
    #   - the supervisor runs MAX_AGENT_STEPS + 1 times (the last run hits
    #     the limit and routes to respond);
    #   - a task agent runs after each of the first MAX_AGENT_STEPS;
    #   - respond runs once.
    # That is 2 * MAX_AGENT_STEPS + 2 in total. ``MAX_AGENT_STEPS * 4``
    # alone is 0 for MAX_AGENT_STEPS == 0 and has no headroom at 1, so keep
    # a floor. The limit is unchanged for MAX_AGENT_STEPS >= 2.
    recursion_limit = max(MAX_AGENT_STEPS * 4, 2 * MAX_AGENT_STEPS + 4)

    final_state = graph.invoke(
        initial_state,
        config={"recursion_limit": recursion_limit},
    )

    # Nodes stamp steps with the supervisor iteration, so several steps can
    # share a number. Renumber sequentially for display.
    steps = final_state.get("steps", [])
    for i, s in enumerate(steps, start=1):
        s["step"] = i

    return {
        "reply": final_state.get("reply", "") or "",
        "steps": steps,
        "extracted": {
            "tool_results": final_state.get("tool_results", []),
            "total_iterations": final_state.get("iteration", 0),
        },
    }
|
|
|
|
|
def build_code_snippets(user_message, steps):
    """Render an illustrative LangGraph listing followed by the actual step log.

    The listing mirrors how this module wires its supervisor graph. Each
    recorded step is then appended as ``#`` comment lines. Multi-line values
    are prefixed with ``# `` on every line, so nothing escapes the comments.

    Args:
        user_message: The user's original message. It appears in a header
            comment and, via ``repr``, in the example ``invoke`` call.
        steps: Step dicts as returned by ``run``, each with ``step``,
            ``type``, ``tool``, ``args`` and ``result`` keys.

    Returns:
        The snippet as a single newline-joined string.
    """

    def _commented(prefix, value):
        # The user message, supervisor args (a slice of the user message) and
        # step results (notably the final reply) can contain newlines. Without
        # this, their continuation lines would become bare code in the snippet.
        return prefix + "\n# ".join(f"{value}".splitlines())

    lines = [
        "# Backend: LangGraph (supervisor pattern)",
        "# Explicit state graph with supervisor node + 2 task nodes + respond node.",
        _commented("# User message: ", user_message),
        "",
        "from typing import TypedDict, Annotated",
        "from operator import add",
        "from langgraph.graph import StateGraph, START, END",
        "from langchain_mistralai import ChatMistralAI",
        "",
        "class AgentState(TypedDict):",
        " user_message: str",
        " steps: Annotated[list, add] # concat across nodes",
        " tool_results: Annotated[list, add] # concat across nodes",
        " next_action: str # 'math', 'info', or 'respond'",
        " reply: str",
        " iteration: int",
        "",
        "# --- Build the graph ---",
        "graph = StateGraph(AgentState)",
        "graph.add_node('supervisor', supervisor_node)",
        "graph.add_node('math_agent', math_agent_node)",
        "graph.add_node('info_agent', info_agent_node)",
        "graph.add_node('respond', respond_node)",
        "",
        "graph.add_edge(START, 'supervisor')",
        "graph.add_conditional_edges(",
        " 'supervisor', route_from_supervisor,",
        " {",
        " 'math_agent': 'math_agent',",
        " 'info_agent': 'info_agent',",
        " 'respond': 'respond',",
        " },",
        ")",
        "graph.add_edge('math_agent', 'supervisor') # loop back",
        "graph.add_edge('info_agent', 'supervisor') # loop back",
        "graph.add_edge('respond', END)",
        "",
        "compiled = graph.compile()",
        f"final = compiled.invoke({{'user_message': {user_message!r}, ...}})",
        "reply = final['reply']",
        "",
        "# ---------- actual step log ----------",
    ]
    for s in steps:
        lines.append(f"# Step {s['step']} [{s['type']}] node/tool={s['tool']}")
        lines.append(_commented("# args: ", s["args"]))
        lines.append(_commented("# result: ", s["result"]))
    return "\n".join(lines)
|
|
|