| | import inspect |
| | import logging |
| | import re |
| | import inspect |
| | import aiohttp |
| | import asyncio |
| | import yaml |
| | import json |
| |
|
| | from pydantic import BaseModel |
| | from pydantic.fields import FieldInfo |
| | from typing import ( |
| | Any, |
| | Awaitable, |
| | Callable, |
| | get_type_hints, |
| | get_args, |
| | get_origin, |
| | Dict, |
| | List, |
| | Tuple, |
| | Union, |
| | Optional, |
| | Type, |
| | ) |
| | from functools import update_wrapper, partial |
| |
|
| |
|
| | from fastapi import Request |
| | from pydantic import BaseModel, Field, create_model |
| |
|
| | from langchain_core.utils.function_calling import ( |
| | convert_to_openai_function as convert_pydantic_model_to_openai_function_spec, |
| | ) |
| |
|
| |
|
| | from open_webui.utils.misc import is_string_allowed |
| | from open_webui.models.tools import Tools |
| | from open_webui.models.users import UserModel |
| | from open_webui.models.groups import Groups |
| | from open_webui.models.access_grants import AccessGrants |
| | from open_webui.utils.plugin import load_tool_module_by_id |
| | from open_webui.utils.access_control import has_access |
| | from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL |
| | from open_webui.env import ( |
| | AIOHTTP_CLIENT_TIMEOUT, |
| | AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, |
| | AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, |
| | ENABLE_FORWARD_USER_INFO_HEADERS, |
| | FORWARD_SESSION_INFO_HEADER_CHAT_ID, |
| | FORWARD_SESSION_INFO_HEADER_MESSAGE_ID, |
| | ) |
| | from open_webui.utils.headers import include_user_info_headers |
| | from open_webui.tools.builtin import ( |
| | search_web, |
| | fetch_url, |
| | generate_image, |
| | edit_image, |
| | execute_code, |
| | search_memories, |
| | add_memory, |
| | replace_memory_content, |
| | get_current_timestamp, |
| | calculate_timestamp, |
| | search_notes, |
| | search_chats, |
| | search_channels, |
| | search_channel_messages, |
| | view_note, |
| | view_chat, |
| | view_channel_message, |
| | view_channel_thread, |
| | replace_note_content, |
| | write_note, |
| | list_knowledge_bases, |
| | search_knowledge_bases, |
| | query_knowledge_bases, |
| | search_knowledge_files, |
| | query_knowledge_files, |
| | view_file, |
| | view_knowledge_file, |
| | view_skill, |
| | ) |
| |
|
| | import copy |
| |
|
| | log = logging.getLogger(__name__) |
| |
|
| |
|
def get_async_tool_function_and_apply_extra_params(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """Wrap ``function`` as an async callable with ``extra_params`` pre-bound.

    Only entries of ``extra_params`` that actually appear in the function's
    signature are bound; those parameters are also removed from the exposed
    ``__signature__`` so spec generation never sees them. The original
    function and the bound params are stashed on the wrapper so it can be
    rebuilt later (see get_updated_tool_function).
    """
    sig = inspect.signature(function)
    applicable = {
        name: value for name, value in extra_params.items() if name in sig.parameters
    }
    bound = partial(function, **applicable)

    # Expose only the parameters the caller is still expected to supply.
    remaining = [
        param for name, param in sig.parameters.items() if name not in applicable
    ]
    exposed_sig = inspect.Signature(
        parameters=remaining, return_annotation=sig.return_annotation
    )

    if inspect.iscoroutinefunction(function):

        async def new_function(*args, **kwargs):
            return await bound(*args, **kwargs)

    else:
        # Sync tool functions are still exposed behind an async facade.
        async def new_function(*args, **kwargs):
            return bound(*args, **kwargs)

    update_wrapper(new_function, function)
    new_function.__signature__ = exposed_sig

    # Keep the originals around so the wrapper can be rebuilt with more params.
    new_function.__function__ = function
    new_function.__extra_params__ = applicable

    return new_function
| |
|
| |
|
def get_updated_tool_function(function: Callable, extra_params: dict):
    """Rebuild a wrapped tool function with additional extra params merged in.

    If ``function`` was produced by
    get_async_tool_function_and_apply_extra_params, re-wrap the underlying
    original with the previous params merged under ``extra_params`` (new
    values win). Otherwise the function is returned unchanged.
    """
    original = getattr(function, "__function__", None)
    previous_params = getattr(function, "__extra_params__", None)

    if original is None or previous_params is None:
        return function

    merged = {**previous_params, **extra_params}
    return get_async_tool_function_and_apply_extra_params(original, merged)
| |
|
| |
|
def has_tool_server_access(
    user: UserModel, server_connection: dict, user_group_ids: set = None
) -> bool:
    """Check if user has access to a tool server (MCP or OpenAPI)."""
    # Admins skip per-server grants when the global bypass flag is on.
    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
        return True

    # Allow callers to pass precomputed group ids to avoid repeated lookups.
    if user_group_ids is None:
        member_groups = Groups.get_groups_by_member_id(user.id)
        user_group_ids = {group.id for group in member_groups}

    grants = server_connection.get("config", {}).get("access_grants", [])
    return has_access(user.id, "read", grants, user_group_ids)
| |
|
| |
|
async def get_tools(
    request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
    """Load tools for the given tool_ids, checking access control.

    Two kinds of ids are supported:
      * plain tool ids stored in the Tools table (local tool modules), and
      * "server:<id>" / "server:<type>:<id>" ids referring to external tool
        servers (only the "openapi" type is handled in this function).

    Returns a dict mapping a (possibly collision-prefixed) function name to
    a dict with keys: tool_id, callable, spec, and metadata (local) or
    type="external" (server).
    """
    tools_dict = {}

    # Computed once and reused for every access check in the loop below.
    user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}

    for tool_id in tool_ids:
        tool = Tools.get_tool_by_id(tool_id)
        if tool:
            # Local tool: owner always has access; others need a read grant
            # (admins skip the check when the global bypass flag is on).
            if (
                not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
                and tool.user_id != user.id
                and not AccessGrants.has_access(
                    user_id=user.id,
                    resource_type="tool",
                    resource_id=tool.id,
                    permission="read",
                    user_group_ids=user_group_ids,
                )
            ):
                log.warning(f"Access denied to tool {tool_id} for user {user.id}")
                continue

            # Lazily import and cache the tool module on app state.
            module = request.app.state.TOOLS.get(tool_id, None)
            if module is None:
                module, _ = load_tool_module_by_id(tool_id)
                request.app.state.TOOLS[tool_id] = module

            # Shallow copy so this tool's valves don't leak into the
            # __user__ dict seen by other tools.
            __user__ = {
                **extra_params["__user__"],
            }

            # Apply stored admin valves and per-user valves when the module
            # declares the corresponding pydantic models.
            if hasattr(module, "valves") and hasattr(module, "Valves"):
                valves = Tools.get_tool_valves_by_id(tool_id) or {}
                module.valves = module.Valves(**valves)
            if hasattr(module, "UserValves"):
                __user__["valves"] = module.UserValves(
                    **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
                )

            for spec in tool.specs:
                # Normalize the non-standard "str" type name to JSON-schema
                # "string" so downstream consumers accept the spec.
                for val in spec.get("parameters", {}).get("properties", {}).values():
                    if val.get("type") == "str":
                        val["type"] = "string"

                # Hide framework-injected dunder parameters from the spec.
                spec["parameters"]["properties"] = {
                    key: val
                    for key, val in spec["parameters"]["properties"].items()
                    if not key.startswith("__")
                }

                # Wrap the module function with the context params pre-bound.
                function_name = spec["name"]
                tool_function = getattr(module, function_name)
                callable = get_async_tool_function_and_apply_extra_params(
                    tool_function,
                    {
                        **extra_params,
                        "__id__": tool_id,
                        "__user__": __user__,
                    },
                )

                # Description: docstring text before the first :param/:return
                # field, falling back to the bare function name.
                if callable.__doc__ and callable.__doc__.strip() != "":
                    s = re.split(":(param|return)", callable.__doc__, 1)
                    spec["description"] = s[0]
                else:
                    spec["description"] = function_name

                tool_dict = {
                    "tool_id": tool_id,
                    "callable": callable,
                    "spec": spec,
                    "metadata": {
                        "file_handler": hasattr(module, "file_handler")
                        and module.file_handler,
                        "citation": hasattr(module, "citation") and module.citation,
                    },
                }

                # On name collision, keep prefixing with the tool id until unique.
                while function_name in tools_dict:
                    log.warning(
                        f"Tool {function_name} already exists in another tools!"
                    )
                    function_name = f"{tool_id}_{function_name}"

                tools_dict[function_name] = tool_dict
        else:
            if tool_id.startswith("server:"):
                splits = tool_id.split(":")

                # "server:<id>" (implicitly openapi) or "server:<type>:<id>".
                # NOTE(review): any other split count leaves type/server_id
                # unbound and would raise NameError below — presumably ids are
                # always well-formed upstream; confirm.
                if len(splits) == 2:
                    type = "openapi"
                    server_id = splits[1]
                elif len(splits) == 3:
                    type = splits[1]
                    server_id = splits[2]

                # "<id>|<fn1>,<fn2>" syntax selects specific functions.
                # NOTE(review): function_names is parsed but never used below;
                # filtering relies on function_name_filter_list instead — verify.
                server_id_splits = server_id.split("|")
                if len(server_id_splits) == 2:
                    server_id = server_id_splits[0]
                    function_names = server_id_splits[1].split(",")

                if type == "openapi":
                    # Look up the cached server entry for this id.
                    tool_server_data = None
                    for server in await get_tool_servers(request):
                        if server["id"] == server_id:
                            tool_server_data = server
                            break

                    if tool_server_data is None:
                        log.warning(f"Tool server data not found for {server_id}")
                        continue

                    tool_server_idx = tool_server_data.get("idx", 0)
                    tool_server_connection = (
                        request.app.state.config.TOOL_SERVER_CONNECTIONS[
                            tool_server_idx
                        ]
                    )

                    # Per-server access control.
                    if not has_tool_server_access(
                        user, tool_server_connection, user_group_ids
                    ):
                        log.warning(
                            f"Access denied to tool server {server_id} for user {user.id}"
                        )
                        continue

                    specs = tool_server_data.get("specs", [])
                    function_name_filter_list = tool_server_connection.get(
                        "config", {}
                    ).get("function_name_filter_list", "")

                    # The filter may be stored as a comma-separated string.
                    if isinstance(function_name_filter_list, str):
                        function_name_filter_list = function_name_filter_list.split(",")

                    for spec in specs:
                        function_name = spec["name"]
                        if function_name_filter_list:
                            if not is_string_allowed(
                                function_name, function_name_filter_list
                            ):
                                # Excluded by the server's allow/deny list.
                                continue

                        auth_type = tool_server_connection.get("auth_type", "bearer")

                        # Build auth headers/cookies for this server call.
                        cookies = {}
                        headers = {
                            "Content-Type": "application/json",
                        }

                        if auth_type == "bearer":
                            headers["Authorization"] = (
                                f"Bearer {tool_server_connection.get('key', '')}"
                            )
                        elif auth_type == "none":
                            # No authentication configured.
                            pass
                        elif auth_type == "session":
                            # Forward the caller's own session token + cookies.
                            cookies = request.cookies
                            headers["Authorization"] = (
                                f"Bearer {request.state.token.credentials}"
                            )
                        elif auth_type == "system_oauth":
                            cookies = request.cookies
                            oauth_token = extra_params.get("__oauth_token__", None)
                            if oauth_token:
                                headers["Authorization"] = (
                                    f"Bearer {oauth_token.get('access_token', '')}"
                                )

                        # Static headers configured on the connection override
                        # anything set above.
                        connection_headers = tool_server_connection.get("headers", None)
                        if connection_headers and isinstance(connection_headers, dict):
                            for key, value in connection_headers.items():
                                headers[key] = value

                        # Optionally forward user identity and chat context.
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
                            headers = include_user_info_headers(headers, user)
                            metadata = extra_params.get("__metadata__", {})
                            if metadata and metadata.get("chat_id"):
                                headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = (
                                    metadata.get("chat_id")
                                )
                            if metadata and metadata.get("message_id"):
                                headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
                                    metadata.get("message_id")
                                )

                        # Factory binds the loop variables by value, avoiding
                        # the late-binding closure pitfall for these three.
                        # NOTE(review): `cookies` is still captured from the
                        # enclosing scope; it is recomputed identically on each
                        # iteration, so this is benign today — confirm if the
                        # loop ever produces per-spec cookies.
                        def make_tool_function(
                            function_name, tool_server_data, headers
                        ):
                            async def tool_function(**kwargs):
                                return await execute_tool_server(
                                    url=tool_server_data["url"],
                                    headers=headers,
                                    cookies=cookies,
                                    name=function_name,
                                    params=kwargs,
                                    server_data=tool_server_data,
                                )

                            return tool_function

                        tool_function = make_tool_function(
                            function_name, tool_server_data, headers
                        )

                        callable = get_async_tool_function_and_apply_extra_params(
                            tool_function,
                            {},
                        )

                        tool_dict = {
                            "tool_id": tool_id,
                            "callable": callable,
                            "spec": spec,
                            "type": "external",
                        }

                        # On name collision, prefix with the server id until unique.
                        while function_name in tools_dict:
                            log.warning(
                                f"Tool {function_name} already exists in another tools!"
                            )
                            function_name = f"{server_id}_{function_name}"

                        tools_dict[function_name] = tool_dict

                else:
                    # Unsupported server type (e.g. mcp is handled elsewhere).
                    continue

    return tools_dict
| |
|
| |
|
def get_builtin_tools(
    request: Request, extra_params: dict, features: dict = None, model: dict = None
) -> dict[str, dict]:
    """
    Get built-in tools for native function calling.
    Only returns tools when BOTH the global config is enabled AND the model capability allows it.
    """
    tools_dict = {}
    builtin_functions = []
    features = features or {}
    model = model or {}

    # Model capability flag (model.info.meta.capabilities.<name>); defaults on.
    def get_model_capability(name: str, default: bool = True) -> bool:
        return (model.get("info", {}).get("meta", {}).get("capabilities") or {}).get(
            name, default
        )

    # Per-model builtin-tool category toggle
    # (model.info.meta.builtinTools.<category>); defaults on when unset.
    def is_builtin_tool_enabled(category: str) -> bool:
        builtin_tools = model.get("info", {}).get("meta", {}).get("builtinTools", {})
        return builtin_tools.get(category, True)

    # Time helpers.
    if is_builtin_tool_enabled("time"):
        builtin_functions.extend([get_current_timestamp, calculate_timestamp])

    # Knowledge tools: when the model has attached knowledge, expose only the
    # scoped query/view tools; otherwise expose full workspace search.
    # model_knowledge is read unconditionally because it is also injected into
    # every builtin below as __model_knowledge__.
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", [])
    if is_builtin_tool_enabled("knowledge"):
        if model_knowledge:
            builtin_functions.append(query_knowledge_files)

            knowledge_types = {item.get("type") for item in model_knowledge}
            if "file" in knowledge_types or "collection" in knowledge_types:
                builtin_functions.append(view_file)
            if "note" in knowledge_types:
                builtin_functions.append(view_note)
        else:
            builtin_functions.extend(
                [
                    list_knowledge_bases,
                    search_knowledge_bases,
                    query_knowledge_bases,
                    search_knowledge_files,
                    query_knowledge_files,
                    view_knowledge_file,
                ]
            )

    # Chat history search/view.
    if is_builtin_tool_enabled("chats"):
        builtin_functions.extend([search_chats, view_chat])

    # Memory tools additionally require the per-request feature flag.
    if is_builtin_tool_enabled("memory") and features.get("memory"):
        builtin_functions.extend([search_memories, add_memory, replace_memory_content])

    # Web search: model toggle AND global config AND capability AND feature.
    if (
        is_builtin_tool_enabled("web_search")
        and getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False)
        and get_model_capability("web_search")
        and features.get("web_search")
    ):
        builtin_functions.extend([search_web, fetch_url])

    # Image generation/editing share the capability and feature flags but
    # have separate global config switches.
    if (
        is_builtin_tool_enabled("image_generation")
        and getattr(request.app.state.config, "ENABLE_IMAGE_GENERATION", False)
        and get_model_capability("image_generation")
        and features.get("image_generation")
    ):
        builtin_functions.append(generate_image)
    if (
        is_builtin_tool_enabled("image_generation")
        and getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False)
        and get_model_capability("image_generation")
        and features.get("image_generation")
    ):
        builtin_functions.append(edit_image)

    # Code interpreter (note: the global config default here is True).
    if (
        is_builtin_tool_enabled("code_interpreter")
        and getattr(request.app.state.config, "ENABLE_CODE_INTERPRETER", True)
        and get_model_capability("code_interpreter")
        and features.get("code_interpreter")
    ):
        builtin_functions.append(execute_code)

    # Notes tools.
    if is_builtin_tool_enabled("notes") and getattr(
        request.app.state.config, "ENABLE_NOTES", False
    ):
        builtin_functions.extend(
            [search_notes, view_note, write_note, replace_note_content]
        )

    # Channel tools.
    if is_builtin_tool_enabled("channels") and getattr(
        request.app.state.config, "ENABLE_CHANNELS", False
    ):
        builtin_functions.extend(
            [
                search_channels,
                search_channel_messages,
                view_channel_thread,
                view_channel_message,
            ]
        )

    # Skill viewer only when skills are attached to this request.
    if extra_params.get("__skill_ids__"):
        builtin_functions.append(view_skill)

    # Wrap each builtin with the context params it may request, and derive its
    # OpenAI function spec from its signature/docstring.
    for func in builtin_functions:
        callable = get_async_tool_function_and_apply_extra_params(
            func,
            {
                "__request__": request,
                "__user__": extra_params.get("__user__", {}),
                "__event_emitter__": extra_params.get("__event_emitter__"),
                "__event_call__": extra_params.get("__event_call__"),
                "__metadata__": extra_params.get("__metadata__"),
                "__chat_id__": extra_params.get("__chat_id__"),
                "__message_id__": extra_params.get("__message_id__"),
                "__model_knowledge__": model_knowledge,
            },
        )

        pydantic_model = convert_function_to_pydantic_model(func)
        spec = convert_pydantic_model_to_openai_function_spec(pydantic_model)

        tools_dict[func.__name__] = {
            "tool_id": f"builtin:{func.__name__}",
            "callable": callable,
            "spec": spec,
            "type": "builtin",
        }

    return tools_dict
| |
|
| |
|
| | def parse_description(docstring: str | None) -> str: |
| | """ |
| | Parse a function's docstring to extract the description. |
| | |
| | Args: |
| | docstring (str): The docstring to parse. |
| | |
| | Returns: |
| | str: The description. |
| | """ |
| |
|
| | if not docstring: |
| | return "" |
| |
|
| | lines = [line.strip() for line in docstring.strip().split("\n")] |
| | description_lines: list[str] = [] |
| |
|
| | for line in lines: |
| | if re.match(r":param", line) or re.match(r":return", line): |
| | break |
| |
|
| | description_lines.append(line) |
| |
|
| | return "\n".join(description_lines) |
| |
|
| |
|
def parse_docstring(docstring):
    """
    Parse a function's docstring to extract parameter descriptions in reST format.

    Args:
        docstring (str): The docstring to parse.

    Returns:
        dict: A dictionary where keys are parameter names and values are descriptions.
    """
    if not docstring:
        return {}

    field_re = re.compile(r":param (\w+):\s*(.+)")

    descriptions = {}
    for raw in docstring.splitlines():
        match = field_re.match(raw.strip())
        if match is None:
            continue
        name, text = match.group(1), match.group(2)
        # Framework-injected dunder params are not user-facing.
        if name.startswith("__"):
            continue
        descriptions[name] = text

    return descriptions
| |
|
| |
|
def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
    """
    Converts a Python function's type hints and docstring to a Pydantic model,
    including support for nested types, default values, and descriptions.

    Args:
        func: The function whose type hints and docstring should be converted.

    Returns:
        A Pydantic model class named after the function, with the docstring
        description as the model docstring.
    """
    hints = get_type_hints(func)
    sig = inspect.signature(func)

    doc = func.__doc__
    model_description = parse_description(doc)
    param_docs = parse_docstring(doc)

    field_definitions = {}
    for param_name, param in sig.parameters.items():
        annotation = hints.get(param_name, Any)
        # Ellipsis marks the field as required when no default exists.
        default = ... if param.default is param.empty else param.default

        doc_text = param_docs.get(param_name, None)
        if doc_text:
            field_definitions[param_name] = (
                annotation,
                Field(default, description=doc_text),
            )
        else:
            field_definitions[param_name] = (annotation, default)

    model = create_model(func.__name__, **field_definitions)
    model.__doc__ = model_description

    return model
| |
|
| |
|
def get_functions_from_tool(tool: object) -> list[Callable]:
    """Return the public callables exposed by a tool object.

    Excludes dunder-prefixed names and classes (e.g. nested Valves models).
    Each attribute is resolved with a single ``getattr`` instead of three,
    which also avoids re-triggering any property side effects per check.
    """
    members = ((name, getattr(tool, name)) for name in dir(tool))
    return [
        attr
        for name, attr in members
        if callable(attr) and not name.startswith("__") and not inspect.isclass(attr)
    ]
| |
|
| |
|
def get_tool_specs(tool_module: object) -> list[dict]:
    """Build OpenAI-style function specs for every public function on a tool module."""
    return [
        convert_pydantic_model_to_openai_function_spec(
            convert_function_to_pydantic_model(tool_function)
        )
        for tool_function in get_functions_from_tool(tool_module)
    ]
| |
|
| |
|
def resolve_schema(schema, components):
    """
    Recursively resolves a JSON schema using OpenAPI components.

    ``$ref`` pointers of the form ``#/components/...`` are replaced by the
    referenced definition; nested ``properties`` and ``items`` are resolved
    too. The input schema is never mutated.
    """
    if not schema:
        return {}

    if "$ref" in schema:
        # "#/components/schemas/X" -> walk components["schemas"]["X"]
        # (the leading "components" segment is skipped because `components`
        # is already that sub-dict).
        ref_parts = schema["$ref"].strip("#/").split("/")
        target = components
        for part in ref_parts[1:]:
            target = target.get(part, {})
        return resolve_schema(target, components)

    result = copy.deepcopy(schema)

    if "properties" in result:
        for key, sub_schema in result["properties"].items():
            result["properties"][key] = resolve_schema(sub_schema, components)

    if "items" in result:
        result["items"] = resolve_schema(result["items"], components)

    return result
| |
|
| |
|
def convert_openapi_to_tool_payload(openapi_spec):
    """
    Converts an OpenAPI specification into a custom tool payload structure.

    Args:
        openapi_spec (dict): The OpenAPI specification as a Python dict.

    Returns:
        list: A list of tool payloads, one per operation with an operationId.

    Fix: enum values are stringified before joining — OpenAPI enums may hold
    non-string members (ints, bools), which previously raised TypeError.
    """
    tool_payload = []

    for path, methods in openapi_spec.get("paths", {}).items():
        for method, operation in methods.items():
            # Only operations with an operationId are addressable as tools.
            if not operation.get("operationId"):
                continue

            tool = {
                "name": operation.get("operationId"),
                "description": operation.get(
                    "description",
                    operation.get("summary", "No description available."),
                ),
                "parameters": {"type": "object", "properties": {}, "required": []},
            }

            # Path/query/header parameters declared on the operation.
            for param in operation.get("parameters", []):
                param_name = param.get("name")
                if not param_name:
                    continue
                param_schema = param.get("schema", {})
                description = param_schema.get("description", "")
                if not description:
                    description = param.get("description") or ""
                enum_values = param_schema.get("enum")
                if enum_values and isinstance(enum_values, list):
                    # str() each member: enums are not guaranteed to be strings.
                    description += (
                        f". Possible values: {', '.join(str(v) for v in enum_values)}"
                    )
                param_property = {
                    "type": param_schema.get("type"),
                    "description": description,
                }

                # Preserve item typing for array parameters.
                if param_schema.get("type") == "array" and "items" in param_schema:
                    param_property["items"] = param_schema["items"]

                tool["parameters"]["properties"][param_name] = param_property
                if param.get("required"):
                    tool["parameters"]["required"].append(param_name)

            # Merge JSON request-body properties into the same parameter object.
            request_body = operation.get("requestBody")
            if request_body:
                content = request_body.get("content", {})
                json_schema = content.get("application/json", {}).get("schema")
                if json_schema:
                    resolved_schema = resolve_schema(
                        json_schema, openapi_spec.get("components", {})
                    )

                    if resolved_schema.get("properties"):
                        tool["parameters"]["properties"].update(
                            resolved_schema["properties"]
                        )
                        if "required" in resolved_schema:
                            # De-duplicate against parameter-level requireds.
                            tool["parameters"]["required"] = list(
                                set(
                                    tool["parameters"]["required"]
                                    + resolved_schema["required"]
                                )
                            )
                    elif resolved_schema.get("type") == "array":
                        # A top-level array body replaces the object wrapper.
                        tool["parameters"] = resolved_schema

            tool_payload.append(tool)

    return tool_payload
| |
|
| |
|
async def set_tool_servers(request: Request):
    """Refresh the cached tool-server list from config and mirror it to Redis."""
    servers = await get_tool_servers_data(
        request.app.state.config.TOOL_SERVER_CONNECTIONS
    )
    request.app.state.TOOL_SERVERS = servers

    redis = request.app.state.redis
    if redis is not None:
        # Keep the shared cache in sync for other workers.
        await redis.set("tool_servers", json.dumps(servers))

    return servers
| |
|
| |
|
async def get_tool_servers(request: Request):
    """Return the tool-server list, preferring the shared Redis cache."""
    tool_servers = []
    redis = request.app.state.redis
    if redis is not None:
        try:
            # A cache miss yields None, which json.loads rejects — handled
            # by the except below like any other Redis failure.
            cached = await redis.get("tool_servers")
            tool_servers = json.loads(cached)
            request.app.state.TOOL_SERVERS = tool_servers
        except Exception as e:
            log.error(f"Error fetching tool_servers from Redis: {e}")

    # Fall back to (re)building the list from config.
    if not tool_servers:
        tool_servers = await set_tool_servers(request)

    return tool_servers
| |
|
| |
|
async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
    """Fetch and parse an OpenAPI spec (JSON or YAML) from a tool server URL.

    Args:
        url: Location of the spec document.
        headers: Extra request headers merged over the JSON defaults.

    Returns:
        The parsed spec as a dict.

    Raises:
        Exception: with a readable message when fetching or parsing fails.

    Fix: the previous error handler checked ``isinstance(err, dict)``, which
    is never true for an Exception, so structured ``{"detail": ...}`` error
    bodies were never surfaced; the detail is now read from ``err.args``.
    """
    _headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }

    if headers:
        _headers.update(headers)

    error = None
    try:
        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
            ) as response:
                if response.status != 200:
                    error_body = await response.json()
                    raise Exception(error_body)

                text_content = await response.text()

                if url.lower().endswith((".yaml", ".yml")):
                    res = yaml.safe_load(text_content)
                else:
                    try:
                        res = json.loads(text_content)
                    except json.JSONDecodeError:
                        # Some servers serve YAML without a .yaml extension.
                        res = yaml.safe_load(text_content)

    except Exception as err:
        log.exception(f"Could not fetch tool server spec from {url}")
        # Surface a structured {"detail": ...} error body if one was raised
        # above; otherwise fall back to the exception's string form.
        detail = None
        if err.args and isinstance(err.args[0], dict):
            detail = err.args[0].get("detail")
        error = detail or str(err)
        raise Exception(error)

    log.debug(f"Fetched data: {res}")
    return res
| |
|
| |
|
async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Fetch and parse OpenAPI specs for all enabled openapi-type connections.

    Returns one entry per reachable server with its id, config index, url,
    raw openapi document, info block, and converted tool specs. Servers whose
    spec cannot be fetched or parsed are logged and skipped.
    """

    # Build one awaitable per eligible server, remembering its context so the
    # gathered results can be zipped back to their origin below.
    tasks = []
    server_entries = []
    for idx, server in enumerate(servers):
        if (
            server.get("config", {}).get("enable")
            and server.get("type", "openapi") == "openapi"
        ):
            info = server.get("info", {})

            auth_type = server.get("auth_type", "bearer")
            token = None

            if auth_type == "bearer":
                token = server.get("key", "")
            elif auth_type == "none":
                # No authentication required.
                pass

            # Stable id: explicit info.id, else the positional index.
            # (Shadows the builtin `id` for the rest of this function.)
            id = info.get("id")
            if not id:
                id = str(idx)

            server_url = server.get("url")
            spec_type = server.get("spec_type", "url")

            task = None
            if spec_type == "url":
                # Spec is fetched over HTTP from <url>/<path>.
                openapi_path = server.get("path", "openapi.json")
                spec_url = get_tool_server_url(server_url, openapi_path)

                task = get_tool_server_data(
                    spec_url,
                    {"Authorization": f"Bearer {token}"} if token else None,
                )
            elif spec_type == "json" and server.get("spec", ""):
                # Inline JSON spec; wrap it in a no-op awaitable so it can be
                # gathered uniformly alongside the HTTP fetches.
                spec_json = None
                try:
                    spec_json = json.loads(server.get("spec", ""))
                except Exception as e:
                    log.error(f"Error parsing JSON spec for tool server {id}: {e}")

                if spec_json:
                    task = asyncio.sleep(
                        0,
                        result=spec_json,
                    )

            if task:
                tasks.append(task)
                server_entries.append((id, idx, server, server_url, info, token))

    # Fetch all specs concurrently; failures come back as exception objects.
    responses = await asyncio.gather(*tasks, return_exceptions=True)

    results = []
    for (id, idx, server, url, info, _), response in zip(server_entries, responses):
        if isinstance(response, Exception):
            log.error(f"Failed to connect to {url} OpenAPI tool server")
            continue

        # Normalize into the cached shape used across the app.
        response = {
            "openapi": response,
            "info": response.get("info", {}),
            "specs": convert_openapi_to_tool_payload(response),
        }

        # Let the connection's configured name/description override the
        # spec's own info block.
        openapi_data = response.get("openapi", {})
        if info and isinstance(openapi_data, dict):
            openapi_data["info"] = openapi_data.get("info", {})

            if "name" in info:
                openapi_data["info"]["title"] = info.get("name", "Tool Server")

            if "description" in info:
                openapi_data["info"]["description"] = info.get("description", "")

        results.append(
            {
                "id": str(id),
                "idx": idx,
                "url": server.get("url"),
                "openapi": openapi_data,
                "info": response.get("info"),
                "specs": response.get("specs"),
            }
        )

    return results
| |
|
| |
|
async def execute_tool_server(
    url: str,
    headers: Dict[str, str],
    cookies: Dict[str, str],
    name: str,
    params: Dict[str, Any],
    server_data: Dict[str, Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Invoke operation ``name`` on an OpenAPI tool server.

    Locates the route and HTTP method by operationId in
    ``server_data["openapi"]``, maps ``params`` onto path/query/body, performs
    the request, and returns ``(response_data, response_headers)``. On any
    failure the error is logged and ``({"error": <message>}, None)`` is
    returned instead of raising.
    """
    error = None
    try:
        openapi = server_data.get("openapi", {})
        paths = openapi.get("paths", {})

        # Find the route whose operationId matches the requested name.
        matching_route = None
        for route_path, methods in paths.items():
            for http_method, operation in methods.items():
                if isinstance(operation, dict) and operation.get("operationId") == name:
                    matching_route = (route_path, methods)
                    break
            if matching_route:
                break

        if not matching_route:
            raise Exception(f"No matching route found for operationId: {name}")

        route_path, methods = matching_route

        # Resolve which HTTP method on that route carries the operation.
        method_entry = None
        for http_method, operation in methods.items():
            if operation.get("operationId") == name:
                method_entry = (http_method.lower(), operation)
                break

        if not method_entry:
            raise Exception(f"No matching method found for operationId: {name}")

        http_method, operation = method_entry

        # Split incoming params into path/query buckets per the spec.
        path_params = {}
        query_params = {}
        body_params = {}

        for param in operation.get("parameters", []):
            param_name = param.get("name")
            if not param_name:
                continue
            param_in = param.get("in")
            if param_name in params:
                if param_in == "path":
                    path_params[param_name] = params[param_name]
                elif param_in == "query":
                    query_params[param_name] = params[param_name]

        # Substitute {placeholders} in the route path.
        final_url = f"{url}{route_path}"
        for key, value in path_params.items():
            final_url = final_url.replace(f"{{{key}}}", str(value))

        # NOTE(review): values are not URL-encoded here; params containing
        # '&', '=', or spaces would yield a malformed query string — confirm
        # whether upstream guarantees safe values.
        if query_params:
            query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
            final_url = f"{final_url}?{query_string}"

        # NOTE(review): the JSON body is the full params dict, so path/query
        # params are duplicated into the body when one is declared.
        if operation.get("requestBody", {}).get("content"):
            if params:
                body_params = params

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            request_method = getattr(session, http_method.lower())

            if http_method in ["post", "put", "patch", "delete"]:
                # Methods that may carry a JSON body.
                async with request_method(
                    final_url,
                    json=body_params,
                    headers=headers,
                    cookies=cookies,
                    ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
                    allow_redirects=False,
                ) as response:
                    if response.status >= 400:
                        text = await response.text()
                        raise Exception(f"HTTP error {response.status}: {text}")

                    # Fall back to plain text for non-JSON responses.
                    try:
                        response_data = await response.json()
                    except Exception:
                        response_data = await response.text()

                    response_headers = response.headers
                    return (response_data, response_headers)
            else:
                # GET/HEAD/etc.: no request body.
                async with request_method(
                    final_url,
                    headers=headers,
                    cookies=cookies,
                    ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
                    allow_redirects=False,
                ) as response:
                    if response.status >= 400:
                        text = await response.text()
                        raise Exception(f"HTTP error {response.status}: {text}")

                    # Fall back to plain text for non-JSON responses.
                    try:
                        response_data = await response.json()
                    except Exception:
                        response_data = await response.text()

                    response_headers = response.headers
                    return (response_data, response_headers)

    except Exception as err:
        error = str(err)
        log.exception(f"API Request Error: {error}")
        return ({"error": error}, None)
| |
|
| |
|
def get_tool_server_url(url: Optional[str], path: str) -> str:
    """
    Build the full URL for a tool server, given a base url and a path.

    An absolute ``path`` (one containing "://") is returned verbatim;
    otherwise it is joined to ``url`` with exactly one leading slash.
    """
    # An absolute URL in `path` wins outright.
    if "://" in path:
        return path

    separator = "" if path.startswith("/") else "/"
    return f"{url}{separator}{path}"
| |
|