leideng/QCFuse / srt /entrypoints /openai /tool_server.py
leideng's picture
download
raw
5.66 kB
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any
try:
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import ListToolsResult
except ImportError as e:
ClientSession = sse_client = ListToolsResult = e
from openai_harmony import ToolDescription, ToolNamespaceConfig
# Module-level logger; used below for the duplicate-server and missing-tool
# warnings emitted by MCPToolServer.
logger = logging.getLogger(__name__)
async def list_server_and_tools(server_url: str):
    """Query an MCP server for its metadata and the tools it exposes.

    A short-lived SSE connection is opened to ``server_url``, an MCP session
    is initialized on it, and the tool listing is requested. Both the SSE
    client and the session are exited before the caller receives the result.

    Args:
        server_url: URL passed directly to ``sse_client`` (callers in this
            module pass ``http://<host>/sse``).

    Returns:
        Tuple ``(initialize_result, list_tools_result)`` as returned by
        ``ClientSession.initialize()`` and ``ClientSession.list_tools()``.
    """
    async with sse_client(url=server_url) as streams:
        async with ClientSession(*streams) as session:
            server_info = await session.initialize()
            tool_listing = await session.list_tools()
            return server_info, tool_listing
def trim_schema(schema: dict) -> dict:
    """Convert an MCP-generated JSON Schema into Harmony's simpler variant.

    The schema is modified **in place** and also returned. Transformations:

    * ``title`` keys are dropped.
    * ``"default": None`` is dropped (non-``None`` defaults are kept).
    * An ``anyOf`` whose variants all carry a ``type`` is collapsed into a
      ``type`` list, e.g. ``{"anyOf": [{"type": "string"}, {"type": "null"}]}``
      becomes ``{"type": ["string"]}``. ``"null"`` is removed only when some
      other type remains; an ``anyOf`` consisting solely of ``null`` becomes
      ``{"type": ["null"]}`` rather than an empty type list.
    * An ``anyOf`` containing a variant without ``type`` (e.g. a ``$ref``)
      cannot be flattened, so it is kept, with each dict variant trimmed
      recursively, instead of raising ``KeyError``.
    * Dict sub-schemas under ``properties`` and ``items`` are trimmed
      recursively; non-dict entries (e.g. boolean schemas) are left untouched.

    Args:
        schema: A JSON Schema object.

    Returns:
        The same ``schema`` object after trimming.
    """
    schema.pop("title", None)
    if "default" in schema and schema["default"] is None:
        del schema["default"]

    if "anyOf" in schema:
        variants = schema["anyOf"]
        if all(isinstance(v, dict) and "type" in v for v in variants):
            # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
            # into "type": ["type-1", "type-2"].
            all_types = [v["type"] for v in variants]
            # Harmony ignores "null" next to real types, so drop it there,
            # but keep it when it is the only type present.
            schema["type"] = [t for t in all_types if t != "null"] or all_types
            del schema["anyOf"]
        else:
            schema["anyOf"] = [
                trim_schema(v) if isinstance(v, dict) else v for v in variants
            ]

    if "properties" in schema:
        schema["properties"] = {
            k: trim_schema(v) if isinstance(v, dict) else v
            for k, v in schema["properties"].items()
        }
    if isinstance(schema.get("items"), dict):
        schema["items"] = trim_schema(schema["items"])
    return schema
def post_process_tools_description(
    list_tools_result: "ListToolsResult",
) -> "ListToolsResult":
    """Adapt an MCP ``list_tools`` result for use in a Harmony prompt.

    Every tool's ``inputSchema`` is rewritten with ``trim_schema``. Tools
    whose ``annotations`` expose a falsy ``include_in_prompt`` attribute are
    then dropped from the result; a tool whose ``annotations`` is ``None`` or
    lacks that attribute is kept. Some tools (e.g. a simple text-in/text-out
    Python tool) do not need their schema in the prompt.

    The result object is mutated in place and returned.
    """
    prompt_tools = []
    for tool in list_tools_result.tools:
        tool.inputSchema = trim_schema(tool.inputSchema)
        if getattr(tool.annotations, "include_in_prompt", True):
            prompt_tools.append(tool)
    list_tools_result.tools = prompt_tools
    return list_tools_result
class ToolServer(ABC):
    """Abstract interface for a provider of tools usable by the model.

    Implementations report which tools they offer, supply each tool's
    description, and open a per-tool session as an async context manager.
    """

    @abstractmethod
    def has_tool(self, tool_name: str):
        """Return whether ``tool_name`` is provided by this server."""

    @abstractmethod
    def get_tool_description(self, tool_name: str):
        """Return the description for ``tool_name``.

        Implementations in this module return ``None`` for unknown names.
        """

    @abstractmethod
    def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
        """Return an async context manager that yields a session for ``tool_name``."""
class MCPToolServer(ToolServer):
    """Tool server backed by one or more remote MCP servers reached over SSE.

    Tool namespaces are discovered by ``add_tool_server``. Each later
    ``get_tool_session`` call opens a fresh SSE connection to the server that
    registered the requested namespace.
    """

    def __init__(self):
        # Namespace name -> Harmony ToolNamespaceConfig built from MCP metadata.
        self.harmony_tool_descriptions: dict[str, ToolNamespaceConfig] = {}
        # Namespace name -> SSE URL of the server that provides it. Created here
        # so get_tool_session() works (and reports "not found") even before
        # add_tool_server() has been called.
        self.urls: dict[str, str] = {}

    async def add_tool_server(self, server_url: str):
        """Discover tool namespaces from a comma-separated list of servers.

        Each entry (presumably ``host`` or ``host:port``) is turned into
        ``http://<entry>/sse``. Surrounding whitespace is stripped and empty
        entries (e.g. from a trailing comma) are skipped. Every server is
        queried for its info and tools and registered under its
        ``serverInfo.name``. When two servers report the same name, the first
        one wins for both its description and its URL; later ones are logged
        and ignored.

        Note: previously registered servers are discarded first, so each call
        replaces the registry rather than extending it.

        Args:
            server_url: Comma-separated server addresses.
        """
        self.harmony_tool_descriptions = {}
        self.urls = {}
        for entry in server_url.split(","):
            entry = entry.strip()
            if not entry:
                continue
            url = f"http://{entry}/sse"
            initialize_response, list_tools_response = await list_server_and_tools(url)
            list_tools_response = post_process_tools_description(list_tools_response)
            tool_from_mcp = ToolNamespaceConfig(
                name=initialize_response.serverInfo.name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.inputSchema,
                    )
                    for tool in list_tools_response.tools
                ],
            )
            if tool_from_mcp.name in self.urls:
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    tool_from_mcp.name,
                    url,
                )
                continue
            # Register description and URL together so the prompt description
            # and the session target always come from the same server.
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            self.urls[tool_from_mcp.name] = url

    def has_tool(self, tool_name: str) -> bool:
        """Return True if a namespace named ``tool_name`` was discovered."""
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(self, tool_name: str):
        """Return the ToolNamespaceConfig for ``tool_name``, or None if unknown."""
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        """Yield an initialized MCP ClientSession for ``tool_name``.

        A new SSE connection is opened per call and closed when the context
        exits.

        Raises:
            RuntimeError: If no server is registered for ``tool_name``.
        """
        url = self.urls.get(tool_name)
        if not url:
            logger.warning("Tool %s not found", tool_name)
            # Without an explicit raise the generator would end without
            # yielding, and asynccontextmanager would raise an opaque
            # RuntimeError("generator didn't yield"). Keep the same exception
            # type, but with a meaningful message.
            raise RuntimeError(f"Tool {tool_name} not found")
        async with sse_client(url=url) as streams, ClientSession(*streams) as session:
            await session.initialize()
            yield session
class DemoToolServer(ToolServer):
    """In-process tool server exposing the built-in Harmony demo tools.

    Registers a ``"browser"`` and a ``"python"`` tool instance, each only if
    that instance reports ``enabled``.
    """

    def __init__(self):
        # Deferred import: only needed once the demo server is constructed.
        from sglang.srt.entrypoints.tool import (
            HarmonyBrowserTool,
            HarmonyPythonTool,
            Tool,
        )

        self.tools: dict[str, Tool] = {}
        # Construct each tool in order and keep only the enabled ones.
        for registered_name, tool_cls in (
            ("browser", HarmonyBrowserTool),
            ("python", HarmonyPythonTool),
        ):
            candidate = tool_cls()
            if candidate.enabled:
                self.tools[registered_name] = candidate

    def has_tool(self, tool_name: str):
        """Return True if ``tool_name`` was registered (i.e. was enabled)."""
        return tool_name in self.tools

    def get_tool_description(self, tool_name: str):
        """Return the Harmony namespace config for a registered demo tool.

        Returns None for tools that are not registered. Raises ValueError for
        a registered name with no known namespace config.
        """
        if tool_name not in self.tools:
            return None
        if tool_name == "browser":
            return ToolNamespaceConfig.browser()
        if tool_name == "python":
            return ToolNamespaceConfig.python()
        raise ValueError(f"Unknown tool {tool_name}")

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        """Yield the in-process tool object registered as ``tool_name``."""
        session = self.tools[tool_name]
        yield session

Xet Storage Details

Size:
5.66 kB
·
Xet hash:
9e0e73bd00f3b062f25499cc05debeabffcb7ce340da1d841645b341367e4fc8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.