File size: 5,440 Bytes
8f51ef2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
#!/usr/bin/env python3
"""
Reactive Echo Agent
Implements a LangGraph-based ReAct loop where the LLM decides whether to
invoke tools and receives their results before continuing the conversation.
"""
from __future__ import annotations

import json
import operator
from datetime import datetime, timezone
from pathlib import Path
from typing import Annotated, Any, Dict, List, Optional, TypedDict

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
class ToolCallLog(TypedDict):
    """Structured record of an executed tool call, as serialized to the JSON log."""
    # ISO-8601 time at which the call was logged.
    timestamp: str
    # Identifier correlating this result with the model's tool-call request.
    tool_call_id: str
    # Registered tool name ("unknown_tool" when the call carried no name).
    name: str
    # Arguments the model supplied for the call (shape depends on the tool).
    args: Any
    # JSON-encoded tool result or error payload.
    content: str
class EchoAgentState(TypedDict):
    """State carried through the LangGraph execution."""
    # Conversation history. The `operator.add` reducer makes LangGraph
    # concatenate each node's returned messages onto the existing list
    # instead of replacing it, so history accumulates across the loop.
    messages: Annotated[List[AnyMessage], operator.add]
class ReactiveEchoAgent:
    """
    Minimal ReAct-style agent.

    The agent delegates decision making to the bound language model. Whenever
    the model emits tool calls, the specified LangChain tools are executed and
    their `ToolMessage` responses are appended to the conversation history
    before handing control back to the model.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        *,
        system_prompt: str = "",
        checkpointer: Any = None,
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ) -> None:
        """
        Args:
            model: Chat model supporting ``bind_tools``.
            tools: LangChain tools the model may invoke; registered by ``tool.name``.
            system_prompt: Optional system message prepended on every model call.
            checkpointer: Optional LangGraph checkpointer passed to ``compile``.
            log_tools: When True, executed tool calls are written to JSON files.
            log_dir: Directory for tool-call logs (created if missing; falls
                back to "logs" when None is passed).
        """
        self._system_prompt = system_prompt
        self._log_tools = log_tools
        self._log_dir = Path(log_dir or "logs")
        if self._log_tools:
            self._log_dir.mkdir(parents=True, exist_ok=True)
        # Build the two-node ReAct loop:
        # process (model) --tool calls?--> execute (tools) --> process ... --> END
        workflow = StateGraph(EchoAgentState)
        workflow.add_node("process", self._process_request)
        workflow.add_node("execute", self._execute_tools)
        workflow.add_conditional_edges(
            "process", self._has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")
        self.workflow = workflow.compile(checkpointer=checkpointer)
        self.tools = {tool.name: tool for tool in tools}
        self.model = model.bind_tools(list(self.tools.values()))

    @property
    def system_prompt(self) -> str:
        """Current system prompt used for model invocations."""
        return self._system_prompt

    def update_system_prompt(self, prompt: str) -> None:
        """Set a new system prompt for subsequent runs."""
        self._system_prompt = prompt

    # -- LangGraph node implementations -------------------------------------------------

    def _process_request(self, state: Dict[str, Any]) -> Dict[str, List[AnyMessage]]:
        """Invoke the model on the accumulated history (plus system prompt)."""
        messages: List[AnyMessage] = list(state.get("messages", []))
        if self._system_prompt:
            # Prepended on every call but never stored in graph state, so a
            # later `update_system_prompt` takes effect on the very next turn.
            messages = [SystemMessage(content=self._system_prompt)] + messages
        response = self.model.invoke(messages)
        return {"messages": [response]}

    def _has_tool_calls(self, state: Dict[str, Any]) -> bool:
        """Conditional-edge predicate: True when the latest message requests tools."""
        last_message = state["messages"][-1]
        return bool(getattr(last_message, "tool_calls", []))

    def _execute_tools(self, state: Dict[str, Any]) -> Dict[str, List[ToolMessage]]:
        """Execute every tool call from the latest model message.

        Unknown tools and tool exceptions are reported back to the model as
        JSON error payloads instead of being raised, so the loop keeps going.
        """
        tool_messages: List[ToolMessage] = []
        for call in state["messages"][-1].tool_calls:
            tool_name = call.get("name")
            tool_args = call.get("args", {})
            tool_id = call.get("id", "")
            if tool_name not in self.tools:
                result_content = json.dumps(
                    {"status": "error", "error": f"Unknown tool '{tool_name}'"},
                    ensure_ascii=False,
                )
            else:
                try:
                    result = self.tools[tool_name].invoke(tool_args)
                    # Tool results can be complex objects; coerce to a JSON
                    # string, stringifying anything non-serializable.
                    result_content = json.dumps(result, ensure_ascii=False, default=str)
                except Exception as exc:  # noqa: BLE001
                    result_content = json.dumps(
                        {"status": "error", "error": f"{type(exc).__name__}: {exc}"},
                        ensure_ascii=False,
                    )
            tool_messages.append(
                ToolMessage(
                    tool_call_id=tool_id,
                    name=tool_name or "unknown_tool",
                    content=result_content,
                    # Stash the original args so the log helper can record them.
                    additional_kwargs={"args": tool_args},
                )
            )
        self._log_tool_messages(tool_messages)
        return {"messages": tool_messages}

    # -- Helpers ------------------------------------------------------------------------

    def _log_tool_messages(self, tool_messages: List[ToolMessage]) -> None:
        """Write the executed tool calls of one round to a timestamped JSON file."""
        if not self._log_tools or not tool_messages:
            return
        # Timezone-aware UTC (datetime.utcnow() is deprecated and naive);
        # microseconds included so rapid successive tool rounds do not
        # overwrite each other's log file.
        now = datetime.now(timezone.utc)
        log_path = self._log_dir / f"tool_calls_{now.strftime('%Y%m%d_%H%M%S_%f')}.json"
        logs: List[ToolCallLog] = [
            ToolCallLog(
                timestamp=now.isoformat(),
                tool_call_id=message.tool_call_id,
                name=message.name,
                args=message.additional_kwargs.get("args", {}),
                content=message.content,
            )
            for message in tool_messages
        ]
        log_path.write_text(
            json.dumps(logs, indent=2, ensure_ascii=False), encoding="utf-8"
        )
|