File size: 5,440 Bytes
8f51ef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
"""
Reactive Echo Agent

Implements a LangGraph-based ReAct loop where the LLM decides whether to
invoke tools and receives their results before continuing the conversation.
"""

from __future__ import annotations

import json
import operator
from datetime import datetime, timezone
from pathlib import Path
from typing import Annotated, Any, Dict, List, Optional, TypedDict

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph


class ToolCallLog(TypedDict, total=True):
    """One executed tool call, as serialized into the JSON tool-call log."""

    # When the entry was recorded (ISO-8601 string).
    timestamp: str
    # Identifier of the model-issued tool call this result answers.
    tool_call_id: str
    # Name of the tool that was requested.
    name: str
    # Arguments the tool was called with, exactly as the model supplied them.
    args: Any
    # Message content returned to the model (tool result or error payload).
    content: str


class EchoAgentState(TypedDict, total=True):
    """Graph state shared between the agent's LangGraph nodes."""

    # Conversation history. The ``operator.add`` reducer makes each node's
    # returned list get appended to the existing history rather than replace it.
    messages: Annotated[List[AnyMessage], operator.add]


class ReactiveEchoAgent:
    """
    Minimal ReAct-style agent.

    The agent delegates decision making to the bound language model. Whenever
    the model emits tool calls, the specified LangChain tools are executed and
    their `ToolMessage` responses are appended to the conversation history
    before handing control back to the model.

    Graph shape: ``process`` (model call) -> ``execute`` (run tool calls) ->
    ``process`` -> ... until the model answers without requesting tools.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        *,
        system_prompt: str = "",
        checkpointer: Any = None,
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ) -> None:
        """Build and compile the ReAct graph.

        Args:
            model: Language model; ``bind_tools`` is called on it with ``tools``.
            tools: Tools the model may call. They are indexed by ``tool.name``,
                so a later tool with a duplicate name replaces an earlier one.
            system_prompt: Prepended as a ``SystemMessage`` on every model call;
                an empty string means no system message is sent.
            checkpointer: Forwarded unchanged to ``StateGraph.compile``.
            log_tools: When true, each batch of tool results is written to a
                JSON file under ``log_dir`` (the directory is created here).
            log_dir: Log directory; ``None`` or ``""`` falls back to ``"logs"``.
        """
        self._system_prompt = system_prompt
        self._log_tools = log_tools
        self._log_dir = Path(log_dir or "logs")
        if self._log_tools:
            self._log_dir.mkdir(parents=True, exist_ok=True)

        self.tools = {tool.name: tool for tool in tools}
        self.model = model.bind_tools(list(self.tools.values()))

        # Prepare LangGraph workflow
        workflow = StateGraph(EchoAgentState)
        workflow.add_node("process", self._process_request)
        workflow.add_node("execute", self._execute_tools)
        workflow.add_conditional_edges("process", self._has_tool_calls, {True: "execute", False: END})
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")

        self.workflow = workflow.compile(checkpointer=checkpointer)

    @property
    def system_prompt(self) -> str:
        """The system prompt used for subsequent model calls."""
        return self._system_prompt

    def update_system_prompt(self, prompt: str) -> None:
        """Set a new system prompt for subsequent runs."""
        self._system_prompt = prompt

    # -- LangGraph node implementations -------------------------------------------------
    def _process_request(self, state: Dict[str, Any]) -> Dict[str, List[AnyMessage]]:
        """Call the model on the history and return its reply as a state update.

        The system prompt is prepended to a copy of the history only; it is
        never written back into the graph state.
        """
        messages: List[AnyMessage] = list(state.get("messages", []))
        if self._system_prompt:
            messages = [SystemMessage(content=self._system_prompt)] + messages

        response = self.model.invoke(messages)
        return {"messages": [response]}

    def _has_tool_calls(self, state: Dict[str, Any]) -> bool:
        """Routing predicate: True when the last message requests tool calls."""
        last_message = state["messages"][-1]
        return bool(getattr(last_message, "tool_calls", []))

    def _execute_tools(self, state: Dict[str, Any]) -> Dict[str, List[ToolMessage]]:
        """Run every tool call requested by the last message.

        Each call yields exactly one ``ToolMessage`` (even on failure) so the
        model always receives an answer for every ``tool_call_id``. The call's
        arguments are kept in ``additional_kwargs["args"]`` for logging.
        """
        tool_messages: List[ToolMessage] = []
        for call in state["messages"][-1].tool_calls:
            tool_name = call.get("name")
            tool_args = call.get("args", {})
            tool_id = call.get("id", "")

            message = ToolMessage(
                tool_call_id=tool_id,
                name=tool_name or "unknown_tool",
                content=self._run_tool(tool_name, tool_args),
                additional_kwargs={"args": tool_args},
            )
            tool_messages.append(message)

        self._log_tool_messages(tool_messages)
        return {"messages": tool_messages}

    # -- Helpers ------------------------------------------------------------------------
    def _run_tool(self, tool_name: Optional[str], tool_args: Any) -> str:
        """Invoke one tool and return the text to send back to the model.

        Unknown tools and any exception raised while running or serializing a
        tool are reported as a JSON error payload instead of aborting the graph,
        so the model can see the failure and recover.
        """
        tool = self.tools.get(tool_name)
        if tool is None:
            return self._error_payload(f"Unknown tool '{tool_name}'")

        try:
            result = tool.invoke(tool_args)
            if isinstance(result, str):
                # Already text: JSON-encoding would wrap it in quotes and escape it.
                return result
            # Tool results can be complex objects; coerce to JSON string if possible.
            return json.dumps(result, ensure_ascii=False, default=str)
        except Exception as exc:  # noqa: BLE001 - failures are reported to the model
            return self._error_payload(f"{type(exc).__name__}: {exc}")

    @staticmethod
    def _error_payload(error: str) -> str:
        """Serialize an error description as the JSON payload given to the model."""
        return json.dumps({"status": "error", "error": error}, ensure_ascii=False)

    def _log_tool_messages(self, tool_messages: List[ToolMessage]) -> None:
        """Write one new JSON log file for a batch of tool results.

        Does nothing when logging is disabled or the batch is empty. Existing
        log files are never overwritten: batches that land on the same
        timestamp get a numeric suffix.
        """
        if not self._log_tools or not tool_messages:
            return

        now = datetime.now(timezone.utc)
        logged_at = now.isoformat()
        logs: List[ToolCallLog] = [
            {
                "timestamp": logged_at,
                "tool_call_id": message.tool_call_id,
                "name": message.name,
                "args": message.additional_kwargs.get("args", {}),
                "content": message.content,
            }
            for message in tool_messages
        ]
        # default=str: model-produced args are not guaranteed to be JSON-native.
        payload = json.dumps(logs, indent=2, ensure_ascii=False, default=str)

        # Microsecond stem; exclusive-create mode ("x") guarantees that several
        # tool rounds within one clock tick each get their own file.
        stem = f"tool_calls_{now.strftime('%Y%m%d_%H%M%S_%f')}"
        attempt = 0
        while True:
            candidate = stem if attempt == 0 else f"{stem}_{attempt}"
            try:
                with (self._log_dir / f"{candidate}.json").open("x", encoding="utf-8") as handle:
                    handle.write(payload)
                return
            except FileExistsError:
                attempt += 1