Spaces:
Paused
Paused
File size: 4,534 Bytes
d8328bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
"""Agent controller orchestrating the LLM ↔ tool server interaction loop."""
from __future__ import annotations
import json
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence
from tools.schemas import ToolCall, ToolResult
from .client_llm import Message, NexaSciModelClient
from .client_llm_remote import RemoteNexaSciClient
from .tool_client import ToolClient
# Matches a tool-call block opened by "~~~toolcall" and closed by the next "~~~";
# group 1 is the payload between the markers. DOTALL lets the payload span
# several lines, and the non-greedy ".*?" stops at the first closing "~~~", so
# findall() returns each block in a response separately.
TOOLCALL_REGEX = re.compile(r"~~~toolcall(.*?)~~~", re.DOTALL)
# Matches the final-answer block opened by "~~~final" and closed by the next
# "~~~"; group 1 is the payload, which the agent decodes as JSON.
FINAL_REGEX = re.compile(r"~~~final(.*?)~~~", re.DOTALL)
@dataclass
class AgentRunResult:
    """Container describing the outcome of an agent run.

    Attributes
    ----------
    final_response:
        JSON object decoded from the assistant's final block.
    messages:
        Conversation transcript of the run (user, assistant and tool turns
        when produced by ``AgentController.run``).
    tool_results:
        Results of the tool calls dispatched during the run, in order.
    """

    final_response: Dict[str, Any]
    messages: Sequence[Message]
    # default_factory gives each instance its own list instead of a shared one.
    tool_results: Sequence[ToolResult] = field(default_factory=list)

    def pretty(self) -> str:
        """Return a human-readable, indented JSON rendering of ``final_response``.

        Non-ASCII characters are written as-is (``ensure_ascii=False``) instead
        of as escaped code points, which keeps the output readable and matches
        how tool outputs are serialized into the conversation.
        """
        return json.dumps(self.final_response, indent=2, ensure_ascii=False)
class AgentController:
    """Core agent loop handling tool invocation and final response parsing.

    Each turn sends the running conversation to the LLM client. If the reply
    contains ``~~~toolcall`` blocks, every call is dispatched through the tool
    client and its output is appended as a ``tool`` message. Otherwise the
    reply is searched for a ``~~~final`` block, whose payload ends the run.
    """

    def __init__(
        self,
        llm_client: NexaSciModelClient | RemoteNexaSciClient | None = None,
        tool_client: ToolClient | None = None,
        *,
        max_turns: int = 8,
        use_remote_model: bool = False,
        model_server_url: str = "http://127.0.0.1:8001",
    ) -> None:
        """Initialize the agent controller.

        Parameters
        ----------
        llm_client:
            Optional LLM client. If None, one is created based on
            ``use_remote_model``.
        tool_client:
            Optional tool client. If None, one is created with
            ``ToolClient.from_config()``.
        max_turns:
            Maximum number of LLM generations ``run`` performs before giving up.
        use_remote_model:
            If True (and ``llm_client`` is None), connect to a remote model
            server instead of using a local ``NexaSciModelClient``.
        model_server_url:
            URL of the model server; only used when a remote client is created.
        """
        if llm_client is None:
            if use_remote_model:
                llm_client = RemoteNexaSciClient(base_url=model_server_url)
            else:
                llm_client = NexaSciModelClient(lazy_load=True)
        self.llm_client = llm_client
        # Compare against None explicitly (as for llm_client above) so a
        # caller-supplied client is never replaced merely because it is falsy.
        self.tool_client = (
            tool_client if tool_client is not None else ToolClient.from_config()
        )
        self.max_turns = max_turns

    def run(self, user_prompt: str) -> AgentRunResult:
        """Execute the agent loop until a final response is produced.

        Parameters
        ----------
        user_prompt:
            Initial user message that starts the conversation.

        Returns
        -------
        AgentRunResult
            The decoded final response, the full message transcript and every
            tool result collected along the way.

        Raises
        ------
        RuntimeError
            If no final response is produced within ``max_turns`` turns.
        """
        messages: List[Message] = [Message(role="user", content=user_prompt)]
        tool_results: List[ToolResult] = []
        for _ in range(self.max_turns):
            response_text = self.llm_client.generate(messages)
            messages.append(Message(role="assistant", content=response_text))

            tool_calls = _extract_tool_calls(response_text)
            if tool_calls:
                # Tool calls take precedence: any ~~~final block in the same
                # reply is ignored, and the model gets another turn in which it
                # can see the tool outputs.
                for call in tool_calls:
                    result = self._dispatch_tool(call)
                    tool_results.append(result)
                    messages.append(
                        Message(
                            role="tool",
                            content=json.dumps(result.output, ensure_ascii=False),
                        )
                    )
                continue

            final_payload = _extract_final_response(response_text)
            if final_payload is not None:
                return AgentRunResult(
                    final_response=final_payload,
                    messages=messages,
                    tool_results=tool_results,
                )
            # NOTE(review): a reply with neither a tool call nor a final block
            # just triggers another generation (with that reply appended) and
            # no extra guidance to the model; confirm this retry is intended.
        raise RuntimeError(
            "Agent did not produce a final response within the maximum number of turns."
        )

    def _dispatch_tool(self, call: ToolCall) -> ToolResult:
        """Invoke the requested tool via the ToolClient and return its result."""
        return self.tool_client.call_tool(call)
def _extract_tool_calls(text: str) -> List[ToolCall]:
    """Parse tool call JSON payloads embedded in the assistant response.

    Every ``~~~toolcall ... ~~~`` block is decoded as JSON and expanded into
    ``ToolCall(**payload)``. Blocks that are empty, are not valid JSON, or do
    not decode to a JSON object are skipped, so one malformed block does not
    abort the agent loop. Errors raised by ``ToolCall`` itself are not caught.

    Parameters
    ----------
    text:
        Raw assistant output.

    Returns
    -------
    list of ToolCall
        Parsed calls in the order they appear in ``text`` (possibly empty).
    """
    tool_calls: List[ToolCall] = []
    for match in TOOLCALL_REGEX.findall(text):
        snippet = match.strip()
        if not snippet:
            continue
        try:
            payload = json.loads(snippet)
        except json.JSONDecodeError:
            continue
        # ToolCall(**payload) needs a mapping; a JSON array, string or number
        # would raise TypeError here and crash the whole run.
        if not isinstance(payload, dict):
            continue
        tool_calls.append(ToolCall(**payload))
    return tool_calls
def _extract_final_response(text: str) -> Dict[str, Any] | None:
    """Parse the final response JSON block from the assistant output.

    Only the first ``~~~final ... ~~~`` block in ``text`` is considered.

    Parameters
    ----------
    text:
        Raw assistant output.

    Returns
    -------
    dict or None
        The decoded JSON payload; ``{}`` when the block is present but empty.
        ``None`` when there is no final block or its payload is not valid
        JSON. The caller then treats the turn as having no final answer
        instead of aborting the run.
    """
    match = FINAL_REGEX.search(text)
    if match is None:
        return None
    snippet = match.group(1).strip()
    if not snippet:
        return {}
    try:
        # NOTE(review): a payload that is valid JSON but not an object (e.g. a
        # bare string) is returned as-is, despite the Dict annotation.
        return json.loads(snippet)
    except json.JSONDecodeError:
        # Like malformed ~~~toolcall blocks, a malformed final block is skipped
        # rather than propagating out of AgentController.run.
        return None
# Public API of this module: the controller and its result container.
__all__ = ["AgentController", "AgentRunResult"]
|