File size: 13,896 Bytes
a52bae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
# ============================================================================
# agent_langgraph.py — LangGraph backend (supervisor + task nodes + edges)
# ============================================================================
#
# CONTRACT: BACKEND_NAME, get_client, run, build_code_snippets
#
# PATTERN — THE SUPERVISOR STATE GRAPH
# ------------------------------------
# Unlike the tool-calling loop in agent_py.py, LangGraph makes the control
# flow an EXPLICIT graph with named nodes and directed edges. This is
# the "supervisor" pattern: one router node dispatches work to one of
# several specialized task agents, each with a scoped set of tools.
#
# Nodes:
#   supervisor   — decides which task agent to call next, or to stop
#   math_agent   — handles arithmetic tools (add, multiply)
#   info_agent   — handles weather + ML paper catalog lookups
#   respond      — writes the final user-facing reply from accumulated results
#
# Edges:
#   START -> supervisor
#   supervisor -> math_agent      (conditional)
#   supervisor -> info_agent      (conditional)
#   supervisor -> respond         (conditional)
#   math_agent -> supervisor      (loop back)
#   info_agent -> supervisor      (loop back)
#   respond    -> END
#
# IMPORT NOTE
# -----------
# Imports langchain_mistralai and langgraph. If either is missing,
# importing this module raises ImportError and app.py hides the
# LangGraph mode from the dropdown.
# ============================================================================

import os
import json
from typing import TypedDict, Annotated
from operator import add as _list_merge

from langchain_mistralai import ChatMistralAI
from langgraph.graph import StateGraph, START, END

from parameters import MODEL, TEMPERATURE, MAX_TOKENS, MAX_AGENT_STEPS
from tools import TOOL_FUNCTIONS, TOOL_SCHEMAS


# Human-readable backend name, exported as part of the module contract
# listed in the file header.
BACKEND_NAME = "LangGraph Agent"


# ----------------------------------------------------------------
# Tool scoping: the tool names each task agent may bind and execute.
# _run_task_agent matches these against TOOL_SCHEMAS / TOOL_FUNCTIONS.
# ----------------------------------------------------------------
MATH_TOOLS = {
    "add",
    "multiply",
}
INFO_TOOLS = {
    "get_weather",
    "search_ml_examples",
    "ml_paper_info",
    "list_ml_papers",
}


# ----------------------------------------------------------------
# Graph state — a TypedDict that flows through every node.
# The Annotated[list, _list_merge] tells LangGraph to CONCATENATE
# these lists when multiple nodes write to them, instead of replacing.
# ----------------------------------------------------------------
class AgentState(TypedDict):
    """State schema shared by every node of the supervisor graph.

    Nodes return partial dicts. Fields annotated with ``_list_merge``
    (``operator.add``) are reducer channels: a node returns only its new
    entries and LangGraph appends them to the existing list. Unannotated
    fields are overwritten by the latest update.
    """

    # Original request; set once by run() and never returned by a node.
    user_message: str
    # Display log; each entry has step/type/tool/args/result keys.
    steps: Annotated[list, _list_merge]
    # Executed tool calls; each entry has tool/args/result string values.
    tool_results: Annotated[list, _list_merge]
    # Supervisor's routing choice: "math", "info" or "respond" ("" initially).
    next_action: str
    # Final answer text, written by respond_node.
    reply: str
    # Supervisor call counter, compared against MAX_AGENT_STEPS.
    iteration: int


# ----------------------------------------------------------------
# Client
# ----------------------------------------------------------------
def get_client(api_key):
    """Create the ChatMistralAI chat model shared by every graph node.

    A non-blank ``api_key`` (surrounding whitespace ignored) takes
    precedence; otherwise the MISTRAL_API_KEY environment variable is
    used, or an empty string if that variable is unset.
    """
    explicit_key = (api_key or "").strip()
    resolved_key = (
        explicit_key if explicit_key
        else os.environ.get("MISTRAL_API_KEY", "")
    )
    model_kwargs = {
        "model": MODEL,
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "mistral_api_key": resolved_key,
    }
    return ChatMistralAI(**model_kwargs)


# ----------------------------------------------------------------
# NODE: supervisor
# Reads the user message plus any prior tool results and decides
# whether to dispatch to math_agent, info_agent, or respond.
# Uses simple prompt-based routing (ask for one word back) which is
# more reliable across providers than function-calling for this.
# ----------------------------------------------------------------
def supervisor_node(state, client):
    """Choose the next node to run and record the decision as a step.

    Args:
        state: current AgentState; reads ``iteration``, ``tool_results``
            and ``user_message``.
        client: chat model exposing ``invoke(prompt)``.

    Returns:
        Partial state update with ``next_action`` ("math", "info" or
        "respond"), the incremented ``iteration`` and one new ``steps``
        entry.
    """
    iteration = state.get("iteration", 0) + 1

    # Safety cap — after MAX_AGENT_STEPS routing decisions, force the
    # final answer instead of looping through the task agents again.
    if iteration > MAX_AGENT_STEPS:
        return {
            "next_action": "respond",
            "iteration": iteration,
            "steps": [{
                "step": iteration, "type": "limit", "tool": "supervisor",
                "args": "-", "result": "max iterations reached",
            }],
        }

    prior = state.get("tool_results", [])
    prior_summary = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior else "none yet"
    )

    supervisor_prompt = (
        "You are a supervisor routing tasks to specialized sub-agents.\n\n"
        f"Original user message: {state['user_message']}\n\n"
        f"Prior tool results:\n{prior_summary}\n\n"
        "Available sub-agents:\n"
        "  math    — handles arithmetic (add, multiply)\n"
        "  info    — handles weather lookups and the ML paper catalog\n"
        "  respond — emit the final answer to the user "
        "(choose this when all needed information has been gathered)\n\n"
        "Reply with EXACTLY ONE WORD: math, info, or respond."
    )

    resp = client.invoke(supervisor_prompt)
    content = getattr(resp, "content", "") or ""
    if isinstance(content, list):
        # LangChain message content may be a list of blocks (e.g. text plus
        # reasoning); only the text blocks can carry the routing word.
        content = "".join(
            block if isinstance(block, str) else str(block.get("text", ""))
            for block in content
            if isinstance(block, str)
            or (isinstance(block, dict) and block.get("type") == "text")
        )
    text = content.strip().lower()

    # Match whole words only: a substring test would route a reply such as
    # "respond - all information gathered" to info ("info" in "information").
    # Non-letters act as separators, so "Math." or "math_agent" still count.
    # The first recognised word wins; anything else falls back to respond,
    # which always terminates the graph.
    words = "".join(ch if ch.isalpha() else " " for ch in text).split()
    action = next(
        (word for word in words if word in ("math", "info", "respond")),
        "respond",
    )

    return {
        "next_action": action,
        "iteration": iteration,
        "steps": [{
            "step": iteration,
            "type": "llm_call",
            "tool": "supervisor",
            "args": state["user_message"][:80],
            "result": f"routed to {action}",
        }],
    }


# ----------------------------------------------------------------
# Helper used by both task nodes — bind a scoped set of tools and
# make one LLM call, then execute whatever tool calls come back.
# ----------------------------------------------------------------
def _run_task_agent(state, client, tool_names, agent_label):
    """Run one tool-calling turn for a task agent limited to ``tool_names``.

    Args:
        state: current AgentState; reads ``user_message``, ``tool_results``
            and ``iteration``.
        client: LangChain chat model supporting ``bind_tools``.
        tool_names: names of the tools this agent may bind and execute.
        agent_label: node name used in the prompt and the step log.

    Returns:
        Partial state update with new ``steps`` and ``tool_results``
        entries. A tool that raises, or a call to a tool outside
        ``tool_names``, is recorded in the result text instead of
        aborting the graph, so the supervisor and respond node can see it.
    """
    scoped_schemas = [
        {"type": "function", "function": s["function"]}
        for s in TOOL_SCHEMAS
        if s["function"]["name"] in tool_names
    ]
    model_with_tools = client.bind_tools(scoped_schemas)

    prior = state.get("tool_results", [])
    prior_str = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior else "none"
    )

    prompt = (
        f"User asked: {state['user_message']}\n"
        f"Prior tool results:\n{prior_str}\n\n"
        f"You are the {agent_label}. Call the appropriate tool to make "
        f"progress on the part of the request that falls in your scope."
    )

    resp = model_with_tools.invoke(prompt)
    iteration = state.get("iteration", 0)

    new_steps = []
    new_results = []
    for tc in (getattr(resp, "tool_calls", []) or []):
        if isinstance(tc, dict):
            name, args = tc.get("name"), tc.get("args")
        else:
            name, args = getattr(tc, "name", None), getattr(tc, "args", None)
        args = args or {}
        args_str = json.dumps(args, default=str)

        if name not in tool_names or name not in TOOL_FUNCTIONS:
            # The model requested a tool outside this agent's scope (or an
            # unknown one): never execute it, but leave a visible trace.
            new_steps.append({
                "step": iteration,
                "type": "tool_call",
                "tool": str(name),
                "args": args_str,
                "result": f"skipped: tool not available to {agent_label}",
            })
            continue

        try:
            result = str(TOOL_FUNCTIONS[name](**args))
        except Exception as exc:
            # Arguments come from the LLM and may not match the tool's
            # signature; report the failure instead of crashing the run.
            result = f"error: {type(exc).__name__}: {exc}"

        new_steps.append({
            "step": iteration,
            "type": "tool_call",
            "tool": name,
            "args": args_str,
            "result": result,
        })
        new_results.append({
            "tool": name,
            "args": args_str,
            "result": result,
        })

    if not new_steps:
        # The task agent decided not to call any tool — record a no-op.
        new_steps.append({
            "step": iteration,
            "type": "tool_call",
            "tool": agent_label,
            "args": state["user_message"][:80],
            "result": "no tool call made",
        })

    return {"steps": new_steps, "tool_results": new_results}


# ----------------------------------------------------------------
# Task-agent nodes. Each is a thin wrapper that runs the shared
# tool-calling helper with its own tool scope and step-log label.
# ----------------------------------------------------------------
def math_agent_node(state, client):
    """Task node restricted to the arithmetic tools in MATH_TOOLS."""
    return _run_task_agent(
        state, client, tool_names=MATH_TOOLS, agent_label="math_agent"
    )


def info_agent_node(state, client):
    """Task node restricted to the weather / ML-catalog tools in INFO_TOOLS."""
    return _run_task_agent(
        state, client, tool_names=INFO_TOOLS, agent_label="info_agent"
    )


# ----------------------------------------------------------------
# NODE: respond — synthesize the final reply from accumulated results
# ----------------------------------------------------------------
def respond_node(state, client):
    """Write the user-facing answer from all tool results gathered so far.

    Args:
        state: current AgentState; reads ``user_message``,
            ``tool_results`` and ``iteration``.
        client: chat model exposing ``invoke(prompt)``.

    Returns:
        Partial state update with ``reply`` (stripped text, "" if the
        model returned nothing) and one ``"final"`` step.
    """
    prior = state.get("tool_results", [])
    prior_summary = (
        "\n".join(f"- {r['tool']}({r['args']}) -> {r['result']}" for r in prior)
        if prior else "no tools were called"
    )

    prompt = (
        f"User asked: {state['user_message']}\n\n"
        f"Tool results gathered:\n{prior_summary}\n\n"
        "Write a clear, direct reply to the user based on these results."
    )
    resp = client.invoke(prompt)
    content = getattr(resp, "content", "") or ""
    if isinstance(content, list):
        # LangChain message content may be a list of blocks (e.g. reasoning
        # plus text); only the text blocks belong in the user-facing reply.
        content = "".join(
            block if isinstance(block, str) else str(block.get("text", ""))
            for block in content
            if isinstance(block, str)
            or (isinstance(block, dict) and block.get("type") == "text")
        )
    reply = content.strip()

    # Provisional step number; run() renumbers all steps afterwards.
    iteration = state.get("iteration", 0) + 1
    return {
        "reply": reply,
        "steps": [{
            "step": iteration,
            "type": "final",
            "tool": "respond",
            "args": "-",
            "result": reply,
        }],
    }


# ----------------------------------------------------------------
# ROUTER: conditional edge function from supervisor
# ----------------------------------------------------------------
def route_from_supervisor(state):
    """Translate the supervisor's ``next_action`` into a graph node name.

    "math" -> "math_agent" and "info" -> "info_agent"; any other value,
    or a missing ``next_action``, routes to "respond".
    """
    node_for_action = {"math": "math_agent", "info": "info_agent"}
    requested = state.get("next_action", "respond")
    return node_for_action.get(requested, "respond")


# ----------------------------------------------------------------
# Graph builder — compiled on every run so the client is captured in closures
# ----------------------------------------------------------------
def _build_graph(client):
    """Assemble and compile the supervisor graph bound to ``client``.

    Every node function here takes ``(state, client)`` while LangGraph
    calls nodes with the state only, so each is wrapped in a closure
    that supplies ``client``.
    """

    def bind_client(node_fn):
        # Factory (rather than a lambda inside the loop) so each closure
        # captures its own node_fn instead of the loop variable.
        return lambda s: node_fn(s, client)

    builder = StateGraph(AgentState)
    for node_name, node_fn in (
        ("supervisor", supervisor_node),
        ("math_agent", math_agent_node),
        ("info_agent", info_agent_node),
        ("respond", respond_node),
    ):
        builder.add_node(node_name, bind_client(node_fn))

    builder.add_edge(START, "supervisor")
    # The router already returns node names, so the path map is the identity.
    builder.add_conditional_edges(
        "supervisor",
        route_from_supervisor,
        {target: target for target in ("math_agent", "info_agent", "respond")},
    )
    # Task agents always hand control back to the supervisor.
    for task_node in ("math_agent", "info_agent"):
        builder.add_edge(task_node, "supervisor")
    builder.add_edge("respond", END)

    return builder.compile()


def run(client, user_message):
    """Build and execute the supervisor state graph end-to-end.

    Args:
        client: chat model returned by ``get_client``.
        user_message: the user's request.

    Returns:
        dict with ``reply`` (final answer text), ``steps`` (step log,
        renumbered 1..N for display) and ``extracted`` (raw tool results
        and the supervisor iteration count).
    """
    graph = _build_graph(client)

    initial_state = {
        "user_message": user_message,
        "steps": [],
        "tool_results": [],
        "next_action": "",
        "reply": "",
        "iteration": 0,
    }

    # A capped run needs at most 2 * MAX_AGENT_STEPS + 2 supersteps
    # (supervisor + task agent per round, then a final supervisor pass and
    # respond). 4x is normally ample; the floor keeps small or zero
    # MAX_AGENT_STEPS values from tripping LangGraph's recursion limit
    # before the supervisor's own cap can route to respond.
    recursion_limit = max(MAX_AGENT_STEPS * 4, 2 * max(MAX_AGENT_STEPS, 0) + 4)

    final_state = graph.invoke(
        initial_state,
        config={"recursion_limit": recursion_limit},
    )

    # Renumber steps sequentially for display
    steps = final_state.get("steps", [])
    for i, s in enumerate(steps, start=1):
        s["step"] = i

    return {
        "reply": final_state.get("reply", "") or "",
        "steps": steps,
        "extracted": {
            "tool_results": final_state.get("tool_results", []),
            "total_iterations": final_state.get("iteration", 0),
        },
    }


def build_code_snippets(user_message, steps):
    """Render an illustrative LangGraph listing followed by the run's step log.

    Args:
        user_message: the user's request; echoed in a header comment and,
            via ``repr``, in the sample ``invoke`` call.
        steps: step dicts with ``step``, ``type``, ``tool``, ``args`` and
            ``result`` keys, as returned by ``run``.

    Returns:
        The listing as one newline-joined string. Free-text values are
        emitted as comments; multi-line values get every line commented
        so they cannot leak out of the comment as bare text.
    """

    def commented(label, value):
        # First line follows the label; continuation lines keep the "#"
        # and are aligned under the first line's text.
        value_lines = str(value).splitlines() or [""]
        pad = "#" + " " * (len(label) - 1)
        return [label + value_lines[0]] + [pad + line for line in value_lines[1:]]

    lines = [
        "# Backend: LangGraph (supervisor pattern)",
        "# Explicit state graph with supervisor node + 2 task nodes + respond node.",
        *commented("# User message: ", user_message),
        "",
        "from typing import TypedDict, Annotated",
        "from operator import add",
        "from langgraph.graph import StateGraph, START, END",
        "from langchain_mistralai import ChatMistralAI",
        "",
        "class AgentState(TypedDict):",
        "    user_message: str",
        "    steps: Annotated[list, add]           # concat across nodes",
        "    tool_results: Annotated[list, add]   # concat across nodes",
        "    next_action: str                     # 'math', 'info', or 'respond'",
        "    reply: str",
        "    iteration: int",
        "",
        "# --- Build the graph ---",
        "graph = StateGraph(AgentState)",
        "graph.add_node('supervisor', supervisor_node)",
        "graph.add_node('math_agent', math_agent_node)",
        "graph.add_node('info_agent', info_agent_node)",
        "graph.add_node('respond',    respond_node)",
        "",
        "graph.add_edge(START, 'supervisor')",
        "graph.add_conditional_edges(",
        "    'supervisor', route_from_supervisor,",
        "    {",
        "        'math_agent': 'math_agent',",
        "        'info_agent': 'info_agent',",
        "        'respond':    'respond',",
        "    },",
        ")",
        "graph.add_edge('math_agent', 'supervisor')   # loop back",
        "graph.add_edge('info_agent', 'supervisor')   # loop back",
        "graph.add_edge('respond',    END)",
        "",
        "compiled = graph.compile()",
        # repr() escapes newlines, so this line is always a single line.
        f"final = compiled.invoke({{'user_message': {user_message!r}, ...}})",
        "reply = final['reply']",
        "",
        "# ---------- actual step log ----------",
    ]
    for s in steps:
        lines.append(f"# Step {s['step']} [{s['type']}] node/tool={s['tool']}")
        lines.extend(commented("#   args:   ", s["args"]))
        lines.extend(commented("#   result: ", s["result"]))
    return "\n".join(lines)