# Source: Echo/agents/react_agent.py (Hugging Face Space "Echo", commit 8f51ef2)
#!/usr/bin/env python3
"""
Reactive Echo Agent
Implements a LangGraph-based ReAct loop where the LLM decides whether to
invoke tools and receives their results before continuing the conversation.
"""
from __future__ import annotations

import json
import operator
from datetime import datetime, timezone
from pathlib import Path
from typing import Annotated, Any, Dict, List, Optional, TypedDict

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
# Structured record of one executed tool call, as serialized to the log file.
ToolCallLog = TypedDict(
    "ToolCallLog",
    {
        "timestamp": str,      # time the call was logged (ISO format)
        "tool_call_id": str,   # id from the model's tool-call request
        "name": str,           # name of the tool that was invoked
        "args": Any,           # arguments the tool was invoked with
        "content": str,        # serialized tool result (or error payload)
    },
)
class EchoAgentState(TypedDict):
    """State carried through the LangGraph execution.

    The ``operator.add`` reducer tells LangGraph to merge node outputs by
    list concatenation, so each node returns only its *new* messages rather
    than the full history.
    """
    messages: Annotated[List[AnyMessage], operator.add]
class ReactiveEchoAgent:
    """
    Minimal ReAct-style agent.

    The agent delegates decision making to the bound language model. Whenever
    the model emits tool calls, the specified LangChain tools are executed and
    their `ToolMessage` responses are appended to the conversation history
    before handing control back to the model.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        *,
        system_prompt: str = "",
        checkpointer: Any = None,
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ) -> None:
        """Wire up the two-node LangGraph workflow (model turn <-> tool turn).

        Args:
            model: Chat model that supports ``bind_tools``.
            tools: Tools the model may invoke; indexed by ``tool.name``.
            system_prompt: Optional prompt prepended on every model turn.
            checkpointer: Optional checkpointer forwarded to ``compile``.
            log_tools: When True, each executed batch of tool calls is
                written as a JSON file under ``log_dir``.
            log_dir: Directory for tool-call logs (created if missing);
                falls back to ``"logs"`` when None is passed.
        """
        self._system_prompt = system_prompt
        self._log_tools = log_tools
        self._log_dir = Path(log_dir or "logs")
        if self._log_tools:
            self._log_dir.mkdir(parents=True, exist_ok=True)

        # Graph: "process" asks the model; if it requested tools, "execute"
        # runs them and loops back to "process", otherwise the run ends.
        workflow = StateGraph(EchoAgentState)
        workflow.add_node("process", self._process_request)
        workflow.add_node("execute", self._execute_tools)
        workflow.add_conditional_edges(
            "process", self._has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")
        self.workflow = workflow.compile(checkpointer=checkpointer)

        self.tools = {tool.name: tool for tool in tools}
        self.model = model.bind_tools(list(self.tools.values()))

    @property
    def system_prompt(self) -> str:
        """Current system prompt (empty string disables the system message)."""
        return self._system_prompt

    def update_system_prompt(self, prompt: str) -> None:
        """Set a new system prompt for subsequent runs."""
        self._system_prompt = prompt

    # -- LangGraph node implementations -------------------------------------------------
    def _process_request(self, state: Dict[str, Any]) -> Dict[str, List[AnyMessage]]:
        """Model node: invoke the LLM on the accumulated history."""
        messages: List[AnyMessage] = list(state.get("messages", []))
        if self._system_prompt:
            # Prepended per call rather than stored in state, so prompt
            # updates take effect on the very next turn.
            messages = [SystemMessage(content=self._system_prompt)] + messages
        response = self.model.invoke(messages)
        return {"messages": [response]}

    def _has_tool_calls(self, state: Dict[str, Any]) -> bool:
        """Routing predicate: did the last model message request any tools?"""
        last_message = state["messages"][-1]
        return bool(getattr(last_message, "tool_calls", []))

    def _execute_tools(self, state: Dict[str, Any]) -> Dict[str, List[ToolMessage]]:
        """Tool node: run every requested tool and wrap results as ToolMessages.

        Unknown tool names and tool exceptions are reported back to the model
        as JSON error payloads instead of raising, so the loop can continue.
        """
        tool_messages: List[ToolMessage] = []
        for call in state["messages"][-1].tool_calls:
            tool_name = call.get("name")
            tool_args = call.get("args", {})
            tool_id = call.get("id", "")
            if tool_name not in self.tools:
                result_content = json.dumps(
                    {"status": "error", "error": f"Unknown tool '{tool_name}'"}, ensure_ascii=False
                )
            else:
                try:
                    result = self.tools[tool_name].invoke(tool_args)
                    # Tool results can be complex objects; coerce to JSON string if possible.
                    result_content = json.dumps(result, ensure_ascii=False, default=str)
                except Exception as exc:  # noqa: BLE001 - surface any failure to the model
                    result_content = json.dumps(
                        {"status": "error", "error": f"{type(exc).__name__}: {exc}"}, ensure_ascii=False
                    )
            message = ToolMessage(
                tool_call_id=tool_id,
                name=tool_name or "unknown_tool",
                content=result_content,
                # Echo the arguments so the log file can record them.
                additional_kwargs={"args": tool_args},
            )
            tool_messages.append(message)
        self._log_tool_messages(tool_messages)
        return {"messages": tool_messages}

    # -- Helpers ------------------------------------------------------------------------
    def _log_tool_messages(self, tool_messages: List[ToolMessage]) -> None:
        """Write one batch of tool results to a timestamped JSON file."""
        if not self._log_tools or not tool_messages:
            return
        # Timezone-aware UTC: datetime.utcnow() is naive and deprecated (3.12+).
        now = datetime.now(timezone.utc)
        # Microsecond precision so two batches logged within the same second
        # do not overwrite each other's file.
        log_path = self._log_dir / f"tool_calls_{now.strftime('%Y%m%d_%H%M%S_%f')}.json"
        logs: List[ToolCallLog] = [
            ToolCallLog(
                timestamp=now.isoformat(),
                tool_call_id=message.tool_call_id,
                name=message.name,
                args=message.additional_kwargs.get("args", {}),
                content=message.content,
            )
            for message in tool_messages
        ]
        log_path.write_text(json.dumps(logs, indent=2, ensure_ascii=False), encoding="utf-8")