ohollo's picture
Split up experiment files
b5b19b9
Raw
History Blame Contribute Delete
1.71 kB
import asyncio
from langchain_core.language_models import BaseChatModel
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from .base import Agent, AgentResult
def build_graph(llm: BaseChatModel, tools: list) -> CompiledStateGraph:
    """Compile a minimal ReAct-style loop for ``llm`` over ``tools``.

    The graph state is a ``MessagesState``, so the conversation lives under
    the ``"messages"`` key. Execution starts at the ``"agent"`` node, which
    sends the message history to the tool-bound model and appends its reply.
    If that reply carries tool calls, control passes to the ``"tools"`` node
    (a ``ToolNode`` over the same ``tools``) and then back to ``"agent"``;
    otherwise the graph ends.

    :param llm: Chat model; ``bind_tools`` is called on it with ``tools``.
    :param tools: Tools made available both to the model and to the
        ``ToolNode``.
    :return: The compiled graph.
    """
    model = llm.bind_tools(tools)

    def _should_continue(state: MessagesState):
        # Route on the most recent message: pending tool calls -> run tools.
        latest = state["messages"][-1]
        if latest.tool_calls:
            return "tools"
        return END

    def _call_model(state: MessagesState):
        # Feed the full history to the model; the reply is appended to state.
        reply = model.invoke(state["messages"])
        return {"messages": [reply]}

    builder = StateGraph(MessagesState)
    builder.add_node("agent", _call_model)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "agent")
    builder.add_conditional_edges("agent", _should_continue, ["tools", END])
    # After tools run, the model gets another turn with the tool results.
    builder.add_edge("tools", "agent")

    return builder.compile()
class LangGraphAgent(Agent):
    """Agent that hands a single prompt to a pre-compiled LangGraph graph.

    Contract for ``graph``: it is invoked with ``{"user_input": str}`` and its
    output must contain a ``"response"`` key.

    NOTE(review): graphs produced by ``build_graph`` use ``MessagesState``
    (state keyed by ``"messages"``) and presumably do not satisfy this
    contract without an adapter -- confirm before combining the two.

    :param mcp_url: MCP server endpoint, forwarded to ``Agent.__init__``
        (stored for reference; tools should be pre-loaded and bound into the
        graph before passing it here).
    :param graph: Compiled graph to run.
    """

    def __init__(self, mcp_url: str, graph: CompiledStateGraph):
        super().__init__(mcp_url)
        # Kept private; only ``run`` uses it.
        self._graph = graph

    def run(self, user_prompt: str) -> AgentResult:
        """Invoke the graph once and wrap its ``"response"`` value.

        :param user_prompt: Text passed to the graph under ``"user_input"``.
        :return: ``AgentResult`` carrying the graph's response. ``tool_calls``
            is always an empty list: this wrapper does not extract tool calls
            from the graph output.
        :raises KeyError: presumably, when the graph output (assumed to be a
            dict) has no ``"response"`` key.
        """
        graph_input = {"user_input": user_prompt}
        final_state = self._graph.invoke(graph_input)
        answer = final_state["response"]
        return AgentResult(response=answer, tool_calls=[])