File size: 1,709 Bytes
ca2bb92
 
 
 
 
 
 
b5b19b9
ca2bb92
 
 
b5b19b9
ca2bb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5b19b9
ca2bb92
b5b19b9
 
ca2bb92
b5b19b9
 
 
 
ca2bb92
b5b19b9
 
 
ca2bb92
b5b19b9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import asyncio

from langchain_core.language_models import BaseChatModel
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from .base import Agent, AgentResult


def build_graph(llm: BaseChatModel, tools: list) -> CompiledStateGraph:
    """Compile a minimal ReAct loop that alternates between the LLM and tools.

    The LLM is bound to ``tools``. The ``"agent"`` node calls it on the
    accumulated ``messages``. Whenever its reply carries tool calls, control
    moves to the ``"tools"`` node (a ``ToolNode`` over the same tools) and then
    back to ``"agent"``. A reply with no tool calls ends the run.

    The graph's state is ``MessagesState``, so it is driven and read through
    the ``"messages"`` key, not ``"user_input"``/``"response"``.

    :param llm: Chat model that supports ``bind_tools``.
    :param tools: Tools made available to both the model and the tool node.
    :returns: The compiled graph.
    """
    model = llm.bind_tools(tools)

    def _invoke_model(state: MessagesState):
        # Append the model's reply to the running message list.
        reply = model.invoke(state["messages"])
        return {"messages": [reply]}

    def _should_continue(state: MessagesState):
        # Route on the most recent message: any tool calls -> run the tools.
        latest = state["messages"][-1]
        if latest.tool_calls:
            return "tools"
        return END

    workflow = StateGraph(MessagesState)
    workflow.add_node("agent", _invoke_model)
    workflow.add_node("tools", ToolNode(tools))
    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", _should_continue, ["tools", END])
    workflow.add_edge("tools", "agent")
    compiled = workflow.compile()
    return compiled


class LangGraphAgent(Agent):
    """Agent wrapping any compiled LangGraph graph.

    The graph must accept ``{"user_input": str}`` as input and include a
    ``"response"`` key in its output. Graphs produced by :func:`build_graph`
    use ``MessagesState`` (a ``"messages"`` key), so they do not satisfy this
    contract on their own.

    :param mcp_url: MCP server endpoint, forwarded to ``Agent.__init__``.
        Tools should be pre-loaded and bound into the graph before the graph
        is passed here.
    :param graph: Compiled graph to run.
    """

    def __init__(self, mcp_url: str, graph: CompiledStateGraph):
        super().__init__(mcp_url)
        self._graph = graph

    def run(self, user_prompt: str) -> AgentResult:
        """Invoke the wrapped graph once with *user_prompt*.

        :param user_prompt: Text passed to the graph under ``"user_input"``.
        :returns: An ``AgentResult`` whose ``response`` is the graph's
            ``"response"`` output. ``tool_calls`` is always an empty list,
            because this wrapper does not extract tool calls from the output.
        :raises KeyError: If the graph output has no ``"response"`` key.
        """
        result = self._graph.invoke({"user_input": user_prompt})
        try:
            response = result["response"]
        except KeyError:
            # Include the keys actually produced so a graph that violates the
            # contract (e.g. a MessagesState graph) is easy to diagnose.
            raise KeyError(
                "LangGraph output is missing the required 'response' key; "
                f"got keys: {list(result)}"
            ) from None
        return AgentResult(response=response, tool_calls=[])