# agent_world_model_env/server/web_agent.py
# (Hugging Face upload residue: "Upload folder using huggingface_hub",
#  commit 11a4e6e, verified — kept as a comment so the module stays importable.)
"""LLM-driven agent loop for the AWM web UI."""
from __future__ import annotations
import json
import re
from dataclasses import dataclass, field
from typing import Any, AsyncIterator
from openai import AsyncOpenAI
from .prompts import DEFAULT_SYSTEM_PROMPT
def parse_tool_call(content: str) -> dict | None:
m = re.search(r"<tool_call>\s*(.*?)\s*</tool_call>", content, re.DOTALL)
if not m:
return None
try:
data = json.loads(m.group(1).strip())
except json.JSONDecodeError:
return None
if isinstance(data, list):
data = data[0] if data else None
if not isinstance(data, dict) or "name" not in data:
return None
return data
def format_tools(tools: list[dict]) -> str:
    """Render a numbered, human-readable listing of MCP tool specs.

    Tolerates both ``name``/``tool_name`` and ``input_schema``/``inputSchema``
    key spellings; returns one newline-joined string for the LLM prompt.
    """
    out: list[str] = [f"Available MCP Tools ({len(tools)} tools):", "=" * 60]
    for idx, tool in enumerate(tools, start=1):
        tool_name = tool.get("name") or tool.get("tool_name", "")
        description = tool.get("description", "")
        schema = tool.get("input_schema") or tool.get("inputSchema") or {}
        out.append(f"{idx}. {tool_name}")
        out.append(f" Description: {description}")
        properties = schema.get("properties", {})
        required = set(schema.get("required", []))
        if not properties:
            out.append(" Parameters: None")
        else:
            out.append(" Parameters:")
            for pname, pinfo in properties.items():
                suffix = " (required)" if pname in required else ""
                ptype = pinfo.get("type", "any")
                pdesc = pinfo.get("description", "")
                out.append(f" - {pname}: {ptype}{suffix} — {pdesc}")
        out.append("")  # blank line between tool entries
    return "\n".join(out)
def _truncate(s: str, limit: int = 2000) -> str:
if len(s) <= limit:
return s
return s[:limit] + f"\n... (truncated, full length {len(s)} chars)"
@dataclass
class AgentEvent:
    """A single streamed event emitted by ``AwmAgent.run``."""
    # Event discriminator; one of:
    # "info" | "llm_response" | "tool_call" | "tool_result" | "verify" | "done" | "error"
    kind: str
    # Human-readable text for the event (tool output may be truncated).
    text: str = ""
    # Structured extras (step number, tool name/arguments, reward info, paths).
    payload: dict = field(default_factory=dict)
class AwmAgent:
    """Drives an LLM through an MCP-style tool loop against a web environment.

    The agent repeatedly queries the model, parses a single ``<tool_call>``
    block from the response, dispatches it through *web_manager*, and feeds
    the tool output back as a user message. Progress is streamed to the
    caller as :class:`AgentEvent` items so a UI can render it live.
    """

    def __init__(
        self,
        web_manager: Any,
        llm_base_url: str,
        llm_api_key: str,
        llm_model: str,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_iterations: int = 10,
        temperature: float = 1.0,
        max_tokens: int = 2048,
    ):
        # web_manager exposes async step_environment(action: dict) returning
        # a dict with "observation" (and "reward" for verify) — contract
        # inferred from usage below; confirm against the manager class.
        self._web = web_manager
        self._client = AsyncOpenAI(base_url=llm_base_url, api_key=llm_api_key)
        self._model = llm_model
        self._system_prompt = system_prompt
        self._max_iterations = max_iterations
        self._temperature = temperature
        self._max_tokens = max_tokens
        # Cooperative cancellation flag; checked at the top of each loop step.
        self._stop_requested = False

    def request_stop(self) -> None:
        """Ask a running ``run()`` loop to stop before its next LLM call."""
        self._stop_requested = True

    async def _list_tools_dict(self) -> list[dict]:
        """Fetch the environment's tool list, normalizing entries to dicts."""
        result = await self._web.step_environment({"type": "list_tools"})
        obs = result.get("observation", {}) or {}
        tools = obs.get("tools", []) or []
        out = []
        for t in tools:
            if isinstance(t, dict):
                out.append(t)
            else:
                # Tolerate object-style tool specs (attribute access).
                out.append(
                    {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", ""),
                        "input_schema": getattr(t, "input_schema", {}),
                    }
                )
        return out

    async def _call_tool(self, tool_name: str, args: dict) -> dict:
        """Invoke a single environment tool by name with *args*."""
        return await self._web.step_environment(
            {"type": "call_tool", "tool_name": tool_name, "arguments": args}
        )

    async def run(
        self,
        task: str,
        verifier_mode: str | None = None,
        final_answer_fallback: str = "",
        auto_verify: bool = True,
        auto_done: bool = True,
    ) -> AsyncIterator[AgentEvent]:
        """Run the agent on the already-reset env, yielding events as it goes.

        Args:
            task: Natural-language task handed to the model.
            verifier_mode: When set (and ``auto_verify``), a ``verify`` tool
                call is issued after the loop with this mode.
            final_answer_fallback: Answer sent to the verifier when the model
                produced no assistant message at all.
            auto_verify: Call the ``verify`` tool after the loop finishes.
            auto_done: Call the ``done`` tool to close the episode.

        Yields:
            AgentEvent items for LLM responses, tool calls/results,
            verification, episode completion, and errors.
        """
        try:
            tools = await self._list_tools_dict()
        except Exception as e:
            yield AgentEvent(kind="error", text=f"list_tools failed: {e}")
            return
        yield AgentEvent(
            kind="info",
            text=f"Discovered {len(tools)} tools.",
            payload={"tool_names": [t.get("name") for t in tools]},
        )
        tools_text = format_tools(tools)
        messages: list[dict] = [
            {"role": "system", "content": self._system_prompt},
            {"role": "user", "content": task},
            {
                "role": "user",
                "content": f"Available tools:\n{tools_text}",
            },
        ]
        last_assistant_content = ""
        for step in range(1, self._max_iterations + 1):
            if self._stop_requested:
                yield AgentEvent(kind="info", text="Stopped by user.")
                return
            try:
                resp = await self._client.chat.completions.create(
                    model=self._model,
                    messages=messages,
                    temperature=self._temperature,
                    max_completion_tokens=self._max_tokens,
                )
            except Exception as e:
                yield AgentEvent(
                    kind="error", text=f"LLM call failed at step {step}: {e}"
                )
                return
            content = resp.choices[0].message.content or ""
            last_assistant_content = content
            messages.append({"role": "assistant", "content": content})
            yield AgentEvent(
                kind="llm_response",
                text=content,
                payload={"step": step},
            )
            tc = parse_tool_call(content)
            if tc is None:
                # No tool call: the model's message is its final answer.
                yield AgentEvent(
                    kind="info",
                    text=f"No <tool_call> in step {step}; treating as final answer.",
                )
                break
            name = tc.get("name", "")
            arguments = tc.get("arguments") or {}
            yield AgentEvent(
                kind="tool_call",
                text=f"{name} {json.dumps(arguments, ensure_ascii=False)[:300]}",
                payload={"name": name, "arguments": arguments, "step": step},
            )
            tool_response = ""
            try:
                if name == "list_tools":
                    # Fix: a redundant extra step_environment list_tools call
                    # (whose result was discarded) has been removed here —
                    # _list_tools_dict already performs that environment step.
                    tools = await self._list_tools_dict()
                    tool_response = format_tools(tools)
                elif name == "call_tool":
                    inner_name = arguments.get("tool_name", "")
                    inner_args = arguments.get("arguments", "{}")
                    # The model may emit arguments as a JSON string; decode it
                    # and fall back to {} on anything that isn't a dict.
                    if isinstance(inner_args, str):
                        try:
                            inner_args = json.loads(inner_args)
                        except json.JSONDecodeError:
                            inner_args = {}
                    if not isinstance(inner_args, dict):
                        inner_args = {}
                    result = await self._call_tool(inner_name, inner_args)
                    obs = result.get("observation", {}) or {}
                    if obs.get("tool_result") is not None:
                        tr = obs["tool_result"]
                        tool_response = (
                            json.dumps(tr, ensure_ascii=False)
                            if not isinstance(tr, str)
                            else tr
                        )
                    elif obs.get("error"):
                        tool_response = f"Error: {obs['error']}"
                    else:
                        tool_response = json.dumps(obs, ensure_ascii=False)
                else:
                    tool_response = (
                        f"Error: Unknown function '{name}'. "
                        "Use 'list_tools' or 'call_tool'."
                    )
            except Exception as e:
                # Surface dispatch failures to the model instead of crashing
                # the loop; it can decide how to recover.
                tool_response = f"Error during tool dispatch: {e}"
            yield AgentEvent(
                kind="tool_result",
                text=_truncate(tool_response),
                payload={"step": step},
            )
            messages.append(
                {
                    "role": "user",
                    "content": f"Tool response:\n{_truncate(tool_response, 6000)}",
                }
            )
        else:
            # for/else: only reached when the loop ran out of iterations
            # without a break (i.e. the model never gave a final answer).
            yield AgentEvent(
                kind="info",
                text=f"Max iterations ({self._max_iterations}) reached.",
            )
        if auto_verify and verifier_mode:
            verify_args: dict = {"verifier_mode": verifier_mode}
            # Prefer the model's last message; fall back to the caller-supplied
            # answer when the model produced nothing.
            if last_assistant_content:
                verify_args["final_answer"] = last_assistant_content
            elif final_answer_fallback:
                verify_args["final_answer"] = final_answer_fallback
            try:
                verify_result = await self._call_tool("verify", verify_args)
                obs = verify_result.get("observation", {}) or {}
                yield AgentEvent(
                    kind="verify",
                    text=(
                        f"reward_type={obs.get('reward_type')} "
                        f"reward={verify_result.get('reward')}"
                    ),
                    payload={
                        "reward_type": obs.get("reward_type"),
                        "reward": verify_result.get("reward"),
                        "verify_result": obs.get("verify_result"),
                    },
                )
            except Exception as e:
                yield AgentEvent(kind="error", text=f"verify failed: {e}")
        if auto_done:
            try:
                # keep_session=True leaves the browser session alive for the UI.
                done_result = await self._call_tool("done", {"keep_session": True})
                obs = done_result.get("observation", {}) or {}
                yield AgentEvent(
                    kind="done",
                    text=(
                        f"Episode done. trajectory_path="
                        f"{obs.get('trajectory_path') or '(none)'}"
                    ),
                    payload={
                        "trajectory_path": obs.get("trajectory_path"),
                        "session_dir": obs.get("session_dir"),
                    },
                )
            except Exception as e:
                yield AgentEvent(kind="error", text=f"done failed: {e}")