# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. #!/usr/bin/env python3 """ Demo-Ready MCP Server - New Standard Implementation Combines robust session management with comprehensive tool definitions. Features: workspace isolation, tool call tracking, rate limiting, security, and full tool suite. """ import argparse import asyncio import json import logging import time import uuid import yaml from collections import defaultdict, deque from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path from threading import Thread, Event from typing import Any, Dict, List, Optional # Third-party imports from starlette.applications import Starlette from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, StreamingResponse import uvicorn # Add project root to Python path for imports import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from src.utils.status_codes import JsonRpcErr from http import HTTPStatus # Handle both relative and absolute imports try: from .mcp_tools import MCPTools, get_tool_schemas from .mcp_tools_async import AsyncMCPTools except ImportError: # Fallback for direct script execution from src.tools.mcp_tools import MCPTools, get_tool_schemas try: from src.tools.mcp_tools_async import AsyncMCPTools except ImportError: AsyncMCPTools = None # Workspace knowledge manager disabled WORKSPACE_KNOWLEDGE_AVAILABLE = False # Configure structured logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s', handlers=[ logging.StreamHandler(sys.stdout), logging.FileHandler('mcp_server.log') ] ) logger = logging.getLogger(__name__) # ================ CONFIGURATION ================ @dataclass class ServerConfig: """Server configuration with only actually implemented options""" # Server Core Settings host: str = "127.0.0.1" port: int = 6274 debug_mode: bool = False # Session Management session_ttl_seconds: int = 3600 # 1 hour default max_sessions: int = 1000 cleanup_interval_seconds: int = 300 # 5 minutes enable_session_keepalive: bool = True keepalive_touch_interval: int = 300 # Request Handling request_timeout_seconds: int = 120 max_request_size_mb: int = 10 # Client Rate Limiting (per IP) rate_limit_requests_per_minute: int = 300 # Workspace Management base_workspace_dir: str = "workspaces" # Tool Call Tracking & Logging enable_tool_tracking: bool = True max_tracked_calls_per_session: int = 1000 track_detailed_errors: bool = True # Per-tool Rate Limiting Configuration tool_rate_limits: Dict[str, Dict[str, int]] = field(default_factory=dict) @classmethod def from_yaml(cls, config_path: str) -> 'ServerConfig': """Load configuration from YAML file""" try: with open(config_path, 'r') as f: config_data = yaml.safe_load(f) # Extract configuration sections with defaults server_config = config_data.get('server', {}) tracking_config = config_data.get('tracking', {}) tool_rate_limits = config_data.get('tool_rate_limits', {}) return cls( # Server Core Settings host=server_config.get('host', "127.0.0.1"), port=server_config.get('port', 6274), debug_mode=server_config.get('debug_mode', False), # Session Management session_ttl_seconds=server_config.get('session_ttl_seconds', 3600), max_sessions=server_config.get('max_sessions', 1000), cleanup_interval_seconds=server_config.get('cleanup_interval_seconds', 300), enable_session_keepalive=server_config.get('enable_session_keepalive', True), keepalive_touch_interval=server_config.get('keepalive_touch_interval', 300), # Request Handling request_timeout_seconds=server_config.get('request_timeout_seconds', 120), max_request_size_mb=server_config.get('max_request_size_mb', 10), # Client Rate Limiting rate_limit_requests_per_minute=server_config.get('rate_limit_requests_per_minute', 300), # Workspace Management base_workspace_dir=server_config.get('base_workspace_dir', "workspaces"), # Tool Call Tracking & Logging enable_tool_tracking=tracking_config.get('enable_tool_tracking', True), max_tracked_calls_per_session=tracking_config.get('max_tracked_calls_per_session', 1000), track_detailed_errors=tracking_config.get('track_detailed_errors', True), # Per-tool Rate Limiting tool_rate_limits=tool_rate_limits ) except Exception as e: logger.error(f"Failed to load configuration from {config_path}: {e}") logger.info("Using default configuration") return cls() # Global configuration instance - will be set during startup config: Optional[ServerConfig] = None # ================ GLOBAL PER-TOOL RATE LIMITING ================ @dataclass class ToolRateLimit: """Rate limit configuration for a specific tool""" requests_per_minute: float requests_per_hour: float burst_limit: int class GlobalToolRateLimiter: """ Global rate limiter that controls QPS to external APIs per tool. This is shared across all sessions and clients to manage upstream service load. """ def __init__(self, tool_rate_limits: Dict[str, Dict[str, int]]): self.tool_limits: Dict[str, ToolRateLimit] = {} self.tool_requests: Dict[str, deque] = defaultdict(deque) self.lock = asyncio.Lock() # Initialize rate limits for each tool for tool_name, limits_config in tool_rate_limits.items(): self.tool_limits[tool_name] = ToolRateLimit( requests_per_minute=limits_config.get('requests_per_minute', float('inf')), requests_per_hour=limits_config.get('requests_per_hour', float('inf')), burst_limit=limits_config.get('burst_limit', 10) ) self.tool_requests[tool_name] = deque() logger.info(f"Initialized global tool rate limiter for {len(self.tool_limits)} tools") async def is_allowed(self, tool_name: str) -> tuple[bool, Optional[str]]: """ Check if a request to the specified tool is allowed based on global rate limits. Returns: tuple[bool, Optional[str]]: (allowed, reason_if_denied) """ if tool_name not in self.tool_limits: # Tool not configured for rate limiting - allow return True, None async with self.lock: now = time.time() limits = self.tool_limits[tool_name] requests = self.tool_requests[tool_name] # Clean old requests outside the time windows self._cleanup_old_requests(requests, now) # Check various time window limits recent_requests = list(requests) # Check burst limit (rapid requests in last second) - only if specified if limits.burst_limit != float('inf'): burst_count = sum(1 for req_time in recent_requests if now - req_time < 1.0) if burst_count >= limits.burst_limit: return False, f"Tool '{tool_name}' burst limit exceeded ({limits.burst_limit} requests/burst)" # Check per-minute limit - only if specified if limits.requests_per_minute != float('inf'): minute_count = sum(1 for req_time in recent_requests if now - req_time < 60.0) if minute_count >= limits.requests_per_minute: return False, f"Tool '{tool_name}' per-minute limit exceeded ({limits.requests_per_minute} requests/minute)" # Check per-hour limit - only if specified if limits.requests_per_hour != float('inf'): hour_count = sum(1 for req_time in recent_requests if now - req_time < 3600.0) if hour_count >= limits.requests_per_hour: return False, f"Tool '{tool_name}' per-hour limit exceeded ({limits.requests_per_hour} requests/hour)" return True, None async def record_request(self, tool_name: str): """Record a successful request for rate limiting tracking""" if tool_name not in self.tool_limits: return async with self.lock: now = time.time() self.tool_requests[tool_name].append(now) # Keep deque size manageable (only keep last hour of requests) self._cleanup_old_requests(self.tool_requests[tool_name], now) @staticmethod def _cleanup_old_requests(requests: deque, now: float): """Remove requests older than 1 hour to keep memory usage bounded""" while requests and now - requests[0] > 3600.0: # 1 hour requests.popleft() async def get_tool_stats(self, tool_name: str) -> Dict[str, Any]: """Get current usage statistics for a tool""" if tool_name not in self.tool_limits: return {"error": f"Tool '{tool_name}' not configured for rate limiting"} async with self.lock: now = time.time() requests = self.tool_requests[tool_name] limits = self.tool_limits[tool_name] # Clean old requests first self._cleanup_old_requests(requests, now) recent_requests = list(requests) return { "tool_name": tool_name, "current_usage": { "last_second": sum(1 for req_time in recent_requests if now - req_time < 1.0), "last_minute": sum(1 for req_time in recent_requests if now - req_time < 60.0), "last_hour": sum(1 for req_time in recent_requests if now - req_time < 3600.0) }, "limits": { "requests_per_minute": limits.requests_per_minute if limits.requests_per_minute != float('inf') else None, "requests_per_hour": limits.requests_per_hour if limits.requests_per_hour != float('inf') else None, "burst_limit": limits.burst_limit if limits.burst_limit != float('inf') else None }, "utilization": { "minute_utilization": sum(1 for req_time in recent_requests if now - req_time < 60.0) / limits.requests_per_minute if limits.requests_per_minute != float('inf') else 0, "hour_utilization": sum(1 for req_time in recent_requests if now - req_time < 3600.0) / limits.requests_per_hour if limits.requests_per_hour != float('inf') else 0 } } def get_all_stats(self) -> Dict[str, Any]: """Get usage statistics for all tools""" return { tool_name: self.get_tool_stats(tool_name) for tool_name in self.tool_limits.keys() } # Global tool rate limiter instance - will be initialized during startup global_tool_rate_limiter: Optional[GlobalToolRateLimiter] = None # ================ TOOL DEFINITIONS ================ # Tool execution function mapping - maps tool names to their implementation functions def get_tool_function(tool_name: str): """Get the actual function for a tool""" tool_map = { "batch_web_search": lambda tools, **kwargs: tools.batch_web_search(**kwargs), "url_crawler": lambda tools, **kwargs: tools.url_crawler(**kwargs), "download_files": lambda tools, **kwargs: tools.download_files(**kwargs), "list_workspace": lambda tools, **kwargs: tools.list_workspace(**kwargs), "str_replace_based_edit_tool": lambda tools, **kwargs: tools.str_replace_based_edit_tool(**kwargs), "file_stats": lambda tools, **kwargs: tools.file_stats(**kwargs), "file_read": lambda tools, **kwargs: tools.file_read(**kwargs), "file_read_lines": lambda tools, **kwargs: tools.file_read_lines(**kwargs), "content_preview": lambda tools, **kwargs: tools.content_preview(**kwargs), "file_write": lambda tools, **kwargs: tools.file_write(**kwargs), "file_grep_search": lambda tools, **kwargs: tools.file_grep_search(**kwargs), "file_grep_with_context": lambda tools, **kwargs: tools.file_grep_with_context(**kwargs), "file_find_by_name": lambda tools, **kwargs: tools.file_find_by_name(**kwargs), "bash": lambda tools, **kwargs: tools.bash(**kwargs), "task_done": lambda tools, **kwargs: tools.task_done(**kwargs), "think": lambda tools, **kwargs: tools.think(**kwargs), "reflect": lambda tools, **kwargs: tools.reflect(**kwargs), "document_qa": lambda tools, **kwargs: tools.document_qa(**kwargs), "extract_markdown_toc": lambda tools, **kwargs: tools.extract_markdown_toc(**kwargs), "extract_markdown_section": lambda tools, **kwargs: tools.extract_markdown_section(**kwargs), "document_extract": lambda tools, **kwargs: tools.document_extract(**kwargs), "search_result_classifier": lambda tools, **kwargs: tools.search_result_classifier(**kwargs), "info_seeker_subjective_task_done": None, "writer_subjective_task_done": None, "section_writer": lambda tools, **kwargs: tools.section_writer(**kwargs), "concat_section_files": lambda tools, **kwargs: tools.concat_section_files(**kwargs), # Internal tools - available to server but NOT exposed to agents via tool schemas "internal_file_read_unlimited": lambda tools, **kwargs: tools.internal_file_read_unlimited(**kwargs), } return tool_map.get(tool_name) # ================ TOOL CALL TRACKING ================ @dataclass class ToolCallLog: """Individual tool call log entry""" call_id: str timestamp: datetime tool_name: str input_args: Dict[str, Any] output_result: Dict[str, Any] success: bool duration_ms: float error_details: Optional[str] = None session_id: str = "" agent_info: Optional[Dict[str, Any]] = None def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for JSON serialization""" return { "call_id": self.call_id, "timestamp": self.timestamp.isoformat(), "tool_name": self.tool_name, "input_args": self.input_args, "output_result": self.output_result, "success": self.success, "duration_ms": self.duration_ms, "error_details": self.error_details, "session_id": self.session_id, "agent_info": self.agent_info } class ToolCallTracker: """Tracks and saves tool calls to workspace-specific files""" def __init__(self, workspace_path: Path, session_id: str): self.workspace_path = workspace_path self.session_id = session_id self.logs_dir = workspace_path / "tool_call_logs" self.logs_dir.mkdir(exist_ok=True) # Create daily log file today = datetime.now().strftime("%Y-%m-%d") self.current_log_file = self.logs_dir / f"tool_calls_{today}.jsonl" self.summary_file = self.logs_dir / "session_summary.json" # Track call counts self.call_count = 0 self.tool_usage_stats = defaultdict(int) # Initialize session summary self._initialize_session_summary() def _initialize_session_summary(self): """Initialize or update session summary file""" summary = { "session_id": self.session_id, "session_start": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(), "total_tool_calls": 0, "tool_usage_stats": {}, "agent_activity": {}, "workspace_path": str(self.workspace_path) } # Load existing summary if it exists if self.summary_file.exists(): try: with open(self.summary_file, 'r') as f: existing_summary = json.load(f) summary.update(existing_summary) # Don't overwrite session_start if it already exists if "session_start" in existing_summary: summary["session_start"] = existing_summary["session_start"] except Exception as e: logger.warning(f"Could not load existing session summary: {e}") self._save_summary(summary) def _save_summary(self, summary: Dict[str, Any]): """Save session summary to file""" try: with open(self.summary_file, 'w') as f: json.dump(summary, f, indent=2, ensure_ascii=False) except Exception as e: logger.error(f"Failed to save session summary: {e}") def log_tool_call(self, tool_name: str, input_args: Dict[str, Any], output_result: Dict[str, Any], success: bool, duration_ms: float, error_details: Optional[str] = None, agent_info: Optional[Dict[str, Any]] = None) -> str: """Log a tool call and return the call ID""" if not config.enable_tool_tracking: return "" # Respect max call limit per session if self.call_count >= config.max_tracked_calls_per_session: logger.warning(f"Max tracked calls reached for session {self.session_id}") return "" call_id = str(uuid.uuid4()) timestamp = datetime.now() # Create log entry log_entry = ToolCallLog( call_id=call_id, timestamp=timestamp, tool_name=tool_name, input_args=self._sanitize_args(input_args), output_result=self._sanitize_result(output_result), success=success, duration_ms=duration_ms, error_details=error_details if config.track_detailed_errors else None, session_id=self.session_id, agent_info=agent_info ) # Save to JSONL file (one JSON object per line) try: with open(self.current_log_file, 'a', encoding="utf-8") as f: f.write(json.dumps(log_entry.to_dict(), ensure_ascii=False) + '\n') except Exception as e: logger.error(f"Failed to save tool call log: {e}") # Update session summary self._update_session_summary(log_entry) self.call_count += 1 self.tool_usage_stats[tool_name] += 1 return call_id @staticmethod def _sanitize_args(args: Dict[str, Any]) -> Dict[str, Any]: """Sanitize arguments for logging (remove sensitive data)""" sanitized = {} for key, value in args.items(): if isinstance(value, str) and len(value) > 1000: sanitized[key] = value[:1000] + "... [truncated]" elif key.lower() in ['password', 'token', 'secret', 'key']: sanitized[key] = "[REDACTED]" else: sanitized[key] = value return sanitized def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]: """Sanitize result for logging (remove large content)""" if not isinstance(result, dict): return result sanitized = {} for key, value in result.items(): if isinstance(value, str) and len(value) > 2000: sanitized[key] = value[:2000] + "... [truncated]" elif isinstance(value, dict): sanitized[key] = self._sanitize_result(value) else: sanitized[key] = value return sanitized def _update_session_summary(self, log_entry: ToolCallLog): """Update session summary with new tool call""" try: summary = { "session_id": self.session_id, "last_updated": datetime.now().isoformat(), "total_tool_calls": self.call_count + 1, "tool_usage_stats": dict(self.tool_usage_stats), "workspace_path": str(self.workspace_path) } # Load existing summary if self.summary_file.exists(): with open(self.summary_file, 'r') as f: existing_summary = json.load(f) summary.update(existing_summary) # Update with new data summary["last_updated"] = datetime.now().isoformat() summary["total_tool_calls"] = self.call_count + 1 summary["tool_usage_stats"] = dict(self.tool_usage_stats) summary["tool_usage_stats"][log_entry.tool_name] = self.tool_usage_stats[log_entry.tool_name] + 1 # Track agent activity if log_entry.agent_info: agent_type = log_entry.agent_info.get('type', 'unknown') if 'agent_activity' not in summary: summary['agent_activity'] = {} if agent_type not in summary['agent_activity']: summary['agent_activity'][agent_type] = { 'tool_calls': 0, 'last_active': log_entry.timestamp.isoformat() } summary['agent_activity'][agent_type]['tool_calls'] += 1 summary['agent_activity'][agent_type]['last_active'] = log_entry.timestamp.isoformat() self._save_summary(summary) except Exception as e: logger.error(f"Failed to update session summary: {e}") # ================ SESSION KEEP-ALIVE FOR LONG OPERATIONS ================ class KeepAliveSessionWrapper: """Wrapper that keeps a session alive during long-running operations""" def __init__(self, session: 'Session', touch_interval: int = 300): # Touch every 5 minutes self.session = session self.touch_interval = touch_interval self.keep_alive_thread = None self.stop_event = Event() self.active = False def start_keep_alive(self): """Start the keep-alive mechanism""" if self.active: return self.active = True self.stop_event.clear() def keep_alive_worker(): while not self.stop_event.wait(self.touch_interval): try: self.session.touch() logger.debug("Keep-alive: Touched session {%s}", self.session.id) except Exception as e: logger.error(f"Keep-alive error for session {self.session.id}: {e}") break self.keep_alive_thread = Thread(target=keep_alive_worker, daemon=True) self.keep_alive_thread.start() logger.info(f"Started keep-alive for session {self.session.id}") def stop_keep_alive(self): """Stop the keep-alive mechanism""" if not self.active: return self.active = False self.stop_event.set() if self.keep_alive_thread and self.keep_alive_thread.is_alive(): self.keep_alive_thread.join(timeout=1.0) # Final touch try: self.session.touch() except Exception as e: logger.error(f"Final keep-alive touch error for session {self.session.id}: {e}") logger.info(f"Stopped keep-alive for session {self.session.id}") def __enter__(self): self.start_keep_alive() return self def __exit__(self, exc_type, exc_val, exc_tb): self.stop_keep_alive() # ================ SESSION MANAGEMENT ================ @dataclass class Session: """Thread-safe session data structure with workspace management""" id: str created_at: datetime last_accessed: datetime initialized: bool = False request_count: int = 0 metadata: Dict[str, Any] = field(default_factory=dict) workspace_path: Optional[Path] = None mcp_tools: Optional[MCPTools] = None tool_tracker: Optional[ToolCallTracker] = None def is_expired(self, ttl_seconds: int) -> bool: """Check if session has expired""" return datetime.now() - self.last_accessed > timedelta(seconds=ttl_seconds) def touch(self): """Update last accessed time""" self.last_accessed = datetime.now() self.request_count += 1 def get_mcp_tools(self, prefer_async: bool = True) -> MCPTools: """Get or create MCP tools instance for this session""" if self.mcp_tools is None: # Use async tools if available and preferred if prefer_async and AsyncMCPTools is not None: self.mcp_tools = AsyncMCPTools(workspace_path=str(self.workspace_path) if self.workspace_path else None) else: self.mcp_tools = MCPTools(workspace_path=str(self.workspace_path) if self.workspace_path else None) return self.mcp_tools def get_tool_tracker(self) -> Optional[ToolCallTracker]: """Get or create tool call tracker for this session""" if config.enable_tool_tracking and self.workspace_path: if self.tool_tracker is None: self.tool_tracker = ToolCallTracker(self.workspace_path, self.id) return self.tool_tracker return None class AsyncRLock: """异步可重入锁,模拟 threading.RLock 的异步版本""" def __init__(self): self._lock = asyncio.Lock() self._owner: Optional[asyncio.Task] = None # 记录持有锁的协程任务 self._count = 0 # 重入次数 async def acquire(self): current_task = asyncio.current_task() # 如果当前协程已持有锁,直接增加重入次数 if self._owner == current_task: self._count += 1 return # 否则等待获取锁 await self._lock.acquire() self._owner = current_task self._count = 1 async def release(self): if self._owner != asyncio.current_task(): raise RuntimeError("不能释放非当前协程持有的锁") self._count -= 1 if self._count == 0: # 重入次数归零时,真正释放锁 self._owner = None self._lock.release() # 支持 async with 语法 async def __aenter__(self): await self.acquire() return self async def __aexit__(self, exc_type, exc, tb): await self.release() class ThreadSafeSessionManager: """Thread-safe session manager with workspace management""" def __init__(self, ttl_seconds: int = 3600, max_sessions: int = 1000, base_workspace_dir: str = "workspaces"): self.ttl_seconds = ttl_seconds self.max_sessions = max_sessions self.base_workspace_dir = Path(base_workspace_dir) self.base_workspace_dir.mkdir(exist_ok=True) # Thread-safe session storage self.sessions: Dict[str, Session] = {} self.lock = AsyncRLock() # Start cleanup thread self._start_cleanup_thread() async def create_session(self) -> str: """Create a new session and return session ID""" session_id = str(uuid.uuid4()) async with self.lock: # Check session limits if len(self.sessions) >= self.max_sessions: await self._cleanup_oldest_sessions() # Create workspace directory workspace_path = self.base_workspace_dir / session_id workspace_path.mkdir(exist_ok=True, parents=True) # Create session session = Session( id=session_id, created_at=datetime.now(), last_accessed=datetime.now(), workspace_path=workspace_path ) self.sessions[session_id] = session logger.info(f"Created session {session_id} with workspace {workspace_path}") return session_id async def get_session(self, session_id: str) -> Optional[Session]: """Get session by ID if it exists and is not expired""" async with self.lock: session = self.sessions.get(session_id) if session and not session.is_expired(self.ttl_seconds): session.touch() return session elif session: # Remove expired session del self.sessions[session_id] logger.info(f"Removed expired session {session_id}") return None async def get_or_create_session(self, session_id: Optional[str] = None) -> Session: """Get existing session or create new one""" if session_id: session = await self.get_session(session_id) if session: return session # Create new session new_session_id = await self.create_session() return self.sessions[new_session_id] async def _cleanup_expired_sessions(self): """Remove expired sessions""" async with self.lock: expired_sessions = [] for session_id, session in self.sessions.items(): if session.is_expired(self.ttl_seconds): expired_sessions.append(session_id) for session_id in expired_sessions: del self.sessions[session_id] logger.info(f"Cleaned up expired session {session_id}") async def _cleanup_oldest_sessions(self): """Remove oldest sessions when limit is reached""" async with self.lock: if len(self.sessions) < self.max_sessions: return # Sort by last accessed time and remove oldest sorted_sessions = sorted( self.sessions.items(), key=lambda x: x[1].last_accessed ) sessions_to_remove = len(self.sessions) - self.max_sessions + 10 # Remove extra for i in range(sessions_to_remove): if i < len(sorted_sessions): session_id = sorted_sessions[i][0] del self.sessions[session_id] logger.info(f"Removed old session {session_id} due to session limit") def _start_cleanup_thread(self): """Start background cleanup thread""" def cleanup_worker(): while True: try: time.sleep(config.cleanup_interval_seconds) # Run async method in sync context loop = asyncio.new_event_loop() loop.run_until_complete(self._cleanup_expired_sessions()) loop.close() except Exception as e: logger.error(f"Error in cleanup thread: {e}") import threading cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) cleanup_thread.start() logger.info("Started session cleanup thread") async def get_stats(self) -> Dict[str, Any]: """Get session manager statistics""" async with self.lock: return { "total_sessions": len(self.sessions), "max_sessions": self.max_sessions, "ttl_seconds": self.ttl_seconds, "session_ids": list(self.sessions.keys()) } # ================ MIDDLEWARE AND SECURITY ================ class RateLimiter: """Simple rate limiter with time-window tracking""" def __init__(self, requests_per_minute: int = 60): self.requests_per_minute = requests_per_minute self.requests: Dict[str, List[float]] = defaultdict(list) self.lock = asyncio.Lock() async def is_allowed(self, client_id: str) -> bool: """Check if request is allowed for client""" async with self.lock: now = time.time() minute_ago = now - 60 # Clean old requests self.requests[client_id] = [ req_time for req_time in self.requests[client_id] if req_time > minute_ago ] # Check rate limit if len(self.requests[client_id]) >= self.requests_per_minute: return False # Add current request self.requests[client_id].append(now) return True class RequestValidator: """Validates incoming MCP requests""" @staticmethod def validate_mcp_request(data: Dict[str, Any]) -> tuple[bool, Optional[str]]: """Validate basic MCP request structure""" if not isinstance(data, dict): return False, "Request must be a JSON object" if "method" not in data: return False, "Missing 'method' field" if "id" not in data: return False, "Missing 'id' field" return True, None @staticmethod def validate_tool_call(params: Dict[str, Any]) -> tuple[bool, Optional[str]]: """Validate tool call parameters""" if not isinstance(params, dict): return False, "Tool parameters must be a JSON object" if "name" not in params: return False, "Missing tool 'name'" if "arguments" not in params: return False, "Missing tool 'arguments'" tool_name = params["name"] # Get detailed schemas detailed_schemas = get_tool_schemas() if tool_name not in detailed_schemas: return False, f"Unknown tool: {tool_name}. Available tools: {sorted(list(detailed_schemas.keys()))}" return True, None class SecurityMiddleware(BaseHTTPMiddleware): """Security middleware for basic protection""" async def dispatch(self, request: Request, call_next): # Check content length content_length = request.headers.get("content-length") if content_length and int(content_length) > config.max_request_size_mb * 1024 * 1024: return JSONResponse( status_code=HTTPStatus.REQUEST_ENTITY_TOO_LARGE, content={"error": "Request too large"} ) # Add security headers response = await call_next(request) response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["X-XSS-Protection"] = "1; mode=block" return response class RateLimitMiddleware(BaseHTTPMiddleware): """Rate limiting middleware""" def __init__(self, app, input_rate_limiter: RateLimiter): super().__init__(app) self.rate_limiter = input_rate_limiter async def dispatch(self, request: Request, call_next): # Get client identifier (IP address) client_ip = request.client.host if request.client else "unknown" if not await self.rate_limiter.is_allowed(client_ip): return JSONResponse( status_code=HTTPStatus.TOO_MANY_REQUESTS, content={"error": "Rate limit exceeded"} ) return await call_next(request) # Global session manager session_manager = None rate_limiter = None @dataclass class RateLimitViolation: """Represents a rate limit violation with standardized error information""" tool_name: str limit_type: str # "burst", "second", "minute", "hour" current_usage: int limit_value: float retry_after_seconds: float def to_user_friendly_message(self) -> str: """Generate user-friendly error message""" if self.limit_type == "burst": return f"Service temporarily unavailable: Too many rapid requests to {self.tool_name}. Please wait {self.retry_after_seconds:.0f} seconds before trying again." elif self.limit_type == "second": return f"Service temporarily unavailable: {self.tool_name} request rate exceeded ({self.limit_value}/second). Please wait {self.retry_after_seconds:.0f} seconds before trying again." elif self.limit_type == "minute": return f"Service temporarily unavailable: {self.tool_name} quota exceeded ({self.limit_value}/minute). Please try again in {self.retry_after_seconds:.0f} seconds." elif self.limit_type == "hour": return f"Service temporarily unavailable: {self.tool_name} hourly quota exceeded ({self.limit_value}/hour). Please try again in {self.retry_after_seconds:.0f} minutes." else: return f"Service temporarily unavailable: {self.tool_name} rate limit exceeded. Please try again later." def to_technical_message(self) -> str: """Generate technical error message for debugging""" return f"Tool '{self.tool_name}' {self.limit_type} limit exceeded ({self.current_usage}/{self.limit_value} {self.limit_type})" def _parse_rate_limit_denial(tool_name: str, denial_reason: str) -> RateLimitViolation: """Parse rate limit denial reason into structured violation information""" import re # Default values limit_type = "unknown" current_usage = 0 limit_value = 0.0 retry_after_seconds = 60.0 # Default retry after 1 minute # Parse different types of rate limit violations if "burst limit exceeded" in denial_reason: limit_type = "burst" retry_after_seconds = 1.0 # Burst limits reset quickly match = re.search(r'\((\d+) requests/burst\)', denial_reason) if match: limit_value = float(match.group(1)) current_usage = int(limit_value) # Approximation elif "per-second limit exceeded" in denial_reason: limit_type = "second" retry_after_seconds = 1.0 # Wait 1 second match = re.search(r'\(([0-9.]+) requests/second\)', denial_reason) if match: limit_value = float(match.group(1)) current_usage = int(limit_value) # Approximation elif "per-minute limit exceeded" in denial_reason: limit_type = "minute" retry_after_seconds = 10.0 # Wait 10 seconds for minute limits match = re.search(r'\(([0-9.]+) requests/minute\)', denial_reason) if match: limit_value = float(match.group(1)) current_usage = int(limit_value) # Approximation elif "per-hour limit exceeded" in denial_reason: limit_type = "hour" retry_after_seconds = 300.0 # Wait 5 minutes for hour limits match = re.search(r'\(([0-9.]+) requests/hour\)', denial_reason) if match: limit_value = float(match.group(1)) current_usage = int(limit_value) # Approximation return RateLimitViolation( tool_name=tool_name, limit_type=limit_type, current_usage=current_usage, limit_value=limit_value, retry_after_seconds=retry_after_seconds ) async def _call_session_tool_async(session: Session, tool_name: str, tool_args: Dict[str, Any], client_ip: str = "unknown") -> Dict[str, Any]: """Execute a tool within a session context with full tracking, workspace management, and global rate limiting""" start_time = time.time() success = False error_details = None result_data = None # Touch session at start of tool execution to prevent expiry during long operations session.touch() try: # CHECK GLOBAL TOOL RATE LIMITS FIRST if global_tool_rate_limiter: allowed, deny_reason = await global_tool_rate_limiter.is_allowed(tool_name) if not allowed: # Parse the denial reason to create structured rate limit violation rate_limit_violation = _parse_rate_limit_denial(tool_name, deny_reason) # Create user-friendly error message user_message = rate_limit_violation.to_user_friendly_message() technical_message = rate_limit_violation.to_technical_message() logger.warning(f"Session {session.id}: {technical_message}") result_data = { "success": False, "error": user_message, "error_code": "RATE_LIMIT_EXCEEDED", "error_type": "rate_limit", "tool_name": tool_name, "limit_type": rate_limit_violation.limit_type, "retry_after_seconds": rate_limit_violation.retry_after_seconds, "data": None, "rate_limited": True, # Keep for backward compatibility "technical_details": technical_message # For debugging } # Still log this for tracking purposes duration_ms = (time.time() - start_time) * 1000 tracker = session.get_tool_tracker() if tracker: try: agent_info = { "client_ip": client_ip, "type": "unknown", "session_request_count": session.request_count } tracker.log_tool_call( tool_name=tool_name, input_args=tool_args, output_result=result_data, success=False, duration_ms=duration_ms, error_details=user_message, agent_info=agent_info ) except Exception as e: logger.error(f"Failed to log rate-limited tool call: {e}") return result_data # Get MCP tools instance for this session (handles workspace isolation) mcp_tools = session.get_mcp_tools(prefer_async=True) # Get tool method directly from the mcp_tools instance if not hasattr(mcp_tools, tool_name): raise ValueError(f"Tool '{tool_name}' not implemented") tool_method = getattr(mcp_tools, tool_name) # Add session context to tool arguments for workspace-aware tools if hasattr(mcp_tools, 'set_session_context'): mcp_tools.set_session_context(session.id, str(session.workspace_path)) # Execute tool with keep-alive for potentially long operations logger.info(f"Session {session.id}: Executing tool '{tool_name}' with args: {list(tool_args.keys())}") # Use keep-alive wrapper for tools that might take a long time long_running_tools = {'batch_web_search', 'url_crawler', 'document_qa', 'document_extract', 'bash'} # Check if the tool method is async import inspect is_async_tool = inspect.iscoroutinefunction(tool_method) # Execute tool based on whether it's async or sync if is_async_tool: # Tool is async - execute directly logger.debug("Executing async tool '{%s}'", tool_name) if config.enable_session_keepalive and tool_name in long_running_tools: # For long-running async tools, use keep-alive with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval): result = await tool_method(**tool_args) else: # For regular async tools, execute directly result = await tool_method(**tool_args) else: # Tool is sync - execute in thread pool logger.debug("Executing sync tool '{%s}' in thread pool", tool_name) # Define the synchronous tool execution function def execute_tool_sync(): """Synchronous tool execution to be run in thread pool""" return tool_method(**tool_args) # Execute tool asynchronously in thread pool for true non-blocking execution import asyncio import concurrent.futures # Create a thread pool executor for CPU-bound/blocking operations loop = asyncio.get_event_loop() if config.enable_session_keepalive and tool_name in long_running_tools: # For long-running tools, use keep-alive with async execution with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval): # Run in thread pool to avoid blocking the event loop with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: result = await loop.run_in_executor(executor, execute_tool_sync) else: # For regular tools, use async execution without keep-alive with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: result = await loop.run_in_executor(executor, execute_tool_sync) # Touch session after tool execution to update activity session.touch() # Handle different result formats if hasattr(result, 'to_dict'): result_data = result.to_dict() elif isinstance(result, dict): result_data = result else: result_data = {"result": result} success = result_data.get('success', True) if success: logger.info(f"Session {session.id}: Tool '{tool_name}' completed successfully") # RECORD SUCCESSFUL REQUEST FOR RATE LIMITING if global_tool_rate_limiter: await global_tool_rate_limiter.record_request(tool_name) else: error_details = result_data.get('error', 'Unknown error') logger.warning(f"Session {session.id}: Tool '{tool_name}' failed: {error_details}") except Exception as e: success = False error_details = str(e) result_data = { "success": False, "error": error_details, "data": None } logger.error(f"Session {session.id}: Tool '{tool_name}' exception: {e}") # Calculate execution time duration_ms = (time.time() - start_time) * 1000 # Log tool call if tracking is enabled tracker = session.get_tool_tracker() if tracker: try: agent_info = { "client_ip": client_ip, "type": "unknown", # Could be enhanced to detect agent type "session_request_count": session.request_count } tracker.log_tool_call( tool_name=tool_name, input_args=tool_args, output_result=result_data, success=success, duration_ms=duration_ms, error_details=error_details, agent_info=agent_info ) except Exception as e: logger.error(f"Failed to log tool call: {e}") return result_data def create_sse_response(response_data: dict, session_id: str = None) -> StreamingResponse: """Create Server-Sent Events response with proper formatting""" def generate_sse(): try: # Add session info to response if available if session_id: response_data["session_id"] = session_id json_data = json.dumps(response_data, ensure_ascii=False) yield f"event: message\n" yield f"data: {json_data}\n" yield f"\n" except Exception as e: error_data = { "jsonrpc": "2.0", "error": {"code": JsonRpcErr.INTERNAL_ERROR, "message": f"Internal error: {str(e)}"}, "id": response_data.get("id") } json_data = json.dumps(error_data, ensure_ascii=False) yield f"event: error\n" yield f"data: {json_data}\n" yield f"\n" return StreamingResponse( generate_sse(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Access-Control-Allow-Origin": "*", } ) def create_error_response(request_id: Any, code: int, message: str, session_id: str = None) -> StreamingResponse: """Create error response in SSE format""" error_data = { "jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": request_id } return create_sse_response(error_data, session_id) def create_rate_limit_response( request_id: Any, tool_name: str, error_message: str, retry_after_seconds: float, limit_type: str, technical_details: str = "", session_id: str = None ) -> JSONResponse: """ Create HTTP 429 Rate Limit Exceeded response with proper headers and error format. Returns proper HTTP status code instead of SSE for rate limiting errors. """ # Calculate retry-after header value retry_after_header = int(max(1.0, retry_after_seconds)) # Create standardized error response error_data = { "error": { "type": "rate_limit_exceeded", "code": "RATE_LIMIT_EXCEEDED", "message": error_message, "details": { "tool_name": tool_name, "limit_type": limit_type, "retry_after_seconds": retry_after_seconds, "technical_details": technical_details } }, "request_id": request_id, "timestamp": datetime.now().isoformat(), "session_id": session_id } # Set appropriate headers headers = { "Retry-After": str(retry_after_header), # HTTP standard header "X-RateLimit-Limit-Type": limit_type, "X-RateLimit-Tool": tool_name, "X-RateLimit-Retry-After": str(retry_after_seconds), "Content-Type": "application/json" } return JSONResponse( status_code=HTTPStatus.TOO_MANY_REQUESTS, # Too Many Requests content=error_data, headers=headers ) async def handle_mcp_request(request: Request) -> StreamingResponse: """Main MCP request handler with session management and tool execution""" try: # Check content length before reading body content_length = request.headers.get("content-length") if content_length: content_size_mb = int(content_length) / (1024 * 1024) if content_size_mb > config.max_request_size_mb: logger.warning(f"Request too large: {content_size_mb:.2f}MB > {config.max_request_size_mb}MB") return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Request too large: {content_size_mb:.2f}MB") # Parse request with timeout protection try: body = await asyncio.wait_for(request.body(), timeout=config.request_timeout_seconds) except asyncio.TimeoutError: logger.error("Timeout while reading request body") return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request body read timeout") if not body: return create_error_response(None, JsonRpcErr.PARSE_ERROR, "Empty request body") try: data = json.loads(body.decode('utf-8')) except json.JSONDecodeError as e: return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Invalid JSON: {str(e)}") # Validate MCP request structure is_valid, error_msg = RequestValidator.validate_mcp_request(data) if not is_valid: return create_error_response(data.get("id"), JsonRpcErr.INVALID_REQUEST, error_msg) request_id = data["id"] method = data["method"] params = data.get("params", {}) # Get or create session session_id = request.headers.get("X-Session-ID") client_ip = request.client.host if request.client else "unknown" session = await session_manager.get_or_create_session(session_id) logger.info(f"Processing {method} request for session {session.id} from {client_ip}") # Handle different MCP methods if method == "initialize": # MCP initialization response_data = { "jsonrpc": "2.0", "result": { "protocolVersion": "2025-06-18", "capabilities": { "tools": {"supportsProgress": True}, "resources": {}, "prompts": {} }, "serverInfo": { "name": "DeepDiver-Demo-MCP", "version": "1.0.0" } }, "id": request_id } elif method == "tools/list": # List available tools using detailed schemas from get_tool_schemas() tools_list = [] detailed_schemas = get_tool_schemas() # Build tools list from schemas for _, detailed_schema in detailed_schemas.items(): tools_list.append({ "name": detailed_schema["name"], "description": detailed_schema["description"], "inputSchema": detailed_schema["inputSchema"] }) logger.info(f"Serving {len(tools_list)} tools with detailed schemas to client") response_data = { "jsonrpc": "2.0", "result": {"tools": tools_list}, "id": request_id } elif method == "tools/call": # Execute tool call is_valid, error_msg = RequestValidator.validate_tool_call(params) if not is_valid: return create_error_response(request_id, JsonRpcErr.INVALID_PARAMS, error_msg, session.id) tool_name = params["name"] tool_arguments = params["arguments"] # Execute tool in session context asynchronously result = await _call_session_tool_async(session, tool_name, tool_arguments, client_ip) # CHECK FOR RATE LIMITING AND RETURN PROPER HTTP STATUS if result.get("rate_limited", False): return create_rate_limit_response( request_id=request_id, tool_name=tool_name, error_message=result.get("error", "Rate limit exceeded"), retry_after_seconds=result.get("retry_after_seconds", 60), limit_type=result.get("limit_type", "unknown"), technical_details=result.get("technical_details", ""), session_id=session.id ) # Format normal response response_data = { "jsonrpc": "2.0", "result": { "content": [ { "type": "text", "text": json.dumps(result, indent=2, ensure_ascii=False) } ] }, "id": request_id } else: return create_error_response(request_id, JsonRpcErr.METHOD_NOT_FOUND, f"Method not found: {method}", session.id) return create_sse_response(response_data, session.id) except asyncio.TimeoutError: logger.warning("Request timeout - client may have disconnected") return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request timeout") except Exception as e: # Handle client disconnects gracefully if "ClientDisconnect" in str(e) or "ConnectionClosedError" in str(e): logger.warning(f"Client disconnected during request processing: {e}") return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Client disconnected") logger.error(f"Unexpected error in MCP request handler: {e}") import traceback logger.error(traceback.format_exc()) return create_error_response(None, JsonRpcErr.INTERNAL_ERROR, f"Internal server error: {str(e)}") async def handle_health_check(request: Request) -> JSONResponse: """Health check endpoint""" try: stats = await session_manager.get_stats() if session_manager else {} # Get rate limiting summary rate_limit_summary = {} if global_tool_rate_limiter: all_stats = global_tool_rate_limiter.get_all_stats() rate_limit_summary = { "enabled": True, "tools_with_limits": len(all_stats), "total_configured_tools": list(all_stats.keys()) } else: rate_limit_summary = {"enabled": False} health_data = { "status": "healthy", "timestamp": datetime.now().isoformat(), "version": "1.0.0", "session_stats": stats, "features": { "workspace_isolation": True, "tool_call_tracking": config.enable_tool_tracking if config else False, "client_rate_limiting": True, "global_tool_rate_limiting": rate_limit_summary["enabled"], "security_middleware": True, "standardized_rate_limit_responses": True }, "rate_limiting": rate_limit_summary, "error_formats": { "rate_limit_exceeded": { "http_status": HTTPStatus.TOO_MANY_REQUESTS, "headers": ["Retry-After", "X-RateLimit-*"], "error_code": "RATE_LIMIT_EXCEEDED", "response_format": "application/json" } } } return JSONResponse(content=health_data) except Exception as e: return JSONResponse( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"status": "unhealthy", "error": str(e)} ) async def handle_tracking_info(request: Request) -> JSONResponse: """Get tool call tracking information for a session""" try: session_id = request.query_params.get("session_id") if not session_id: return JSONResponse( status_code=HTTPStatus.BAD_REQUEST, content={"error": "session_id parameter required"} ) session = await session_manager.get_session(session_id) if not session: return JSONResponse( status_code=HTTPStatus.NOT_FOUND, content={"error": f"Session {session_id} not found"} ) tracker = session.get_tool_tracker() if not tracker: return JSONResponse( content={ "session_id": session_id, "tracking_enabled": False, "message": "Tool call tracking not enabled or no workspace" } ) # Read session summary summary_data = {} if tracker.summary_file.exists(): try: with open(tracker.summary_file, 'r') as f: summary_data = json.load(f) except Exception as e: logger.error(f"Failed to read session summary: {e}") return JSONResponse(content={ "session_id": session_id, "tracking_enabled": True, "summary": summary_data, "logs_directory": str(tracker.logs_dir), "current_log_file": str(tracker.current_log_file) }) except Exception as e: return JSONResponse( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"error": str(e)} ) async def handle_rate_limit_stats(request: Request) -> JSONResponse: """Get global tool rate limiting statistics""" try: if not global_tool_rate_limiter: return JSONResponse( status_code=HTTPStatus.NOT_FOUND, content={"error": "Global tool rate limiter not initialized"} ) # Check if specific tool requested tool_name = request.query_params.get("tool") if tool_name: # Get stats for specific tool stats = await global_tool_rate_limiter.get_tool_stats(tool_name) return JSONResponse(content=stats) else: # Get stats for all tools all_stats = global_tool_rate_limiter.get_all_stats() return JSONResponse(content={ "timestamp": datetime.now().isoformat(), "global_tool_rate_limiting": True, "tools": all_stats, "summary": { "total_tools_with_limits": len(all_stats), "tools_configured": list(all_stats.keys()) } }) except Exception as e: logger.error(f"Failed to get rate limit stats: {e}") return JSONResponse( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"error": str(e)} ) def create_app() -> Starlette: """Create and configure the Starlette application""" global session_manager, rate_limiter, global_tool_rate_limiter if not config: raise RuntimeError("Server configuration not initialized") # Initialize global components session_manager = ThreadSafeSessionManager( ttl_seconds=config.session_ttl_seconds, max_sessions=config.max_sessions, base_workspace_dir=config.base_workspace_dir ) rate_limiter = RateLimiter(config.rate_limit_requests_per_minute) # Initialize global tool rate limiter if config.tool_rate_limits: global_tool_rate_limiter = GlobalToolRateLimiter(config.tool_rate_limits) logger.info(f"Initialized global tool rate limiter with {len(config.tool_rate_limits)} tool limits") else: logger.info("No tool rate limits configured - tools will run without global rate limiting") # Create app app = Starlette(debug=config.debug_mode) app.add_middleware(SecurityMiddleware) app.add_middleware(RateLimitMiddleware, input_rate_limiter=rate_limiter) # Add routes app.add_route("/mcp", handle_mcp_request, methods=["POST"]) app.add_route("/health", handle_health_check, methods=["GET"]) app.add_route("/tracking", handle_tracking_info, methods=["GET"]) app.add_route("/rate-limits", handle_rate_limit_stats, methods=["GET"]) return app def parse_arguments(): """Parse command line arguments""" parser = argparse.ArgumentParser( description="Demo-Ready MCP Server with Per-Tool Rate Limiting", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python src/tools/mcp_server_standard.py --config src/tools/server_config.yaml python src/tools/mcp_server_standard.py --host 127.0.0.1 --port 8080 python src/tools/mcp_server_standard.py --config custom_config.yaml --debug """ ) parser.add_argument( '--config', '-c', type=str, help='Path to YAML configuration file' ) parser.add_argument( '--host', type=str, help='Server host (overrides config file)' ) parser.add_argument( '--port', '-p', type=int, help='Server port (overrides config file)' ) parser.add_argument( '--debug', action='store_true', help='Enable debug mode (overrides config file)' ) parser.add_argument( '--workspace-dir', type=str, help='Base workspace directory (overrides config file)' ) return parser.parse_args() def print_startup_info(): """Print server startup information""" logger.info("🚀 DeepDiver Demo MCP Server") logger.info("=" * 50) logger.info(f"📊 Features:") logger.info(f" • Session Management: ✅ (TTL: {config.session_ttl_seconds}s)") logger.info(f" • Workspace Isolation: ✅ (Base: {config.base_workspace_dir})") logger.info(f" • Tool Call Tracking: {'✅' if config.enable_tool_tracking else '❌'}") logger.info(f" • Client Rate Limiting: ✅ ({config.rate_limit_requests_per_minute}/min)") logger.info(f" • Global Tool Rate Limiting: {'✅' if config.tool_rate_limits else '❌'}") logger.info(f" • Security Middleware: ✅") # Tool rate limiting information if config.tool_rate_limits: logger.info(f"🚦 Tool Rate Limits: {len(config.tool_rate_limits)} tools configured") for tool_name, limits in list(config.tool_rate_limits.items())[:3]: burst = limits.get('burst_limit', '∞') rpm = limits.get('requests_per_minute', '∞') logger.info(f" • {tool_name}: {rpm}/min, burst: {burst}") if len(config.tool_rate_limits) > 3: logger.info(f" • ... and {len(config.tool_rate_limits) - 3} more tools") # Tool information from schemas tool_schemas = get_tool_schemas() available_tools = list(tool_schemas.keys()) logger.info(f"🔧 Tools Available: {len(available_tools)}") logger.info(f" • All tools defined in schemas: {len(available_tools)} tools") logger.info(f" • Sample tools: {', '.join(sorted(available_tools)[:5])}...") logger.info("=" * 50) def main(): """Main function to run the production MCP server""" global config # Parse command line arguments args = parse_arguments() config = ServerConfig.from_yaml("./src/tools/server_config.yaml") # Apply CLI overrides if args.host: config.host = args.host logger.info(f"🔧 Override: Host = {config.host}") if args.port: config.port = args.port logger.info(f"🔧 Override: Port = {config.port}") if args.debug: config.debug_mode = True logger.info(f"🔧 Override: Debug mode enabled") if args.workspace_dir: config.base_workspace_dir = args.workspace_dir logger.info(f"🔧 Override: Workspace directory = {config.base_workspace_dir}") print_startup_info() try: import os # Calculate optimal worker count for high-concurrency FIRST # Use CPU core count indirectly via uvicorn's defaults; no local variable needed # Override for high-concurrency scenarios if os.getenv('FORCE_HIGH_CONCURRENCY', '').lower() == 'true': pass # Configuration handled elsewhere if needed app = create_app() logger.info(f"🌐 Starting server at http://{config.host}:{config.port}") logger.info(f"📡 MCP endpoint: http://{config.host}:{config.port}/mcp") logger.info(f"🏥 Health check: http://{config.host}:{config.port}/health") logger.info(f"📊 Tracking info: http://{config.host}:{config.port}/tracking?session_id=") logger.info(f"🚦 Rate limit stats: http://{config.host}:{config.port}/rate-limits") uvicorn.run( app, # Use app instance directly for single worker with async optimizations host=config.host, port=config.port, log_level="info", timeout_keep_alive=config.request_timeout_seconds, workers=1, # Single worker with async optimizations backlog=1024, # Larger backlog for high-concurrency access_log=False, # Disable access logs for better performance limit_concurrency=None, # No artificial concurrency limit ) except KeyboardInterrupt: print("\n⏹️ Server stopped by user") except Exception as e: print(f"❌ Server startup failed: {e}") import traceback traceback.print_exc() raise if __name__ == "__main__": main()