# Bloom_Ware/features/mcp/mcp_client.py
# Commit 3f0377e (XiaoBai1221): "Bloom Ware 最終修復版本 (Spaces 兼容)"
# i.e. "Bloom Ware final fix release (Spaces-compatible)"
"""
MCP 協議客戶端實作
支援與外部 MCP 服務器進行通信,動態發現和調用工具
"""
import asyncio
import json
import logging
import subprocess
import sys
import os
from typing import Dict, Any, List, Optional, Callable, Union
from pathlib import Path
import uuid
import time
from .types import Tool
# Logger shared by MCPClient; named explicitly "mcp.client" rather than via __name__,
# so logging configuration must target that name.
logger = logging.getLogger("mcp.client")
class MCPClientError(Exception):
    """Raised for MCP client failures, e.g. a server entry with no command configured."""
class MCPClient:
    """Client for one MCP server reached over stdio using JSON-RPC 2.0.

    ``server_config`` supplies ``command`` (required), ``args`` and ``env``.
    The server is spawned as a child process; every message is one line of
    UTF-8 JSON written to its stdin, and replies/notifications are read line
    by line from its stdout by a background task.
    """

    # Seconds to wait for a reply before _send_request gives up.
    DEFAULT_REQUEST_TIMEOUT = 30.0
    # Max bytes per stdout line. asyncio's default StreamReader limit (64 KiB)
    # is easily exceeded by a large tools/list reply, which would otherwise
    # raise inside readline() and end the read loop.
    STREAM_LIMIT = 16 * 1024 * 1024

    def __init__(self, server_name: str, server_config: Dict[str, Any]):
        self.server_name = server_name
        self.server_config = server_config
        self.process: Optional[subprocess.Popen] = None
        self.reader: Optional[asyncio.StreamReader] = None
        # Not used for writing (requests go to process.stdin); kept for compatibility.
        self.writer: Optional[asyncio.StreamWriter] = None
        self.request_id = 0
        # Outstanding requests: JSON-RPC id -> future resolved with the raw reply dict.
        self.pending_requests: Dict[str, asyncio.Future] = {}
        self.tools: Dict[str, Tool] = {}
        self.connected = False
        self.initialized = False
        self.request_timeout = self.DEFAULT_REQUEST_TIMEOUT
        # Strong references to background tasks (the event loop only keeps weak
        # ones) and to pipe transports, so stop() can cancel/close them.
        self._tasks: List[asyncio.Task] = []
        self._transports: List[asyncio.BaseTransport] = []

    def _get_next_request_id(self) -> str:
        """Return the next JSON-RPC request id (a monotonically increasing number, as str)."""
        self.request_id += 1
        return str(self.request_id)

    def _spawn(self, coro) -> asyncio.Task:
        """Start ``coro`` as a task and keep a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._tasks = [t for t in self._tasks if not t.done()]
        self._tasks.append(task)
        return task

    async def start(self) -> bool:
        """Launch the server process, run the MCP handshake and discover tools.

        Returns:
            True when the server initialized (``connected`` is then True);
            False on any failure, in which case everything is stopped again.
        """
        try:
            command = self.server_config.get("command", "")
            args = self.server_config.get("args", [])
            env = self.server_config.get("env", {})
            if not command:
                raise MCPClientError(f"服務器 {self.server_name} 沒有指定命令")

            # Server-specific variables override the inherited environment.
            process_env = dict(os.environ)
            process_env.update(env)

            cmd = [command] + list(args)
            logger.info(f"啟動 MCP 服務器: {' '.join(cmd)}")
            self.process = subprocess.Popen(
                cmd,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=process_env,
                text=False,  # binary mode: messages are encoded/decoded as UTF-8 here
            )

            loop = asyncio.get_running_loop()
            self.reader = asyncio.StreamReader(limit=self.STREAM_LIMIT)
            stdout_transport, _ = await loop.connect_read_pipe(
                lambda: asyncio.StreamReaderProtocol(self.reader), self.process.stdout
            )
            self._transports.append(stdout_transport)

            # stderr must be drained: an unread PIPE fills up and then blocks the
            # server on its next write to stderr.
            stderr_reader = asyncio.StreamReader()
            stderr_transport, _ = await loop.connect_read_pipe(
                lambda: asyncio.StreamReaderProtocol(stderr_reader), self.process.stderr
            )
            self._transports.append(stderr_transport)

            self._spawn(self._message_loop())
            self._spawn(self._drain_stderr(stderr_reader))

            # Give the server a brief moment to start up.
            await asyncio.sleep(0.1)

            if not await self._initialize():
                await self.stop()
                return False

            await self._discover_tools()
            self.connected = True
            logger.info(f"MCP 服務器 {self.server_name} 連接成功,發現 {len(self.tools)} 個工具")
            return True
        except Exception as e:
            logger.error(f"啟動 MCP 服務器失敗 {self.server_name}: {e}")
            await self.stop()
            return False

    async def _drain_stderr(self, stream: asyncio.StreamReader):
        """Forward the server's stderr to the debug log until EOF."""
        try:
            while True:
                # Fixed-size chunks: unlike readline() this cannot hit the line limit.
                chunk = await stream.read(65536)
                if not chunk:
                    break
                for text in chunk.decode("utf-8", errors="replace").splitlines():
                    if text.strip():
                        logger.debug(f"MCP 服務器 {self.server_name} stderr: {text}")
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.debug(f"MCP 服務器 {self.server_name} stderr 讀取結束: {e}")

    @staticmethod
    async def _wait_for_exit(process: subprocess.Popen, timeout: float) -> bool:
        """Poll ``process`` without blocking the event loop; True once it has exited."""
        deadline = time.monotonic() + timeout
        while process.poll() is None:
            if time.monotonic() >= deadline:
                return False
            await asyncio.sleep(0.05)
        return True

    def _fail_pending_requests(self, exc: Exception):
        """Resolve every outstanding request future with ``exc`` so its caller stops waiting."""
        for future in self.pending_requests.values():
            if not future.done():
                future.set_exception(exc)

    async def stop(self):
        """Stop the server process and release all client-side resources.

        Safe to call more than once. Requests still awaiting a reply fail
        immediately instead of waiting for their timeout. Shutdown escalates:
        close stdin, then terminate, then kill.
        """
        self.connected = False
        self.initialized = False

        self._fail_pending_requests(MCPClientError(f"MCP 服務器 {self.server_name} 已停止"))

        process, self.process = self.process, None
        if process:
            try:
                if process.stdin:
                    try:
                        process.stdin.close()  # EOF on stdin lets a well-behaved server exit
                    except OSError:
                        pass
                if not await self._wait_for_exit(process, 1.0):
                    process.terminate()
                    if not await self._wait_for_exit(process, 1.0):
                        process.kill()
                        # poll() inside also reaps the child so it does not linger as a zombie.
                        await self._wait_for_exit(process, 1.0)
            except Exception as e:
                logger.error(f"停止進程時發生錯誤: {e}")

        current = asyncio.current_task()
        for task in self._tasks:
            if task is not current and not task.done():
                task.cancel()
        self._tasks.clear()
        for transport in self._transports:
            transport.close()
        self._transports.clear()

        self.pending_requests.clear()
        self.tools.clear()

    async def _initialize(self) -> bool:
        """Run the MCP ``initialize`` handshake.

        On success, marks the client initialized and sends the
        ``notifications/initialized`` notification the MCP lifecycle requires
        after a successful initialize reply.
        """
        try:
            response = await self._send_request("initialize", {
                "protocolVersion": "2024-11-05",
                # NOTE(review): in the MCP spec `tools` is a server capability, not a
                # client one; servers presumably ignore it here — confirm before relying on it.
                "capabilities": {
                    "tools": {"listChanged": True}
                },
                "clientInfo": {
                    "name": "mcp-client",
                    "version": "1.0.0"
                }
            })
            if response and isinstance(response.get("result"), dict):
                self.initialized = True
                self._send_notification("notifications/initialized")
                logger.info(f"MCP 服務器 {self.server_name} 初始化成功")
                return True
            logger.error(f"MCP 服務器 {self.server_name} 初始化失敗")
            return False
        except Exception as e:
            logger.error(f"初始化 MCP 服務器失敗 {self.server_name}: {e}")
            return False

    async def _discover_tools(self):
        """Fetch the server's tool list (following ``nextCursor`` pages) into ``self.tools``.

        ``self.tools`` is updated in place, so references handed out earlier
        see the new set; on a failed refresh the previous set is kept.
        """
        try:
            discovered: Dict[str, Tool] = {}
            cursor: Optional[str] = None
            while True:
                params: Dict[str, Any] = {"cursor": cursor} if cursor else {}
                response = await self._send_request("tools/list", params)
                result = response.get("result") if response else None
                if not isinstance(result, dict):
                    logger.error(f"發現工具失敗 {self.server_name}: 無效響應")
                    return
                for tool_data in result.get("tools") or []:
                    tool = self._create_tool_from_data(tool_data)
                    if tool:
                        discovered[tool.name] = tool
                        logger.info(f"發現外部工具: {tool.name}")
                next_cursor = result.get("nextCursor")
                # Stop when there are no more pages (or the server repeats a cursor).
                if not next_cursor or next_cursor == cursor:
                    break
                cursor = next_cursor
            self.tools.clear()
            self.tools.update(discovered)
        except Exception as e:
            logger.error(f"發現工具失敗 {self.server_name}: {e}")

    def _create_tool_from_data(self, tool_data: Dict[str, Any]) -> Optional[Tool]:
        """Build a Tool whose handler proxies calls to this server; None if the entry is unusable."""
        try:
            name = tool_data.get("name")
            if not name:
                logger.error(f"創建工具失敗: 缺少工具名稱 {tool_data}")
                return None
            description = tool_data.get("description") or ""
            input_schema = tool_data.get("inputSchema", {"type": "object", "properties": {}})

            async def tool_handler(arguments: Dict[str, Any]) -> Dict[str, Any]:
                return await self._call_tool(name, arguments)

            return Tool(
                name=name,
                description=description,
                inputSchema=input_schema,
                handler=tool_handler
            )
        except Exception as e:
            logger.error(f"創建工具失敗: {e}")
            return None

    async def _call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke ``tools/call`` and flatten the reply.

        Returns ``{"success": True, "content": <joined text items>}`` on success,
        otherwise ``{"success": False, "error": <message>}``. A result flagged
        ``isError`` (a tool-level failure reported inside the result) counts as
        a failure, with the tool's text output as the error message.
        """
        try:
            response = await self._send_request("tools/call", {
                "name": tool_name,
                "arguments": arguments
            })
            if response is None:
                # Timeout, write failure or disconnected server (already logged).
                return {"success": False, "error": "未收到服務器響應"}

            result = response.get("result")
            if isinstance(result, dict):
                text = "\n".join(
                    item.get("text", "")
                    for item in result.get("content") or []
                    if isinstance(item, dict) and item.get("type") == "text"
                )
                if result.get("isError"):
                    return {"success": False, "error": text or "未知錯誤"}
                return {"success": True, "content": text}

            error = response.get("error")
            message = error.get("message", "未知錯誤") if isinstance(error, dict) else "未知錯誤"
            return {"success": False, "error": message}
        except Exception as e:
            logger.error(f"調用工具失敗 {tool_name}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    def _write_message(self, message: Dict[str, Any]):
        """Serialize ``message`` as one line of JSON and write it to the server's stdin."""
        if not self.process or not self.process.stdin:
            raise MCPClientError(f"MCP 服務器 {self.server_name} 未運行")
        data = (json.dumps(message, ensure_ascii=False) + "\n").encode("utf-8")
        self.process.stdin.write(data)
        self.process.stdin.flush()

    def _send_notification(self, method: str, params: Optional[Dict[str, Any]] = None):
        """Send a JSON-RPC notification (no id, no reply expected); errors are logged."""
        message: Dict[str, Any] = {"jsonrpc": "2.0", "method": method}
        if params is not None:
            message["params"] = params
        try:
            self._write_message(message)
        except Exception as e:
            logger.error(f"發送通知失敗 {method}: {e}")

    async def _send_request(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Send a JSON-RPC request and wait for its reply.

        Returns the raw reply dict (containing ``result`` or ``error``), or None
        when the client is not running/initialized, the write fails, the server
        disconnects, or no reply arrives within ``request_timeout`` seconds.
        """
        # The handshake request itself must be allowed before `initialized` is set.
        if not self.process or (not self.initialized and method != "initialize"):
            return None

        request_id = self._get_next_request_id()
        request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": params
        }

        future = asyncio.get_running_loop().create_future()
        self.pending_requests[request_id] = future
        try:
            self._write_message(request)
            return await asyncio.wait_for(future, timeout=self.request_timeout)
        except asyncio.TimeoutError:
            logger.error(f"請求超時: {method}")
            return None
        except Exception as e:
            logger.error(f"發送請求失敗: {e}")
            return None
        finally:
            self.pending_requests.pop(request_id, None)

    async def _message_loop(self):
        """Read newline-delimited JSON-RPC messages from the server's stdout until EOF."""
        try:
            while True:
                try:
                    line = await self.reader.readline()
                except ValueError as e:
                    # Line exceeded STREAM_LIMIT; the reader discards it, keep going.
                    logger.error(f"處理消息時發生錯誤: {e}")
                    continue
                if not line:
                    break  # EOF: the server closed stdout / exited
                if not line.strip():
                    continue
                try:
                    message = json.loads(line)
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # Non-protocol output on stdout; skip it rather than dropping the connection.
                    logger.debug(f"MCP 服務器 {self.server_name} 非 JSON 輸出: {line[:200]!r}")
                    continue
                if isinstance(message, dict):
                    await self._handle_message(message)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error(f"消息循環錯誤: {e}")
        finally:
            self.connected = False
            # Nobody will answer outstanding requests any more; fail them now.
            self._fail_pending_requests(MCPClientError(f"MCP 服務器 {self.server_name} 連接已關閉"))

    def _respond_to_server_request(self, message: Dict[str, Any]):
        """Answer a request initiated by the server: ``ping`` gets ``{}``, anything else 'method not found'."""
        if message.get("method") == "ping":
            reply: Dict[str, Any] = {"jsonrpc": "2.0", "id": message["id"], "result": {}}
        else:
            reply = {
                "jsonrpc": "2.0",
                "id": message["id"],
                "error": {"code": -32601, "message": f"Method not found: {message.get('method')}"},
            }
        try:
            self._write_message(reply)
        except Exception as e:
            logger.error(f"處理消息失敗: {e}")

    async def _handle_message(self, message: Dict[str, Any]):
        """Dispatch one incoming message: a reply, a server request, or a notification."""
        try:
            has_id = "id" in message
            has_method = "method" in message
            if has_id and not has_method and ("result" in message or "error" in message):
                # Reply to one of our requests.
                future = self.pending_requests.get(str(message["id"]))
                if future is not None and not future.done():
                    future.set_result(message)
            elif has_method and has_id:
                self._respond_to_server_request(message)
            elif has_method:
                method = message.get("method")
                # "tools/listChanged" is kept for servers that used the old name.
                if method in ("notifications/tools/list_changed", "tools/listChanged"):
                    logger.info(f"MCP 服務器 {self.server_name} 工具列表發生變化")
                    # Must not be awaited here: the refresh waits for a reply that
                    # only this read loop can deliver, so awaiting would deadlock.
                    self._spawn(self._discover_tools())
        except Exception as e:
            logger.error(f"處理消息失敗: {e}")
class MCPClientManager:
    """Owns the running MCPClient instances, keyed by server name."""

    def __init__(self):
        self.clients: Dict[str, MCPClient] = {}
        self.logger = logging.getLogger("mcp.client_manager")

    async def start_client(self, server_name: str, server_config: Dict[str, Any]) -> bool:
        """Start a client for ``server_name`` and register it if it connects.

        If a client is already registered under that name, it is replaced only
        after the new one has started, and is then stopped so its server
        process is not left running unreferenced.

        Returns:
            True if the new client started; False otherwise (any previously
            registered client is left untouched).
        """
        try:
            client = MCPClient(server_name, server_config)
            if not await client.start():
                self.logger.error(f"MCP 客戶端 {server_name} 啟動失敗")
                return False

            previous = self.clients.get(server_name)
            self.clients[server_name] = client
            if previous is not None and previous is not client:
                try:
                    await previous.stop()
                except Exception as e:
                    self.logger.error(f"停止舊 MCP 客戶端失敗 {server_name}: {e}")
            self.logger.info(f"MCP 客戶端 {server_name} 啟動成功")
            return True
        except Exception as e:
            self.logger.error(f"啟動 MCP 客戶端失敗 {server_name}: {e}")
            return False

    async def stop_client(self, server_name: str):
        """Stop and unregister ``server_name``; does nothing if it is not registered."""
        client = self.clients.pop(server_name, None)
        if client is None:
            return
        await client.stop()
        self.logger.info(f"MCP 客戶端 {server_name} 已停止")

    async def stop_all(self):
        """Stop every client; a failure in one is logged and does not stop the rest."""
        for server_name in list(self.clients):
            try:
                await self.stop_client(server_name)
            except Exception as e:
                self.logger.error(f"停止 MCP 客戶端失敗 {server_name}: {e}")

    def get_client_tools(self, server_name: str) -> Dict[str, Tool]:
        """Return the live tool dict of ``server_name``, or an empty dict if unknown."""
        client = self.clients.get(server_name)
        return client.tools if client is not None else {}

    def get_all_tools(self) -> Dict[str, Tool]:
        """Return every client's tools keyed as ``"<server>_<tool>"`` to avoid name clashes.

        Only the dict keys are prefixed; the Tool objects keep their original names.
        """
        return {
            f"{server_name}_{tool_name}": tool
            for server_name, client in self.clients.items()
            for tool_name, tool in client.tools.items()
        }

    def is_client_connected(self, server_name: str) -> bool:
        """True if ``server_name`` is registered and its client reports a live connection."""
        client = self.clients.get(server_name)
        return client is not None and client.connected