|
|
|
|
|
""" |
|
|
Reactive Echo Agent |
|
|
|
|
|
Implements a LangGraph-based ReAct loop where the LLM decides whether to |
|
|
invoke tools and receives their results before continuing the conversation. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import json
import operator
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict, Annotated

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
|
|
|
|
|
|
|
|
class ToolCallLog(TypedDict):
    """Structured record of an executed tool call.

    One entry is produced per ``ToolMessage`` in the JSON log files written by
    ``ReactiveEchoAgent._log_tool_messages``.
    """

    # ISO-8601 UTC time at which the log entry was written (not when the tool ran).
    timestamp: str
    # Id of the model-issued tool call; empty string when the call carried no id.
    tool_call_id: str
    # Tool name requested by the model, or "unknown_tool" when it was missing.
    name: str
    # Arguments the model supplied for the call, recorded as received.
    args: Any
    # JSON text: the serialized tool result, or a {"status": "error", ...} object.
    content: str
|
|
|
|
|
|
|
|
class EchoAgentState(TypedDict):
    """State carried through the LangGraph execution."""

    # Conversation history. The ``operator.add`` reducer makes LangGraph
    # concatenate each node's returned list onto the existing history instead
    # of replacing it, which is why nodes return only their *new* messages.
    messages: Annotated[List[AnyMessage], operator.add]
|
|
|
|
|
|
|
|
class ReactiveEchoAgent:
    """
    Minimal ReAct-style agent.

    The agent delegates decision making to the bound language model. Whenever
    the model emits tool calls, the specified LangChain tools are executed and
    their `ToolMessage` responses are appended to the conversation history
    before handing control back to the model.

    Graph layout: ``process`` calls the model; if the reply contains tool
    calls the graph moves to ``execute`` and then back to ``process``,
    otherwise it ends.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        *,
        system_prompt: str = "",
        checkpointer: Any = None,
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ) -> None:
        """Bind ``tools`` to ``model`` and compile the process/execute graph.

        Args:
            model: Language model used for every step; must provide ``bind_tools``.
            tools: Tools the model may call, indexed by ``tool.name`` (a later
                tool with a duplicate name replaces an earlier one).
            system_prompt: When non-empty, prepended as a ``SystemMessage`` on
                every model call (it is not stored in the graph state).
            checkpointer: Passed through unchanged to ``StateGraph.compile``.
            log_tools: When true, each batch of tool results is written to a
                JSON file under ``log_dir``.
            log_dir: Directory for tool-call logs; ``None`` or ``""`` falls back
                to ``"logs"``. Created eagerly when logging is enabled.
        """
        self._system_prompt = system_prompt
        self._log_tools = log_tools
        self._log_dir = Path(log_dir or "logs")
        if self._log_tools:
            self._log_dir.mkdir(parents=True, exist_ok=True)

        # Set up the tool registry and bound model before wiring the graph so the
        # instance is fully initialised by the time the graph exists.
        self.tools = {tool.name: tool for tool in tools}
        self.model = model.bind_tools(list(self.tools.values()))

        workflow = StateGraph(EchoAgentState)
        workflow.add_node("process", self._process_request)
        workflow.add_node("execute", self._execute_tools)
        workflow.add_conditional_edges(
            "process", self._has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")
        self.workflow = workflow.compile(checkpointer=checkpointer)

    @property
    def system_prompt(self) -> str:
        """The system prompt prepended to model calls ("" means none)."""
        return self._system_prompt

    def update_system_prompt(self, prompt: str) -> None:
        """Set a new system prompt for subsequent runs."""
        self._system_prompt = prompt

    def _process_request(self, state: Dict[str, Any]) -> Dict[str, List[AnyMessage]]:
        """Graph node: ask the tool-bound model for the next message.

        The system prompt (if set) is prepended only to the list sent to the
        model. Only the new response is returned; the ``operator.add`` reducer
        on ``EchoAgentState.messages`` appends it to the history.
        """
        messages: List[AnyMessage] = list(state.get("messages", []))
        if self._system_prompt:
            messages = [SystemMessage(content=self._system_prompt)] + messages

        response = self.model.invoke(messages)
        return {"messages": [response]}

    def _has_tool_calls(self, state: Dict[str, Any]) -> bool:
        """Routing predicate: True when the latest message requests tool calls."""
        last_message = state["messages"][-1]
        # Messages other than AI replies have no ``tool_calls`` attribute.
        return bool(getattr(last_message, "tool_calls", []))

    def _execute_tools(self, state: Dict[str, Any]) -> Dict[str, List[ToolMessage]]:
        """Graph node: run every tool call on the last message.

        Each call yields one ``ToolMessage`` whose content is JSON text: the
        serialized tool result on success, or ``{"status": "error", ...}`` for
        an unknown tool or a tool that raised. Errors are reported back to the
        model rather than propagated, so one failing tool does not abort the
        run. The batch is logged (when enabled) before being returned.
        """
        tool_messages: List[ToolMessage] = []
        for call in state["messages"][-1].tool_calls:
            tool_name = call.get("name")
            tool_args = call.get("args", {})
            tool_id = call.get("id", "")

            if tool_name not in self.tools:
                result_content = json.dumps(
                    {"status": "error", "error": f"Unknown tool '{tool_name}'"}, ensure_ascii=False
                )
            else:
                try:
                    result = self.tools[tool_name].invoke(tool_args)
                    # default=str keeps non-JSON results (datetimes, objects) reportable.
                    result_content = json.dumps(result, ensure_ascii=False, default=str)
                except Exception as exc:
                    # Deliberate boundary: surface the failure to the model instead of crashing.
                    result_content = json.dumps(
                        {"status": "error", "error": f"{type(exc).__name__}: {exc}"}, ensure_ascii=False
                    )

            message = ToolMessage(
                tool_call_id=tool_id,
                name=tool_name or "unknown_tool",
                content=result_content,
                # Keep the raw args so _log_tool_messages can record them.
                additional_kwargs={"args": tool_args},
            )
            tool_messages.append(message)

        self._log_tool_messages(tool_messages)
        return {"messages": tool_messages}

    def _log_tool_messages(self, tool_messages: List[ToolMessage]) -> None:
        """Write one batch of tool results to a new JSON file in the log directory.

        No-op when logging is disabled or the batch is empty. The file is named
        ``tool_calls_<UTC YYYYmmdd_HHMMSS_ffffff>.json``. It is created
        exclusively, with a numeric suffix added on a name clash, so rounds that
        finish within the same clock tick never overwrite each other's logs.
        """
        if not self._log_tools or not tool_messages:
            return

        now = datetime.now(timezone.utc)
        logged_at = now.isoformat()
        logs: List[ToolCallLog] = [
            ToolCallLog(
                tool_call_id=message.tool_call_id,
                name=message.name,
                args=message.additional_kwargs.get("args", {}),
                content=message.content,
                timestamp=logged_at,
            )
            for message in tool_messages
        ]
        # default=str mirrors the result serialization in _execute_tools, so an
        # unusual argument value cannot abort the agent run at logging time.
        payload = json.dumps(logs, indent=2, ensure_ascii=False, default=str)

        # The directory may have been removed since __init__ created it.
        self._log_dir.mkdir(parents=True, exist_ok=True)
        stem = f"tool_calls_{now.strftime('%Y%m%d_%H%M%S_%f')}"
        attempt = 0
        while True:
            name = f"{stem}.json" if attempt == 0 else f"{stem}_{attempt}.json"
            try:
                # "x" mode fails instead of truncating an existing file.
                with (self._log_dir / name).open("x", encoding="utf-8") as handle:
                    handle.write(payload)
                return
            except FileExistsError:
                attempt += 1
|
|
|