"""Shared base agent functionality for MCPMark agents.""" from __future__ import annotations import asyncio import copy import json import uuid from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Callable from src.logger import get_logger from .mcp import MCPStdioServer, MCPHttpServer from .utils import TokenUsageTracker logger = get_logger(__name__) class BaseMCPAgent(ABC): """Base class with shared functionality for MCPMark agents.""" STDIO_SERVICES = [ "notion", "filesystem", "playwright", "playwright_webarena", "postgres", "insforge", "github", ] HTTP_SERVICES = ["supabase"] DEFAULT_TIMEOUT = 600 COMPACTION_DISABLED_TOKEN = 999_999_999 CLAUDE_THINKING_BUDGETS = { "low": 1024, "medium": 2048, "high": 4096, } def __init__( self, litellm_input_model_name: str, api_key: str, base_url: str, mcp_service: str, timeout: int = DEFAULT_TIMEOUT, service_config: Optional[Dict[str, Any]] = None, service_config_provider: Optional[Callable[[], Dict[str, Any]]] = None, reasoning_effort: Optional[str] = "default", compaction_token: int = COMPACTION_DISABLED_TOKEN, extra_body: Optional[Dict[str, Any]] = None, summarize_tool_response: bool = False, ): self.litellm_input_model_name = litellm_input_model_name self.api_key = api_key self.base_url = base_url self.mcp_service = mcp_service self.timeout = timeout self.service_config = service_config or {} self._service_config_provider = service_config_provider self.reasoning_effort = reasoning_effort or "default" self.compaction_token = int(compaction_token) # Extra request-body fields forwarded verbatim to litellm.acompletion # (e.g. DeepSeek dimcode `{"thinking": {"type": "enabled"}}`). self.extra_body = extra_body or None # When True (playwright runs only), each browser tool result is # LLM-summarized BEFORE entering the message history — the model sees # and the trajectory records the SAME compressed content. self.summarize_tool_response = bool(summarize_tool_response) self.is_claude = self._is_anthropic_model(litellm_input_model_name) self.use_claude_thinking = self.is_claude and self.reasoning_effort != "default" self.usage_tracker = TokenUsageTracker() self.litellm_run_model_name = None self._partial_messages: List[Dict[str, Any]] = [] self._partial_token_usage: Dict[str, int] = {} self._partial_turn_count: int = 0 logger.debug( "Initialized %s for service '%s' with model '%s'", self.__class__.__name__, self.mcp_service, self.litellm_input_model_name, ) # Warn if Gemini 3 model uses unsupported reasoning_effort value if self._is_gemini_3_model() and self.reasoning_effort not in [ "default", "low", "high", ]: logger.warning( "Gemini 3 models only support reasoning_effort 'low' or 'high', " "got '%s'. LiteLLM may map this to the nearest supported value.", self.reasoning_effort, ) def __repr__(self) -> str: # pragma: no cover - debug helper return ( f"{self.__class__.__name__}(service='{self.mcp_service}', " f"model='{self.litellm_input_model_name}')" ) @abstractmethod async def execute( self, instruction: str, tool_call_log_file: Optional[str] = None, ) -> Dict[str, Any]: """Execute the agent logic and return execution metadata.""" def execute_sync( self, instruction: str, tool_call_log_file: Optional[str] = None, ) -> Dict[str, Any]: """Synchronous wrapper for async execution.""" return asyncio.run(self.execute(instruction, tool_call_log_file)) def get_usage_stats(self) -> Dict[str, Any]: """Return aggregated usage statistics.""" return self.usage_tracker.get_stats() def reset_usage_stats(self): """Clear usage statistics.""" self.usage_tracker.reset() # ------------------------------------------------------------------ # Shared helpers # ------------------------------------------------------------------ def _is_anthropic_model(self, model_name: str) -> bool: return "claude" in model_name.lower() def _get_claude_thinking_budget(self) -> Optional[int]: if not self.use_claude_thinking: return None return self.CLAUDE_THINKING_BUDGETS.get(self.reasoning_effort, 2048) def _refresh_service_config(self): if not self._service_config_provider: return try: latest_cfg = self._service_config_provider() or {} self.service_config.update(latest_cfg) except Exception as exc: # pragma: no cover - best effort refresh logger.warning("Failed to refresh service config: %s", exc) def _reset_progress(self): self._partial_messages = [] self._partial_token_usage = {} self._partial_turn_count = 0 def _update_progress( self, messages: List[Dict[str, Any]], token_usage: Dict[str, Any], turn_count: int, ): try: self._partial_messages = copy.deepcopy(messages) self._partial_token_usage = dict(token_usage or {}) self._partial_turn_count = int(turn_count or 0) except Exception: # pragma: no cover - defensive copy pass # ------------------------------------------------------------------ # MCP server management # ------------------------------------------------------------------ async def _create_mcp_server(self) -> Any: if self.mcp_service in self.STDIO_SERVICES: return self._create_stdio_server() if self.mcp_service in self.HTTP_SERVICES: return self._create_http_server() raise ValueError(f"Unsupported MCP service: {self.mcp_service}") def _create_stdio_server(self) -> MCPStdioServer: if self.mcp_service == "notion": notion_key = self.service_config.get("notion_key") if not notion_key: raise ValueError("Notion API key required") return MCPStdioServer( command="npx", args=["-y", "@notionhq/notion-mcp-server"], env={ "OPENAPI_MCP_HEADERS": ( '{"Authorization": "Bearer ' + notion_key + '", ' '"Notion-Version": "2022-06-28"}' ) }, ) if self.mcp_service == "filesystem": test_directory = self.service_config.get("test_directory") if not test_directory: raise ValueError("Test directory required for filesystem service") return MCPStdioServer( command="npx", args=[ "-y", "@modelcontextprotocol/server-filesystem", str(test_directory), ], ) if self.mcp_service in ("playwright", "playwright_webarena"): browser = self.service_config.get("browser", "chromium") headless = self.service_config.get("headless", True) viewport_width = self.service_config.get("viewport_width", 1280) viewport_height = self.service_config.get("viewport_height", 720) args = ["-y", "@playwright/mcp@latest"] if headless: args.append("--headless") args.extend( [ "--isolated", "--no-sandbox", "--browser", browser, "--viewport-size", f"{viewport_width},{viewport_height}", ] ) return MCPStdioServer(command="npx", args=args) if self.mcp_service == "postgres": host = self.service_config.get("host", "localhost") port = self.service_config.get("port", 5432) username = self.service_config.get("username") password = self.service_config.get("password") database = self.service_config.get( "current_database" ) or self.service_config.get("database") if not all([username, password, database]): raise ValueError("PostgreSQL requires username, password, and database") database_url = ( f"postgresql://{username}:{password}@{host}:{port}/{database}" ) return MCPStdioServer( command="pipx", args=["run", "postgres-mcp", "--access-mode=unrestricted"], env={"DATABASE_URI": database_url}, ) if self.mcp_service == "insforge": api_key = self.service_config.get("api_key") backend_url = self.service_config.get("backend_url") if not all([api_key, backend_url]): raise ValueError("Insforge requires api_key and backend_url") return MCPStdioServer( command="npx", args=["-y", "@insforge/mcp@dev"], env={ "INSFORGE_API_KEY": api_key, "INSFORGE_BACKEND_URL": backend_url, }, ) raise ValueError(f"Unsupported stdio service: {self.mcp_service}") def _create_http_server(self) -> MCPHttpServer: if self.mcp_service == "github": github_token = self.service_config.get("github_token") if not github_token: raise ValueError("GitHub token required") return MCPHttpServer( url="https://api.githubcopilot.com/mcp/", headers={ "Authorization": f"Bearer {github_token}", "User-Agent": "MCPMark/1.0", }, ) raise ValueError(f"Unsupported HTTP service: {self.mcp_service}") # ------------------------------------------------------------------ # Message/Tool formatting helpers # ------------------------------------------------------------------ def _compaction_enabled(self) -> bool: return 0 < self.compaction_token < self.COMPACTION_DISABLED_TOKEN def _count_prompt_tokens_litellm(self, messages: List[Dict[str, Any]]) -> int: try: from litellm import token_counter return int( token_counter(model=self.litellm_input_model_name, messages=messages) or 0 ) except Exception: # pragma: no cover - best effort return 0 def _convert_to_sdk_format( self, messages: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: sdk_format: List[Dict[str, Any]] = [] function_call_map: Dict[str, str] = {} for msg in messages: role = msg.get("role") if role == "user": user_content = msg.get("content", "") if isinstance(user_content, list): tool_results = [ item for item in user_content if isinstance(item, dict) and item.get("type") == "tool_result" ] if tool_results: for tr in tool_results: content_items = tr.get("content", []) text_content = "" for ci in content_items: if isinstance(ci, dict) and ci.get("type") == "text": text_content = ci.get("text", "") break sdk_format.append( { "call_id": tr.get("tool_use_id", ""), "output": json.dumps( { "type": "text", "text": text_content, "annotations": None, "meta": None, } ), "type": "function_call_output", } ) else: text_parts = [] for item in user_content: if isinstance(item, dict) and item.get("type") == "text": text_parts.append(item.get("text", "")) sdk_format.append( {"content": "\n".join(text_parts), "role": "user"} ) else: sdk_format.append({"content": user_content, "role": "user"}) elif role == "assistant": tool_calls = msg.get("tool_calls", []) function_call = msg.get("function_call") content = msg.get("content") if isinstance(content, list): text_parts = [] claude_tool_uses = [] for block in content: if isinstance(block, dict): if block.get("type") == "text": text_parts.append(block.get("text", "")) elif block.get("type") == "thinking": thinking_text = block.get("thinking", "") if thinking_text: text_parts.append( f"\n{thinking_text}\n" ) elif block.get("type") == "tool_use": claude_tool_uses.append(block) content = "\n".join(text_parts) if claude_tool_uses and not tool_calls: tool_calls = [] for tu in claude_tool_uses: tool_calls.append( { "id": tu.get("id"), "function": { "name": tu.get("name"), "arguments": json.dumps(tu.get("input", {})), }, } ) if content: sdk_format.append( { "id": "__fake_id__", "content": [ { "annotations": [], "text": content, "type": "output_text", } ], "role": "assistant", "status": "completed", "type": "message", } ) if tool_calls: for tool_call in tool_calls: call_id = tool_call.get("id", f"call_{uuid.uuid4().hex}") func_name = tool_call.get("function", {}).get("name", "") sdk_format.append( { "arguments": tool_call.get("function", {}).get( "arguments", "{}" ), "call_id": call_id, "name": func_name, "type": "function_call", "id": "__fake_id__", } ) if function_call: func_name = function_call.get("name", "") call_id = f"call_{uuid.uuid4().hex}" function_call_map[func_name] = call_id sdk_format.append( { "arguments": function_call.get("arguments", "{}"), "call_id": call_id, "name": func_name, "type": "function_call", "id": "__fake_id__", } ) elif role == "tool": sdk_format.append( { "call_id": msg.get("tool_call_id", ""), "output": json.dumps( { "type": "text", "text": msg.get("content", ""), "annotations": None, "meta": None, } ), "type": "function_call_output", } ) elif role == "function": func_name = msg.get("name", "") call_id = function_call_map.get(func_name, f"call_{uuid.uuid4().hex}") sdk_format.append( { "call_id": call_id, "output": json.dumps( { "type": "text", "text": msg.get("content", ""), "annotations": None, "meta": None, } ), "type": "function_call_output", } ) return sdk_format def _convert_to_anthropic_format( self, tools: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: anthropic_tools = [] for tool in tools: anthropic_tool = { "name": tool.get("name"), "description": tool.get("description", ""), "input_schema": tool.get( "inputSchema", {"type": "object", "properties": {}, "required": []}, ), } anthropic_tools.append(anthropic_tool) return anthropic_tools def _is_gemini_model(self) -> bool: model_lower = self.litellm_input_model_name.lower() return "gemini" in model_lower or "bison" in model_lower def _is_gemini_3_model(self) -> bool: """Check if this is a Gemini 3 series model.""" model_lower = self.litellm_input_model_name.lower() return "gemini-3" in model_lower or "gemini/gemini-3" in model_lower def _simplify_schema_for_gemini( self, schema: Optional[Dict[str, Any]] ) -> Dict[str, Any]: if not isinstance(schema, dict): return schema or {} simplified: Dict[str, Any] = {} for key, value in schema.items(): if key == "type" and isinstance(value, list): simplified[key] = value[0] if value else "string" elif key == "items" and isinstance(value, dict): simplified[key] = self._simplify_schema_for_gemini(value) elif key == "properties" and isinstance(value, dict): simplified[key] = { prop_key: self._simplify_schema_for_gemini(prop_val) for prop_key, prop_val in value.items() } elif isinstance(value, dict): simplified[key] = self._simplify_schema_for_gemini(value) elif isinstance(value, list) and key not in ("required", "enum"): simplified[key] = [ self._simplify_schema_for_gemini(item) if isinstance(item, dict) else item for item in value ] else: simplified[key] = value return simplified def _convert_to_openai_format( self, tools: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: functions = [] is_gemini = self._is_gemini_model() if is_gemini: logger.debug( "Detected Gemini model '%s' – simplifying tool schemas", self.litellm_input_model_name, ) for tool in tools: input_schema = tool.get( "inputSchema", {"type": "object", "properties": {}, "required": []} ) if is_gemini: simplified = self._simplify_schema_for_gemini(input_schema) if simplified != input_schema: input_schema = simplified logger.debug("Simplified schema for tool '%s'", tool.get("name")) functions.append( { "name": tool.get("name"), "description": tool.get("description", ""), "parameters": input_schema, } ) if is_gemini: logger.info("Converted %d tools for Gemini compatibility", len(functions)) return functions