# Spjimr / agent_langgraph.py
# shahidshaikh's picture
# Upload 40 files
# a52bae4 verified
# ============================================================================
# agent_langgraph.py — LangGraph backend (supervisor + task nodes + edges)
# ============================================================================
#
# CONTRACT: BACKEND_NAME, get_client, run, build_code_snippets
#
# PATTERN — THE SUPERVISOR STATE GRAPH
# ------------------------------------
# Unlike the tool-calling loop in agent_py.py, LangGraph makes the control
# flow an EXPLICIT graph with named nodes and directed edges. This is
# the "supervisor" pattern: one router node dispatches work to one of
# several specialized task agents, each with a scoped set of tools.
#
# Nodes:
# supervisor — decides which task agent to call next, or to stop
# math_agent — handles arithmetic tools (add, multiply)
# info_agent — handles weather + ML paper catalog lookups
# respond — writes the final user-facing reply from accumulated results
#
# Edges:
# START -> supervisor
# supervisor -> math_agent (conditional)
# supervisor -> info_agent (conditional)
# supervisor -> respond (conditional)
# math_agent -> supervisor (loop back)
# info_agent -> supervisor (loop back)
# respond -> END
#
# IMPORT NOTE
# -----------
# Imports langchain_mistralai and langgraph. If either is missing,
# importing this module raises ImportError and app.py hides the
# LangGraph mode from the dropdown.
# ============================================================================
import json
import os
import re
from operator import add as _list_merge
from typing import Annotated, TypedDict

from langchain_mistralai import ChatMistralAI
from langgraph.graph import END, START, StateGraph

from parameters import MAX_AGENT_STEPS, MAX_TOKENS, MODEL, TEMPERATURE
from tools import TOOL_FUNCTIONS, TOOL_SCHEMAS
BACKEND_NAME = "LangGraph Agent"

# ----------------------------------------------------------------
# Tool ownership per task agent: _run_task_agent binds only the
# schemas whose names appear in the set passed for that agent.
# ----------------------------------------------------------------
MATH_TOOLS = {"multiply", "add"}
INFO_TOOLS = {
    "get_weather",
    "list_ml_papers",
    "ml_paper_info",
    "search_ml_examples",
}
# ----------------------------------------------------------------
# Graph state — a TypedDict that flows through every node.
# The Annotated[list, _list_merge] tells LangGraph to CONCATENATE
# these lists when multiple nodes write to them, instead of replacing.
# ----------------------------------------------------------------
# State mapping passed to and returned (partially) by every graph node.
# Functional TypedDict form; the two Annotated[list, _list_merge] keys are
# concatenated by LangGraph across node updates instead of overwritten.
AgentState = TypedDict(
    "AgentState",
    {
        "user_message": str,                           # the user's original request
        "steps": Annotated[list, _list_merge],         # display log; every node appends
        "tool_results": Annotated[list, _list_merge],  # tool outputs echoed into prompts
        "next_action": str,                            # supervisor route: math / info / respond
        "reply": str,                                  # final answer written by respond
        "iteration": int,                              # supervisor call counter
    },
)
# ----------------------------------------------------------------
# Client
# ----------------------------------------------------------------
def get_client(api_key):
    """Build the LangChain Mistral chat model shared by every graph node.

    A non-blank ``api_key`` (whitespace-trimmed) takes precedence; if it is
    None or blank, the ``MISTRAL_API_KEY`` environment variable is used,
    falling back to an empty string when that is unset as well.
    """
    explicit_key = (api_key or "").strip()
    if explicit_key:
        resolved_key = explicit_key
    else:
        resolved_key = os.environ.get("MISTRAL_API_KEY", "")
    settings = dict(
        model=MODEL,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        mistral_api_key=resolved_key,
    )
    return ChatMistralAI(**settings)
# ----------------------------------------------------------------
# NODE: supervisor
# Reads the user message plus any prior tool results and decides
# whether to dispatch to math_agent, info_agent, or respond.
# Uses simple prompt-based routing (ask for one word back) which is
# more reliable across providers than function-calling for this.
# ----------------------------------------------------------------
# Routing keywords the supervisor is asked to answer with. Whole-word hits are
# preferred; the prefix pattern is a fallback for replies such as
# "mathematics" or "information".
_ROUTE_WORD_RE = re.compile(r"\b(math|info|respond)\b")
_ROUTE_PREFIX_RE = re.compile(r"\b(math|info|respond)")


def _parse_route(text):
    """Map a lower-cased supervisor reply to 'math', 'info' or 'respond'.

    The EARLIEST keyword in the reply wins, so "respond - math is done"
    routes to respond. A fixed precedence would have picked math. An
    unrecognised reply falls back to 'respond' so the graph terminates.
    """
    match = _ROUTE_WORD_RE.search(text) or _ROUTE_PREFIX_RE.search(text)
    return match.group(1) if match else "respond"


def supervisor_node(state, client):
    """Decide which node runs next: math_agent, info_agent or respond.

    Increments ``iteration``. Once it exceeds MAX_AGENT_STEPS, the node
    forces ``respond`` without calling the model. Otherwise it shows the
    model the user message plus all prior tool results and asks for a
    single routing word, which ``_parse_route`` interprets.

    Args:
        state: current AgentState mapping.
        client: chat model exposing ``invoke``.

    Returns:
        Partial state update with ``next_action``, ``iteration`` and one
        ``steps`` entry describing the decision.
    """
    iteration = state.get("iteration", 0) + 1
    # Safety cap — prevent infinite supervisor <-> task-agent loops.
    if iteration > MAX_AGENT_STEPS:
        return {
            "next_action": "respond",
            "iteration": iteration,
            "steps": [{
                "step": iteration, "type": "limit", "tool": "supervisor",
                "args": "-", "result": "max iterations reached",
            }],
        }
    prior = state.get("tool_results", [])
    prior_summary = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior else "none yet"
    )
    supervisor_prompt = (
        "You are a supervisor routing tasks to specialized sub-agents.\n\n"
        f"Original user message: {state['user_message']}\n\n"
        f"Prior tool results:\n{prior_summary}\n\n"
        "Available sub-agents:\n"
        " math — handles arithmetic (add, multiply)\n"
        " info — handles weather lookups and the ML paper catalog\n"
        " respond — emit the final answer to the user "
        "(choose this when all needed information has been gathered)\n\n"
        "Reply with EXACTLY ONE WORD: math, info, or respond."
    )
    resp = client.invoke(supervisor_prompt)
    content = getattr(resp, "content", "") or ""
    if isinstance(content, list):
        # LangChain message content may be a list of str / {"text": ...} parts.
        content = " ".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )
    action = _parse_route(str(content).strip().lower())
    return {
        "next_action": action,
        "iteration": iteration,
        "steps": [{
            "step": iteration,
            "type": "llm_call",
            "tool": "supervisor",
            "args": state["user_message"][:80],
            "result": f"routed to {action}",
        }],
    }
# ----------------------------------------------------------------
# Helper used by both task nodes — bind a scoped set of tools and
# make one LLM call, then execute whatever tool calls come back.
# ----------------------------------------------------------------
def _run_task_agent(state, client, tool_names, agent_label):
    """Run one scoped task agent: a single tool-bound LLM call plus execution.

    Binds only the schemas whose names are in ``tool_names``. It then shows
    the model the user request and all prior tool results, and executes
    every tool call the model returns.

    Tool calls are handled as follows:
    - A call naming a tool unknown to TOOL_FUNCTIONS is skipped.
    - A known tool outside ``tool_names`` is NOT executed. It is recorded
      as an error so the supervisor can see the misroute.
    - An exception raised by a tool (e.g. bad LLM-supplied arguments) is
      recorded as the result instead of aborting the whole graph.

    Args:
        state: current AgentState mapping.
        client: chat model exposing ``bind_tools``.
        tool_names: names of the tools this agent may use.
        agent_label: node name, used in the prompt and the step log.

    Returns:
        Partial state update ``{"steps": [...], "tool_results": [...]}``.
        LangGraph concatenates both lists onto the existing state.
    """
    scoped_schemas = [
        {"type": "function", "function": s["function"]}
        for s in TOOL_SCHEMAS
        if s["function"]["name"] in tool_names
    ]
    model_with_tools = client.bind_tools(scoped_schemas)
    prior = state.get("tool_results", [])
    prior_str = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior else "none"
    )
    prompt = (
        f"User asked: {state['user_message']}\n"
        f"Prior tool results:\n{prior_str}\n\n"
        f"You are the {agent_label}. Call the appropriate tool to make "
        f"progress on the part of the request that falls in your scope."
    )
    resp = model_with_tools.invoke(prompt)
    iteration = state.get("iteration", 0)
    new_steps = []
    new_results = []
    for tc in (getattr(resp, "tool_calls", []) or []):
        if isinstance(tc, dict):
            name, args = tc.get("name"), tc.get("args")
        else:
            name, args = getattr(tc, "name", None), getattr(tc, "args", None)
        args = args or {}
        if name not in TOOL_FUNCTIONS:
            continue
        if name in tool_names:
            try:
                result = TOOL_FUNCTIONS[name](**args)
            except Exception as exc:
                # Arguments come from the LLM; report the failure back to the
                # supervisor instead of crashing the whole graph run.
                result = f"error: {type(exc).__name__}: {exc}"
        else:
            result = f"error: '{name}' is not available to {agent_label}"
        entry = {
            "tool": name,
            "args": json.dumps(args, default=str),
            "result": str(result),
        }
        new_steps.append({"step": iteration, "type": "tool_call", **entry})
        new_results.append(entry)
    if not new_steps:
        # The task agent decided not to call any tool — record a no-op.
        new_steps.append({
            "step": iteration,
            "type": "tool_call",
            "tool": agent_label,
            "args": state["user_message"][:80],
            "result": "no tool call made",
        })
    return {"steps": new_steps, "tool_results": new_results}
# ----------------------------------------------------------------
# NODE: math_agent — scoped to arithmetic tools
# ----------------------------------------------------------------
def math_agent_node(state, client):
    """Arithmetic task node: one scoped agent turn restricted to MATH_TOOLS."""
    return _run_task_agent(
        state,
        client,
        tool_names=MATH_TOOLS,
        agent_label="math_agent",
    )
# ----------------------------------------------------------------
# NODE: info_agent — scoped to weather + ML catalog tools
# ----------------------------------------------------------------
def info_agent_node(state, client):
    """Weather / ML-catalog task node: one scoped agent turn over INFO_TOOLS."""
    return _run_task_agent(
        state,
        client,
        tool_names=INFO_TOOLS,
        agent_label="info_agent",
    )
# ----------------------------------------------------------------
# NODE: respond — synthesize the final reply from accumulated results
# ----------------------------------------------------------------
def respond_node(state, client):
    """Produce the final user-facing reply from everything gathered so far.

    Lists each accumulated tool result, or notes that none were called. It
    sends one prompt through ``client.invoke``, stores the stripped model
    text under ``reply``, and appends a single ``final`` log step.
    """
    gathered = state.get("tool_results", [])
    if gathered:
        summary_lines = [
            f"- {item['tool']}({item['args']}) -> {item['result']}"
            for item in gathered
        ]
        summary = "\n".join(summary_lines)
    else:
        summary = "no tools were called"
    message = "".join([
        f"User asked: {state['user_message']}\n\n",
        f"Tool results gathered:\n{summary}\n\n",
        "Write a clear, direct reply to the user based on these results.",
    ])
    answer = client.invoke(message)
    final_text = (getattr(answer, "content", "") or "").strip()
    step_number = state.get("iteration", 0) + 1
    final_step = {
        "step": step_number,
        "type": "final",
        "tool": "respond",
        "args": "-",
        "result": final_text,
    }
    return {"reply": final_text, "steps": [final_step]}
# ----------------------------------------------------------------
# ROUTER: conditional edge function from supervisor
# ----------------------------------------------------------------
def route_from_supervisor(state):
    """Translate the supervisor's ``next_action`` into the next node name.

    'math' -> math_agent and 'info' -> info_agent. Anything else, including
    a missing key, goes to respond.
    """
    destinations = {"math": "math_agent", "info": "info_agent"}
    return destinations.get(state.get("next_action", "respond"), "respond")
# ----------------------------------------------------------------
# Graph builder — compiled on every run so the client is captured in closures
# ----------------------------------------------------------------
def _build_graph(client):
    """Compile the supervisor StateGraph with ``client`` bound into every node.

    Wiring: START -> supervisor. supervisor -> {math_agent, info_agent,
    respond} via ``route_from_supervisor``. Both task agents loop back to
    supervisor, and respond -> END.
    """
    def bind(node_fn):
        # Close over client so LangGraph sees a one-argument node callable.
        return lambda s: node_fn(s, client)

    node_table = (
        ("supervisor", supervisor_node),
        ("math_agent", math_agent_node),
        ("info_agent", info_agent_node),
        ("respond", respond_node),
    )
    builder = StateGraph(AgentState)
    for node_name, node_fn in node_table:
        builder.add_node(node_name, bind(node_fn))

    builder.add_edge(START, "supervisor")
    branch_targets = ("math_agent", "info_agent", "respond")
    builder.add_conditional_edges(
        "supervisor",
        route_from_supervisor,
        {target: target for target in branch_targets},
    )
    for worker in ("math_agent", "info_agent"):
        builder.add_edge(worker, "supervisor")
    builder.add_edge("respond", END)
    return builder.compile()
def run(client, user_message):
    """Build the supervisor graph for ``client`` and execute it on one message.

    Args:
        client: chat model returned by ``get_client``.
        user_message: the user's request text.

    Returns:
        dict with ``reply`` (final answer text), ``steps`` (display log,
        renumbered 1..n in list order) and ``extracted`` (raw tool results
        plus the final supervisor iteration counter).
    """
    graph = _build_graph(client)
    initial_state = {
        "user_message": user_message,
        "steps": [],
        "tool_results": [],
        "next_action": "",
        "reply": "",
        "iteration": 0,
    }
    # Worst case the graph runs MAX_AGENT_STEPS supervisor -> task-agent round
    # trips (2 supersteps each). It then makes one more supervisor pass, which
    # hits the cap, and the respond node: 2 * MAX_AGENT_STEPS + 2 supersteps.
    # Keep the historical 4x budget but never drop below that worst case, so a
    # small MAX_AGENT_STEPS (e.g. 0) can't trip GraphRecursionError before the
    # supervisor's own cap routes to respond.
    recursion_limit = max(MAX_AGENT_STEPS * 4, 2 * MAX_AGENT_STEPS + 4)
    final_state = graph.invoke(
        initial_state,
        config={"recursion_limit": recursion_limit},
    )
    # Nodes write provisional step numbers; renumber sequentially for display.
    steps = final_state.get("steps", [])
    for i, s in enumerate(steps, start=1):
        s["step"] = i
    return {
        "reply": final_state.get("reply", "") or "",
        "steps": steps,
        "extracted": {
            "tool_results": final_state.get("tool_results", []),
            "total_iterations": final_state.get("iteration", 0),
        },
    }
def _commented(prefix, value):
    """Render ``value`` after ``prefix`` as comment lines, one per text line.

    The first line is ``prefix`` plus the first line of ``value``. Every
    further line is prefixed with ``"#     "`` so multi-line text (user
    messages, tool args, final replies) cannot leak out of the comment and
    become stray code in the generated snippet.
    """
    parts = str(value).splitlines() or [""]
    return [prefix + parts[0]] + ["#     " + part for part in parts[1:]]


def build_code_snippets(user_message, steps):
    """Return an illustrative LangGraph source listing plus the run's step log.

    Args:
        user_message: the user's request; shown in a header comment and as the
            ``repr`` argument of the sample ``invoke`` call.
        steps: step dicts with ``step``, ``type``, ``tool``, ``args`` and
            ``result`` keys, as produced by ``run``.

    Returns:
        The listing as a single newline-joined string. All user- and
        model-supplied text in it is kept inside ``#`` comments.
    """
    lines = [
        "# Backend: LangGraph (supervisor pattern)",
        "# Explicit state graph with supervisor node + 2 task nodes + respond node.",
        *_commented("# User message: ", user_message),
        "",
        "from typing import TypedDict, Annotated",
        "from operator import add",
        "from langgraph.graph import StateGraph, START, END",
        "from langchain_mistralai import ChatMistralAI",
        "",
        "class AgentState(TypedDict):",
        " user_message: str",
        " steps: Annotated[list, add] # concat across nodes",
        " tool_results: Annotated[list, add] # concat across nodes",
        " next_action: str # 'math', 'info', or 'respond'",
        " reply: str",
        " iteration: int",
        "",
        "# --- Build the graph ---",
        "graph = StateGraph(AgentState)",
        "graph.add_node('supervisor', supervisor_node)",
        "graph.add_node('math_agent', math_agent_node)",
        "graph.add_node('info_agent', info_agent_node)",
        "graph.add_node('respond', respond_node)",
        "",
        "graph.add_edge(START, 'supervisor')",
        "graph.add_conditional_edges(",
        " 'supervisor', route_from_supervisor,",
        " {",
        " 'math_agent': 'math_agent',",
        " 'info_agent': 'info_agent',",
        " 'respond': 'respond',",
        " },",
        ")",
        "graph.add_edge('math_agent', 'supervisor') # loop back",
        "graph.add_edge('info_agent', 'supervisor') # loop back",
        "graph.add_edge('respond', END)",
        "",
        "compiled = graph.compile()",
        f"final = compiled.invoke({{'user_message': {user_message!r}, ...}})",
        "reply = final['reply']",
        "",
        "# ---------- actual step log ----------",
    ]
    for s in steps:
        lines.append(f"# Step {s['step']} [{s['type']}] node/tool={s['tool']}")
        lines.extend(_commented("# args: ", s["args"]))
        lines.extend(_commented("# result: ", s["result"]))
    return "\n".join(lines)