mcpmark / src /agents /react_agent.py
haochengsama's picture
Add files using upload-large-folder tool
a2ec7b6 verified
Raw
History Blame Contribute Delete
21.7 kB
"""ReAct agent implementation for the MCPMark pipeline."""
from __future__ import annotations
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Callable
import litellm
from src.logger import get_logger
from .base_agent import BaseMCPAgent
logger = get_logger(__name__)
class ReActAgent(BaseMCPAgent):
"""ReAct-style agent that reuses MCPMark infrastructure."""
# Fallback system prompt: __init__ stores it as `react_system_prompt` whenever
# the caller passes no (or an empty) `system_prompt`.
DEFAULT_SYSTEM_PROMPT = (
"You are a careful ReAct (reasoning and acting) agent. "
"At each step you must decide whether to call a tool or provide a final response. "
"Only use the tools that are listed for you. When you finish, respond with either the final answer "
"or the phrase \"Task completed.\" if no further detail is required. "
"Every reply must be valid JSON without code fences."
)
# System prompt for the summarisation request that _execute_react_loop sends
# once the counted prompt tokens reach `self.compaction_token`; the returned
# summary replaces the conversation history after the original task message.
COMPACTION_PROMPT = (
"You are performing a CONTEXT CHECKPOINT COMPACTION.\n"
"Summarize the conversation so far for another model to continue.\n\n"
"Include:\n"
"- Current progress and key decisions made\n"
"- Important context, constraints, or user preferences\n"
"- What remains to be done (clear next steps)\n"
"- Any critical data, examples, or references needed to continue\n\n"
"Be concise and structured. Do NOT call tools."
)
def __init__(
    self,
    litellm_input_model_name: str,
    api_key: str,
    base_url: str,
    mcp_service: str,
    timeout: int = BaseMCPAgent.DEFAULT_TIMEOUT,
    service_config: Optional[Dict[str, Any]] = None,
    service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None,
    reasoning_effort: Optional[str] = "default",
    max_iterations: int = 100,
    system_prompt: Optional[str] = None,
    compaction_token: int = BaseMCPAgent.COMPACTION_DISABLED_TOKEN,
):
    """Configure the ReAct agent on top of the shared MCP agent base.

    Args:
        litellm_input_model_name: Model identifier forwarded to the base class.
        api_key: API key forwarded to the base class.
        base_url: Endpoint override forwarded to the base class.
        mcp_service: MCP service name forwarded to the base class.
        timeout: Forwarded to the base class.
        service_config: Optional static service configuration.
        service_config_provider: Optional callable returning a service
            configuration dict.
        reasoning_effort: Forwarded to the base class.
        max_iterations: Upper bound on ReAct loop steps.
        system_prompt: Custom system prompt; ``DEFAULT_SYSTEM_PROMPT`` is
            used when this is ``None`` or empty.
        compaction_token: Forwarded to the base class.
    """
    # Everything the base class understands is passed through unchanged.
    base_options = {
        "litellm_input_model_name": litellm_input_model_name,
        "api_key": api_key,
        "base_url": base_url,
        "mcp_service": mcp_service,
        "timeout": timeout,
        "service_config": service_config,
        "service_config_provider": service_config_provider,
        "reasoning_effort": reasoning_effort,
        "compaction_token": compaction_token,
    }
    super().__init__(**base_options)

    # ReAct-specific settings live on this subclass only.
    self.max_iterations = max_iterations
    if system_prompt:
        self.react_system_prompt = system_prompt
    else:
        self.react_system_prompt = self.DEFAULT_SYSTEM_PROMPT
async def execute(
    self,
    instruction: str,
    tool_call_log_file: Optional[str] = None,
) -> Dict[str, Any]:
    """Run one task through the ReAct loop under the agent-wide timeout.

    Args:
        instruction: Task text handed to ``_execute_react_loop``.
        tool_call_log_file: Optional log path handed to
            ``_execute_react_loop``.

    Returns:
        The loop's result dict with ``execution_time`` added, or, when the
        run raises or times out, a failure dict assembled from the partial
        progress recorded on the instance.
    """
    began = time.time()
    try:
        self._reset_progress()
        self._refresh_service_config()

        outcome = await asyncio.wait_for(
            self._execute_react_loop(instruction, tool_call_log_file),
            timeout=self.timeout,
        )

        elapsed = time.time() - began
        self.usage_tracker.update(
            success=outcome.get("success", False),
            token_usage=outcome.get("token_usage", {}),
            turn_count=outcome.get("turn_count", 0),
            execution_time=elapsed,
        )
        outcome["execution_time"] = elapsed
        return outcome
    except Exception as exc:  # noqa: BLE001
        elapsed = time.time() - began

        # Timeouts get a dedicated message and no traceback; anything else
        # is logged together with its stack trace.
        if not isinstance(exc, asyncio.TimeoutError):
            error_msg = f"ReAct agent execution failed: {exc}"
            logger.error(error_msg, exc_info=True)
        else:
            error_msg = f"Execution timed out after {self.timeout} seconds"
            logger.error(error_msg)

        self.usage_tracker.update(
            success=False,
            token_usage=self._partial_token_usage or {},
            turn_count=self._partial_turn_count or 0,
            execution_time=elapsed,
        )

        # Report whatever conversation was captured before the failure.
        final_msg = (
            self._convert_to_sdk_format(self._partial_messages)
            if self._partial_messages
            else []
        )

        failure: Dict[str, Any] = {}
        failure["success"] = False
        failure["output"] = final_msg
        failure["token_usage"] = self._partial_token_usage or {}
        failure["turn_count"] = self._partial_turn_count or 0
        failure["execution_time"] = elapsed
        failure["error"] = error_msg
        failure["litellm_run_model_name"] = self.litellm_run_model_name
        return failure
async def _execute_react_loop(
    self,
    instruction: str,
    tool_call_log_file: Optional[str],
) -> Dict[str, Any]:
    """Drive the ReAct reason/act/observe loop against the MCP server.

    Each iteration optionally compacts the conversation, asks the model for
    a JSON reply and then either finishes (``answer``), runs an MCP tool
    (``action``), feeds a model-supplied ``result`` back as an observation,
    or asks the model to restate a malformed reply.

    Args:
        instruction: Task text embedded in the first user message.
        tool_call_log_file: Optional path; thoughts, tool calls and
            compaction events are appended to it on a best-effort basis.

    Returns:
        A dict with ``success``, ``output`` (the messages passed through
        ``_convert_to_sdk_format``), ``token_usage``, ``turn_count``,
        ``error`` (``None`` on success) and ``litellm_run_model_name``.
    """
    system_message = {"role": "system", "content": self.react_system_prompt}
    total_tokens = {
        "input_tokens": 0,
        "output_tokens": 0,
        "total_tokens": 0,
        "reasoning_tokens": 0,
    }
    turn_count = 0
    success = False
    final_error: Optional[str] = None

    def _append_log(line: str) -> None:
        """Append one line to the tool-call log; failures never abort the run."""
        if not tool_call_log_file:
            return
        try:
            with open(tool_call_log_file, "a", encoding="utf-8") as log_file:
                log_file.write(f"{line}\n")
        except Exception as log_exc:  # noqa: BLE001
            # Logging to file is deliberately best-effort.
            logger.debug(f"Could not write to tool call log: {log_exc}")

    def _add_usage(usage: Any) -> None:
        """Fold one response's prompt/completion/total counts into the totals."""
        if not usage:
            return
        prompt_tokens = (
            getattr(usage, "prompt_tokens", None)
            or getattr(usage, "input_tokens", None)
            or 0
        )
        completion_tokens = (
            getattr(usage, "completion_tokens", None)
            or getattr(usage, "output_tokens", None)
            or 0
        )
        total_count = getattr(usage, "total_tokens", None)
        if total_count is None:
            total_count = prompt_tokens + completion_tokens
        total_tokens["input_tokens"] += int(prompt_tokens or 0)
        total_tokens["output_tokens"] += int(completion_tokens or 0)
        total_tokens["total_tokens"] += int(total_count or 0)

    mcp_server = await self._create_mcp_server()
    async with mcp_server:
        tools = await mcp_server.list_tools()
        tool_map = {tool.get("name"): tool for tool in tools}
        tools_description = self._render_tools_description(tools)

        task_message = {
            "role": "user",
            "content": self._build_task_prompt(
                instruction=instruction,
                tools_description=tools_description,
            ),
        }
        messages: List[Dict[str, Any]] = [system_message, task_message]
        self._update_progress(messages, total_tokens, turn_count)

        for step in range(1, self.max_iterations + 1):
            # --- Optional context compaction -----------------------------
            if self._compaction_enabled():
                current_prompt_tokens = self._count_prompt_tokens_litellm(messages)
                if current_prompt_tokens >= self.compaction_token:
                    compaction_note = (
                        f"| [compaction] Triggered at prompt tokens: {current_prompt_tokens:,}"
                    )
                    logger.info(compaction_note)
                    _append_log(compaction_note)

                    compact_kwargs = {
                        "model": self.litellm_input_model_name,
                        "messages": [
                            {"role": "system", "content": self.COMPACTION_PROMPT},
                            {
                                "role": "user",
                                "content": json.dumps(messages, ensure_ascii=False),
                            },
                        ],
                        "api_key": self.api_key,
                    }
                    if self.base_url:
                        compact_kwargs["base_url"] = self.base_url
                    compact_response = await litellm.acompletion(**compact_kwargs)
                    _add_usage(getattr(compact_response, "usage", None))

                    try:
                        summary = compact_response.choices[0].message.content or ""
                    except Exception:  # noqa: BLE001
                        # An oddly shaped response degrades to a placeholder.
                        summary = ""
                    summary = summary.strip() or "(no summary)"

                    # Restart the history from the task plus the summary.
                    messages = [
                        system_message,
                        task_message,
                        {
                            "role": "user",
                            "content": (
                                "Context summary (auto-compacted due to token limit):\n"
                                f"{summary}"
                            ),
                        },
                    ]
                    self._update_progress(messages, total_tokens, turn_count)

            # --- Ask the model for the next step -------------------------
            completion_kwargs = {
                "model": self.litellm_input_model_name,
                "messages": messages,
                "api_key": self.api_key,
            }
            if self.base_url:
                completion_kwargs["base_url"] = self.base_url
            if self.reasoning_effort != "default":
                completion_kwargs["reasoning_effort"] = self.reasoning_effort

            try:
                response = await asyncio.wait_for(
                    litellm.acompletion(**completion_kwargs),
                    timeout=self.timeout / 2,
                )
            except asyncio.TimeoutError:
                final_error = f"LLM call timed out on step {step}"
                logger.error(final_error)
                break
            except Exception as exc:  # noqa: BLE001
                final_error = f"LLM call failed on step {step}: {exc}"
                logger.error(final_error)
                # NOTE(review): the retry re-sends the same messages unless
                # compaction shrinks them first - confirm this is intended.
                if "ContextWindowExceededError" in str(exc):
                    continue
                break

            if turn_count == 0 and getattr(response, "model", None):
                self.litellm_run_model_name = response.model.split("/")[-1]

            usage = getattr(response, "usage", None)
            _add_usage(usage)
            if usage:
                # Reasoning tokens are reported separately when available.
                details = getattr(usage, "completion_tokens_details", None)
                total_tokens["reasoning_tokens"] += (
                    getattr(details, "reasoning_tokens", 0) or 0
                )

            choice = response.choices[0]
            message_obj = getattr(choice, "message", None)
            if message_obj is None and isinstance(choice, dict):
                message_obj = choice.get("message")
            if message_obj is None:
                content_raw = getattr(choice, "text", "")
            elif hasattr(message_obj, "get"):
                content_raw = message_obj.get("content", "")
            else:
                content_raw = getattr(message_obj, "content", "")

            assistant_text = self._normalize_content(content_raw)
            messages.append({"role": "assistant", "content": assistant_text})
            turn_count += 1
            self._update_progress(messages, total_tokens, turn_count)

            # --- Interpret the reply -------------------------------------
            parsed = self._parse_react_response(assistant_text)
            # Valid JSON that is not an object (a bare string, list, number)
            # is treated as malformed instead of crashing on `.get` below.
            if not isinstance(parsed, dict) or "thought" not in parsed:
                warning = (
                    "The previous response was not valid JSON following the required schema. "
                    "Please respond again using the JSON formats provided."
                )
                messages.append({"role": "user", "content": warning})
                self._update_progress(messages, total_tokens, turn_count)
                final_error = "Model produced an invalid response format."
                continue

            # The model recovered, so an earlier recoverable error must not
            # be reported as the cause if the loop later runs out of steps.
            final_error = None

            thought = parsed.get("thought", "")
            action = parsed.get("action")
            answer = parsed.get("answer")
            result = parsed.get("result")

            logger.info(f"|\n| \033[1;3mThought\033[0m: {str(thought)}")
            _append_log(f"| {str(thought)}")

            # Only a JSON object can describe a tool call; any other type
            # falls through to the "restate your reply" message below.
            action_is_call = isinstance(action, dict)
            tool_name = None
            arguments: Any = {}
            if action_is_call:
                tool_name = action.get("tool")
                arguments = action.get("arguments", {}) or {}
                args_str = json.dumps(arguments, separators=(",", ": "))
                display_arguments = (
                    args_str[:140] + "..." if len(args_str) > 140 else args_str
                )
                logger.info(
                    f"| \033[1;3mAction\033[0m: \033[1m{tool_name}\033[0m \033[2;37m{display_arguments}\033[0m"
                )

            # A final answer ends the run even if an action was also given.
            if answer is not None:
                success = True
                break

            if action_is_call:
                # The isinstance guard keeps an unhashable tool name from
                # raising inside the dict membership test.
                if not isinstance(tool_name, str) or tool_name not in tool_map:
                    observation = (
                        f"Invalid tool '{tool_name}'. Available tools: "
                        f"{', '.join(tool_map)}"
                    )
                else:
                    try:
                        tool_response = await asyncio.wait_for(
                            mcp_server.call_tool(tool_name, arguments),
                            timeout=60,
                        )
                        observation = self._tool_result_to_text(tool_response)
                    except asyncio.TimeoutError:
                        observation = f"Tool '{tool_name}' timed out"
                    except Exception as tool_exc:  # noqa: BLE001
                        observation = f"Tool '{tool_name}' failed: {tool_exc}"

                _append_log(
                    f"| {tool_name} {json.dumps(arguments, ensure_ascii=False)}"
                )

                messages.append(
                    {
                        "role": "user",
                        "content": (
                            f"Observation:\n{observation}\n"
                            "Please continue reasoning and reply using the required JSON format."
                        ),
                    }
                )
                self._update_progress(messages, total_tokens, turn_count)
                continue

            if result is not None:
                # The model supplied its own result; echo it back.
                messages.append(
                    {
                        "role": "user",
                        "content": (
                            f"Observation:\n{result}\n"
                            "Please continue reasoning and reply using the required JSON format."
                        ),
                    }
                )
                self._update_progress(messages, total_tokens, turn_count)
                continue

            # Unexpected structure: ask model to restate properly
            messages.append(
                {
                    "role": "user",
                    "content": (
                        "The previous reply did not include an action, result, or answer. "
                        "Please respond again using the JSON formats provided."
                    ),
                }
            )
            self._update_progress(messages, total_tokens, turn_count)

        if not success and final_error is None:
            final_error = (
                f"Max iterations ({self.max_iterations}) reached without a final answer."
            )

    if total_tokens["total_tokens"] > 0:
        log_msg = (
            f"|\n|\n| Token usage: Total: {total_tokens['total_tokens']:,} | "
            f"Input: {total_tokens['input_tokens']:,} | "
            f"Output: {total_tokens['output_tokens']:,}"
        )
        if total_tokens.get("reasoning_tokens", 0) > 0:
            log_msg += f" | Reasoning: {total_tokens['reasoning_tokens']:,}"
        logger.info(log_msg)
        logger.info(f"| Turns: {turn_count}")

    sdk_messages = self._convert_to_sdk_format(messages)
    return {
        "success": success,
        "output": sdk_messages,
        "token_usage": total_tokens,
        "turn_count": turn_count,
        "error": None if success else final_error,
        "litellm_run_model_name": self.litellm_run_model_name,
    }
def _build_task_prompt(
self,
instruction: str,
tools_description: str,
) -> str:
return (
f"Task:\n{instruction}\n\n"
f"Available MCP tools:\n{tools_description}\n\n"
"Respond using the JSON formats below.\n\n"
"If you need to use a tool:\n"
"{\n"
' "thought": "Reasoning for the next action",\n'
' "action": {\n'
' "tool": "tool-name",\n'
' "arguments": {\n'
' "parameter": value\n'
" }\n"
" }\n"
"}\n\n"
"If you can provide the final answer:\n"
"{\n"
' "thought": "Reasoning that justifies the answer",\n'
' "answer": "Either the final solution or \'Task completed.\' when no more detail is required"\n'
"}\n\n"
"Remember: omitting the action object ends the task, so only do this when finished."
)
def _render_tools_description(self, tools: List[Dict[str, Any]]) -> str:
descriptions = []
for tool in tools:
name = tool.get("name", "unknown")
description = tool.get("description", "No description provided.")
input_schema = tool.get("inputSchema", {}) or {}
properties = input_schema.get("properties", {}) or {}
required = set(input_schema.get("required", []) or [])
arg_lines = []
for prop_name, prop_details in properties.items():
details = json.dumps(prop_details, ensure_ascii=False, indent=2)
suffix = " (required)" if prop_name in required else ""
arg_lines.append(f"- {prop_name}{suffix}: {details}")
if arg_lines:
arguments_text = "\n".join(arg_lines)
else:
arguments_text = "(no arguments)"
descriptions.append(
f"Tool: {name}\nDescription: {description}\nArguments:\n{arguments_text}"
)
return "\n\n".join(descriptions) if descriptions else "(no tools available)"
def _normalize_content(self, content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif "text" in block:
parts.append(str(block.get("text")))
else:
parts.append(str(block))
return "\n".join(part for part in parts if part)
return json.dumps(content, ensure_ascii=False)
def _parse_react_response(self, payload: str) -> Dict[str, Any]:
candidate = payload.strip().strip("`").strip()
if candidate.lower().startswith("json"):
candidate = candidate[4:].lstrip()
try:
return json.loads(candidate)
except json.JSONDecodeError:
return {}
def _tool_result_to_text(self, result: Any) -> str:
if result is None:
return ""
if isinstance(result, str):
return result
try:
return json.dumps(result, ensure_ascii=False)
except TypeError:
return str(result)