| """ReAct agent implementation for the MCPMark pipeline.""" |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import time |
| from typing import Any, Dict, List, Optional, Callable |
|
|
| import litellm |
|
|
| from src.logger import get_logger |
| from .base_agent import BaseMCPAgent |
|
|
# Module-level logger (from the project's ``get_logger`` helper) used by
# ReActAgent for progress, token-usage and error reporting.
logger = get_logger(__name__)
|
|
|
|
class ReActAgent(BaseMCPAgent):
    """ReAct-style agent that reuses MCPMark infrastructure.

    The agent drives an MCP server through a text-only ReAct loop. MCP tools
    are described in the task prompt (not passed as native tool schemas), and
    every model reply is expected to be a JSON object containing a
    ``thought`` plus one of:

    * ``answer`` -- ends the run successfully;
    * ``action`` -- ``{"tool": ..., "arguments": {...}}``, executed on the MCP
      server and fed back as an observation;
    * ``result`` -- echoed back to the model as an observation.

    Replies that do not follow this schema trigger a re-prompt.
    """

    DEFAULT_SYSTEM_PROMPT = (
        "You are a careful ReAct (reasoning and acting) agent. "
        "At each step you must decide whether to call a tool or provide a final response. "
        "Only use the tools that are listed for you. When you finish, respond with either the final answer "
        "or the phrase \"Task completed.\" if no further detail is required. "
        "Every reply must be valid JSON without code fences."
    )
    COMPACTION_PROMPT = (
        "You are performing a CONTEXT CHECKPOINT COMPACTION.\n"
        "Summarize the conversation so far for another model to continue.\n\n"
        "Include:\n"
        "- Current progress and key decisions made\n"
        "- Important context, constraints, or user preferences\n"
        "- What remains to be done (clear next steps)\n"
        "- Any critical data, examples, or references needed to continue\n\n"
        "Be concise and structured. Do NOT call tools."
    )

    # Seconds to wait for a single MCP tool call before reporting a timeout
    # observation back to the model.
    TOOL_CALL_TIMEOUT = 60

    # Maximum number of characters of serialized tool arguments shown in the
    # console log for each action.
    _ARGS_DISPLAY_LIMIT = 140

    def __init__(
        self,
        litellm_input_model_name: str,
        api_key: str,
        base_url: str,
        mcp_service: str,
        timeout: int = BaseMCPAgent.DEFAULT_TIMEOUT,
        service_config: Optional[Dict[str, Any]] = None,
        service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None,
        reasoning_effort: Optional[str] = "default",
        max_iterations: int = 100,
        system_prompt: Optional[str] = None,
        compaction_token: int = BaseMCPAgent.COMPACTION_DISABLED_TOKEN,
    ):
        """Create a ReAct agent.

        Args:
            litellm_input_model_name: Model identifier passed to ``litellm``.
            api_key: API key forwarded to every ``litellm.acompletion`` call.
            base_url: Endpoint override; only forwarded when truthy.
            mcp_service: MCP service name (passed to the base class).
            timeout: Overall budget in seconds for ``execute``; each model
                call is additionally capped at half of this value.
            service_config: Service configuration (passed to the base class).
            service_config_provider: Callable producing the service config
                (passed to the base class).
            reasoning_effort: Forwarded to the model unless ``"default"``.
            max_iterations: Maximum number of model calls in the ReAct loop.
            system_prompt: Replaces ``DEFAULT_SYSTEM_PROMPT`` when truthy.
            compaction_token: Prompt-token threshold at which the history is
                compacted (whether compaction is enabled at all is decided by
                the base class's ``_compaction_enabled``).
        """
        super().__init__(
            litellm_input_model_name=litellm_input_model_name,
            api_key=api_key,
            base_url=base_url,
            mcp_service=mcp_service,
            timeout=timeout,
            service_config=service_config,
            service_config_provider=service_config_provider,
            reasoning_effort=reasoning_effort,
            compaction_token=compaction_token,
        )
        self.max_iterations = max_iterations
        self.react_system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT

    async def execute(
        self,
        instruction: str,
        tool_call_log_file: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run the ReAct loop for ``instruction`` within ``self.timeout`` seconds.

        Timeouts and other ``Exception``s are logged and converted into a
        result with ``success=False``, built from the partial progress the
        base class recorded (``_partial_messages``, ``_partial_token_usage``,
        ``_partial_turn_count``). Usage is reported to ``usage_tracker`` in
        both cases.

        Returns:
            Dict with ``success``, ``output`` (SDK-formatted messages),
            ``token_usage``, ``turn_count``, ``execution_time``, ``error`` and
            ``litellm_run_model_name``.
        """
        start_time = time.time()

        try:
            self._reset_progress()
            self._refresh_service_config()

            result = await asyncio.wait_for(
                self._execute_react_loop(instruction, tool_call_log_file),
                timeout=self.timeout,
            )
            execution_time = time.time() - start_time
            self.usage_tracker.update(
                success=result.get("success", False),
                token_usage=result.get("token_usage", {}),
                turn_count=result.get("turn_count", 0),
                execution_time=execution_time,
            )
            result["execution_time"] = execution_time
            return result
        except Exception as exc:
            execution_time = time.time() - start_time

            if isinstance(exc, asyncio.TimeoutError):
                error_msg = f"Execution timed out after {self.timeout} seconds"
                logger.error(error_msg)
            else:
                error_msg = f"ReAct agent execution failed: {exc}"
                logger.error(error_msg, exc_info=True)

            token_usage = self._partial_token_usage or {}
            turn_count = self._partial_turn_count or 0
            self.usage_tracker.update(
                success=False,
                token_usage=token_usage,
                turn_count=turn_count,
                execution_time=execution_time,
            )

            output = (
                self._convert_to_sdk_format(self._partial_messages)
                if self._partial_messages
                else []
            )
            return {
                "success": False,
                "output": output,
                "token_usage": token_usage,
                "turn_count": turn_count,
                "execution_time": execution_time,
                "error": error_msg,
                "litellm_run_model_name": self.litellm_run_model_name,
            }

    async def _execute_react_loop(
        self,
        instruction: str,
        tool_call_log_file: Optional[str],
    ) -> Dict[str, Any]:
        """Drive the ReAct conversation until an answer, a fatal error, or the step limit.

        Each step optionally compacts the history, asks the model for a JSON
        reply, and then either finishes (``answer``), runs the requested MCP
        tool (``action``), echoes ``result`` back as an observation, or asks
        the model to reformat. ``_update_progress`` is called after every
        change to the conversation.

        Args:
            instruction: Task description shown to the model.
            tool_call_log_file: Optional path; thoughts, tool calls and
                compaction events are appended to it (best effort).

        Returns:
            Dict with ``success``, ``output``, ``token_usage``, ``turn_count``,
            ``error`` (``None`` on success) and ``litellm_run_model_name``.
        """
        system_message = {"role": "system", "content": self.react_system_prompt}
        total_tokens = {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "reasoning_tokens": 0,
        }
        turn_count = 0
        success = False
        final_error: Optional[str] = None
        # Set after a context-window error so the next step compacts first
        # instead of re-sending the identical oversized request.
        force_compaction = False

        mcp_server = await self._create_mcp_server()
        async with mcp_server:
            tools = await mcp_server.list_tools()
            tool_map = {tool.get("name"): tool for tool in tools}

            task_message = {
                "role": "user",
                "content": self._build_task_prompt(
                    instruction=instruction,
                    tools_description=self._render_tools_description(tools),
                ),
            }
            messages: List[Dict[str, Any]] = [system_message, task_message]
            self._update_progress(messages, total_tokens, turn_count)

            for step in range(1, self.max_iterations + 1):
                if self._compaction_enabled():
                    prompt_tokens = self._count_prompt_tokens_litellm(messages)
                    if force_compaction or prompt_tokens >= self.compaction_token:
                        force_compaction = False
                        notice = f"| [compaction] Triggered at prompt tokens: {prompt_tokens:,}"
                        logger.info(notice)
                        self._append_to_log(tool_call_log_file, notice + "\n")
                        try:
                            messages = await self._compact_messages(
                                messages, system_message, task_message, total_tokens
                            )
                        except asyncio.TimeoutError:
                            final_error = f"Context compaction timed out on step {step}"
                            logger.error(final_error)
                            break
                        except Exception as exc:
                            final_error = f"Context compaction failed on step {step}: {exc}"
                            logger.error(final_error)
                            break
                        self._update_progress(messages, total_tokens, turn_count)

                try:
                    response = await asyncio.wait_for(
                        litellm.acompletion(**self._completion_kwargs(messages)),
                        timeout=self.timeout / 2,
                    )
                except asyncio.TimeoutError:
                    final_error = f"LLM call timed out on step {step}"
                    logger.error(final_error)
                    break
                except Exception as exc:
                    final_error = f"LLM call failed on step {step}: {exc}"
                    logger.error(final_error)
                    # Retrying the same messages cannot succeed; only continue
                    # when compaction can shrink the history first.
                    if self._is_context_window_error(exc) and self._compaction_enabled():
                        force_compaction = True
                        continue
                    break

                if turn_count == 0 and getattr(response, "model", None):
                    self.litellm_run_model_name = response.model.split("/")[-1]

                self._accumulate_usage(getattr(response, "usage", None), total_tokens)

                assistant_text = self._extract_response_text(response)
                messages.append({"role": "assistant", "content": assistant_text})
                turn_count += 1
                self._update_progress(messages, total_tokens, turn_count)

                parsed = self._parse_react_response(assistant_text)
                if "thought" not in parsed:
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                "The previous response was not valid JSON following the required schema. "
                                "Please respond again using the JSON formats provided."
                            ),
                        }
                    )
                    self._update_progress(messages, total_tokens, turn_count)
                    final_error = "Model produced an invalid response format."
                    continue

                # A well-formed reply supersedes any earlier recoverable error,
                # so a later step-limit exit is not blamed on a stale message.
                final_error = None

                action = parsed.get("action")
                answer = parsed.get("answer")
                result = parsed.get("result")
                self._log_step(parsed.get("thought", ""), action, tool_call_log_file)

                if answer is not None:
                    success = True
                    break

                if isinstance(action, dict):
                    tool_name = action.get("tool")
                    arguments = action.get("arguments", {}) or {}
                    observation = await self._run_tool(
                        mcp_server, tool_map, tool_name, arguments
                    )
                    self._append_to_log(
                        tool_call_log_file,
                        f"| {tool_name} {json.dumps(arguments, ensure_ascii=False)}\n",
                    )
                    messages.append(self._observation_message(observation))
                elif result is not None:
                    messages.append(self._observation_message(result))
                else:
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                "The previous reply did not include an action, result, or answer. "
                                "Please respond again using the JSON formats provided."
                            ),
                        }
                    )
                self._update_progress(messages, total_tokens, turn_count)

        if not success and final_error is None:
            final_error = (
                f"Max iterations ({self.max_iterations}) reached without a final answer."
            )

        if total_tokens["total_tokens"] > 0:
            log_msg = (
                f"|\n|\n| Token usage: Total: {total_tokens['total_tokens']:,} | "
                f"Input: {total_tokens['input_tokens']:,} | "
                f"Output: {total_tokens['output_tokens']:,}"
            )
            if total_tokens.get("reasoning_tokens", 0) > 0:
                log_msg += f" | Reasoning: {total_tokens['reasoning_tokens']:,}"
            logger.info(log_msg)
        logger.info(f"| Turns: {turn_count}")

        return {
            "success": success,
            "output": self._convert_to_sdk_format(messages),
            "token_usage": total_tokens,
            "turn_count": turn_count,
            "error": None if success else final_error,
            "litellm_run_model_name": self.litellm_run_model_name,
        }

    def _completion_kwargs(
        self,
        messages: List[Dict[str, Any]],
        *,
        include_reasoning_effort: bool = True,
    ) -> Dict[str, Any]:
        """Build keyword arguments for ``litellm.acompletion``.

        ``base_url`` is only forwarded when set; ``reasoning_effort`` only when
        requested and not the ``"default"`` sentinel (the compaction call
        never sends it).
        """
        kwargs: Dict[str, Any] = {
            "model": self.litellm_input_model_name,
            "messages": messages,
            "api_key": self.api_key,
        }
        if self.base_url:
            kwargs["base_url"] = self.base_url
        if include_reasoning_effort and self.reasoning_effort != "default":
            kwargs["reasoning_effort"] = self.reasoning_effort
        return kwargs

    async def _compact_messages(
        self,
        messages: List[Dict[str, Any]],
        system_message: Dict[str, Any],
        task_message: Dict[str, Any],
        total_tokens: Dict[str, int],
    ) -> List[Dict[str, Any]]:
        """Summarize ``messages`` with the model and return a compacted history.

        The new history is ``[system_message, task_message, summary]``; the
        summarization call's token usage is added to ``total_tokens``.

        Raises:
            asyncio.TimeoutError: if the call exceeds ``self.timeout / 2``.
            Exception: whatever ``litellm.acompletion`` raises.
        """
        compact_messages = [
            {"role": "system", "content": self.COMPACTION_PROMPT},
            {"role": "user", "content": json.dumps(messages, ensure_ascii=False)},
        ]
        response = await asyncio.wait_for(
            litellm.acompletion(
                **self._completion_kwargs(
                    compact_messages, include_reasoning_effort=False
                )
            ),
            timeout=self.timeout / 2,
        )
        self._accumulate_usage(getattr(response, "usage", None), total_tokens)

        try:
            raw_summary = response.choices[0].message.content
        except (AttributeError, IndexError, KeyError, TypeError):
            raw_summary = ""
        summary = self._normalize_content(raw_summary).strip() or "(no summary)"

        return [
            system_message,
            task_message,
            {
                "role": "user",
                "content": (
                    "Context summary (auto-compacted due to token limit):\n"
                    f"{summary}"
                ),
            },
        ]

    @staticmethod
    def _is_context_window_error(exc: BaseException) -> bool:
        """Return True if ``exc`` signals that the prompt exceeded the model's context window."""
        error_type = getattr(litellm, "ContextWindowExceededError", None)
        if isinstance(error_type, type) and isinstance(exc, error_type):
            return True
        # Fallback for wrapped/re-raised errors that only carry the name in their text.
        return "ContextWindowExceededError" in str(exc)

    @staticmethod
    def _field(obj: Any, name: str) -> Any:
        """Read ``name`` from a dict key or an object attribute (``None`` if absent)."""
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    def _accumulate_usage(self, usage: Any, totals: Dict[str, int]) -> None:
        """Add one response's token usage into ``totals`` in place.

        Accepts OpenAI-style (``prompt_tokens``/``completion_tokens``) and
        ``input_tokens``/``output_tokens`` field names; missing values count
        as zero. ``total_tokens`` falls back to input + output when absent,
        and reasoning tokens come from ``completion_tokens_details``.
        """
        if not usage:
            return
        prompt_tokens = int(
            self._field(usage, "prompt_tokens")
            or self._field(usage, "input_tokens")
            or 0
        )
        completion_tokens = int(
            self._field(usage, "completion_tokens")
            or self._field(usage, "output_tokens")
            or 0
        )
        reported_total = self._field(usage, "total_tokens")
        if reported_total is None:
            reported_total = prompt_tokens + completion_tokens
        details = self._field(usage, "completion_tokens_details")
        reasoning_tokens = self._field(details, "reasoning_tokens")

        totals["input_tokens"] += prompt_tokens
        totals["output_tokens"] += completion_tokens
        totals["total_tokens"] += int(reported_total or 0)
        totals["reasoning_tokens"] += int(reasoning_tokens or 0)

    def _extract_response_text(self, response: Any) -> str:
        """Return the normalized assistant text of the first choice in ``response``.

        Works with attribute-style and dict-style choices/messages, falls back
        to a ``text`` field when there is no ``message``, and returns ``""``
        when the response has no choices or no content.
        """
        choices = getattr(response, "choices", None) or []
        if not choices:
            return ""
        choice = choices[0]
        message_obj = self._field(choice, "message")
        if message_obj is None:
            content = self._field(choice, "text")
        else:
            content = self._field(message_obj, "content")
        return self._normalize_content(content)

    def _append_to_log(self, log_path: Optional[str], text: str) -> None:
        """Best-effort append of ``text`` to the tool-call log file, if one is configured.

        Logging must never abort the agent run, so failures are only reported
        at debug level.
        """
        if not log_path:
            return
        try:
            with open(log_path, "a", encoding="utf-8") as log_file:
                log_file.write(text)
        except Exception as exc:
            logger.debug(f"Could not write to tool call log {log_path}: {exc}")

    def _log_step(
        self,
        thought: Any,
        action: Any,
        tool_call_log_file: Optional[str],
    ) -> None:
        """Echo the model's thought (and its requested tool call, if well-formed) to the logs."""
        logger.info(f"|\n| \033[1;3mThought\033[0m: {str(thought)}")
        self._append_to_log(tool_call_log_file, f"| {str(thought)}\n")

        # Only dict actions can be executed; anything else would crash on .get().
        if not isinstance(action, dict):
            return
        func_name = action.get("tool")
        args_str = json.dumps(action.get("arguments", {}) or {}, separators=(",", ": "))
        if len(args_str) > self._ARGS_DISPLAY_LIMIT:
            args_str = args_str[: self._ARGS_DISPLAY_LIMIT] + "..."
        logger.info(
            f"| \033[1;3mAction\033[0m: \033[1m{func_name}\033[0m \033[2;37m{args_str}\033[0m"
        )

    async def _run_tool(
        self,
        mcp_server: Any,
        tool_map: Dict[Any, Any],
        tool_name: Any,
        arguments: Any,
    ) -> str:
        """Execute ``tool_name`` on ``mcp_server`` and return the observation text.

        Unknown tools, timeouts (``TOOL_CALL_TIMEOUT`` seconds) and tool
        errors are turned into descriptive observation strings rather than
        raised, so the model can recover.
        """
        # isinstance guard: a non-string (possibly unhashable) tool value from
        # the model must not raise on the dict membership test.
        if not isinstance(tool_name, str) or tool_name not in tool_map:
            available = ", ".join(str(name) for name in tool_map)
            return f"Invalid tool '{tool_name}'. Available tools: {available}"
        try:
            tool_response = await asyncio.wait_for(
                mcp_server.call_tool(tool_name, arguments),
                timeout=self.TOOL_CALL_TIMEOUT,
            )
            return self._tool_result_to_text(tool_response)
        except asyncio.TimeoutError:
            return f"Tool '{tool_name}' timed out"
        except Exception as tool_exc:
            return f"Tool '{tool_name}' failed: {tool_exc}"

    @staticmethod
    def _observation_message(observation: Any) -> Dict[str, str]:
        """Wrap an observation in the user message that prompts the next ReAct step."""
        return {
            "role": "user",
            "content": (
                f"Observation:\n{observation}\n"
                "Please continue reasoning and reply using the required JSON format."
            ),
        }

    def _build_task_prompt(
        self,
        instruction: str,
        tools_description: str,
    ) -> str:
        """Build the initial user message: task, tool catalogue and the JSON reply formats.

        NOTE: the closing sentence says omitting ``action`` ends the task, but
        the loop only finishes successfully on an ``answer`` key; a reply with
        neither gets a re-prompt. The text is kept unchanged for prompt
        stability.
        """
        return (
            f"Task:\n{instruction}\n\n"
            f"Available MCP tools:\n{tools_description}\n\n"
            "Respond using the JSON formats below.\n\n"
            "If you need to use a tool:\n"
            "{\n"
            ' "thought": "Reasoning for the next action",\n'
            ' "action": {\n'
            ' "tool": "tool-name",\n'
            ' "arguments": {\n'
            ' "parameter": value\n'
            " }\n"
            " }\n"
            "}\n\n"
            "If you can provide the final answer:\n"
            "{\n"
            ' "thought": "Reasoning that justifies the answer",\n'
            ' "answer": "Either the final solution or \'Task completed.\' when no more detail is required"\n'
            "}\n\n"
            "Remember: omitting the action object ends the task, so only do this when finished."
        )

    def _render_tools_description(self, tools: List[Dict[str, Any]]) -> str:
        """Render MCP tool metadata as plain text for the task prompt.

        Each tool lists its name, description, and the JSON schema of every
        input property, suffixed ``(required)`` when named in the schema's
        ``required`` list.
        """
        descriptions = []
        for tool in tools:
            name = tool.get("name") or "unknown"
            # ``or`` (not a .get default) so an explicit null description is
            # not rendered as "None".
            description = tool.get("description") or "No description provided."
            input_schema = tool.get("inputSchema") or {}
            properties = input_schema.get("properties") or {}
            required = set(input_schema.get("required") or [])

            arg_lines = []
            for prop_name, prop_details in properties.items():
                details = json.dumps(prop_details, ensure_ascii=False, indent=2)
                suffix = " (required)" if prop_name in required else ""
                arg_lines.append(f"- {prop_name}{suffix}: {details}")
            arguments_text = "\n".join(arg_lines) if arg_lines else "(no arguments)"

            descriptions.append(
                f"Tool: {name}\nDescription: {description}\nArguments:\n{arguments_text}"
            )

        return "\n\n".join(descriptions) if descriptions else "(no tools available)"

    def _normalize_content(self, content: Any) -> str:
        """Flatten model message content into a single string.

        ``None`` becomes ``""``; strings pass through; lists of content blocks
        are joined with newlines, taking plain strings as-is and the ``text``
        field/attribute of other blocks (empty parts are skipped). Anything
        else is JSON-encoded, or ``str()``-ed if not serializable.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                text = block if isinstance(block, str) else self._field(block, "text")
                if text:
                    parts.append(str(text))
            return "\n".join(parts)
        try:
            return json.dumps(content, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(content)

    @staticmethod
    def _loads_object(text: str) -> Optional[Dict[str, Any]]:
        """Parse ``text`` as JSON, returning it only if it is a JSON object."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError:
            return None
        return value if isinstance(value, dict) else None

    def _parse_react_response(self, payload: str) -> Dict[str, Any]:
        """Parse a model reply into a JSON object, tolerating code fences and stray prose.

        Always returns a ``dict``: an empty one when no JSON object can be
        recovered, including when the reply is valid JSON of another type
        (e.g. ``42`` or a list), so callers can safely use dict methods.
        """
        candidate = payload.strip().strip("`").strip()
        if candidate.lower().startswith("json"):
            candidate = candidate[4:].lstrip()

        parsed = self._loads_object(candidate)
        if parsed is None:
            # Fall back to the outermost {...} span for replies that wrap the
            # JSON object in explanatory text.
            start, end = candidate.find("{"), candidate.rfind("}")
            if 0 <= start < end:
                parsed = self._loads_object(candidate[start : end + 1])
        return parsed if parsed is not None else {}

    def _tool_result_to_text(self, result: Any) -> str:
        """Convert an MCP tool result into observation text.

        ``None`` becomes ``""``, strings pass through, other values are
        JSON-encoded, falling back to ``str()`` when not serializable
        (unsupported types or circular references).
        """
        if result is None:
            return ""
        if isinstance(result, str):
            return result
        try:
            return json.dumps(result, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(result)
|
|