File size: 4,534 Bytes
d8328bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Agent controller orchestrating the LLM ↔ tool server interaction loop."""

from __future__ import annotations

import json
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence

from tools.schemas import ToolCall, ToolResult

from .client_llm import Message, NexaSciModelClient
from .client_llm_remote import RemoteNexaSciClient
from .tool_client import ToolClient

# Sentinel-delimited blocks the assistant embeds in its replies:
# ``~~~toolcall {json} ~~~`` requests a tool invocation and
# ``~~~final {json} ~~~`` carries the final answer payload.
# DOTALL lets the JSON payload span multiple lines; the non-greedy
# group stops at the first closing ``~~~``.
TOOLCALL_REGEX = re.compile(r"~~~toolcall(.*?)~~~", re.DOTALL)
FINAL_REGEX = re.compile(r"~~~final(.*?)~~~", re.DOTALL)


@dataclass
class AgentRunResult:
    """Outcome of one agent run: the final payload plus the full transcript.

    Attributes
    ----------
    final_response:
        Parsed JSON object the model emitted in its ``~~~final~~~`` block.
    messages:
        Complete conversation transcript (user, assistant, and tool turns).
    tool_results:
        Results of every tool invocation made during the run, in call order.
    """

    final_response: Dict[str, Any]
    messages: Sequence[Message]
    tool_results: Sequence[ToolResult] = field(default_factory=list)

    def pretty(self) -> str:
        """Render ``final_response`` as indented JSON for human inspection."""

        rendered = json.dumps(self.final_response, indent=2)
        return rendered


class AgentController:
    """Drives the LLM ↔ tool-server conversation until a final answer emerges."""

    def __init__(
        self,
        llm_client: NexaSciModelClient | RemoteNexaSciClient | None = None,
        tool_client: ToolClient | None = None,
        *,
        max_turns: int = 8,
        use_remote_model: bool = False,
        model_server_url: str = "http://127.0.0.1:8001",
    ) -> None:
        """Initialize the agent controller.

        Parameters
        ----------
        llm_client:
            Optional LLM client. When None, one is built according to
            ``use_remote_model``.
        tool_client:
            Optional tool client. When None, one is created from config.
        max_turns:
            Maximum number of agent turns before giving up.
        use_remote_model:
            If True, connect to a remote model server instead of loading the
            model locally.
        model_server_url:
            URL of the model server (only relevant when ``use_remote_model``).
        """
        if llm_client is None:
            # Lazy-load locally unless the caller opted into the remote server.
            llm_client = (
                RemoteNexaSciClient(base_url=model_server_url)
                if use_remote_model
                else NexaSciModelClient(lazy_load=True)
            )
        self.llm_client = llm_client
        self.tool_client = tool_client or ToolClient.from_config()
        self.max_turns = max_turns

    def run(self, user_prompt: str) -> AgentRunResult:
        """Execute the agent loop until a final response is produced.

        Raises ``RuntimeError`` when ``max_turns`` elapse without the model
        emitting a ``~~~final~~~`` block.
        """

        transcript: List[Message] = [Message(role="user", content=user_prompt)]
        collected: List[ToolResult] = []

        for _turn in range(self.max_turns):
            reply = self.llm_client.generate(transcript)
            transcript.append(Message(role="assistant", content=reply))

            calls = _extract_tool_calls(reply)
            if not calls:
                # No tool requests this turn: check for a final answer,
                # otherwise let the model take another turn.
                payload = _extract_final_response(reply)
                if payload is not None:
                    return AgentRunResult(
                        final_response=payload,
                        messages=transcript,
                        tool_results=collected,
                    )
                continue

            # Tool calls take priority: execute each and feed its JSON output
            # back into the transcript as a ``tool`` message.
            for call in calls:
                outcome = self._dispatch_tool(call)
                collected.append(outcome)
                transcript.append(
                    Message(
                        role="tool",
                        content=json.dumps(outcome.output, ensure_ascii=False),
                    )
                )

        raise RuntimeError("Agent did not produce a final response within the maximum number of turns.")

    def _dispatch_tool(self, call: ToolCall) -> ToolResult:
        """Invoke the requested tool via the ToolClient."""

        return self.tool_client.call_tool(call)


def _extract_tool_calls(text: str) -> List[ToolCall]:
    """Parse tool call JSON payloads embedded in the assistant response.

    Each ``~~~toolcall ... ~~~`` block is expected to hold a JSON object whose
    keys match the ``ToolCall`` constructor. Blocks that are empty, contain
    invalid JSON, or do not map onto ``ToolCall`` are skipped rather than
    aborting the agent turn.
    """

    tool_calls: List[ToolCall] = []
    for match in TOOLCALL_REGEX.findall(text):
        snippet = match.strip()
        if not snippet:
            continue
        try:
            payload = json.loads(snippet)
            tool_calls.append(ToolCall(**payload))
        except (TypeError, ValueError):
            # ValueError subsumes json.JSONDecodeError (and pydantic-style
            # validation errors, if ToolCall validates — TODO confirm).
            # TypeError covers non-mapping payloads (``**payload`` on a list)
            # and unexpected keyword arguments. The original caught only
            # JSONDecodeError, so a well-formed-but-wrong payload crashed run().
            continue
    return tool_calls


def _extract_final_response(text: str) -> Dict[str, Any] | None:
    """Parse the final response JSON block from the assistant output.

    Returns ``None`` when no ``~~~final ... ~~~`` block is present or its
    payload is not valid JSON, and ``{}`` when the block exists but is empty.
    """

    match = FINAL_REGEX.search(text)
    if match is None:
        return None
    snippet = match.group(1).strip()
    if not snippet:
        return {}
    try:
        return json.loads(snippet)
    except json.JSONDecodeError:
        # Malformed JSON previously raised out of AgentController.run();
        # treating it like a missing block lets the loop grant the model
        # another turn, consistent with _extract_tool_calls' tolerance.
        return None


# Public API of this module; the ``_extract_*`` parsers stay private.
__all__ = ["AgentController", "AgentRunResult"]