|
|
|
|
|
|
|
|
""" |
|
|
Demo-Ready MCP Server - New Standard Implementation |
|
|
Combines robust session management with comprehensive tool definitions. |
|
|
Features: workspace isolation, tool call tracking, rate limiting, security, and full tool suite. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import asyncio |
|
|
import json |
|
|
import logging |
|
|
import time |
|
|
import uuid |
|
|
import yaml |
|
|
from collections import defaultdict, deque |
|
|
from dataclasses import dataclass, field |
|
|
from datetime import datetime, timedelta |
|
|
from pathlib import Path |
|
|
from threading import Thread, Event |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
|
|
|
from starlette.applications import Starlette |
|
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
|
from starlette.requests import Request |
|
|
from starlette.responses import JSONResponse, StreamingResponse |
|
|
import uvicorn |
|
|
|
|
|
|
|
|
import sys |
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent)) |
|
|
from src.utils.status_codes import JsonRpcErr |
|
|
from http import HTTPStatus |
|
|
|
|
|
|
|
|
try: |
|
|
from .mcp_tools import MCPTools, get_tool_schemas |
|
|
from .mcp_tools_async import AsyncMCPTools |
|
|
except ImportError: |
|
|
|
|
|
from src.tools.mcp_tools import MCPTools, get_tool_schemas |
|
|
try: |
|
|
from src.tools.mcp_tools_async import AsyncMCPTools |
|
|
except ImportError: |
|
|
AsyncMCPTools = None |
|
|
|
|
|
|
|
|
# Hard-coded off: workspace-knowledge integration is not available in this build.
# NOTE(review): nothing in this chunk reads the flag -- presumably checked elsewhere.
WORKSPACE_KNOWLEDGE_AVAILABLE = False
|
|
|
|
|
|
|
|
# Root logging configuration: INFO level with function/line context in every
# record, mirrored to stdout and to a local file next to the process.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('mcp_server.log')
    ]
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ServerConfig:
    """Server configuration with only actually implemented options."""

    # --- Network ---
    host: str = "127.0.0.1"
    port: int = 6274
    debug_mode: bool = False

    # --- Session lifecycle ---
    session_ttl_seconds: int = 3600
    max_sessions: int = 1000
    cleanup_interval_seconds: int = 300
    enable_session_keepalive: bool = True
    keepalive_touch_interval: int = 300

    # --- Request handling ---
    request_timeout_seconds: int = 120
    max_request_size_mb: int = 10

    # --- Per-client input rate limiting ---
    rate_limit_requests_per_minute: int = 300

    # --- Workspace isolation ---
    base_workspace_dir: str = "workspaces"

    # --- Tool call tracking ---
    enable_tool_tracking: bool = True
    max_tracked_calls_per_session: int = 1000
    track_detailed_errors: bool = True

    # Per-tool global limits, e.g. {"batch_web_search": {"requests_per_minute": 60}}.
    tool_rate_limits: Dict[str, Dict[str, int]] = field(default_factory=dict)

    @classmethod
    def from_yaml(cls, config_path: str) -> 'ServerConfig':
        """Load configuration from a YAML file.

        Reads the 'server', 'tracking', and 'tool_rate_limits' sections; any
        missing key falls back to the dataclass default.  On any failure
        (missing file, malformed YAML, ...) logs the error and returns the
        built-in defaults instead of raising.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                # Fix: yaml.safe_load returns None for an empty document,
                # which previously raised AttributeError below and was
                # misreported as a load failure.  Treat it as an empty mapping.
                config_data = yaml.safe_load(f) or {}

            server_config = config_data.get('server', {})
            tracking_config = config_data.get('tracking', {})
            tool_rate_limits = config_data.get('tool_rate_limits', {})

            return cls(
                host=server_config.get('host', "127.0.0.1"),
                port=server_config.get('port', 6274),
                debug_mode=server_config.get('debug_mode', False),

                session_ttl_seconds=server_config.get('session_ttl_seconds', 3600),
                max_sessions=server_config.get('max_sessions', 1000),
                cleanup_interval_seconds=server_config.get('cleanup_interval_seconds', 300),
                enable_session_keepalive=server_config.get('enable_session_keepalive', True),
                keepalive_touch_interval=server_config.get('keepalive_touch_interval', 300),

                request_timeout_seconds=server_config.get('request_timeout_seconds', 120),
                max_request_size_mb=server_config.get('max_request_size_mb', 10),

                rate_limit_requests_per_minute=server_config.get('rate_limit_requests_per_minute', 300),

                base_workspace_dir=server_config.get('base_workspace_dir', "workspaces"),

                enable_tool_tracking=tracking_config.get('enable_tool_tracking', True),
                max_tracked_calls_per_session=tracking_config.get('max_tracked_calls_per_session', 1000),
                track_detailed_errors=tracking_config.get('track_detailed_errors', True),

                tool_rate_limits=tool_rate_limits
            )

        except Exception as e:
            logger.error(f"Failed to load configuration from {config_path}: {e}")
            logger.info("Using default configuration")
            return cls()
|
|
|
|
|
|
|
|
# Module-global configuration read by the classes below; None until assigned
# by startup code (not visible in this chunk).
config: Optional[ServerConfig] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ToolRateLimit:
    """Rate limit configuration for a specific tool."""
    # Rolling 60-second window cap; float('inf') disables the check.
    requests_per_minute: float
    # Rolling 3600-second window cap; float('inf') disables the check.
    requests_per_hour: float
    # Rolling 1-second window cap (defaults to 10 when unconfigured).
    burst_limit: int


class GlobalToolRateLimiter:
    """
    Global rate limiter that controls QPS to external APIs per tool.

    This is shared across all sessions and clients to manage upstream service load.
    """

    def __init__(self, tool_rate_limits: Dict[str, Dict[str, int]]):
        """Build per-tool limits from configuration.

        Args:
            tool_rate_limits: tool name -> {'requests_per_minute',
                'requests_per_hour', 'burst_limit'}; missing keys default to
                unlimited (inf) except burst_limit, which defaults to 10.
                Tools absent from the mapping are never throttled.
        """
        self.tool_limits: Dict[str, ToolRateLimit] = {}
        # Per-tool request timestamps (time.time()), oldest first.
        self.tool_requests: Dict[str, deque] = defaultdict(deque)
        self.lock = asyncio.Lock()

        for tool_name, limits_config in tool_rate_limits.items():
            self.tool_limits[tool_name] = ToolRateLimit(
                requests_per_minute=limits_config.get('requests_per_minute', float('inf')),
                requests_per_hour=limits_config.get('requests_per_hour', float('inf')),
                burst_limit=limits_config.get('burst_limit', 10)
            )
            self.tool_requests[tool_name] = deque()

        logger.info(f"Initialized global tool rate limiter for {len(self.tool_limits)} tools")

    async def is_allowed(self, tool_name: str) -> tuple[bool, Optional[str]]:
        """
        Check if a request to the specified tool is allowed based on global rate limits.

        Returns:
            tuple[bool, Optional[str]]: (allowed, reason_if_denied)
        """
        if tool_name not in self.tool_limits:
            # Unconfigured tools are never throttled.
            return True, None

        async with self.lock:
            now = time.time()
            limits = self.tool_limits[tool_name]
            requests = self.tool_requests[tool_name]

            # Drop entries older than the largest (1h) window before counting.
            self._cleanup_old_requests(requests, now)
            recent_requests = list(requests)

            # Check windows from shortest to longest so the most specific
            # denial reason is reported.
            if limits.burst_limit != float('inf'):
                burst_count = sum(1 for req_time in recent_requests if now - req_time < 1.0)
                if burst_count >= limits.burst_limit:
                    return False, f"Tool '{tool_name}' burst limit exceeded ({limits.burst_limit} requests/burst)"

            if limits.requests_per_minute != float('inf'):
                minute_count = sum(1 for req_time in recent_requests if now - req_time < 60.0)
                if minute_count >= limits.requests_per_minute:
                    return False, f"Tool '{tool_name}' per-minute limit exceeded ({limits.requests_per_minute} requests/minute)"

            if limits.requests_per_hour != float('inf'):
                hour_count = sum(1 for req_time in recent_requests if now - req_time < 3600.0)
                if hour_count >= limits.requests_per_hour:
                    return False, f"Tool '{tool_name}' per-hour limit exceeded ({limits.requests_per_hour} requests/hour)"

            return True, None

    async def record_request(self, tool_name: str):
        """Record a successful request for rate limiting tracking."""
        if tool_name not in self.tool_limits:
            # Nothing to track for unconfigured tools.
            return

        async with self.lock:
            now = time.time()
            self.tool_requests[tool_name].append(now)
            # Keep the deque bounded while we hold the lock anyway.
            self._cleanup_old_requests(self.tool_requests[tool_name], now)

    @staticmethod
    def _cleanup_old_requests(requests: deque, now: float):
        """Remove requests older than 1 hour to keep memory usage bounded."""
        while requests and now - requests[0] > 3600.0:
            requests.popleft()

    async def get_tool_stats(self, tool_name: str) -> Dict[str, Any]:
        """Get current usage statistics for a tool.

        Returns an {"error": ...} dict for tools that are not configured.
        Limit values of inf are reported as None; utilization is the ratio of
        window usage to the window's limit (0 when unlimited).
        """
        if tool_name not in self.tool_limits:
            return {"error": f"Tool '{tool_name}' not configured for rate limiting"}

        async with self.lock:
            now = time.time()
            requests = self.tool_requests[tool_name]
            limits = self.tool_limits[tool_name]

            self._cleanup_old_requests(requests, now)
            recent_requests = list(requests)

            return {
                "tool_name": tool_name,
                "current_usage": {
                    "last_second": sum(1 for req_time in recent_requests if now - req_time < 1.0),
                    "last_minute": sum(1 for req_time in recent_requests if now - req_time < 60.0),
                    "last_hour": sum(1 for req_time in recent_requests if now - req_time < 3600.0)
                },
                "limits": {
                    "requests_per_minute": limits.requests_per_minute if limits.requests_per_minute != float('inf') else None,
                    "requests_per_hour": limits.requests_per_hour if limits.requests_per_hour != float('inf') else None,
                    "burst_limit": limits.burst_limit if limits.burst_limit != float('inf') else None
                },
                "utilization": {
                    "minute_utilization": sum(1 for req_time in recent_requests if now - req_time < 60.0) / limits.requests_per_minute if limits.requests_per_minute != float('inf') else 0,
                    "hour_utilization": sum(1 for req_time in recent_requests if now - req_time < 3600.0) / limits.requests_per_hour if limits.requests_per_hour != float('inf') else 0
                }
            }

    async def get_all_stats(self) -> Dict[str, Any]:
        """Get usage statistics for all configured tools.

        Fix: this method used to be synchronous and called the async
        get_tool_stats() without awaiting, so it returned a dict of
        un-awaited coroutine objects.  It is now a coroutine that awaits
        each per-tool snapshot.
        """
        return {
            tool_name: await self.get_tool_stats(tool_name)
            for tool_name in self.tool_limits
        }
|
|
|
|
|
|
|
|
# Shared process-wide limiter instance; None until startup code assigns it.
global_tool_rate_limiter: Optional[GlobalToolRateLimiter] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tools whose server-side implementation is simply the MCPTools method of the
# same name.  The previous implementation rebuilt a dict of ~25 identical
# delegating lambdas on every call; a module-level frozenset plus one
# getattr-based delegate keeps a single code path and allocates nothing per
# lookup.  The two *_subjective_task_done tools intentionally have no server
# implementation and are excluded so they keep resolving to None.
_DELEGATED_TOOL_NAMES = frozenset({
    "batch_web_search", "url_crawler", "download_files", "list_workspace",
    "str_replace_based_edit_tool", "file_stats", "file_read", "file_read_lines",
    "content_preview", "file_write", "file_grep_search", "file_grep_with_context",
    "file_find_by_name", "bash", "task_done", "think", "reflect", "document_qa",
    "extract_markdown_toc", "extract_markdown_section", "document_extract",
    "search_result_classifier", "section_writer", "concat_section_files",
    "internal_file_read_unlimited",
})


def get_tool_function(tool_name: str):
    """Get the actual function for a tool.

    Returns a callable ``fn(tools, **kwargs)`` that forwards to the MCPTools
    method named *tool_name*, or None when the tool is unknown or has no
    server-side implementation ("info_seeker_subjective_task_done" and
    "writer_subjective_task_done").
    """
    if tool_name not in _DELEGATED_TOOL_NAMES:
        return None
    return lambda tools, **kwargs: getattr(tools, tool_name)(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ToolCallLog:
    """Individual tool call log entry."""
    call_id: str
    timestamp: datetime
    tool_name: str
    input_args: Dict[str, Any]
    output_result: Dict[str, Any]
    success: bool
    duration_ms: float
    error_details: Optional[str] = None
    session_id: str = ""
    agent_info: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        # Every field serializes as-is except the timestamp, which JSON
        # cannot encode directly and is rendered as an ISO-8601 string.
        serialized = dict(self.__dict__)
        serialized["timestamp"] = self.timestamp.isoformat()
        return serialized
|
|
|
|
|
|
|
|
class ToolCallTracker:
    """Tracks and saves tool calls to workspace-specific files.

    Writes two artifacts under ``<workspace>/tool_call_logs/``:

    * ``tool_calls_<YYYY-MM-DD>.jsonl`` -- append-only, one JSON object per call
    * ``session_summary.json``          -- rolling aggregate for the session

    Reads the module-global ``config`` for the tracking switch and caps.
    """

    def __init__(self, workspace_path: Path, session_id: str):
        self.workspace_path = workspace_path
        self.session_id = session_id
        self.logs_dir = workspace_path / "tool_call_logs"
        self.logs_dir.mkdir(exist_ok=True)

        # The JSONL file is named for the day the tracker was constructed.
        # NOTE(review): a tracker that lives past midnight keeps appending to
        # the previous day's file -- confirm that is intended.
        today = datetime.now().strftime("%Y-%m-%d")
        self.current_log_file = self.logs_dir / f"tool_calls_{today}.jsonl"
        self.summary_file = self.logs_dir / "session_summary.json"

        # In-memory counters for this instance; persisted into the summary
        # file after every logged call.
        self.call_count = 0
        self.tool_usage_stats = defaultdict(int)

        self._initialize_session_summary()

    def _initialize_session_summary(self):
        """Initialize or update session summary file"""
        summary = {
            "session_id": self.session_id,
            "session_start": datetime.now().isoformat(),
            "last_updated": datetime.now().isoformat(),
            "total_tool_calls": 0,
            "tool_usage_stats": {},
            "agent_activity": {},
            "workspace_path": str(self.workspace_path)
        }

        # Merge any summary left by a previous tracker for this workspace so
        # accumulated history survives restarts.
        if self.summary_file.exists():
            try:
                with open(self.summary_file, 'r') as f:
                    existing_summary = json.load(f)
                summary.update(existing_summary)

                # Keep the original session start time (redundant after the
                # update() above, but explicit about the intent).
                if "session_start" in existing_summary:
                    summary["session_start"] = existing_summary["session_start"]
            except Exception as e:
                logger.warning(f"Could not load existing session summary: {e}")

        self._save_summary(summary)

    def _save_summary(self, summary: Dict[str, Any]):
        """Save session summary to file"""
        try:
            # NOTE(review): no explicit encoding here, unlike the JSONL writer
            # below; non-ASCII output depends on the platform default -- confirm.
            with open(self.summary_file, 'w') as f:
                json.dump(summary, f, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"Failed to save session summary: {e}")

    def log_tool_call(self,
                      tool_name: str,
                      input_args: Dict[str, Any],
                      output_result: Dict[str, Any],
                      success: bool,
                      duration_ms: float,
                      error_details: Optional[str] = None,
                      agent_info: Optional[Dict[str, Any]] = None) -> str:
        """Log a tool call and return the call ID.

        Returns an empty string when tracking is disabled or the per-session
        tracking cap has been reached.
        """
        if not config.enable_tool_tracking:
            return ""

        # Hard cap keeps a runaway session from growing logs without bound.
        if self.call_count >= config.max_tracked_calls_per_session:
            logger.warning(f"Max tracked calls reached for session {self.session_id}")
            return ""

        call_id = str(uuid.uuid4())
        timestamp = datetime.now()

        log_entry = ToolCallLog(
            call_id=call_id,
            timestamp=timestamp,
            tool_name=tool_name,
            input_args=self._sanitize_args(input_args),
            output_result=self._sanitize_result(output_result),
            success=success,
            duration_ms=duration_ms,
            error_details=error_details if config.track_detailed_errors else None,
            session_id=self.session_id,
            agent_info=agent_info
        )

        # Append to the per-day JSONL file; one JSON object per line.
        try:
            with open(self.current_log_file, 'a', encoding="utf-8") as f:
                f.write(json.dumps(log_entry.to_dict(), ensure_ascii=False) + '\n')
        except Exception as e:
            logger.error(f"Failed to save tool call log: {e}")

        # The summary is updated BEFORE the counters below are incremented;
        # the "+1" adjustments inside _update_session_summary compensate for
        # this ordering.
        self._update_session_summary(log_entry)

        self.call_count += 1
        self.tool_usage_stats[tool_name] += 1

        return call_id

    @staticmethod
    def _sanitize_args(args: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize arguments for logging (remove sensitive data)"""
        sanitized = {}
        for key, value in args.items():
            if isinstance(value, str) and len(value) > 1000:
                # Cap long string arguments so the log stays readable.
                sanitized[key] = value[:1000] + "... [truncated]"
            elif key.lower() in ['password', 'token', 'secret', 'key']:
                # Never persist credential-looking values.
                sanitized[key] = "[REDACTED]"
            else:
                sanitized[key] = value
        return sanitized

    def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize result for logging (remove large content)"""
        if not isinstance(result, dict):
            # Defensive: non-dict results pass through unchanged.
            return result

        sanitized = {}
        for key, value in result.items():
            if isinstance(value, str) and len(value) > 2000:
                sanitized[key] = value[:2000] + "... [truncated]"
            elif isinstance(value, dict):
                # Recurse so oversized strings are capped at any nesting depth.
                sanitized[key] = self._sanitize_result(value)
            else:
                sanitized[key] = value
        return sanitized

    def _update_session_summary(self, log_entry: ToolCallLog):
        """Update session summary with new tool call"""
        try:
            summary = {
                "session_id": self.session_id,
                "last_updated": datetime.now().isoformat(),
                "total_tool_calls": self.call_count + 1,
                "tool_usage_stats": dict(self.tool_usage_stats),
                "workspace_path": str(self.workspace_path)
            }

            # Pull in what is already on disk (e.g. agent_activity history,
            # session_start); fresh values clobbered by update() are re-applied
            # just below.
            if self.summary_file.exists():
                with open(self.summary_file, 'r') as f:
                    existing_summary = json.load(f)
                summary.update(existing_summary)

            summary["last_updated"] = datetime.now().isoformat()
            # "+1" because log_tool_call increments its counters only after
            # this method returns.
            summary["total_tool_calls"] = self.call_count + 1
            summary["tool_usage_stats"] = dict(self.tool_usage_stats)
            summary["tool_usage_stats"][log_entry.tool_name] = self.tool_usage_stats[log_entry.tool_name] + 1

            # Per-agent activity rollup, keyed by agent_info['type'].
            if log_entry.agent_info:
                agent_type = log_entry.agent_info.get('type', 'unknown')
                if 'agent_activity' not in summary:
                    summary['agent_activity'] = {}
                if agent_type not in summary['agent_activity']:
                    summary['agent_activity'][agent_type] = {
                        'tool_calls': 0,
                        'last_active': log_entry.timestamp.isoformat()
                    }
                summary['agent_activity'][agent_type]['tool_calls'] += 1
                summary['agent_activity'][agent_type]['last_active'] = log_entry.timestamp.isoformat()

            self._save_summary(summary)

        except Exception as e:
            logger.error(f"Failed to update session summary: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KeepAliveSessionWrapper:
    """Wrapper that keeps a session alive during long-running operations.

    Runs a daemon heartbeat thread that touches the session every
    ``touch_interval`` seconds; usable as a context manager.
    """

    def __init__(self, session: 'Session', touch_interval: int = 300):
        self.session = session
        self.touch_interval = touch_interval
        self.keep_alive_thread = None
        self.stop_event = Event()
        self.active = False

    def start_keep_alive(self):
        """Start the heartbeat thread (no-op if already running)."""
        if self.active:
            return

        self.active = True
        self.stop_event.clear()

        def _heartbeat():
            # Event.wait doubles as the sleep; it returns True (ending the
            # loop) as soon as stop_keep_alive sets the event.
            while not self.stop_event.wait(self.touch_interval):
                try:
                    self.session.touch()
                    logger.debug("Keep-alive: Touched session {%s}", self.session.id)
                except Exception as e:
                    logger.error(f"Keep-alive error for session {self.session.id}: {e}")
                    break

        self.keep_alive_thread = Thread(target=_heartbeat, daemon=True)
        self.keep_alive_thread.start()
        logger.info(f"Started keep-alive for session {self.session.id}")

    def stop_keep_alive(self):
        """Stop the heartbeat thread and record one final touch."""
        if not self.active:
            return

        self.active = False
        self.stop_event.set()

        worker = self.keep_alive_thread
        if worker and worker.is_alive():
            worker.join(timeout=1.0)

        # One last touch so the session's TTL is fresh when we hand it back.
        try:
            self.session.touch()
        except Exception as e:
            logger.error(f"Final keep-alive touch error for session {self.session.id}: {e}")

        logger.info(f"Stopped keep-alive for session {self.session.id}")

    def __enter__(self):
        self.start_keep_alive()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_keep_alive()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Session:
    """Thread-safe session data structure with workspace management."""
    id: str
    created_at: datetime
    last_accessed: datetime
    initialized: bool = False
    request_count: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)
    workspace_path: Optional[Path] = None
    mcp_tools: Optional[MCPTools] = None
    tool_tracker: Optional[ToolCallTracker] = None

    def is_expired(self, ttl_seconds: int) -> bool:
        """Return True when the idle time since last access exceeds the TTL."""
        idle = datetime.now() - self.last_accessed
        return idle > timedelta(seconds=ttl_seconds)

    def touch(self):
        """Refresh the last-accessed timestamp and bump the request counter."""
        self.request_count += 1
        self.last_accessed = datetime.now()

    def get_mcp_tools(self, prefer_async: bool = True) -> MCPTools:
        """Lazily build (and cache) the MCP tools instance for this session."""
        if self.mcp_tools is None:
            workspace = str(self.workspace_path) if self.workspace_path else None
            use_async = prefer_async and AsyncMCPTools is not None
            tools_cls = AsyncMCPTools if use_async else MCPTools
            self.mcp_tools = tools_cls(workspace_path=workspace)
        return self.mcp_tools

    def get_tool_tracker(self) -> Optional[ToolCallTracker]:
        """Lazily build the tracker; None when tracking is off or no workspace."""
        if not (config.enable_tool_tracking and self.workspace_path):
            return None
        if self.tool_tracker is None:
            self.tool_tracker = ToolCallTracker(self.workspace_path, self.id)
        return self.tool_tracker
|
|
|
|
|
|
|
|
|
|
|
class AsyncRLock:
    """Async re-entrant lock mirroring the semantics of threading.RLock.

    The task that holds the lock may acquire it again without blocking;
    the underlying asyncio.Lock is only released when the re-entry depth
    returns to zero.
    """

    def __init__(self):
        self._lock = asyncio.Lock()
        # Task currently holding the lock (None when free).
        self._owner: Optional[asyncio.Task] = None
        # Re-entry depth for the owning task.
        self._count = 0

    async def acquire(self):
        """Acquire the lock, re-entering without blocking if this task owns it."""
        me = asyncio.current_task()
        if self._owner is me:
            self._count += 1
            return
        await self._lock.acquire()
        self._owner = me
        self._count = 1

    async def release(self):
        """Release one level; frees the underlying lock at depth zero."""
        if self._owner != asyncio.current_task():
            raise RuntimeError("不能释放非当前协程持有的锁")
        self._count -= 1
        if self._count == 0:
            self._owner = None
            self._lock.release()

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.release()
|
|
|
|
|
|
|
|
class ThreadSafeSessionManager:
    """Thread-safe session manager with workspace management.

    Holds all live Session objects keyed by session id, creates one isolated
    workspace directory per session, and evicts sessions on TTL expiry or
    when the max_sessions cap is reached.  All dict mutations happen under an
    AsyncRLock, which is re-entrant within a single task.
    """

    def __init__(self, ttl_seconds: int = 3600, max_sessions: int = 1000, base_workspace_dir: str = "workspaces"):
        self.ttl_seconds = ttl_seconds
        self.max_sessions = max_sessions
        self.base_workspace_dir = Path(base_workspace_dir)
        self.base_workspace_dir.mkdir(exist_ok=True)

        # session id -> Session; guarded by self.lock for all reads/writes.
        self.sessions: Dict[str, Session] = {}
        self.lock = AsyncRLock()

        # Background TTL sweeper (daemon thread, see _start_cleanup_thread).
        self._start_cleanup_thread()

    async def create_session(self) -> str:
        """Create a new session and return session ID"""
        session_id = str(uuid.uuid4())

        async with self.lock:
            # Make room first; _cleanup_oldest_sessions re-acquires self.lock,
            # which is safe because AsyncRLock is re-entrant for this task.
            if len(self.sessions) >= self.max_sessions:
                await self._cleanup_oldest_sessions()

            # One isolated workspace directory per session, named by its UUID.
            workspace_path = self.base_workspace_dir / session_id
            workspace_path.mkdir(exist_ok=True, parents=True)

            session = Session(
                id=session_id,
                created_at=datetime.now(),
                last_accessed=datetime.now(),
                workspace_path=workspace_path
            )

            self.sessions[session_id] = session

        logger.info(f"Created session {session_id} with workspace {workspace_path}")
        return session_id

    async def get_session(self, session_id: str) -> Optional[Session]:
        """Get session by ID if it exists and is not expired"""
        async with self.lock:
            session = self.sessions.get(session_id)
            if session and not session.is_expired(self.ttl_seconds):
                # A successful lookup counts as activity and resets the TTL.
                session.touch()
                return session
            elif session:
                # Expired: drop it eagerly instead of waiting for the sweeper.
                del self.sessions[session_id]
                logger.info(f"Removed expired session {session_id}")
            return None

    async def get_or_create_session(self, session_id: Optional[str] = None) -> Session:
        """Get existing session or create new one"""
        if session_id:
            session = await self.get_session(session_id)
            if session:
                return session

        # Unknown or expired id: fall through to a brand-new session.
        new_session_id = await self.create_session()
        # NOTE(review): this lookup happens outside the lock; a concurrent
        # eviction between create and lookup would raise KeyError -- confirm
        # acceptable.
        return self.sessions[new_session_id]

    async def _cleanup_expired_sessions(self):
        """Remove expired sessions"""
        async with self.lock:
            expired_sessions = []
            for session_id, session in self.sessions.items():
                if session.is_expired(self.ttl_seconds):
                    expired_sessions.append(session_id)

            # Delete in a second pass so the dict is not mutated mid-iteration.
            for session_id in expired_sessions:
                del self.sessions[session_id]
                logger.info(f"Cleaned up expired session {session_id}")

    async def _cleanup_oldest_sessions(self):
        """Remove oldest sessions when limit is reached"""
        async with self.lock:
            if len(self.sessions) < self.max_sessions:
                return

            # Least-recently-accessed sessions first.
            sorted_sessions = sorted(
                self.sessions.items(),
                key=lambda x: x[1].last_accessed
            )

            # Free 10 extra slots beyond the cap so we do not evict on every
            # subsequent create_session call.
            sessions_to_remove = len(self.sessions) - self.max_sessions + 10
            for i in range(sessions_to_remove):
                if i < len(sorted_sessions):
                    session_id = sorted_sessions[i][0]
                    del self.sessions[session_id]
                    logger.info(f"Removed old session {session_id} due to session limit")

    def _start_cleanup_thread(self):
        """Start background cleanup thread"""
        def cleanup_worker():
            # Daemon loop: sleep for the configured interval (read from the
            # module-global config), then sweep expired sessions.
            while True:
                try:
                    time.sleep(config.cleanup_interval_seconds)

                    # NOTE(review): a fresh event loop is created per sweep to
                    # run the async cleanup from this thread, while self.lock
                    # is an asyncio-based AsyncRLock that request handlers use
                    # on the server's own loop.  asyncio primitives are not
                    # designed to be shared across event loops -- verify this
                    # cannot raise or silently bypass locking under contention.
                    loop = asyncio.new_event_loop()
                    loop.run_until_complete(self._cleanup_expired_sessions())
                    loop.close()
                except Exception as e:
                    logger.error(f"Error in cleanup thread: {e}")

        import threading
        cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        cleanup_thread.start()
        logger.info("Started session cleanup thread")

    async def get_stats(self) -> Dict[str, Any]:
        """Get session manager statistics"""
        async with self.lock:
            return {
                "total_sessions": len(self.sessions),
                "max_sessions": self.max_sessions,
                "ttl_seconds": self.ttl_seconds,
                "session_ids": list(self.sessions.keys())
            }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RateLimiter:
    """Simple rate limiter with time-window tracking.

    Keeps a rolling one-minute list of request timestamps per client id.
    """

    def __init__(self, requests_per_minute: int = 60):
        self.requests_per_minute = requests_per_minute
        self.requests: Dict[str, List[float]] = defaultdict(list)
        self.lock = asyncio.Lock()

    async def is_allowed(self, client_id: str) -> bool:
        """Check whether a request from *client_id* fits in its window."""
        async with self.lock:
            now = time.time()
            window_start = now - 60

            # Keep only timestamps still inside the rolling minute.
            recent = [stamp for stamp in self.requests[client_id] if stamp > window_start]
            self.requests[client_id] = recent

            if len(recent) >= self.requests_per_minute:
                return False

            # Count this request against the window.
            recent.append(now)
            return True
|
|
|
|
|
|
|
|
class RequestValidator:
    """Validates incoming MCP requests."""

    @staticmethod
    def validate_mcp_request(data: Dict[str, Any]) -> tuple[bool, Optional[str]]:
        """Validate basic MCP request structure; returns (ok, error_message)."""
        if not isinstance(data, dict):
            return False, "Request must be a JSON object"

        # JSON-RPC requires both a method and a request id.
        for required_field, message in (
            ("method", "Missing 'method' field"),
            ("id", "Missing 'id' field"),
        ):
            if required_field not in data:
                return False, message

        return True, None

    @staticmethod
    def validate_tool_call(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
        """Validate tool call parameters; returns (ok, error_message)."""
        if not isinstance(params, dict):
            return False, "Tool parameters must be a JSON object"

        missing = next((key for key in ("name", "arguments") if key not in params), None)
        if missing is not None:
            return False, f"Missing tool '{missing}'"

        tool_name = params["name"]

        # The tool must exist in the registered schema set.
        detailed_schemas = get_tool_schemas()
        if tool_name not in detailed_schemas:
            return False, f"Unknown tool: {tool_name}. Available tools: {sorted(list(detailed_schemas.keys()))}"

        return True, None
|
|
|
|
|
|
|
|
class SecurityMiddleware(BaseHTTPMiddleware):
    """Security middleware for basic protection.

    Rejects oversized request bodies (per config.max_request_size_mb) and
    attaches standard hardening headers to every response.
    """

    async def dispatch(self, request: Request, call_next):
        # Reject oversized payloads before doing any work.
        declared_length = request.headers.get("content-length")
        max_bytes = config.max_request_size_mb * 1024 * 1024
        if declared_length and int(declared_length) > max_bytes:
            return JSONResponse(
                status_code=HTTPStatus.REQUEST_ENTITY_TOO_LARGE,
                content={"error": "Request too large"}
            )

        response = await call_next(request)

        # Standard browser hardening headers on every response.
        for header_name, header_value in (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
        ):
            response.headers[header_name] = header_value

        return response
|
|
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware.

    Throttles per client IP using the shared RateLimiter instance.
    """

    def __init__(self, app, input_rate_limiter: RateLimiter):
        super().__init__(app)
        self.rate_limiter = input_rate_limiter

    async def dispatch(self, request: Request, call_next):
        # Requests without client info all share the "unknown" bucket.
        client = request.client
        client_ip = client.host if client else "unknown"

        if await self.rate_limiter.is_allowed(client_ip):
            return await call_next(request)

        return JSONResponse(
            status_code=HTTPStatus.TOO_MANY_REQUESTS,
            content={"error": "Rate limit exceeded"}
        )
|
|
|
|
|
|
|
|
# Module-level singletons; None until startup code (not in this chunk) assigns them.
session_manager: Optional[ThreadSafeSessionManager] = None
rate_limiter: Optional[RateLimiter] = None
|
|
|
|
|
|
|
|
@dataclass
class RateLimitViolation:
    """Structured description of a single rate-limit violation."""

    tool_name: str               # tool whose limit was hit
    limit_type: str              # "burst" | "second" | "minute" | "hour" | other
    current_usage: int           # observed request count at denial time
    limit_value: float           # configured ceiling for the limit window
    retry_after_seconds: float   # suggested client back-off

    def to_user_friendly_message(self) -> str:
        """Generate user-friendly error message"""
        wait = f"{self.retry_after_seconds:.0f}"
        # One template per known window kind; anything else gets the generic text.
        by_kind = {
            "burst": f"Service temporarily unavailable: Too many rapid requests to {self.tool_name}. Please wait {wait} seconds before trying again.",
            "second": f"Service temporarily unavailable: {self.tool_name} request rate exceeded ({self.limit_value}/second). Please wait {wait} seconds before trying again.",
            "minute": f"Service temporarily unavailable: {self.tool_name} quota exceeded ({self.limit_value}/minute). Please try again in {wait} seconds.",
            "hour": f"Service temporarily unavailable: {self.tool_name} hourly quota exceeded ({self.limit_value}/hour). Please try again in {wait} minutes.",
        }
        fallback = f"Service temporarily unavailable: {self.tool_name} rate limit exceeded. Please try again later."
        return by_kind.get(self.limit_type, fallback)

    def to_technical_message(self) -> str:
        """Generate technical error message for debugging"""
        return f"Tool '{self.tool_name}' {self.limit_type} limit exceeded ({self.current_usage}/{self.limit_value} {self.limit_type})"
|
|
|
|
|
|
|
|
def _parse_rate_limit_denial(tool_name: str, denial_reason: str) -> RateLimitViolation:
    """Parse rate limit denial reason into structured violation information"""
    import re

    # (marker substring, limit kind, suggested retry delay, value-extraction regex)
    # Checked in order; the first matching marker wins, mirroring the limiter's
    # own precedence (burst before second/minute/hour).
    known_limits = (
        ("burst limit exceeded", "burst", 1.0, r'\((\d+) requests/burst\)'),
        ("per-second limit exceeded", "second", 1.0, r'\(([0-9.]+) requests/second\)'),
        ("per-minute limit exceeded", "minute", 10.0, r'\(([0-9.]+) requests/minute\)'),
        ("per-hour limit exceeded", "hour", 300.0, r'\(([0-9.]+) requests/hour\)'),
    )

    # Defaults used when the reason string matches no known pattern.
    limit_type = "unknown"
    retry_after_seconds = 60.0
    limit_value = 0.0
    current_usage = 0

    for marker, kind, delay, pattern in known_limits:
        if marker not in denial_reason:
            continue
        limit_type = kind
        retry_after_seconds = delay
        found = re.search(pattern, denial_reason)
        if found:
            limit_value = float(found.group(1))
            # The denial text carries only the configured limit, so usage is
            # reported as that ceiling.
            current_usage = int(limit_value)
        break

    return RateLimitViolation(
        tool_name=tool_name,
        limit_type=limit_type,
        current_usage=current_usage,
        limit_value=limit_value,
        retry_after_seconds=retry_after_seconds,
    )
|
|
|
|
|
|
|
|
async def _call_session_tool_async(session: Session, tool_name: str, tool_args: Dict[str, Any],
                                   client_ip: str = "unknown") -> Dict[str, Any]:
    """Execute a tool within a session context with full tracking, workspace
    management, and global rate limiting.

    Args:
        session: Active session supplying the workspace, tracker, and tool suite.
        tool_name: Name of the tool method to invoke on the session's tool object.
        tool_args: Keyword arguments forwarded verbatim to the tool.
        client_ip: Originating client address, recorded for tracking only.

    Returns:
        The tool result normalized to a dict. Rate-limited calls return a dict
        with "rate_limited": True rather than raising; unexpected exceptions
        are converted to {"success": False, "error": ...}.
    """
    import inspect
    import concurrent.futures

    start_time = time.time()
    success = False
    error_details = None
    result_data = None

    # Mark the session as active before doing any work.
    session.touch()

    def _log_call(outcome_success: bool, outcome_error: Optional[str],
                  failure_label: str = "tool call") -> None:
        """Best-effort tracker logging; never lets tracking break the call."""
        tracker = session.get_tool_tracker()
        if not tracker:
            return
        try:
            tracker.log_tool_call(
                tool_name=tool_name,
                input_args=tool_args,
                output_result=result_data,
                success=outcome_success,
                duration_ms=(time.time() - start_time) * 1000,
                error_details=outcome_error,
                agent_info={
                    "client_ip": client_ip,
                    "type": "unknown",
                    "session_request_count": session.request_count,
                },
            )
        except Exception as e:
            logger.error(f"Failed to log {failure_label}: {e}")

    try:
        # Enforce the server-wide per-tool quota before touching the tool.
        if global_tool_rate_limiter:
            allowed, deny_reason = await global_tool_rate_limiter.is_allowed(tool_name)
            if not allowed:
                violation = _parse_rate_limit_denial(tool_name, deny_reason)
                user_message = violation.to_user_friendly_message()
                technical_message = violation.to_technical_message()

                logger.warning(f"Session {session.id}: {technical_message}")

                result_data = {
                    "success": False,
                    "error": user_message,
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "error_type": "rate_limit",
                    "tool_name": tool_name,
                    "limit_type": violation.limit_type,
                    "retry_after_seconds": violation.retry_after_seconds,
                    "data": None,
                    "rate_limited": True,
                    "technical_details": technical_message,
                }
                _log_call(False, user_message, failure_label="rate-limited tool call")
                return result_data

        mcp_tools = session.get_mcp_tools(prefer_async=True)

        if not hasattr(mcp_tools, tool_name):
            raise ValueError(f"Tool '{tool_name}' not implemented")

        tool_method = getattr(mcp_tools, tool_name)

        # Point the tool suite at this session's isolated workspace.
        if hasattr(mcp_tools, 'set_session_context'):
            mcp_tools.set_session_context(session.id, str(session.workspace_path))

        logger.info(f"Session {session.id}: Executing tool '{tool_name}' with args: {list(tool_args.keys())}")

        # Tools that may run longer than the session TTL get periodic
        # keepalive touches while executing.
        long_running_tools = {'batch_web_search', 'url_crawler', 'document_qa', 'document_extract', 'bash'}
        use_keepalive = config.enable_session_keepalive and tool_name in long_running_tools

        if inspect.iscoroutinefunction(tool_method):
            logger.debug("Executing async tool '{%s}'", tool_name)
            if use_keepalive:
                with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
                    result = await tool_method(**tool_args)
            else:
                result = await tool_method(**tool_args)
        else:
            logger.debug("Executing sync tool '{%s}' in thread pool", tool_name)

            def execute_tool_sync():
                """Synchronous tool execution to be run in thread pool"""
                return tool_method(**tool_args)

            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated in this context since Python 3.10.
            loop = asyncio.get_running_loop()

            if use_keepalive:
                with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
                    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                        result = await loop.run_in_executor(executor, execute_tool_sync)
            else:
                with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                    result = await loop.run_in_executor(executor, execute_tool_sync)

        session.touch()

        # Normalize heterogeneous tool return types into a plain dict.
        if hasattr(result, 'to_dict'):
            result_data = result.to_dict()
        elif isinstance(result, dict):
            result_data = result
        else:
            result_data = {"result": result}

        success = result_data.get('success', True)

        if success:
            logger.info(f"Session {session.id}: Tool '{tool_name}' completed successfully")
            # Only successful calls are counted against the global quota.
            if global_tool_rate_limiter:
                await global_tool_rate_limiter.record_request(tool_name)
        else:
            error_details = result_data.get('error', 'Unknown error')
            logger.warning(f"Session {session.id}: Tool '{tool_name}' failed: {error_details}")

    except Exception as e:
        success = False
        error_details = str(e)
        result_data = {
            "success": False,
            "error": error_details,
            "data": None
        }
        logger.error(f"Session {session.id}: Tool '{tool_name}' exception: {e}")

    _log_call(success, error_details)
    return result_data
|
|
|
|
|
|
|
|
|
|
|
def create_sse_response(response_data: dict, session_id: str = None) -> StreamingResponse:
    """Create Server-Sent Events response with proper formatting"""

    def generate_sse():
        try:
            # Tag the payload with the session so clients can correlate.
            if session_id:
                response_data["session_id"] = session_id
            body = json.dumps(response_data, ensure_ascii=False)
            yield "event: message\n"
            yield f"data: {body}\n"
            yield "\n"
        except Exception as e:
            # Serialization failure: emit a JSON-RPC internal error event instead.
            failure = {
                "jsonrpc": "2.0",
                "error": {"code": JsonRpcErr.INTERNAL_ERROR, "message": f"Internal error: {str(e)}"},
                "id": response_data.get("id")
            }
            body = json.dumps(failure, ensure_ascii=False)
            yield "event: error\n"
            yield f"data: {body}\n"
            yield "\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Access-Control-Allow-Origin": "*",
    }
    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
|
|
def create_error_response(request_id: Any, code: int, message: str, session_id: str = None) -> StreamingResponse:
    """Create error response in SSE format"""
    # Wrap the code/message pair in a JSON-RPC 2.0 error envelope and send it
    # over the standard SSE channel.
    payload = {
        "jsonrpc": "2.0",
        "error": {"code": code, "message": message},
        "id": request_id
    }
    return create_sse_response(payload, session_id)
|
|
|
|
|
|
|
|
def create_rate_limit_response(
    request_id: Any,
    tool_name: str,
    error_message: str,
    retry_after_seconds: float,
    limit_type: str,
    technical_details: str = "",
    session_id: str = None
) -> JSONResponse:
    """
    Create HTTP 429 Rate Limit Exceeded response with proper headers and error format.

    Returns proper HTTP status code instead of SSE for rate limiting errors.
    """
    # Retry-After must be a whole number of seconds, and at least 1.
    retry_after_header = int(max(1.0, retry_after_seconds))

    body = {
        "error": {
            "type": "rate_limit_exceeded",
            "code": "RATE_LIMIT_EXCEEDED",
            "message": error_message,
            "details": {
                "tool_name": tool_name,
                "limit_type": limit_type,
                "retry_after_seconds": retry_after_seconds,
                "technical_details": technical_details
            }
        },
        "request_id": request_id,
        "timestamp": datetime.now().isoformat(),
        "session_id": session_id
    }

    return JSONResponse(
        status_code=HTTPStatus.TOO_MANY_REQUESTS,
        content=body,
        headers={
            "Retry-After": str(retry_after_header),
            "X-RateLimit-Limit-Type": limit_type,
            "X-RateLimit-Tool": tool_name,
            "X-RateLimit-Retry-After": str(retry_after_seconds),
            "Content-Type": "application/json"
        },
    )
|
|
|
|
|
|
|
|
async def handle_mcp_request(request: Request) -> StreamingResponse:
    """Main MCP request handler with session management and tool execution.

    Parses a JSON-RPC request, resolves or creates a session from the
    X-Session-ID header, dispatches on the MCP method ("initialize",
    "tools/list", "tools/call"), and returns the response over SSE.
    Rate-limited tool calls are the exception: they return HTTP 429 JSON.
    """
    try:
        # Reject oversized payloads using the declared Content-Length before
        # buffering the body.
        content_length = request.headers.get("content-length")
        if content_length:
            content_size_mb = int(content_length) / (1024 * 1024)
            if content_size_mb > config.max_request_size_mb:
                logger.warning(f"Request too large: {content_size_mb:.2f}MB > {config.max_request_size_mb}MB")
                return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Request too large: {content_size_mb:.2f}MB")

        # Read the body with a timeout so a stalled client cannot pin the handler.
        try:
            body = await asyncio.wait_for(request.body(), timeout=config.request_timeout_seconds)
        except asyncio.TimeoutError:
            logger.error("Timeout while reading request body")
            return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request body read timeout")

        if not body:
            return create_error_response(None, JsonRpcErr.PARSE_ERROR, "Empty request body")

        try:
            data = json.loads(body.decode('utf-8'))
        except json.JSONDecodeError as e:
            return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Invalid JSON: {str(e)}")

        # Structural JSON-RPC validation of the decoded payload.
        is_valid, error_msg = RequestValidator.validate_mcp_request(data)
        if not is_valid:
            return create_error_response(data.get("id"), JsonRpcErr.INVALID_REQUEST, error_msg)

        request_id = data["id"]
        method = data["method"]
        params = data.get("params", {})

        # Sessions are keyed on the optional X-Session-ID header; a missing or
        # unknown id yields a fresh session.
        session_id = request.headers.get("X-Session-ID")
        client_ip = request.client.host if request.client else "unknown"

        session = await session_manager.get_or_create_session(session_id)
        logger.info(f"Processing {method} request for session {session.id} from {client_ip}")

        if method == "initialize":
            # MCP handshake: advertise protocol version, capabilities, and
            # server identity.
            response_data = {
                "jsonrpc": "2.0",
                "result": {
                    "protocolVersion": "2025-06-18",
                    "capabilities": {
                        "tools": {"supportsProgress": True},
                        "resources": {},
                        "prompts": {}
                    },
                    "serverInfo": {
                        "name": "DeepDiver-Demo-MCP",
                        "version": "1.0.0"
                    }
                },
                "id": request_id
            }

        elif method == "tools/list":
            # Surface every tool schema to the client, keeping only the
            # fields the MCP tools/list result defines.
            tools_list = []
            detailed_schemas = get_tool_schemas()

            for _, detailed_schema in detailed_schemas.items():
                tools_list.append({
                    "name": detailed_schema["name"],
                    "description": detailed_schema["description"],
                    "inputSchema": detailed_schema["inputSchema"]
                })

            logger.info(f"Serving {len(tools_list)} tools with detailed schemas to client")

            response_data = {
                "jsonrpc": "2.0",
                "result": {"tools": tools_list},
                "id": request_id
            }

        elif method == "tools/call":
            # Validate tool-call params (name + arguments) before dispatch.
            is_valid, error_msg = RequestValidator.validate_tool_call(params)
            if not is_valid:
                return create_error_response(request_id, JsonRpcErr.INVALID_PARAMS, error_msg, session.id)

            tool_name = params["name"]
            tool_arguments = params["arguments"]

            result = await _call_session_tool_async(session, tool_name, tool_arguments, client_ip)

            # Rate-limited calls are reported as HTTP 429 JSON, not SSE.
            if result.get("rate_limited", False):
                return create_rate_limit_response(
                    request_id=request_id,
                    tool_name=tool_name,
                    error_message=result.get("error", "Rate limit exceeded"),
                    retry_after_seconds=result.get("retry_after_seconds", 60),
                    limit_type=result.get("limit_type", "unknown"),
                    technical_details=result.get("technical_details", ""),
                    session_id=session.id
                )

            # MCP tool result: the tool output serialized as one JSON text block.
            response_data = {
                "jsonrpc": "2.0",
                "result": {
                    "content": [
                        {
                            "type": "text",
                            "text": json.dumps(result, indent=2, ensure_ascii=False)
                        }
                    ]
                },
                "id": request_id
            }

        else:
            return create_error_response(request_id, JsonRpcErr.METHOD_NOT_FOUND, f"Method not found: {method}", session.id)

        return create_sse_response(response_data, session.id)

    except asyncio.TimeoutError:
        logger.warning("Request timeout - client may have disconnected")
        return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request timeout")
    except Exception as e:
        # Heuristic string match: exceptions that look like the peer going
        # away are treated as a disconnect, not a server fault.
        if "ClientDisconnect" in str(e) or "ConnectionClosedError" in str(e):
            logger.warning(f"Client disconnected during request processing: {e}")
            return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Client disconnected")

        logger.error(f"Unexpected error in MCP request handler: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return create_error_response(None, JsonRpcErr.INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
|
async def handle_health_check(request: Request) -> JSONResponse:
    """Health check endpoint"""
    try:
        session_stats = await session_manager.get_stats() if session_manager else {}

        # Summarize per-tool rate limiting, if the global limiter exists.
        if global_tool_rate_limiter:
            limiter_stats = global_tool_rate_limiter.get_all_stats()
            rate_limit_summary = {
                "enabled": True,
                "tools_with_limits": len(limiter_stats),
                "total_configured_tools": list(limiter_stats.keys())
            }
        else:
            rate_limit_summary = {"enabled": False}

        payload = {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
            "version": "1.0.0",
            "session_stats": session_stats,
            "features": {
                "workspace_isolation": True,
                "tool_call_tracking": config.enable_tool_tracking if config else False,
                "client_rate_limiting": True,
                "global_tool_rate_limiting": rate_limit_summary["enabled"],
                "security_middleware": True,
                "standardized_rate_limit_responses": True
            },
            "rate_limiting": rate_limit_summary,
            "error_formats": {
                "rate_limit_exceeded": {
                    "http_status": HTTPStatus.TOO_MANY_REQUESTS,
                    "headers": ["Retry-After", "X-RateLimit-*"],
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "response_format": "application/json"
                }
            }
        }

        return JSONResponse(content=payload)

    except Exception as e:
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"status": "unhealthy", "error": str(e)}
        )
|
|
|
|
|
|
|
|
async def handle_tracking_info(request: Request) -> JSONResponse:
    """Get tool call tracking information for a session"""
    try:
        requested_session = request.query_params.get("session_id")
        if not requested_session:
            return JSONResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content={"error": "session_id parameter required"}
            )

        session = await session_manager.get_session(requested_session)
        if not session:
            return JSONResponse(
                status_code=HTTPStatus.NOT_FOUND,
                content={"error": f"Session {requested_session} not found"}
            )

        tracker = session.get_tool_tracker()
        if not tracker:
            # Tracking is optional; report it as disabled rather than erroring.
            return JSONResponse(content={
                "session_id": requested_session,
                "tracking_enabled": False,
                "message": "Tool call tracking not enabled or no workspace"
            })

        # Load the persisted summary, if the tracker has written one.
        summary_data = {}
        if tracker.summary_file.exists():
            try:
                with open(tracker.summary_file, 'r') as summary_fh:
                    summary_data = json.load(summary_fh)
            except Exception as e:
                logger.error(f"Failed to read session summary: {e}")

        return JSONResponse(content={
            "session_id": requested_session,
            "tracking_enabled": True,
            "summary": summary_data,
            "logs_directory": str(tracker.logs_dir),
            "current_log_file": str(tracker.current_log_file)
        })

    except Exception as e:
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error": str(e)}
        )
|
|
|
|
|
|
|
|
|
|
|
async def handle_rate_limit_stats(request: Request) -> JSONResponse:
    """Get global tool rate limiting statistics"""
    try:
        if not global_tool_rate_limiter:
            return JSONResponse(
                status_code=HTTPStatus.NOT_FOUND,
                content={"error": "Global tool rate limiter not initialized"}
            )

        requested_tool = request.query_params.get("tool")

        if requested_tool:
            # Per-tool view.
            tool_stats = await global_tool_rate_limiter.get_tool_stats(requested_tool)
            return JSONResponse(content=tool_stats)

        # Aggregate view across every configured tool.
        all_stats = global_tool_rate_limiter.get_all_stats()
        return JSONResponse(content={
            "timestamp": datetime.now().isoformat(),
            "global_tool_rate_limiting": True,
            "tools": all_stats,
            "summary": {
                "total_tools_with_limits": len(all_stats),
                "tools_configured": list(all_stats.keys())
            }
        })

    except Exception as e:
        logger.error(f"Failed to get rate limit stats: {e}")
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error": str(e)}
        )
|
|
|
|
|
|
|
|
def create_app() -> Starlette:
    """Create and configure the Starlette application.

    Initializes the module-level singletons (session manager, client rate
    limiter, optional global per-tool rate limiter), installs middleware,
    and registers the HTTP routes.

    Raises:
        RuntimeError: If the global server config has not been loaded yet.
    """
    global session_manager, rate_limiter, global_tool_rate_limiter

    if not config:
        raise RuntimeError("Server configuration not initialized")

    # Core singletons shared by all request handlers.
    session_manager = ThreadSafeSessionManager(
        ttl_seconds=config.session_ttl_seconds,
        max_sessions=config.max_sessions,
        base_workspace_dir=config.base_workspace_dir
    )
    rate_limiter = RateLimiter(config.rate_limit_requests_per_minute)

    # Optional server-wide per-tool quotas.
    if config.tool_rate_limits:
        global_tool_rate_limiter = GlobalToolRateLimiter(config.tool_rate_limits)
        logger.info(f"Initialized global tool rate limiter with {len(config.tool_rate_limits)} tool limits")
    else:
        # Assign explicitly so handlers that test this global never hit a
        # NameError when no tool limits are configured.
        global_tool_rate_limiter = None
        logger.info("No tool rate limits configured - tools will run without global rate limiting")

    app = Starlette(debug=config.debug_mode)

    # Middleware: security checks first, then per-client request throttling.
    app.add_middleware(SecurityMiddleware)
    app.add_middleware(RateLimitMiddleware, input_rate_limiter=rate_limiter)

    # HTTP routes.
    app.add_route("/mcp", handle_mcp_request, methods=["POST"])
    app.add_route("/health", handle_health_check, methods=["GET"])
    app.add_route("/tracking", handle_tracking_info, methods=["GET"])
    app.add_route("/rate-limits", handle_rate_limit_stats, methods=["GET"])

    return app
|
|
|
|
|
|
|
|
def parse_arguments():
    """Parse command line arguments"""
    arg_parser = argparse.ArgumentParser(
        description="Demo-Ready MCP Server with Per-Tool Rate Limiting",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python src/tools/mcp_server_standard.py --config src/tools/server_config.yaml
python src/tools/mcp_server_standard.py --host 127.0.0.1 --port 8080
python src/tools/mcp_server_standard.py --config custom_config.yaml --debug
"""
    )

    # All options default to None/False; main() treats them as overrides on
    # top of the YAML configuration.
    arg_parser.add_argument('--config', '-c', type=str,
                            help='Path to YAML configuration file')
    arg_parser.add_argument('--host', type=str,
                            help='Server host (overrides config file)')
    arg_parser.add_argument('--port', '-p', type=int,
                            help='Server port (overrides config file)')
    arg_parser.add_argument('--debug', action='store_true',
                            help='Enable debug mode (overrides config file)')
    arg_parser.add_argument('--workspace-dir', type=str,
                            help='Base workspace directory (overrides config file)')

    return arg_parser.parse_args()
|
|
|
|
|
|
|
|
def print_startup_info():
    """Log a startup banner: enabled features, rate limits, and tool counts."""
    banner_rule = "=" * 50
    logger.info("🚀 DeepDiver Demo MCP Server")
    logger.info(banner_rule)
    logger.info("📊 Features:")
    logger.info(f"   • Session Management: ✅ (TTL: {config.session_ttl_seconds}s)")
    logger.info(f"   • Workspace Isolation: ✅ (Base: {config.base_workspace_dir})")
    logger.info(f"   • Tool Call Tracking: {'✅' if config.enable_tool_tracking else '❌'}")
    logger.info(f"   • Client Rate Limiting: ✅ ({config.rate_limit_requests_per_minute}/min)")
    logger.info(f"   • Global Tool Rate Limiting: {'✅' if config.tool_rate_limits else '❌'}")
    logger.info("   • Security Middleware: ✅")

    if config.tool_rate_limits:
        logger.info(f"🚦 Tool Rate Limits: {len(config.tool_rate_limits)} tools configured")
        # Preview at most three tools to keep the banner short.
        for tool_name, limits in list(config.tool_rate_limits.items())[:3]:
            burst = limits.get('burst_limit', '∞')
            rpm = limits.get('requests_per_minute', '∞')
            logger.info(f"   • {tool_name}: {rpm}/min, burst: {burst}")
        if len(config.tool_rate_limits) > 3:
            logger.info(f"   • ... and {len(config.tool_rate_limits) - 3} more tools")

    available_tools = list(get_tool_schemas().keys())

    logger.info(f"🔧 Tools Available: {len(available_tools)}")
    logger.info(f"   • All tools defined in schemas: {len(available_tools)} tools")
    logger.info(f"   • Sample tools: {', '.join(sorted(available_tools)[:5])}...")
    logger.info(banner_rule)
|
|
|
|
|
|
|
|
def main():
    """Main function to run the production MCP server.

    Loads configuration (honoring --config), applies CLI overrides, and
    starts a single-worker uvicorn server.
    """
    global config

    args = parse_arguments()

    # Honor --config from the CLI; fall back to the bundled default path.
    # (Previously args.config was parsed but ignored.)
    config_path = args.config or "./src/tools/server_config.yaml"
    config = ServerConfig.from_yaml(config_path)

    # CLI flags override values loaded from the YAML file.
    if args.host:
        config.host = args.host
        logger.info(f"🔧 Override: Host = {config.host}")

    # Compare against None so an explicit --port 0 is not silently dropped.
    if args.port is not None:
        config.port = args.port
        logger.info(f"🔧 Override: Port = {config.port}")

    if args.debug:
        config.debug_mode = True
        logger.info(f"🔧 Override: Debug mode enabled")

    if args.workspace_dir:
        config.base_workspace_dir = args.workspace_dir
        logger.info(f"🔧 Override: Workspace directory = {config.base_workspace_dir}")

    print_startup_info()

    try:
        import os

        # Placeholder: the flag is recognized but currently has no effect.
        if os.getenv('FORCE_HIGH_CONCURRENCY', '').lower() == 'true':
            pass

        app = create_app()

        logger.info(f"🌐 Starting server at http://{config.host}:{config.port}")
        logger.info(f"📡 MCP endpoint: http://{config.host}:{config.port}/mcp")
        logger.info(f"🏥 Health check: http://{config.host}:{config.port}/health")
        logger.info(f"📊 Tracking info: http://{config.host}:{config.port}/tracking?session_id=<id>")
        logger.info(f"🚦 Rate limit stats: http://{config.host}:{config.port}/rate-limits")

        uvicorn.run(
            app,
            host=config.host,
            port=config.port,
            log_level="info",
            timeout_keep_alive=config.request_timeout_seconds,
            workers=1,
            backlog=1024,
            access_log=False,
            limit_concurrency=None,
        )

    except KeyboardInterrupt:
        print("\n⏹️ Server stopped by user")
    except Exception as e:
        print(f"❌ Server startup failed: {e}")
        import traceback
        traceback.print_exc()
        raise
|
|
|
|
|
# Script entry point: start the server when executed directly.
if __name__ == "__main__":
    main()