drizzlezyk commited on
Commit
edebd9d
·
verified ·
1 Parent(s): 1c24ecc

Upload deepdiver_v2/src/tools/mcp_server_standard.py with huggingface_hub

Browse files
deepdiver_v2/src/tools/mcp_server_standard.py ADDED
@@ -0,0 +1,1751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ #!/usr/bin/env python3
3
+ """
4
+ Demo-Ready MCP Server - New Standard Implementation
5
+ Combines robust session management with comprehensive tool definitions.
6
+ Features: workspace isolation, tool call tracking, rate limiting, security, and full tool suite.
7
+ """
8
+
9
+ import argparse
10
+ import asyncio
11
+ import json
12
+ import logging
13
+ import time
14
+ import uuid
15
+ import yaml
16
+ from collections import defaultdict, deque
17
+ from dataclasses import dataclass, field
18
+ from datetime import datetime, timedelta
19
+ from pathlib import Path
20
+ from threading import Thread, Event
21
+ from typing import Any, Dict, List, Optional
22
+
23
+ # Third-party imports
24
+ from starlette.applications import Starlette
25
+ from starlette.middleware.base import BaseHTTPMiddleware
26
+ from starlette.requests import Request
27
+ from starlette.responses import JSONResponse, StreamingResponse
28
+ import uvicorn
29
+
30
+ # Add project root to Python path for imports
31
+ import sys
32
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
33
+ from src.utils.status_codes import JsonRpcErr
34
+ from http import HTTPStatus
35
+
36
+ # Handle both relative and absolute imports
37
+ try:
38
+ from .mcp_tools import MCPTools, get_tool_schemas
39
+ from .mcp_tools_async import AsyncMCPTools
40
+ except ImportError:
41
+ # Fallback for direct script execution
42
+ from src.tools.mcp_tools import MCPTools, get_tool_schemas
43
+ try:
44
+ from src.tools.mcp_tools_async import AsyncMCPTools
45
+ except ImportError:
46
+ AsyncMCPTools = None
47
+
48
+ # Workspace knowledge manager disabled
49
+ WORKSPACE_KNOWLEDGE_AVAILABLE = False
50
+
51
+ # Configure structured logging
52
+ logging.basicConfig(
53
+ level=logging.INFO,
54
+ format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
55
+ handlers=[
56
+ logging.StreamHandler(sys.stdout),
57
+ logging.FileHandler('mcp_server.log')
58
+ ]
59
+ )
60
+ logger = logging.getLogger(__name__)
61
+
62
+ # ================ CONFIGURATION ================
63
+
64
+
65
+ @dataclass
66
+ class ServerConfig:
67
+ """Server configuration with only actually implemented options"""
68
+ # Server Core Settings
69
+ host: str = "127.0.0.1"
70
+ port: int = 6274
71
+ debug_mode: bool = False
72
+
73
+ # Session Management
74
+ session_ttl_seconds: int = 3600 # 1 hour default
75
+ max_sessions: int = 1000
76
+ cleanup_interval_seconds: int = 300 # 5 minutes
77
+ enable_session_keepalive: bool = True
78
+ keepalive_touch_interval: int = 300
79
+
80
+ # Request Handling
81
+ request_timeout_seconds: int = 120
82
+ max_request_size_mb: int = 10
83
+
84
+ # Client Rate Limiting (per IP)
85
+ rate_limit_requests_per_minute: int = 300
86
+
87
+ # Workspace Management
88
+ base_workspace_dir: str = "workspaces"
89
+
90
+ # Tool Call Tracking & Logging
91
+ enable_tool_tracking: bool = True
92
+ max_tracked_calls_per_session: int = 1000
93
+ track_detailed_errors: bool = True
94
+
95
+
96
+
97
+ # Per-tool Rate Limiting Configuration
98
+ tool_rate_limits: Dict[str, Dict[str, int]] = field(default_factory=dict)
99
+
100
+ @classmethod
101
+ def from_yaml(cls, config_path: str) -> 'ServerConfig':
102
+ """Load configuration from YAML file"""
103
+ try:
104
+ with open(config_path, 'r') as f:
105
+ config_data = yaml.safe_load(f)
106
+
107
+ # Extract configuration sections with defaults
108
+ server_config = config_data.get('server', {})
109
+ tracking_config = config_data.get('tracking', {})
110
+ tool_rate_limits = config_data.get('tool_rate_limits', {})
111
+
112
+ return cls(
113
+ # Server Core Settings
114
+ host=server_config.get('host', "127.0.0.1"),
115
+ port=server_config.get('port', 6274),
116
+ debug_mode=server_config.get('debug_mode', False),
117
+
118
+ # Session Management
119
+ session_ttl_seconds=server_config.get('session_ttl_seconds', 3600),
120
+ max_sessions=server_config.get('max_sessions', 1000),
121
+ cleanup_interval_seconds=server_config.get('cleanup_interval_seconds', 300),
122
+ enable_session_keepalive=server_config.get('enable_session_keepalive', True),
123
+ keepalive_touch_interval=server_config.get('keepalive_touch_interval', 300),
124
+
125
+ # Request Handling
126
+ request_timeout_seconds=server_config.get('request_timeout_seconds', 120),
127
+ max_request_size_mb=server_config.get('max_request_size_mb', 10),
128
+
129
+ # Client Rate Limiting
130
+ rate_limit_requests_per_minute=server_config.get('rate_limit_requests_per_minute', 300),
131
+
132
+ # Workspace Management
133
+ base_workspace_dir=server_config.get('base_workspace_dir', "workspaces"),
134
+
135
+ # Tool Call Tracking & Logging
136
+ enable_tool_tracking=tracking_config.get('enable_tool_tracking', True),
137
+ max_tracked_calls_per_session=tracking_config.get('max_tracked_calls_per_session', 1000),
138
+ track_detailed_errors=tracking_config.get('track_detailed_errors', True),
139
+
140
+ # Per-tool Rate Limiting
141
+ tool_rate_limits=tool_rate_limits
142
+ )
143
+
144
+ except Exception as e:
145
+ logger.error(f"Failed to load configuration from {config_path}: {e}")
146
+ logger.info("Using default configuration")
147
+ return cls()
148
+
149
+ # Global configuration instance - will be set during startup
150
+ config: Optional[ServerConfig] = None
151
+
152
+ # ================ GLOBAL PER-TOOL RATE LIMITING ================
153
+
154
+
155
+ @dataclass
156
+ class ToolRateLimit:
157
+ """Rate limit configuration for a specific tool"""
158
+ requests_per_minute: float
159
+ requests_per_hour: float
160
+ burst_limit: int
161
+
162
+
163
+ class GlobalToolRateLimiter:
164
+ """
165
+ Global rate limiter that controls QPS to external APIs per tool.
166
+ This is shared across all sessions and clients to manage upstream service load.
167
+ """
168
+
169
+ def __init__(self, tool_rate_limits: Dict[str, Dict[str, int]]):
170
+ self.tool_limits: Dict[str, ToolRateLimit] = {}
171
+ self.tool_requests: Dict[str, deque] = defaultdict(deque)
172
+ self.lock = asyncio.Lock()
173
+
174
+ # Initialize rate limits for each tool
175
+ for tool_name, limits_config in tool_rate_limits.items():
176
+ self.tool_limits[tool_name] = ToolRateLimit(
177
+ requests_per_minute=limits_config.get('requests_per_minute', float('inf')),
178
+ requests_per_hour=limits_config.get('requests_per_hour', float('inf')),
179
+ burst_limit=limits_config.get('burst_limit', 10)
180
+ )
181
+ self.tool_requests[tool_name] = deque()
182
+
183
+ logger.info(f"Initialized global tool rate limiter for {len(self.tool_limits)} tools")
184
+
185
+ async def is_allowed(self, tool_name: str) -> tuple[bool, Optional[str]]:
186
+ """
187
+ Check if a request to the specified tool is allowed based on global rate limits.
188
+
189
+ Returns:
190
+ tuple[bool, Optional[str]]: (allowed, reason_if_denied)
191
+ """
192
+ if tool_name not in self.tool_limits:
193
+ # Tool not configured for rate limiting - allow
194
+ return True, None
195
+
196
+ async with self.lock:
197
+ now = time.time()
198
+ limits = self.tool_limits[tool_name]
199
+ requests = self.tool_requests[tool_name]
200
+
201
+ # Clean old requests outside the time windows
202
+ self._cleanup_old_requests(requests, now)
203
+
204
+ # Check various time window limits
205
+ recent_requests = list(requests)
206
+
207
+ # Check burst limit (rapid requests in last second) - only if specified
208
+ if limits.burst_limit != float('inf'):
209
+ burst_count = sum(1 for req_time in recent_requests if now - req_time < 1.0)
210
+ if burst_count >= limits.burst_limit:
211
+ return False, f"Tool '{tool_name}' burst limit exceeded ({limits.burst_limit} requests/burst)"
212
+
213
+ # Check per-minute limit - only if specified
214
+ if limits.requests_per_minute != float('inf'):
215
+ minute_count = sum(1 for req_time in recent_requests if now - req_time < 60.0)
216
+ if minute_count >= limits.requests_per_minute:
217
+ return False, f"Tool '{tool_name}' per-minute limit exceeded ({limits.requests_per_minute} requests/minute)"
218
+
219
+ # Check per-hour limit - only if specified
220
+ if limits.requests_per_hour != float('inf'):
221
+ hour_count = sum(1 for req_time in recent_requests if now - req_time < 3600.0)
222
+ if hour_count >= limits.requests_per_hour:
223
+ return False, f"Tool '{tool_name}' per-hour limit exceeded ({limits.requests_per_hour} requests/hour)"
224
+
225
+ return True, None
226
+
227
+ async def record_request(self, tool_name: str):
228
+ """Record a successful request for rate limiting tracking"""
229
+ if tool_name not in self.tool_limits:
230
+ return
231
+
232
+ async with self.lock:
233
+ now = time.time()
234
+ self.tool_requests[tool_name].append(now)
235
+
236
+ # Keep deque size manageable (only keep last hour of requests)
237
+ self._cleanup_old_requests(self.tool_requests[tool_name], now)
238
+
239
+ @staticmethod
240
+ def _cleanup_old_requests(requests: deque, now: float):
241
+ """Remove requests older than 1 hour to keep memory usage bounded"""
242
+ while requests and now - requests[0] > 3600.0: # 1 hour
243
+ requests.popleft()
244
+
245
+ async def get_tool_stats(self, tool_name: str) -> Dict[str, Any]:
246
+ """Get current usage statistics for a tool"""
247
+ if tool_name not in self.tool_limits:
248
+ return {"error": f"Tool '{tool_name}' not configured for rate limiting"}
249
+
250
+ async with self.lock:
251
+ now = time.time()
252
+ requests = self.tool_requests[tool_name]
253
+ limits = self.tool_limits[tool_name]
254
+
255
+ # Clean old requests first
256
+ self._cleanup_old_requests(requests, now)
257
+
258
+ recent_requests = list(requests)
259
+
260
+ return {
261
+ "tool_name": tool_name,
262
+ "current_usage": {
263
+ "last_second": sum(1 for req_time in recent_requests if now - req_time < 1.0),
264
+ "last_minute": sum(1 for req_time in recent_requests if now - req_time < 60.0),
265
+ "last_hour": sum(1 for req_time in recent_requests if now - req_time < 3600.0)
266
+ },
267
+ "limits": {
268
+ "requests_per_minute": limits.requests_per_minute if limits.requests_per_minute != float('inf') else None,
269
+ "requests_per_hour": limits.requests_per_hour if limits.requests_per_hour != float('inf') else None,
270
+ "burst_limit": limits.burst_limit if limits.burst_limit != float('inf') else None
271
+ },
272
+ "utilization": {
273
+ "minute_utilization": sum(1 for req_time in recent_requests if now - req_time < 60.0) / limits.requests_per_minute if limits.requests_per_minute != float('inf') else 0,
274
+ "hour_utilization": sum(1 for req_time in recent_requests if now - req_time < 3600.0) / limits.requests_per_hour if limits.requests_per_hour != float('inf') else 0
275
+ }
276
+ }
277
+
278
+ def get_all_stats(self) -> Dict[str, Any]:
279
+ """Get usage statistics for all tools"""
280
+ return {
281
+ tool_name: self.get_tool_stats(tool_name)
282
+ for tool_name in self.tool_limits.keys()
283
+ }
284
+
285
+ # Global tool rate limiter instance - will be initialized during startup
286
+ global_tool_rate_limiter: Optional[GlobalToolRateLimiter] = None
287
+
288
+ # ================ TOOL DEFINITIONS ================
289
+
290
+ # Tool execution function mapping - maps tool names to their implementation functions
291
+
292
+
293
+ def get_tool_function(tool_name: str):
294
+ """Get the actual function for a tool"""
295
+ tool_map = {
296
+ "batch_web_search": lambda tools, **kwargs: tools.batch_web_search(**kwargs),
297
+ "url_crawler": lambda tools, **kwargs: tools.url_crawler(**kwargs),
298
+ "download_files": lambda tools, **kwargs: tools.download_files(**kwargs),
299
+ "list_workspace": lambda tools, **kwargs: tools.list_workspace(**kwargs),
300
+ "str_replace_based_edit_tool": lambda tools, **kwargs: tools.str_replace_based_edit_tool(**kwargs),
301
+ "file_stats": lambda tools, **kwargs: tools.file_stats(**kwargs),
302
+ "file_read": lambda tools, **kwargs: tools.file_read(**kwargs),
303
+ "file_read_lines": lambda tools, **kwargs: tools.file_read_lines(**kwargs),
304
+ "content_preview": lambda tools, **kwargs: tools.content_preview(**kwargs),
305
+ "file_write": lambda tools, **kwargs: tools.file_write(**kwargs),
306
+ "file_grep_search": lambda tools, **kwargs: tools.file_grep_search(**kwargs),
307
+ "file_grep_with_context": lambda tools, **kwargs: tools.file_grep_with_context(**kwargs),
308
+ "file_find_by_name": lambda tools, **kwargs: tools.file_find_by_name(**kwargs),
309
+ "bash": lambda tools, **kwargs: tools.bash(**kwargs),
310
+ "task_done": lambda tools, **kwargs: tools.task_done(**kwargs),
311
+ "think": lambda tools, **kwargs: tools.think(**kwargs),
312
+ "reflect": lambda tools, **kwargs: tools.reflect(**kwargs),
313
+ "document_qa": lambda tools, **kwargs: tools.document_qa(**kwargs),
314
+ "extract_markdown_toc": lambda tools, **kwargs: tools.extract_markdown_toc(**kwargs),
315
+ "extract_markdown_section": lambda tools, **kwargs: tools.extract_markdown_section(**kwargs),
316
+
317
+ "document_extract": lambda tools, **kwargs: tools.document_extract(**kwargs),
318
+ "search_result_classifier": lambda tools, **kwargs: tools.search_result_classifier(**kwargs),
319
+ "info_seeker_subjective_task_done": None,
320
+ "writer_subjective_task_done": None,
321
+ "section_writer": lambda tools, **kwargs: tools.section_writer(**kwargs),
322
+ "concat_section_files": lambda tools, **kwargs: tools.concat_section_files(**kwargs),
323
+
324
+ # Internal tools - available to server but NOT exposed to agents via tool schemas
325
+ "internal_file_read_unlimited": lambda tools, **kwargs: tools.internal_file_read_unlimited(**kwargs),
326
+ }
327
+ return tool_map.get(tool_name)
328
+
329
+
330
+ # ================ TOOL CALL TRACKING ================
331
+
332
+
333
+ @dataclass
334
+ class ToolCallLog:
335
+ """Individual tool call log entry"""
336
+ call_id: str
337
+ timestamp: datetime
338
+ tool_name: str
339
+ input_args: Dict[str, Any]
340
+ output_result: Dict[str, Any]
341
+ success: bool
342
+ duration_ms: float
343
+ error_details: Optional[str] = None
344
+ session_id: str = ""
345
+ agent_info: Optional[Dict[str, Any]] = None
346
+
347
+ def to_dict(self) -> Dict[str, Any]:
348
+ """Convert to dictionary for JSON serialization"""
349
+ return {
350
+ "call_id": self.call_id,
351
+ "timestamp": self.timestamp.isoformat(),
352
+ "tool_name": self.tool_name,
353
+ "input_args": self.input_args,
354
+ "output_result": self.output_result,
355
+ "success": self.success,
356
+ "duration_ms": self.duration_ms,
357
+ "error_details": self.error_details,
358
+ "session_id": self.session_id,
359
+ "agent_info": self.agent_info
360
+ }
361
+
362
+
363
+ class ToolCallTracker:
364
+ """Tracks and saves tool calls to workspace-specific files"""
365
+
366
+ def __init__(self, workspace_path: Path, session_id: str):
367
+ self.workspace_path = workspace_path
368
+ self.session_id = session_id
369
+ self.logs_dir = workspace_path / "tool_call_logs"
370
+ self.logs_dir.mkdir(exist_ok=True)
371
+
372
+ # Create daily log file
373
+ today = datetime.now().strftime("%Y-%m-%d")
374
+ self.current_log_file = self.logs_dir / f"tool_calls_{today}.jsonl"
375
+ self.summary_file = self.logs_dir / "session_summary.json"
376
+
377
+ # Track call counts
378
+ self.call_count = 0
379
+ self.tool_usage_stats = defaultdict(int)
380
+
381
+ # Initialize session summary
382
+ self._initialize_session_summary()
383
+
384
+ def _initialize_session_summary(self):
385
+ """Initialize or update session summary file"""
386
+ summary = {
387
+ "session_id": self.session_id,
388
+ "session_start": datetime.now().isoformat(),
389
+ "last_updated": datetime.now().isoformat(),
390
+ "total_tool_calls": 0,
391
+ "tool_usage_stats": {},
392
+ "agent_activity": {},
393
+ "workspace_path": str(self.workspace_path)
394
+ }
395
+
396
+ # Load existing summary if it exists
397
+ if self.summary_file.exists():
398
+ try:
399
+ with open(self.summary_file, 'r') as f:
400
+ existing_summary = json.load(f)
401
+ summary.update(existing_summary)
402
+ # Don't overwrite session_start if it already exists
403
+ if "session_start" in existing_summary:
404
+ summary["session_start"] = existing_summary["session_start"]
405
+ except Exception as e:
406
+ logger.warning(f"Could not load existing session summary: {e}")
407
+
408
+ self._save_summary(summary)
409
+
410
+ def _save_summary(self, summary: Dict[str, Any]):
411
+ """Save session summary to file"""
412
+ try:
413
+ with open(self.summary_file, 'w') as f:
414
+ json.dump(summary, f, indent=2, ensure_ascii=False)
415
+ except Exception as e:
416
+ logger.error(f"Failed to save session summary: {e}")
417
+
418
+ def log_tool_call(self,
419
+ tool_name: str,
420
+ input_args: Dict[str, Any],
421
+ output_result: Dict[str, Any],
422
+ success: bool,
423
+ duration_ms: float,
424
+ error_details: Optional[str] = None,
425
+ agent_info: Optional[Dict[str, Any]] = None) -> str:
426
+ """Log a tool call and return the call ID"""
427
+
428
+ if not config.enable_tool_tracking:
429
+ return ""
430
+
431
+ # Respect max call limit per session
432
+ if self.call_count >= config.max_tracked_calls_per_session:
433
+ logger.warning(f"Max tracked calls reached for session {self.session_id}")
434
+ return ""
435
+
436
+ call_id = str(uuid.uuid4())
437
+ timestamp = datetime.now()
438
+
439
+ # Create log entry
440
+ log_entry = ToolCallLog(
441
+ call_id=call_id,
442
+ timestamp=timestamp,
443
+ tool_name=tool_name,
444
+ input_args=self._sanitize_args(input_args),
445
+ output_result=self._sanitize_result(output_result),
446
+ success=success,
447
+ duration_ms=duration_ms,
448
+ error_details=error_details if config.track_detailed_errors else None,
449
+ session_id=self.session_id,
450
+ agent_info=agent_info
451
+ )
452
+
453
+ # Save to JSONL file (one JSON object per line)
454
+ try:
455
+ with open(self.current_log_file, 'a', encoding="utf-8") as f:
456
+ f.write(json.dumps(log_entry.to_dict(), ensure_ascii=False) + '\n')
457
+ except Exception as e:
458
+ logger.error(f"Failed to save tool call log: {e}")
459
+
460
+ # Update session summary
461
+ self._update_session_summary(log_entry)
462
+
463
+ self.call_count += 1
464
+ self.tool_usage_stats[tool_name] += 1
465
+
466
+ return call_id
467
+
468
+ @staticmethod
469
+ def _sanitize_args(args: Dict[str, Any]) -> Dict[str, Any]:
470
+ """Sanitize arguments for logging (remove sensitive data)"""
471
+ sanitized = {}
472
+ for key, value in args.items():
473
+ if isinstance(value, str) and len(value) > 1000:
474
+ sanitized[key] = value[:1000] + "... [truncated]"
475
+ elif key.lower() in ['password', 'token', 'secret', 'key']:
476
+ sanitized[key] = "[REDACTED]"
477
+ else:
478
+ sanitized[key] = value
479
+ return sanitized
480
+
481
+ def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
482
+ """Sanitize result for logging (remove large content)"""
483
+ if not isinstance(result, dict):
484
+ return result
485
+
486
+ sanitized = {}
487
+ for key, value in result.items():
488
+ if isinstance(value, str) and len(value) > 2000:
489
+ sanitized[key] = value[:2000] + "... [truncated]"
490
+ elif isinstance(value, dict):
491
+ sanitized[key] = self._sanitize_result(value)
492
+ else:
493
+ sanitized[key] = value
494
+ return sanitized
495
+
496
+ def _update_session_summary(self, log_entry: ToolCallLog):
497
+ """Update session summary with new tool call"""
498
+ try:
499
+ summary = {
500
+ "session_id": self.session_id,
501
+ "last_updated": datetime.now().isoformat(),
502
+ "total_tool_calls": self.call_count + 1,
503
+ "tool_usage_stats": dict(self.tool_usage_stats),
504
+ "workspace_path": str(self.workspace_path)
505
+ }
506
+
507
+ # Load existing summary
508
+ if self.summary_file.exists():
509
+ with open(self.summary_file, 'r') as f:
510
+ existing_summary = json.load(f)
511
+ summary.update(existing_summary)
512
+
513
+ # Update with new data
514
+ summary["last_updated"] = datetime.now().isoformat()
515
+ summary["total_tool_calls"] = self.call_count + 1
516
+ summary["tool_usage_stats"] = dict(self.tool_usage_stats)
517
+ summary["tool_usage_stats"][log_entry.tool_name] = self.tool_usage_stats[log_entry.tool_name] + 1
518
+
519
+ # Track agent activity
520
+ if log_entry.agent_info:
521
+ agent_type = log_entry.agent_info.get('type', 'unknown')
522
+ if 'agent_activity' not in summary:
523
+ summary['agent_activity'] = {}
524
+ if agent_type not in summary['agent_activity']:
525
+ summary['agent_activity'][agent_type] = {
526
+ 'tool_calls': 0,
527
+ 'last_active': log_entry.timestamp.isoformat()
528
+ }
529
+ summary['agent_activity'][agent_type]['tool_calls'] += 1
530
+ summary['agent_activity'][agent_type]['last_active'] = log_entry.timestamp.isoformat()
531
+
532
+ self._save_summary(summary)
533
+
534
+ except Exception as e:
535
+ logger.error(f"Failed to update session summary: {e}")
536
+
537
+ # ================ SESSION KEEP-ALIVE FOR LONG OPERATIONS ================
538
+
539
+
540
+ class KeepAliveSessionWrapper:
541
+ """Wrapper that keeps a session alive during long-running operations"""
542
+
543
+ def __init__(self, session: 'Session', touch_interval: int = 300): # Touch every 5 minutes
544
+ self.session = session
545
+ self.touch_interval = touch_interval
546
+ self.keep_alive_thread = None
547
+ self.stop_event = Event()
548
+ self.active = False
549
+
550
+ def start_keep_alive(self):
551
+ """Start the keep-alive mechanism"""
552
+ if self.active:
553
+ return
554
+
555
+ self.active = True
556
+ self.stop_event.clear()
557
+
558
+ def keep_alive_worker():
559
+ while not self.stop_event.wait(self.touch_interval):
560
+ try:
561
+ self.session.touch()
562
+ logger.debug("Keep-alive: Touched session {%s}", self.session.id)
563
+ except Exception as e:
564
+ logger.error(f"Keep-alive error for session {self.session.id}: {e}")
565
+ break
566
+
567
+ self.keep_alive_thread = Thread(target=keep_alive_worker, daemon=True)
568
+ self.keep_alive_thread.start()
569
+ logger.info(f"Started keep-alive for session {self.session.id}")
570
+
571
+ def stop_keep_alive(self):
572
+ """Stop the keep-alive mechanism"""
573
+ if not self.active:
574
+ return
575
+
576
+ self.active = False
577
+ self.stop_event.set()
578
+
579
+ if self.keep_alive_thread and self.keep_alive_thread.is_alive():
580
+ self.keep_alive_thread.join(timeout=1.0)
581
+
582
+ # Final touch
583
+ try:
584
+ self.session.touch()
585
+ except Exception as e:
586
+ logger.error(f"Final keep-alive touch error for session {self.session.id}: {e}")
587
+
588
+ logger.info(f"Stopped keep-alive for session {self.session.id}")
589
+
590
+ def __enter__(self):
591
+ self.start_keep_alive()
592
+ return self
593
+
594
+ def __exit__(self, exc_type, exc_val, exc_tb):
595
+ self.stop_keep_alive()
596
+
597
+ # ================ SESSION MANAGEMENT ================
598
+
599
+
600
+ @dataclass
601
+ class Session:
602
+ """Thread-safe session data structure with workspace management"""
603
+ id: str
604
+ created_at: datetime
605
+ last_accessed: datetime
606
+ initialized: bool = False
607
+ request_count: int = 0
608
+ metadata: Dict[str, Any] = field(default_factory=dict)
609
+ workspace_path: Optional[Path] = None
610
+ mcp_tools: Optional[MCPTools] = None
611
+ tool_tracker: Optional[ToolCallTracker] = None
612
+
613
+
614
+ def is_expired(self, ttl_seconds: int) -> bool:
615
+ """Check if session has expired"""
616
+ return datetime.now() - self.last_accessed > timedelta(seconds=ttl_seconds)
617
+
618
+ def touch(self):
619
+ """Update last accessed time"""
620
+ self.last_accessed = datetime.now()
621
+ self.request_count += 1
622
+
623
+ def get_mcp_tools(self, prefer_async: bool = True) -> MCPTools:
624
+ """Get or create MCP tools instance for this session"""
625
+ if self.mcp_tools is None:
626
+ # Use async tools if available and preferred
627
+ if prefer_async and AsyncMCPTools is not None:
628
+ self.mcp_tools = AsyncMCPTools(workspace_path=str(self.workspace_path) if self.workspace_path else None)
629
+ else:
630
+ self.mcp_tools = MCPTools(workspace_path=str(self.workspace_path) if self.workspace_path else None)
631
+ return self.mcp_tools
632
+
633
+ def get_tool_tracker(self) -> Optional[ToolCallTracker]:
634
+ """Get or create tool call tracker for this session"""
635
+ if config.enable_tool_tracking and self.workspace_path:
636
+ if self.tool_tracker is None:
637
+ self.tool_tracker = ToolCallTracker(self.workspace_path, self.id)
638
+ return self.tool_tracker
639
+ return None
640
+
641
+
642
+
643
+ class AsyncRLock:
644
+ """异步可重入锁,模拟 threading.RLock 的异步版本"""
645
+ def __init__(self):
646
+ self._lock = asyncio.Lock()
647
+ self._owner: Optional[asyncio.Task] = None # 记录持有锁的协程任务
648
+ self._count = 0 # 重入次数
649
+
650
+ async def acquire(self):
651
+ current_task = asyncio.current_task()
652
+ # 如果当前协程已持有锁,直接增加重入次数
653
+ if self._owner == current_task:
654
+ self._count += 1
655
+ return
656
+ # 否则等待获取锁
657
+ await self._lock.acquire()
658
+ self._owner = current_task
659
+ self._count = 1
660
+
661
+ async def release(self):
662
+ if self._owner != asyncio.current_task():
663
+ raise RuntimeError("不能释放非当前协程持有的锁")
664
+ self._count -= 1
665
+ if self._count == 0: # 重入次数归零时,真正释放锁
666
+ self._owner = None
667
+ self._lock.release()
668
+
669
+ # 支持 async with 语法
670
+ async def __aenter__(self):
671
+ await self.acquire()
672
+ return self
673
+
674
+ async def __aexit__(self, exc_type, exc, tb):
675
+ await self.release()
676
+
677
+
678
+ class ThreadSafeSessionManager:
679
+ """Thread-safe session manager with workspace management"""
680
+
681
+ def __init__(self, ttl_seconds: int = 3600, max_sessions: int = 1000, base_workspace_dir: str = "workspaces"):
682
+ self.ttl_seconds = ttl_seconds
683
+ self.max_sessions = max_sessions
684
+ self.base_workspace_dir = Path(base_workspace_dir)
685
+ self.base_workspace_dir.mkdir(exist_ok=True)
686
+
687
+ # Thread-safe session storage
688
+ self.sessions: Dict[str, Session] = {}
689
+ self.lock = AsyncRLock()
690
+
691
+ # Start cleanup thread
692
+ self._start_cleanup_thread()
693
+
694
+ async def create_session(self) -> str:
695
+ """Create a new session and return session ID"""
696
+ session_id = str(uuid.uuid4())
697
+
698
+ async with self.lock:
699
+ # Check session limits
700
+ if len(self.sessions) >= self.max_sessions:
701
+ await self._cleanup_oldest_sessions()
702
+
703
+ # Create workspace directory
704
+ workspace_path = self.base_workspace_dir / session_id
705
+ workspace_path.mkdir(exist_ok=True, parents=True)
706
+
707
+ # Create session
708
+ session = Session(
709
+ id=session_id,
710
+ created_at=datetime.now(),
711
+ last_accessed=datetime.now(),
712
+ workspace_path=workspace_path
713
+ )
714
+
715
+ self.sessions[session_id] = session
716
+
717
+ logger.info(f"Created session {session_id} with workspace {workspace_path}")
718
+ return session_id
719
+
720
+ async def get_session(self, session_id: str) -> Optional[Session]:
721
+ """Get session by ID if it exists and is not expired"""
722
+ async with self.lock:
723
+ session = self.sessions.get(session_id)
724
+ if session and not session.is_expired(self.ttl_seconds):
725
+ session.touch()
726
+ return session
727
+ elif session:
728
+ # Remove expired session
729
+ del self.sessions[session_id]
730
+ logger.info(f"Removed expired session {session_id}")
731
+ return None
732
+
733
+ async def get_or_create_session(self, session_id: Optional[str] = None) -> Session:
734
+ """Get existing session or create new one"""
735
+ if session_id:
736
+ session = await self.get_session(session_id)
737
+ if session:
738
+ return session
739
+
740
+ # Create new session
741
+ new_session_id = await self.create_session()
742
+ return self.sessions[new_session_id]
743
+
744
+ async def _cleanup_expired_sessions(self):
745
+ """Remove expired sessions"""
746
+ async with self.lock:
747
+ expired_sessions = []
748
+ for session_id, session in self.sessions.items():
749
+ if session.is_expired(self.ttl_seconds):
750
+ expired_sessions.append(session_id)
751
+
752
+ for session_id in expired_sessions:
753
+ del self.sessions[session_id]
754
+ logger.info(f"Cleaned up expired session {session_id}")
755
+
756
+ async def _cleanup_oldest_sessions(self):
757
+ """Remove oldest sessions when limit is reached"""
758
+ async with self.lock:
759
+ if len(self.sessions) < self.max_sessions:
760
+ return
761
+
762
+ # Sort by last accessed time and remove oldest
763
+ sorted_sessions = sorted(
764
+ self.sessions.items(),
765
+ key=lambda x: x[1].last_accessed
766
+ )
767
+
768
+ sessions_to_remove = len(self.sessions) - self.max_sessions + 10 # Remove extra
769
+ for i in range(sessions_to_remove):
770
+ if i < len(sorted_sessions):
771
+ session_id = sorted_sessions[i][0]
772
+ del self.sessions[session_id]
773
+ logger.info(f"Removed old session {session_id} due to session limit")
774
+
775
+ def _start_cleanup_thread(self):
776
+ """Start background cleanup thread"""
777
+ def cleanup_worker():
778
+ while True:
779
+ try:
780
+ time.sleep(config.cleanup_interval_seconds)
781
+ # Run async method in sync context
782
+ loop = asyncio.new_event_loop()
783
+ loop.run_until_complete(self._cleanup_expired_sessions())
784
+ loop.close()
785
+ except Exception as e:
786
+ logger.error(f"Error in cleanup thread: {e}")
787
+
788
+ import threading
789
+ cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
790
+ cleanup_thread.start()
791
+ logger.info("Started session cleanup thread")
792
+
793
+
794
+
795
+ async def get_stats(self) -> Dict[str, Any]:
796
+ """Get session manager statistics"""
797
+ async with self.lock:
798
+ return {
799
+ "total_sessions": len(self.sessions),
800
+ "max_sessions": self.max_sessions,
801
+ "ttl_seconds": self.ttl_seconds,
802
+ "session_ids": list(self.sessions.keys())
803
+ }
804
+
805
+ # ================ MIDDLEWARE AND SECURITY ================
806
+
807
+
808
+ class RateLimiter:
809
+ """Simple rate limiter with time-window tracking"""
810
+
811
+ def __init__(self, requests_per_minute: int = 60):
812
+ self.requests_per_minute = requests_per_minute
813
+ self.requests: Dict[str, List[float]] = defaultdict(list)
814
+ self.lock = asyncio.Lock()
815
+
816
+ async def is_allowed(self, client_id: str) -> bool:
817
+ """Check if request is allowed for client"""
818
+ async with self.lock:
819
+ now = time.time()
820
+ minute_ago = now - 60
821
+
822
+ # Clean old requests
823
+ self.requests[client_id] = [
824
+ req_time for req_time in self.requests[client_id]
825
+ if req_time > minute_ago
826
+ ]
827
+
828
+ # Check rate limit
829
+ if len(self.requests[client_id]) >= self.requests_per_minute:
830
+ return False
831
+
832
+ # Add current request
833
+ self.requests[client_id].append(now)
834
+ return True
835
+
836
+
837
+ class RequestValidator:
838
+ """Validates incoming MCP requests"""
839
+
840
+ @staticmethod
841
+ def validate_mcp_request(data: Dict[str, Any]) -> tuple[bool, Optional[str]]:
842
+ """Validate basic MCP request structure"""
843
+ if not isinstance(data, dict):
844
+ return False, "Request must be a JSON object"
845
+
846
+ if "method" not in data:
847
+ return False, "Missing 'method' field"
848
+
849
+ if "id" not in data:
850
+ return False, "Missing 'id' field"
851
+
852
+ return True, None
853
+
854
+ @staticmethod
855
+ def validate_tool_call(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
856
+ """Validate tool call parameters"""
857
+ if not isinstance(params, dict):
858
+ return False, "Tool parameters must be a JSON object"
859
+
860
+ if "name" not in params:
861
+ return False, "Missing tool 'name'"
862
+
863
+ if "arguments" not in params:
864
+ return False, "Missing tool 'arguments'"
865
+
866
+ tool_name = params["name"]
867
+
868
+ # Get detailed schemas
869
+ detailed_schemas = get_tool_schemas()
870
+
871
+ if tool_name not in detailed_schemas:
872
+ return False, f"Unknown tool: {tool_name}. Available tools: {sorted(list(detailed_schemas.keys()))}"
873
+
874
+ return True, None
875
+
876
+
877
+ class SecurityMiddleware(BaseHTTPMiddleware):
878
+ """Security middleware for basic protection"""
879
+
880
+ async def dispatch(self, request: Request, call_next):
881
+ # Check content length
882
+ content_length = request.headers.get("content-length")
883
+ if content_length and int(content_length) > config.max_request_size_mb * 1024 * 1024:
884
+ return JSONResponse(
885
+ status_code=HTTPStatus.REQUEST_ENTITY_TOO_LARGE,
886
+ content={"error": "Request too large"}
887
+ )
888
+
889
+ # Add security headers
890
+ response = await call_next(request)
891
+ response.headers["X-Content-Type-Options"] = "nosniff"
892
+ response.headers["X-Frame-Options"] = "DENY"
893
+ response.headers["X-XSS-Protection"] = "1; mode=block"
894
+
895
+ return response
896
+
897
+
898
+ class RateLimitMiddleware(BaseHTTPMiddleware):
899
+ """Rate limiting middleware"""
900
+
901
+ def __init__(self, app, input_rate_limiter: RateLimiter):
902
+ super().__init__(app)
903
+ self.rate_limiter = input_rate_limiter
904
+
905
+ async def dispatch(self, request: Request, call_next):
906
+ # Get client identifier (IP address)
907
+ client_ip = request.client.host if request.client else "unknown"
908
+
909
+ if not await self.rate_limiter.is_allowed(client_ip):
910
+ return JSONResponse(
911
+ status_code=HTTPStatus.TOO_MANY_REQUESTS,
912
+ content={"error": "Rate limit exceeded"}
913
+ )
914
+
915
+ return await call_next(request)
916
+
917
+ # Global session manager
918
+ session_manager = None
919
+ rate_limiter = None
920
+
921
+
922
+ @dataclass
923
+ class RateLimitViolation:
924
+ """Represents a rate limit violation with standardized error information"""
925
+ tool_name: str
926
+ limit_type: str # "burst", "second", "minute", "hour"
927
+ current_usage: int
928
+ limit_value: float
929
+ retry_after_seconds: float
930
+
931
+ def to_user_friendly_message(self) -> str:
932
+ """Generate user-friendly error message"""
933
+ if self.limit_type == "burst":
934
+ return f"Service temporarily unavailable: Too many rapid requests to {self.tool_name}. Please wait {self.retry_after_seconds:.0f} seconds before trying again."
935
+ elif self.limit_type == "second":
936
+ return f"Service temporarily unavailable: {self.tool_name} request rate exceeded ({self.limit_value}/second). Please wait {self.retry_after_seconds:.0f} seconds before trying again."
937
+ elif self.limit_type == "minute":
938
+ return f"Service temporarily unavailable: {self.tool_name} quota exceeded ({self.limit_value}/minute). Please try again in {self.retry_after_seconds:.0f} seconds."
939
+ elif self.limit_type == "hour":
940
+ return f"Service temporarily unavailable: {self.tool_name} hourly quota exceeded ({self.limit_value}/hour). Please try again in {self.retry_after_seconds:.0f} minutes."
941
+ else:
942
+ return f"Service temporarily unavailable: {self.tool_name} rate limit exceeded. Please try again later."
943
+
944
+ def to_technical_message(self) -> str:
945
+ """Generate technical error message for debugging"""
946
+ return f"Tool '{self.tool_name}' {self.limit_type} limit exceeded ({self.current_usage}/{self.limit_value} {self.limit_type})"
947
+
948
+
949
+ def _parse_rate_limit_denial(tool_name: str, denial_reason: str) -> RateLimitViolation:
950
+ """Parse rate limit denial reason into structured violation information"""
951
+ import re
952
+
953
+ # Default values
954
+ limit_type = "unknown"
955
+ current_usage = 0
956
+ limit_value = 0.0
957
+ retry_after_seconds = 60.0 # Default retry after 1 minute
958
+
959
+ # Parse different types of rate limit violations
960
+ if "burst limit exceeded" in denial_reason:
961
+ limit_type = "burst"
962
+ retry_after_seconds = 1.0 # Burst limits reset quickly
963
+ match = re.search(r'\((\d+) requests/burst\)', denial_reason)
964
+ if match:
965
+ limit_value = float(match.group(1))
966
+ current_usage = int(limit_value) # Approximation
967
+
968
+ elif "per-second limit exceeded" in denial_reason:
969
+ limit_type = "second"
970
+ retry_after_seconds = 1.0 # Wait 1 second
971
+ match = re.search(r'\(([0-9.]+) requests/second\)', denial_reason)
972
+ if match:
973
+ limit_value = float(match.group(1))
974
+ current_usage = int(limit_value) # Approximation
975
+
976
+ elif "per-minute limit exceeded" in denial_reason:
977
+ limit_type = "minute"
978
+ retry_after_seconds = 10.0 # Wait 10 seconds for minute limits
979
+ match = re.search(r'\(([0-9.]+) requests/minute\)', denial_reason)
980
+ if match:
981
+ limit_value = float(match.group(1))
982
+ current_usage = int(limit_value) # Approximation
983
+
984
+ elif "per-hour limit exceeded" in denial_reason:
985
+ limit_type = "hour"
986
+ retry_after_seconds = 300.0 # Wait 5 minutes for hour limits
987
+ match = re.search(r'\(([0-9.]+) requests/hour\)', denial_reason)
988
+ if match:
989
+ limit_value = float(match.group(1))
990
+ current_usage = int(limit_value) # Approximation
991
+
992
+ return RateLimitViolation(
993
+ tool_name=tool_name,
994
+ limit_type=limit_type,
995
+ current_usage=current_usage,
996
+ limit_value=limit_value,
997
+ retry_after_seconds=retry_after_seconds
998
+ )
999
+
1000
+
1001
+ async def _call_session_tool_async(session: Session, tool_name: str, tool_args: Dict[str, Any],
1002
+ client_ip: str = "unknown") -> Dict[str, Any]:
1003
+ """Execute a tool within a session context with full tracking, workspace management, and global rate limiting"""
1004
+
1005
+ start_time = time.time()
1006
+ success = False
1007
+ error_details = None
1008
+ result_data = None
1009
+
1010
+ # Touch session at start of tool execution to prevent expiry during long operations
1011
+ session.touch()
1012
+
1013
+ try:
1014
+ # CHECK GLOBAL TOOL RATE LIMITS FIRST
1015
+ if global_tool_rate_limiter:
1016
+ allowed, deny_reason = await global_tool_rate_limiter.is_allowed(tool_name)
1017
+ if not allowed:
1018
+ # Parse the denial reason to create structured rate limit violation
1019
+ rate_limit_violation = _parse_rate_limit_denial(tool_name, deny_reason)
1020
+
1021
+ # Create user-friendly error message
1022
+ user_message = rate_limit_violation.to_user_friendly_message()
1023
+ technical_message = rate_limit_violation.to_technical_message()
1024
+
1025
+ logger.warning(f"Session {session.id}: {technical_message}")
1026
+
1027
+ result_data = {
1028
+ "success": False,
1029
+ "error": user_message,
1030
+ "error_code": "RATE_LIMIT_EXCEEDED",
1031
+ "error_type": "rate_limit",
1032
+ "tool_name": tool_name,
1033
+ "limit_type": rate_limit_violation.limit_type,
1034
+ "retry_after_seconds": rate_limit_violation.retry_after_seconds,
1035
+ "data": None,
1036
+ "rate_limited": True, # Keep for backward compatibility
1037
+ "technical_details": technical_message # For debugging
1038
+ }
1039
+
1040
+ # Still log this for tracking purposes
1041
+ duration_ms = (time.time() - start_time) * 1000
1042
+ tracker = session.get_tool_tracker()
1043
+ if tracker:
1044
+ try:
1045
+ agent_info = {
1046
+ "client_ip": client_ip,
1047
+ "type": "unknown",
1048
+ "session_request_count": session.request_count
1049
+ }
1050
+
1051
+ tracker.log_tool_call(
1052
+ tool_name=tool_name,
1053
+ input_args=tool_args,
1054
+ output_result=result_data,
1055
+ success=False,
1056
+ duration_ms=duration_ms,
1057
+ error_details=user_message,
1058
+ agent_info=agent_info
1059
+ )
1060
+ except Exception as e:
1061
+ logger.error(f"Failed to log rate-limited tool call: {e}")
1062
+
1063
+ return result_data
1064
+
1065
+ # Get MCP tools instance for this session (handles workspace isolation)
1066
+ mcp_tools = session.get_mcp_tools(prefer_async=True)
1067
+
1068
+ # Get tool method directly from the mcp_tools instance
1069
+ if not hasattr(mcp_tools, tool_name):
1070
+ raise ValueError(f"Tool '{tool_name}' not implemented")
1071
+
1072
+ tool_method = getattr(mcp_tools, tool_name)
1073
+
1074
+ # Add session context to tool arguments for workspace-aware tools
1075
+ if hasattr(mcp_tools, 'set_session_context'):
1076
+ mcp_tools.set_session_context(session.id, str(session.workspace_path))
1077
+
1078
+ # Execute tool with keep-alive for potentially long operations
1079
+ logger.info(f"Session {session.id}: Executing tool '{tool_name}' with args: {list(tool_args.keys())}")
1080
+
1081
+ # Use keep-alive wrapper for tools that might take a long time
1082
+ long_running_tools = {'batch_web_search', 'url_crawler', 'document_qa', 'document_extract', 'bash'}
1083
+
1084
+ # Check if the tool method is async
1085
+ import inspect
1086
+ is_async_tool = inspect.iscoroutinefunction(tool_method)
1087
+
1088
+ # Execute tool based on whether it's async or sync
1089
+ if is_async_tool:
1090
+ # Tool is async - execute directly
1091
+ logger.debug("Executing async tool '{%s}'", tool_name)
1092
+
1093
+ if config.enable_session_keepalive and tool_name in long_running_tools:
1094
+ # For long-running async tools, use keep-alive
1095
+ with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
1096
+ result = await tool_method(**tool_args)
1097
+ else:
1098
+ # For regular async tools, execute directly
1099
+ result = await tool_method(**tool_args)
1100
+ else:
1101
+ # Tool is sync - execute in thread pool
1102
+ logger.debug("Executing sync tool '{%s}' in thread pool", tool_name)
1103
+
1104
+ # Define the synchronous tool execution function
1105
+ def execute_tool_sync():
1106
+ """Synchronous tool execution to be run in thread pool"""
1107
+ return tool_method(**tool_args)
1108
+
1109
+ # Execute tool asynchronously in thread pool for true non-blocking execution
1110
+ import asyncio
1111
+ import concurrent.futures
1112
+
1113
+ # Create a thread pool executor for CPU-bound/blocking operations
1114
+ loop = asyncio.get_event_loop()
1115
+
1116
+ if config.enable_session_keepalive and tool_name in long_running_tools:
1117
+ # For long-running tools, use keep-alive with async execution
1118
+ with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
1119
+ # Run in thread pool to avoid blocking the event loop
1120
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
1121
+ result = await loop.run_in_executor(executor, execute_tool_sync)
1122
+ else:
1123
+ # For regular tools, use async execution without keep-alive
1124
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
1125
+ result = await loop.run_in_executor(executor, execute_tool_sync)
1126
+
1127
+ # Touch session after tool execution to update activity
1128
+ session.touch()
1129
+
1130
+ # Handle different result formats
1131
+ if hasattr(result, 'to_dict'):
1132
+ result_data = result.to_dict()
1133
+ elif isinstance(result, dict):
1134
+ result_data = result
1135
+ else:
1136
+ result_data = {"result": result}
1137
+
1138
+ success = result_data.get('success', True)
1139
+
1140
+ if success:
1141
+ logger.info(f"Session {session.id}: Tool '{tool_name}' completed successfully")
1142
+
1143
+ # RECORD SUCCESSFUL REQUEST FOR RATE LIMITING
1144
+ if global_tool_rate_limiter:
1145
+ await global_tool_rate_limiter.record_request(tool_name)
1146
+
1147
+
1148
+
1149
+ else:
1150
+ error_details = result_data.get('error', 'Unknown error')
1151
+ logger.warning(f"Session {session.id}: Tool '{tool_name}' failed: {error_details}")
1152
+
1153
+ except Exception as e:
1154
+ success = False
1155
+ error_details = str(e)
1156
+ result_data = {
1157
+ "success": False,
1158
+ "error": error_details,
1159
+ "data": None
1160
+ }
1161
+ logger.error(f"Session {session.id}: Tool '{tool_name}' exception: {e}")
1162
+
1163
+ # Calculate execution time
1164
+ duration_ms = (time.time() - start_time) * 1000
1165
+
1166
+ # Log tool call if tracking is enabled
1167
+ tracker = session.get_tool_tracker()
1168
+ if tracker:
1169
+ try:
1170
+ agent_info = {
1171
+ "client_ip": client_ip,
1172
+ "type": "unknown", # Could be enhanced to detect agent type
1173
+ "session_request_count": session.request_count
1174
+ }
1175
+
1176
+ tracker.log_tool_call(
1177
+ tool_name=tool_name,
1178
+ input_args=tool_args,
1179
+ output_result=result_data,
1180
+ success=success,
1181
+ duration_ms=duration_ms,
1182
+ error_details=error_details,
1183
+ agent_info=agent_info
1184
+ )
1185
+ except Exception as e:
1186
+ logger.error(f"Failed to log tool call: {e}")
1187
+
1188
+ return result_data
1189
+
1190
+
1191
+
1192
+ def create_sse_response(response_data: dict, session_id: str = None) -> StreamingResponse:
1193
+ """Create Server-Sent Events response with proper formatting"""
1194
+ def generate_sse():
1195
+ try:
1196
+ # Add session info to response if available
1197
+ if session_id:
1198
+ response_data["session_id"] = session_id
1199
+
1200
+ json_data = json.dumps(response_data, ensure_ascii=False)
1201
+ yield f"event: message\n"
1202
+ yield f"data: {json_data}\n"
1203
+ yield f"\n"
1204
+ except Exception as e:
1205
+ error_data = {
1206
+ "jsonrpc": "2.0",
1207
+ "error": {"code": JsonRpcErr.INTERNAL_ERROR, "message": f"Internal error: {str(e)}"},
1208
+ "id": response_data.get("id")
1209
+ }
1210
+ json_data = json.dumps(error_data, ensure_ascii=False)
1211
+ yield f"event: error\n"
1212
+ yield f"data: {json_data}\n"
1213
+ yield f"\n"
1214
+
1215
+ return StreamingResponse(
1216
+ generate_sse(),
1217
+ media_type="text/event-stream",
1218
+ headers={
1219
+ "Cache-Control": "no-cache",
1220
+ "Connection": "keep-alive",
1221
+ "Access-Control-Allow-Origin": "*",
1222
+ }
1223
+ )
1224
+
1225
+
1226
+ def create_error_response(request_id: Any, code: int, message: str, session_id: str = None) -> StreamingResponse:
1227
+ """Create error response in SSE format"""
1228
+ error_data = {
1229
+ "jsonrpc": "2.0",
1230
+ "error": {"code": code, "message": message},
1231
+ "id": request_id
1232
+ }
1233
+ return create_sse_response(error_data, session_id)
1234
+
1235
+
1236
+ def create_rate_limit_response(
1237
+ request_id: Any,
1238
+ tool_name: str,
1239
+ error_message: str,
1240
+ retry_after_seconds: float,
1241
+ limit_type: str,
1242
+ technical_details: str = "",
1243
+ session_id: str = None
1244
+ ) -> JSONResponse:
1245
+ """
1246
+ Create HTTP 429 Rate Limit Exceeded response with proper headers and error format.
1247
+
1248
+ Returns proper HTTP status code instead of SSE for rate limiting errors.
1249
+ """
1250
+
1251
+ # Calculate retry-after header value
1252
+ retry_after_header = int(max(1.0, retry_after_seconds))
1253
+
1254
+ # Create standardized error response
1255
+ error_data = {
1256
+ "error": {
1257
+ "type": "rate_limit_exceeded",
1258
+ "code": "RATE_LIMIT_EXCEEDED",
1259
+ "message": error_message,
1260
+ "details": {
1261
+ "tool_name": tool_name,
1262
+ "limit_type": limit_type,
1263
+ "retry_after_seconds": retry_after_seconds,
1264
+ "technical_details": technical_details
1265
+ }
1266
+ },
1267
+ "request_id": request_id,
1268
+ "timestamp": datetime.now().isoformat(),
1269
+ "session_id": session_id
1270
+ }
1271
+
1272
+ # Set appropriate headers
1273
+ headers = {
1274
+ "Retry-After": str(retry_after_header), # HTTP standard header
1275
+ "X-RateLimit-Limit-Type": limit_type,
1276
+ "X-RateLimit-Tool": tool_name,
1277
+ "X-RateLimit-Retry-After": str(retry_after_seconds),
1278
+ "Content-Type": "application/json"
1279
+ }
1280
+
1281
+ return JSONResponse(
1282
+ status_code=HTTPStatus.TOO_MANY_REQUESTS, # Too Many Requests
1283
+ content=error_data,
1284
+ headers=headers
1285
+ )
1286
+
1287
+
1288
+ async def handle_mcp_request(request: Request) -> StreamingResponse:
1289
+ """Main MCP request handler with session management and tool execution"""
1290
+
1291
+ try:
1292
+ # Check content length before reading body
1293
+ content_length = request.headers.get("content-length")
1294
+ if content_length:
1295
+ content_size_mb = int(content_length) / (1024 * 1024)
1296
+ if content_size_mb > config.max_request_size_mb:
1297
+ logger.warning(f"Request too large: {content_size_mb:.2f}MB > {config.max_request_size_mb}MB")
1298
+ return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Request too large: {content_size_mb:.2f}MB")
1299
+
1300
+ # Parse request with timeout protection
1301
+ try:
1302
+ body = await asyncio.wait_for(request.body(), timeout=config.request_timeout_seconds)
1303
+ except asyncio.TimeoutError:
1304
+ logger.error("Timeout while reading request body")
1305
+ return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request body read timeout")
1306
+
1307
+ if not body:
1308
+ return create_error_response(None, JsonRpcErr.PARSE_ERROR, "Empty request body")
1309
+
1310
+ try:
1311
+ data = json.loads(body.decode('utf-8'))
1312
+ except json.JSONDecodeError as e:
1313
+ return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Invalid JSON: {str(e)}")
1314
+
1315
+ # Validate MCP request structure
1316
+ is_valid, error_msg = RequestValidator.validate_mcp_request(data)
1317
+ if not is_valid:
1318
+ return create_error_response(data.get("id"), JsonRpcErr.INVALID_REQUEST, error_msg)
1319
+
1320
+ request_id = data["id"]
1321
+ method = data["method"]
1322
+ params = data.get("params", {})
1323
+
1324
+ # Get or create session
1325
+ session_id = request.headers.get("X-Session-ID")
1326
+ client_ip = request.client.host if request.client else "unknown"
1327
+
1328
+ session = await session_manager.get_or_create_session(session_id)
1329
+ logger.info(f"Processing {method} request for session {session.id} from {client_ip}")
1330
+
1331
+ # Handle different MCP methods
1332
+ if method == "initialize":
1333
+ # MCP initialization
1334
+ response_data = {
1335
+ "jsonrpc": "2.0",
1336
+ "result": {
1337
+ "protocolVersion": "2025-06-18",
1338
+ "capabilities": {
1339
+ "tools": {"supportsProgress": True},
1340
+ "resources": {},
1341
+ "prompts": {}
1342
+ },
1343
+ "serverInfo": {
1344
+ "name": "DeepDiver-Demo-MCP",
1345
+ "version": "1.0.0"
1346
+ }
1347
+ },
1348
+ "id": request_id
1349
+ }
1350
+
1351
+ elif method == "tools/list":
1352
+ # List available tools using detailed schemas from get_tool_schemas()
1353
+ tools_list = []
1354
+ detailed_schemas = get_tool_schemas()
1355
+
1356
+ # Build tools list from schemas
1357
+ for _, detailed_schema in detailed_schemas.items():
1358
+ tools_list.append({
1359
+ "name": detailed_schema["name"],
1360
+ "description": detailed_schema["description"],
1361
+ "inputSchema": detailed_schema["inputSchema"]
1362
+ })
1363
+
1364
+ logger.info(f"Serving {len(tools_list)} tools with detailed schemas to client")
1365
+
1366
+ response_data = {
1367
+ "jsonrpc": "2.0",
1368
+ "result": {"tools": tools_list},
1369
+ "id": request_id
1370
+ }
1371
+
1372
+ elif method == "tools/call":
1373
+ # Execute tool call
1374
+ is_valid, error_msg = RequestValidator.validate_tool_call(params)
1375
+ if not is_valid:
1376
+ return create_error_response(request_id, JsonRpcErr.INVALID_PARAMS, error_msg, session.id)
1377
+
1378
+ tool_name = params["name"]
1379
+ tool_arguments = params["arguments"]
1380
+
1381
+ # Execute tool in session context asynchronously
1382
+ result = await _call_session_tool_async(session, tool_name, tool_arguments, client_ip)
1383
+
1384
+ # CHECK FOR RATE LIMITING AND RETURN PROPER HTTP STATUS
1385
+ if result.get("rate_limited", False):
1386
+ return create_rate_limit_response(
1387
+ request_id=request_id,
1388
+ tool_name=tool_name,
1389
+ error_message=result.get("error", "Rate limit exceeded"),
1390
+ retry_after_seconds=result.get("retry_after_seconds", 60),
1391
+ limit_type=result.get("limit_type", "unknown"),
1392
+ technical_details=result.get("technical_details", ""),
1393
+ session_id=session.id
1394
+ )
1395
+
1396
+ # Format normal response
1397
+ response_data = {
1398
+ "jsonrpc": "2.0",
1399
+ "result": {
1400
+ "content": [
1401
+ {
1402
+ "type": "text",
1403
+ "text": json.dumps(result, indent=2, ensure_ascii=False)
1404
+ }
1405
+ ]
1406
+ },
1407
+ "id": request_id
1408
+ }
1409
+
1410
+ else:
1411
+ return create_error_response(request_id, JsonRpcErr.METHOD_NOT_FOUND, f"Method not found: {method}", session.id)
1412
+
1413
+ return create_sse_response(response_data, session.id)
1414
+
1415
+ except asyncio.TimeoutError:
1416
+ logger.warning("Request timeout - client may have disconnected")
1417
+ return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request timeout")
1418
+ except Exception as e:
1419
+ # Handle client disconnects gracefully
1420
+ if "ClientDisconnect" in str(e) or "ConnectionClosedError" in str(e):
1421
+ logger.warning(f"Client disconnected during request processing: {e}")
1422
+ return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Client disconnected")
1423
+
1424
+ logger.error(f"Unexpected error in MCP request handler: {e}")
1425
+ import traceback
1426
+ logger.error(traceback.format_exc())
1427
+ return create_error_response(None, JsonRpcErr.INTERNAL_ERROR, f"Internal server error: {str(e)}")
1428
+
1429
+
1430
+ async def handle_health_check(request: Request) -> JSONResponse:
1431
+ """Health check endpoint"""
1432
+ try:
1433
+ stats = await session_manager.get_stats() if session_manager else {}
1434
+
1435
+ # Get rate limiting summary
1436
+ rate_limit_summary = {}
1437
+ if global_tool_rate_limiter:
1438
+ all_stats = global_tool_rate_limiter.get_all_stats()
1439
+ rate_limit_summary = {
1440
+ "enabled": True,
1441
+ "tools_with_limits": len(all_stats),
1442
+ "total_configured_tools": list(all_stats.keys())
1443
+ }
1444
+ else:
1445
+ rate_limit_summary = {"enabled": False}
1446
+
1447
+ health_data = {
1448
+ "status": "healthy",
1449
+ "timestamp": datetime.now().isoformat(),
1450
+ "version": "1.0.0",
1451
+ "session_stats": stats,
1452
+ "features": {
1453
+ "workspace_isolation": True,
1454
+ "tool_call_tracking": config.enable_tool_tracking if config else False,
1455
+ "client_rate_limiting": True,
1456
+ "global_tool_rate_limiting": rate_limit_summary["enabled"],
1457
+ "security_middleware": True,
1458
+ "standardized_rate_limit_responses": True
1459
+ },
1460
+ "rate_limiting": rate_limit_summary,
1461
+ "error_formats": {
1462
+ "rate_limit_exceeded": {
1463
+ "http_status": HTTPStatus.TOO_MANY_REQUESTS,
1464
+ "headers": ["Retry-After", "X-RateLimit-*"],
1465
+ "error_code": "RATE_LIMIT_EXCEEDED",
1466
+ "response_format": "application/json"
1467
+ }
1468
+ }
1469
+ }
1470
+
1471
+ return JSONResponse(content=health_data)
1472
+
1473
+ except Exception as e:
1474
+ return JSONResponse(
1475
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
1476
+ content={"status": "unhealthy", "error": str(e)}
1477
+ )
1478
+
1479
+
1480
+ async def handle_tracking_info(request: Request) -> JSONResponse:
1481
+ """Get tool call tracking information for a session"""
1482
+ try:
1483
+ session_id = request.query_params.get("session_id")
1484
+ if not session_id:
1485
+ return JSONResponse(
1486
+ status_code=HTTPStatus.BAD_REQUEST,
1487
+ content={"error": "session_id parameter required"}
1488
+ )
1489
+
1490
+ session = await session_manager.get_session(session_id)
1491
+ if not session:
1492
+ return JSONResponse(
1493
+ status_code=HTTPStatus.NOT_FOUND,
1494
+ content={"error": f"Session {session_id} not found"}
1495
+ )
1496
+
1497
+ tracker = session.get_tool_tracker()
1498
+ if not tracker:
1499
+ return JSONResponse(
1500
+ content={
1501
+ "session_id": session_id,
1502
+ "tracking_enabled": False,
1503
+ "message": "Tool call tracking not enabled or no workspace"
1504
+ }
1505
+ )
1506
+
1507
+ # Read session summary
1508
+ summary_data = {}
1509
+ if tracker.summary_file.exists():
1510
+ try:
1511
+ with open(tracker.summary_file, 'r') as f:
1512
+ summary_data = json.load(f)
1513
+ except Exception as e:
1514
+ logger.error(f"Failed to read session summary: {e}")
1515
+
1516
+ return JSONResponse(content={
1517
+ "session_id": session_id,
1518
+ "tracking_enabled": True,
1519
+ "summary": summary_data,
1520
+ "logs_directory": str(tracker.logs_dir),
1521
+ "current_log_file": str(tracker.current_log_file)
1522
+ })
1523
+
1524
+ except Exception as e:
1525
+ return JSONResponse(
1526
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
1527
+ content={"error": str(e)}
1528
+ )
1529
+
1530
+
1531
+
1532
+ async def handle_rate_limit_stats(request: Request) -> JSONResponse:
1533
+ """Get global tool rate limiting statistics"""
1534
+ try:
1535
+ if not global_tool_rate_limiter:
1536
+ return JSONResponse(
1537
+ status_code=HTTPStatus.NOT_FOUND,
1538
+ content={"error": "Global tool rate limiter not initialized"}
1539
+ )
1540
+
1541
+ # Check if specific tool requested
1542
+ tool_name = request.query_params.get("tool")
1543
+
1544
+ if tool_name:
1545
+ # Get stats for specific tool
1546
+ stats = await global_tool_rate_limiter.get_tool_stats(tool_name)
1547
+ return JSONResponse(content=stats)
1548
+ else:
1549
+ # Get stats for all tools
1550
+ all_stats = global_tool_rate_limiter.get_all_stats()
1551
+ return JSONResponse(content={
1552
+ "timestamp": datetime.now().isoformat(),
1553
+ "global_tool_rate_limiting": True,
1554
+ "tools": all_stats,
1555
+ "summary": {
1556
+ "total_tools_with_limits": len(all_stats),
1557
+ "tools_configured": list(all_stats.keys())
1558
+ }
1559
+ })
1560
+
1561
+ except Exception as e:
1562
+ logger.error(f"Failed to get rate limit stats: {e}")
1563
+ return JSONResponse(
1564
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
1565
+ content={"error": str(e)}
1566
+ )
1567
+
1568
+
1569
+ def create_app() -> Starlette:
1570
+ """Create and configure the Starlette application"""
1571
+ global session_manager, rate_limiter, global_tool_rate_limiter
1572
+
1573
+ if not config:
1574
+ raise RuntimeError("Server configuration not initialized")
1575
+
1576
+ # Initialize global components
1577
+ session_manager = ThreadSafeSessionManager(
1578
+ ttl_seconds=config.session_ttl_seconds,
1579
+ max_sessions=config.max_sessions,
1580
+ base_workspace_dir=config.base_workspace_dir
1581
+ )
1582
+ rate_limiter = RateLimiter(config.rate_limit_requests_per_minute)
1583
+
1584
+ # Initialize global tool rate limiter
1585
+ if config.tool_rate_limits:
1586
+ global_tool_rate_limiter = GlobalToolRateLimiter(config.tool_rate_limits)
1587
+ logger.info(f"Initialized global tool rate limiter with {len(config.tool_rate_limits)} tool limits")
1588
+ else:
1589
+ logger.info("No tool rate limits configured - tools will run without global rate limiting")
1590
+
1591
+ # Create app
1592
+ app = Starlette(debug=config.debug_mode)
1593
+
1594
+ app.add_middleware(SecurityMiddleware)
1595
+ app.add_middleware(RateLimitMiddleware, input_rate_limiter=rate_limiter)
1596
+
1597
+ # Add routes
1598
+ app.add_route("/mcp", handle_mcp_request, methods=["POST"])
1599
+ app.add_route("/health", handle_health_check, methods=["GET"])
1600
+ app.add_route("/tracking", handle_tracking_info, methods=["GET"])
1601
+ app.add_route("/rate-limits", handle_rate_limit_stats, methods=["GET"])
1602
+
1603
+ return app
1604
+
1605
+
1606
+ def parse_arguments():
1607
+ """Parse command line arguments"""
1608
+ parser = argparse.ArgumentParser(
1609
+ description="Demo-Ready MCP Server with Per-Tool Rate Limiting",
1610
+ formatter_class=argparse.RawDescriptionHelpFormatter,
1611
+ epilog="""
1612
+ Examples:
1613
+ python src/tools/mcp_server_standard.py --config src/tools/server_config.yaml
1614
+ python src/tools/mcp_server_standard.py --host 127.0.0.1 --port 8080
1615
+ python src/tools/mcp_server_standard.py --config custom_config.yaml --debug
1616
+ """
1617
+ )
1618
+
1619
+ parser.add_argument(
1620
+ '--config', '-c',
1621
+ type=str,
1622
+ help='Path to YAML configuration file'
1623
+ )
1624
+
1625
+ parser.add_argument(
1626
+ '--host',
1627
+ type=str,
1628
+ help='Server host (overrides config file)'
1629
+ )
1630
+
1631
+ parser.add_argument(
1632
+ '--port', '-p',
1633
+ type=int,
1634
+ help='Server port (overrides config file)'
1635
+ )
1636
+
1637
+ parser.add_argument(
1638
+ '--debug',
1639
+ action='store_true',
1640
+ help='Enable debug mode (overrides config file)'
1641
+ )
1642
+
1643
+ parser.add_argument(
1644
+ '--workspace-dir',
1645
+ type=str,
1646
+ help='Base workspace directory (overrides config file)'
1647
+ )
1648
+
1649
+ return parser.parse_args()
1650
+
1651
+
1652
+ def print_startup_info():
1653
+ """Print server startup information"""
1654
+ logger.info("🚀 DeepDiver Demo MCP Server")
1655
+ logger.info("=" * 50)
1656
+ logger.info(f"📊 Features:")
1657
+ logger.info(f" • Session Management: ✅ (TTL: {config.session_ttl_seconds}s)")
1658
+ logger.info(f" • Workspace Isolation: ✅ (Base: {config.base_workspace_dir})")
1659
+ logger.info(f" • Tool Call Tracking: {'✅' if config.enable_tool_tracking else '❌'}")
1660
+ logger.info(f" • Client Rate Limiting: ✅ ({config.rate_limit_requests_per_minute}/min)")
1661
+ logger.info(f" • Global Tool Rate Limiting: {'✅' if config.tool_rate_limits else '❌'}")
1662
+ logger.info(f" • Security Middleware: ✅")
1663
+
1664
+ # Tool rate limiting information
1665
+ if config.tool_rate_limits:
1666
+ logger.info(f"🚦 Tool Rate Limits: {len(config.tool_rate_limits)} tools configured")
1667
+ for tool_name, limits in list(config.tool_rate_limits.items())[:3]:
1668
+ burst = limits.get('burst_limit', '∞')
1669
+ rpm = limits.get('requests_per_minute', '∞')
1670
+ logger.info(f" • {tool_name}: {rpm}/min, burst: {burst}")
1671
+ if len(config.tool_rate_limits) > 3:
1672
+ logger.info(f" • ... and {len(config.tool_rate_limits) - 3} more tools")
1673
+
1674
+ # Tool information from schemas
1675
+ tool_schemas = get_tool_schemas()
1676
+ available_tools = list(tool_schemas.keys())
1677
+
1678
+ logger.info(f"🔧 Tools Available: {len(available_tools)}")
1679
+ logger.info(f" • All tools defined in schemas: {len(available_tools)} tools")
1680
+ logger.info(f" • Sample tools: {', '.join(sorted(available_tools)[:5])}...")
1681
+ logger.info("=" * 50)
1682
+
1683
+
1684
+ def main():
1685
+ """Main function to run the production MCP server"""
1686
+ global config
1687
+
1688
+ # Parse command line arguments
1689
+ args = parse_arguments()
1690
+
1691
+ config = ServerConfig.from_yaml("./src/tools/server_config.yaml")
1692
+
1693
+ # Apply CLI overrides
1694
+ if args.host:
1695
+ config.host = args.host
1696
+ logger.info(f"🔧 Override: Host = {config.host}")
1697
+
1698
+ if args.port:
1699
+ config.port = args.port
1700
+ logger.info(f"🔧 Override: Port = {config.port}")
1701
+
1702
+ if args.debug:
1703
+ config.debug_mode = True
1704
+ logger.info(f"🔧 Override: Debug mode enabled")
1705
+
1706
+ if args.workspace_dir:
1707
+ config.base_workspace_dir = args.workspace_dir
1708
+ logger.info(f"🔧 Override: Workspace directory = {config.base_workspace_dir}")
1709
+
1710
+ print_startup_info()
1711
+
1712
+ try:
1713
+ import os
1714
+
1715
+ # Calculate optimal worker count for high-concurrency FIRST
1716
+ # Use CPU core count indirectly via uvicorn's defaults; no local variable needed
1717
+
1718
+ # Override for high-concurrency scenarios
1719
+ if os.getenv('FORCE_HIGH_CONCURRENCY', '').lower() == 'true':
1720
+ pass # Configuration handled elsewhere if needed
1721
+
1722
+ app = create_app()
1723
+
1724
+ logger.info(f"🌐 Starting server at http://{config.host}:{config.port}")
1725
+ logger.info(f"📡 MCP endpoint: http://{config.host}:{config.port}/mcp")
1726
+ logger.info(f"🏥 Health check: http://{config.host}:{config.port}/health")
1727
+ logger.info(f"📊 Tracking info: http://{config.host}:{config.port}/tracking?session_id=<id>")
1728
+ logger.info(f"🚦 Rate limit stats: http://{config.host}:{config.port}/rate-limits")
1729
+
1730
+ uvicorn.run(
1731
+ app, # Use app instance directly for single worker with async optimizations
1732
+ host=config.host,
1733
+ port=config.port,
1734
+ log_level="info",
1735
+ timeout_keep_alive=config.request_timeout_seconds,
1736
+ workers=1, # Single worker with async optimizations
1737
+ backlog=1024, # Larger backlog for high-concurrency
1738
+ access_log=False, # Disable access logs for better performance
1739
+ limit_concurrency=None, # No artificial concurrency limit
1740
+ )
1741
+
1742
+ except KeyboardInterrupt:
1743
+ print("\n⏹️ Server stopped by user")
1744
+ except Exception as e:
1745
+ print(f"❌ Server startup failed: {e}")
1746
+ import traceback
1747
+ traceback.print_exc()
1748
+ raise
1749
+
1750
+ if __name__ == "__main__":
1751
+ main()