# NOTE: removed stray build-log artifacts ("Spaces:", "Build error") that were
# not part of the Python source.
| import base64 | |
| import inspect | |
| import logging | |
| import re | |
| import inspect | |
| import aiohttp | |
| import asyncio | |
| import yaml | |
| import json | |
| from pydantic import BaseModel | |
| from pydantic.fields import FieldInfo | |
| from typing import ( | |
| Any, | |
| Awaitable, | |
| Callable, | |
| get_type_hints, | |
| get_args, | |
| get_origin, | |
| Dict, | |
| List, | |
| Tuple, | |
| Union, | |
| Optional, | |
| Type, | |
| ) | |
| from functools import update_wrapper, partial | |
| from fastapi import Request | |
| from pydantic import BaseModel, Field, create_model | |
| from langchain_core.utils.function_calling import ( | |
| convert_to_openai_function as convert_pydantic_model_to_openai_function_spec, | |
| ) | |
| from open_webui.utils.misc import is_string_allowed | |
| from open_webui.models.tools import Tools | |
| from open_webui.models.users import UserModel | |
| from open_webui.models.groups import Groups | |
| from open_webui.models.access_grants import AccessGrants | |
| from open_webui.utils.plugin import load_tool_module_by_id | |
| from open_webui.utils.access_control import has_access, has_connection_access | |
| from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL | |
| from open_webui.env import ( | |
| AIOHTTP_CLIENT_SESSION_SSL, | |
| AIOHTTP_CLIENT_TIMEOUT, | |
| AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER, | |
| AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, | |
| AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, | |
| ENABLE_FORWARD_USER_INFO_HEADERS, | |
| FORWARD_SESSION_INFO_HEADER_CHAT_ID, | |
| FORWARD_SESSION_INFO_HEADER_MESSAGE_ID, | |
| REDIS_KEY_PREFIX, | |
| ) | |
| from open_webui.utils.headers import include_user_info_headers | |
| from open_webui.tools.builtin import ( | |
| search_web, | |
| fetch_url, | |
| generate_image, | |
| edit_image, | |
| execute_code, | |
| search_memories, | |
| add_memory, | |
| replace_memory_content, | |
| delete_memory, | |
| list_memories, | |
| get_current_timestamp, | |
| calculate_timestamp, | |
| search_notes, | |
| search_chats, | |
| search_channels, | |
| search_channel_messages, | |
| view_note, | |
| view_chat, | |
| view_channel_message, | |
| view_channel_thread, | |
| replace_note_content, | |
| write_note, | |
| list_knowledge_bases, | |
| search_knowledge_bases, | |
| query_knowledge_bases, | |
| search_knowledge_files, | |
| query_knowledge_files, | |
| list_knowledge, | |
| view_file, | |
| view_knowledge_file, | |
| view_skill, | |
| create_tasks, | |
| update_task, | |
| create_automation, | |
| update_automation, | |
| list_automations, | |
| toggle_automation, | |
| delete_automation, | |
| search_calendar_events, | |
| create_calendar_event, | |
| update_calendar_event, | |
| delete_calendar_event, | |
| ) | |
| import copy | |
| from open_webui.utils.access_control import has_permission | |
| log = logging.getLogger(__name__) | |
| # Let no function be called without need, and let what | |
| # it yields justify the cost of running it. | |
async def get_async_tool_function_and_apply_extra_params(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """Wrap *function* as a coroutine function with matching extra params pre-bound.

    Only the keys of *extra_params* that name actual parameters of *function*
    are frozen. The returned wrapper advertises a signature containing just the
    remaining parameters, because python-genai infers tool schemas for native
    function calling from the signature.
    """
    signature = inspect.signature(function)
    bound_params = {key: value for key, value in extra_params.items() if key in signature.parameters}
    partial_func = partial(function, **bound_params)

    # Advertise only the parameters that were NOT frozen above.
    remaining = [param for name, param in signature.parameters.items() if name not in bound_params]
    new_signature = inspect.Signature(parameters=remaining, return_annotation=signature.return_annotation)

    if inspect.iscoroutinefunction(function):
        # functools.partial of a coroutine trips up python-genai, so wrap it:
        # https://github.com/googleapis/python-genai/issues/907
        async def wrapper(*args, **kwargs):
            return await partial_func(*args, **kwargs)

    else:
        # Promote a plain callable to a coroutine function.
        async def wrapper(*args, **kwargs):
            return partial_func(*args, **kwargs)

    update_wrapper(wrapper, function)
    wrapper.__signature__ = new_signature
    # Stash the original callable and frozen params so the wrapper can be
    # rebuilt later with merged params (see get_updated_tool_function).
    wrapper.__function__ = function  # type: ignore
    wrapper.__extra_params__ = bound_params  # type: ignore
    return wrapper
async def get_updated_tool_function(function: Callable, extra_params: dict):
    """Re-wrap a previously wrapped tool function with merged extra params.

    When *function* was produced by get_async_tool_function_and_apply_extra_params
    it carries the original callable and its frozen params; rebuild the wrapper
    with *extra_params* layered on top of those. Otherwise *function* is
    returned unchanged.
    """
    original = getattr(function, '__function__', None)
    frozen = getattr(function, '__extra_params__', None)
    if original is None or frozen is None:
        return function
    merged = {**frozen, **extra_params}
    return await get_async_tool_function_and_apply_extra_params(original, merged)
async def get_tools(request: Request, tool_ids: list[str], user: UserModel, extra_params: dict) -> dict[str, dict]:
    """Load tools for the given tool_ids, checking access control.

    Each tool_id is either a locally registered tool (looked up via Tools) or a
    remote tool server id of the form ``server:<id>`` / ``server:<type>:<id>``.
    Returns a mapping of function name -> dict with keys 'tool_id', 'callable',
    'spec' and misc info; name collisions are disambiguated by prefixing the
    tool/server id.
    """
    if not tool_ids:
        return {}

    tools_dict = {}

    # Get user's group memberships for access control checks
    user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id)}

    for tool_id in tool_ids:
        tool = await Tools.get_tool_by_id(tool_id)
        if tool:
            # Check access control for local tools: admins may bypass (config
            # flag), owners always pass, otherwise a 'read' grant is required.
            if (
                not (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL)
                and tool.user_id != user.id
                and not await AccessGrants.has_access(
                    user_id=user.id,
                    resource_type='tool',
                    resource_id=tool.id,
                    permission='read',
                    user_group_ids=user_group_ids,
                )
            ):
                log.warning(f'Access denied to tool {tool_id} for user {user.id}')
                continue

            # Lazily load and cache the tool module on app state.
            module = request.app.state.TOOLS.get(tool_id, None)
            if module is None:
                module, _ = await load_tool_module_by_id(tool_id)
                request.app.state.TOOLS[tool_id] = module

            # Per-call copy of the user context so per-tool valves don't leak
            # into other tools sharing extra_params.
            __user__ = {
                **extra_params['__user__'],
            }

            # Set valves for the tool
            if hasattr(module, 'valves') and hasattr(module, 'Valves'):
                valves = await Tools.get_tool_valves_by_id(tool_id) or {}
                module.valves = module.Valves(**valves)

            if hasattr(module, 'UserValves'):
                __user__['valves'] = module.UserValves(  # type: ignore
                    **await Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
                )

            for spec in tool.specs:
                # TODO: Fix hack for OpenAI API
                # Some times breaks OpenAI but others don't. Leaving the comment
                for val in spec.get('parameters', {}).get('properties', {}).values():
                    if val.get('type') == 'str':
                        val['type'] = 'string'

                # Remove internal reserved parameters (e.g. __id__, __user__)
                spec['parameters']['properties'] = {
                    key: val for key, val in spec['parameters']['properties'].items() if not key.startswith('__')
                }

                # convert to function that takes only model params and inserts custom params
                function_name = spec['name']
                tool_function = getattr(module, function_name)
                # NOTE: shadows the builtin `callable` for the rest of this scope.
                callable = await get_async_tool_function_and_apply_extra_params(
                    tool_function,
                    {
                        **extra_params,
                        '__id__': tool_id,
                        '__user__': __user__,
                    },
                )

                # TODO: Support Pydantic models as parameters
                # Description = docstring text before the first :param/:return field.
                if callable.__doc__ and callable.__doc__.strip() != '':
                    s = re.split(':(param|return)', callable.__doc__, 1)
                    spec['description'] = s[0]
                else:
                    spec['description'] = function_name

                tool_dict = {
                    'tool_id': tool_id,
                    'callable': callable,
                    'spec': spec,
                    # Misc info
                    'metadata': {
                        'file_handler': hasattr(module, 'file_handler') and module.file_handler,
                        'citation': hasattr(module, 'citation') and module.citation,
                    },
                }

                # Handle function name collisions
                while function_name in tools_dict:
                    log.warning(f'Tool {function_name} already exists in another tools!')
                    # Prepend tool ID to function name
                    function_name = f'{tool_id}_{function_name}'

                tools_dict[function_name] = tool_dict
        else:
            if tool_id.startswith('server:'):
                # NOTE(review): if len(splits) is neither 2 nor 3, `type` and
                # `server_id` are never bound and the `if type == 'openapi'`
                # check below raises NameError — confirm tool_id format upstream.
                # `type` also shadows the builtin here.
                splits = tool_id.split(':')
                if len(splits) == 2:
                    type = 'openapi'
                    server_id = splits[1]
                elif len(splits) == 3:
                    type = splits[1]
                    server_id = splits[2]

                # Optional "<server_id>|<fn1,fn2,...>" suffix.
                # NOTE(review): `function_names` is parsed but not used below —
                # verify whether it is dead code or consumed elsewhere.
                server_id_splits = server_id.split('|')
                if len(server_id_splits) == 2:
                    server_id = server_id_splits[0]
                    function_names = server_id_splits[1].split(',')

                if type == 'openapi':
                    # Find the cached spec data for this server id.
                    tool_server_data = None
                    for server in await get_tool_servers(request):
                        if server['id'] == server_id:
                            tool_server_data = server
                            break

                    if tool_server_data is None:
                        log.warning(f'Tool server data not found for {server_id}')
                        continue

                    # Map the cached entry back to its configured connection.
                    tool_server_idx = tool_server_data.get('idx', 0)
                    connections = request.app.state.config.TOOL_SERVER_CONNECTIONS
                    if tool_server_idx >= len(connections):
                        log.warning(
                            f'Tool server index {tool_server_idx} out of range '
                            f'(have {len(connections)} connections), skipping server {server_id}'
                        )
                        continue
                    tool_server_connection = connections[tool_server_idx]

                    # Check access control for tool server
                    if not await has_connection_access(user, tool_server_connection, user_group_ids):
                        log.warning(f'Access denied to tool server {server_id} for user {user.id}')
                        continue

                    specs = tool_server_data.get('specs', [])

                    # Optional allow-list of function names (comma-separated string
                    # or list) configured on the connection.
                    function_name_filter_list = tool_server_connection.get('config', {}).get(
                        'function_name_filter_list', ''
                    )
                    if isinstance(function_name_filter_list, str):
                        function_name_filter_list = function_name_filter_list.split(',')

                    for spec in specs:
                        function_name = spec['name']
                        if function_name_filter_list:
                            if not is_string_allowed(function_name, function_name_filter_list):
                                # Skip this function
                                continue

                        # Build per-request auth headers/cookies for the remote call.
                        auth_type = tool_server_connection.get('auth_type', 'bearer')
                        cookies = {}
                        headers = {
                            'Content-Type': 'application/json',
                        }

                        if auth_type == 'bearer':
                            headers['Authorization'] = f'Bearer {tool_server_connection.get("key", "")}'
                        elif auth_type == 'none':
                            # No authentication
                            pass
                        elif auth_type == 'session':
                            # Forward the caller's own session token/cookies.
                            cookies = request.cookies
                            headers['Authorization'] = f'Bearer {request.state.token.credentials}'
                        elif auth_type == 'system_oauth':
                            cookies = request.cookies
                            oauth_token = extra_params.get('__oauth_token__', None)
                            if oauth_token:
                                headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}'

                        # Connection-level static headers override/extend the defaults.
                        connection_headers = tool_server_connection.get('headers', None)
                        if connection_headers and isinstance(connection_headers, dict):
                            for key, value in connection_headers.items():
                                headers[key] = value

                        # Add user info headers if enabled
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
                            headers = include_user_info_headers(headers, user)

                        metadata = extra_params.get('__metadata__', {})
                        if metadata and metadata.get('chat_id'):
                            headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id')
                        if metadata and metadata.get('message_id'):
                            headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = metadata.get('message_id')

                        # Factory binds function_name/server data/headers per spec
                        # (avoids the late-binding closure pitfall in this loop).
                        async def make_tool_function(function_name, tool_server_data, headers):
                            async def tool_function(**kwargs):
                                return await execute_tool_server(
                                    url=tool_server_data['url'],
                                    headers=headers,
                                    cookies=cookies,
                                    name=function_name,
                                    params=kwargs,
                                    server_data=tool_server_data,
                                )

                            return tool_function

                        tool_function = await make_tool_function(function_name, tool_server_data, headers)
                        callable = await get_async_tool_function_and_apply_extra_params(
                            tool_function,
                            {},
                        )

                        tool_dict = {
                            'tool_id': tool_id,
                            'callable': callable,
                            'spec': clean_openai_tool_schema(spec),
                            # Misc info
                            'type': 'external',
                        }

                        # Handle function name collisions
                        while function_name in tools_dict:
                            log.warning(f'Tool {function_name} already exists in another tools!')
                            # Prepend server ID to function name
                            function_name = f'{server_id}_{function_name}'

                        tools_dict[function_name] = tool_dict
            else:
                # Unknown tool id (no DB record, not a server id) — skip it.
                continue

    return tools_dict
async def get_builtin_tools(
    request: Request, extra_params: dict, features: Optional[dict] = None, model: Optional[dict] = None
) -> dict[str, dict]:
    """
    Get built-in tools for native function calling.
    Only returns tools when BOTH the global config is enabled AND the model capability allows it.

    Gating per category: meta.builtinTools toggle, global config flag,
    model capability, per-chat feature flag, and user-level permission
    (admins always pass) — whichever of these apply to the category.
    Returns a mapping of function name -> {'tool_id', 'callable', 'spec', 'type'}.
    """
    tools_dict = {}
    builtin_functions = []

    features = features or {}
    model = model or {}

    # Helper to get model capabilities (defaults to True if not specified)
    def get_model_capability(name: str, default: bool = True) -> bool:
        return (model.get('info', {}).get('meta', {}).get('capabilities') or {}).get(name, default)

    # Helper to check if a builtin tool category is enabled via meta.builtinTools
    # Defaults to True if not specified (backward compatible)
    def is_builtin_tool_enabled(category: str) -> bool:
        builtin_tools = model.get('info', {}).get('meta', {}).get('builtinTools', {})
        return builtin_tools.get(category, True)

    # Helper to check user-level feature permission (admins always pass)
    user = extra_params.get('__user__', {})

    async def has_user_permission(feature_key: str) -> bool:
        if user.get('role') == 'admin':
            return True
        return await has_permission(
            user.get('id', ''),
            f'features.{feature_key}',
            request.app.state.config.USER_PERMISSIONS,
        )

    # Time utilities - available for date calculations
    if is_builtin_tool_enabled('time'):
        builtin_functions.extend([get_current_timestamp, calculate_timestamp])

    # Knowledge base tools - conditional injection based on model knowledge
    # If model has attached knowledge (any type), only provide query_knowledge_files
    # Otherwise, provide all KB browsing tools
    model_knowledge = model.get('info', {}).get('meta', {}).get('knowledge', [])

    # Merge folder-attached knowledge so builtin tools can search it
    folder_knowledge = extra_params.get('__metadata__', {}).get('folder_knowledge')
    if folder_knowledge:
        model_knowledge = list(model_knowledge or []) + list(folder_knowledge)

    if is_builtin_tool_enabled('knowledge'):
        if model_knowledge:
            # Model has attached knowledge - provide discovery, search and semantic tools
            builtin_functions.append(list_knowledge)
            builtin_functions.append(search_knowledge_files)
            builtin_functions.append(query_knowledge_files)

            # Viewer tools depend on which knowledge types are attached.
            knowledge_types = {item.get('type') for item in model_knowledge}
            if 'file' in knowledge_types or 'collection' in knowledge_types:
                builtin_functions.append(view_file)
                builtin_functions.append(view_knowledge_file)
            if 'note' in knowledge_types:
                builtin_functions.append(view_note)
        else:
            # No model knowledge - allow full KB browsing
            builtin_functions.extend(
                [
                    list_knowledge_bases,
                    search_knowledge_bases,
                    query_knowledge_bases,
                    search_knowledge_files,
                    query_knowledge_files,
                    view_knowledge_file,
                ]
            )

    # Chats tools - search and fetch user's chat history
    if is_builtin_tool_enabled('chats'):
        builtin_functions.extend([search_chats, view_chat])

    # Add memory tools if builtin category enabled AND enabled for this chat
    if (
        is_builtin_tool_enabled('memory')
        and (features.get('memory') or get_model_capability('memory', False))
        and await has_user_permission('memories')
    ):
        builtin_functions.extend(
            [
                search_memories,
                add_memory,
                replace_memory_content,
                delete_memory,
                list_memories,
            ]
        )

    # Add web search tools if builtin category enabled AND enabled globally AND model has web_search capability
    if (
        is_builtin_tool_enabled('web_search')
        and getattr(request.app.state.config, 'ENABLE_WEB_SEARCH', False)
        and get_model_capability('web_search')
        and features.get('web_search')
        and await has_user_permission('web_search')
    ):
        builtin_functions.extend([search_web, fetch_url])

    # Add image generation/edit tools if builtin category enabled AND enabled globally AND model has image_generation capability
    if (
        is_builtin_tool_enabled('image_generation')
        and getattr(request.app.state.config, 'ENABLE_IMAGE_GENERATION', False)
        and get_model_capability('image_generation')
        and features.get('image_generation')
        and await has_user_permission('image_generation')
    ):
        builtin_functions.append(generate_image)

    # edit_image is gated separately on ENABLE_IMAGE_EDIT.
    if (
        is_builtin_tool_enabled('image_generation')
        and getattr(request.app.state.config, 'ENABLE_IMAGE_EDIT', False)
        and get_model_capability('image_generation')
        and features.get('image_generation')
        and await has_user_permission('image_generation')
    ):
        builtin_functions.append(edit_image)

    # Add code interpreter tool if builtin category enabled AND enabled globally AND model has code_interpreter capability
    if (
        is_builtin_tool_enabled('code_interpreter')
        and getattr(request.app.state.config, 'ENABLE_CODE_INTERPRETER', True)
        and get_model_capability('code_interpreter')
        and features.get('code_interpreter')
        and await has_user_permission('code_interpreter')
    ):
        builtin_functions.append(execute_code)

    # Notes tools - search, view, create, and update user's notes
    if (
        is_builtin_tool_enabled('notes')
        and getattr(request.app.state.config, 'ENABLE_NOTES', False)
        and await has_user_permission('notes')
    ):
        builtin_functions.extend([search_notes, view_note, write_note, replace_note_content])

    # Channels tools - search channels and messages
    if (
        is_builtin_tool_enabled('channels')
        and getattr(request.app.state.config, 'ENABLE_CHANNELS', False)
        and await has_user_permission('channels')
    ):
        builtin_functions.extend(
            [
                search_channels,
                search_channel_messages,
                view_channel_thread,
                view_channel_message,
            ]
        )

    # Skills tools - view_skill allows model to load full skill instructions on demand
    if extra_params.get('__skill_ids__'):
        builtin_functions.append(view_skill)

    # Task management - break down complex work into trackable steps
    if is_builtin_tool_enabled('tasks'):
        builtin_functions.extend([create_tasks, update_task])

    # Automation tools - create and manage scheduled automations from chat
    if (
        is_builtin_tool_enabled('automations')
        and getattr(request.app.state.config, 'ENABLE_AUTOMATIONS', False)
        and await has_user_permission('automations')
    ):
        builtin_functions.extend(
            [create_automation, update_automation, list_automations, toggle_automation, delete_automation]
        )

    # Calendar tools - search/create/update/delete events
    if (
        is_builtin_tool_enabled('calendar')
        and getattr(request.app.state.config, 'ENABLE_CALENDAR', False)
        and await has_user_permission('calendar')
    ):
        builtin_functions.extend(
            [search_calendar_events, create_calendar_event, update_calendar_event, delete_calendar_event]
        )

    # Wrap each selected builtin with the frozen context params and derive an
    # OpenAI-style function spec from its signature/docstring.
    for func in builtin_functions:
        callable = await get_async_tool_function_and_apply_extra_params(
            func,
            {
                '__request__': request,
                '__user__': extra_params.get('__user__', {}),
                '__event_emitter__': extra_params.get('__event_emitter__'),
                '__event_call__': extra_params.get('__event_call__'),
                '__metadata__': extra_params.get('__metadata__'),
                '__chat_id__': extra_params.get('__chat_id__'),
                '__message_id__': extra_params.get('__message_id__'),
                '__model_knowledge__': model_knowledge,
            },
        )

        # Generate spec from function
        pydantic_model = convert_function_to_pydantic_model(func)
        spec = convert_pydantic_model_to_openai_function_spec(pydantic_model)
        spec = clean_openai_tool_schema(spec)

        tools_dict[func.__name__] = {
            'tool_id': f'builtin:{func.__name__}',
            'callable': callable,
            'spec': spec,
            'type': 'builtin',
        }

    return tools_dict
| def parse_description(docstring: str | None) -> str: | |
| """ | |
| Parse a function's docstring to extract the description. | |
| Args: | |
| docstring (str): The docstring to parse. | |
| Returns: | |
| str: The description. | |
| """ | |
| if not docstring: | |
| return '' | |
| lines = [line.strip() for line in docstring.strip().split('\n')] | |
| description_lines: list[str] = [] | |
| for line in lines: | |
| if re.match(r':param', line) or re.match(r':return', line): | |
| break | |
| description_lines.append(line) | |
| return '\n'.join(description_lines) | |
def parse_docstring(docstring):
    """
    Parse a function's docstring to extract parameter descriptions in reST format.

    Args:
        docstring (str): The docstring to parse.

    Returns:
        dict: A dictionary where keys are parameter names and values are descriptions.
    """
    if not docstring:
        return {}

    # Matches `:param name: description` lines.
    param_pattern = re.compile(r':param (\w+):\s*(.+)')
    matches = (param_pattern.match(line.strip()) for line in docstring.splitlines())
    # Reserved double-underscore params are injected internally, not model-facing.
    return {
        m.group(1): m.group(2)
        for m in matches
        if m is not None and not m.group(1).startswith('__')
    }
def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
    """
    Converts a Python function's type hints and docstring to a Pydantic model,
    including support for nested types, default values, and descriptions.

    Args:
        func: The function whose type hints and docstring should be converted.

    Returns:
        A Pydantic model class named after the function, with the description
        text from the docstring attached as the model's docstring.
    """
    hints = get_type_hints(func)
    sig = inspect.signature(func)
    doc = func.__doc__
    param_docs = parse_docstring(doc)

    field_defs = {}
    for param_name, param in sig.parameters.items():
        annotation = hints.get(param_name, Any)
        # Required params (no default) are marked with Ellipsis for pydantic.
        default = ... if param.default is param.empty else param.default
        description = param_docs.get(param_name, None)
        if description:
            field_defs[param_name] = (annotation, Field(default, description=description))
        else:
            field_defs[param_name] = (annotation, default)

    generated = create_model(func.__name__, **field_defs)
    generated.__doc__ = parse_description(doc)
    return generated
def clean_properties(schema: dict):
    """Normalize a JSON-schema fragment in place for OpenAI compatibility.

    Collapses single-variant ``anyOf`` (after dropping ``null`` variants),
    removes ``default: None``, backfills a missing ``type`` with ``string``,
    and recurses into ``properties`` and ``items``. Non-dict input is ignored.
    """
    if not isinstance(schema, dict):
        return

    variants = schema.get('anyOf')
    if variants is not None:
        kept = [variant for variant in variants if variant.get('type') != 'null']
        if len(kept) == 1:
            # Inline the single remaining variant, then drop the union wrapper.
            schema.update(kept[0])
            del schema['anyOf']
        else:
            schema['anyOf'] = kept

    if 'default' in schema and schema['default'] is None:
        del schema['default']

    # fix missing type
    if not any(key in schema for key in ('type', 'anyOf', 'properties')):
        schema['type'] = 'string'

    for nested in schema.get('properties', {}).values():
        clean_properties(nested)

    if 'items' in schema:
        clean_properties(schema['items'])
def clean_openai_tool_schema(spec: dict) -> dict:
    """Return a deep copy of *spec* with its parameter schema normalized.

    The input spec is never mutated; clean_properties runs on the copy only.
    (Uses the module-level ``copy`` import rather than re-importing locally.)
    """
    cleaned_spec = copy.deepcopy(spec)
    if 'parameters' in cleaned_spec:
        clean_properties(cleaned_spec['parameters'])
    return cleaned_spec
def get_functions_from_tool(tool: object) -> list[Callable]:
    """Collect the public, non-class callables exposed on *tool*.

    Skips names starting with ``_`` (private and dunder members) and any
    attribute that is itself a class; only plain methods/functions remain.
    """
    functions = []
    for attr_name in dir(tool):
        attr = getattr(tool, attr_name)
        if not callable(attr) or attr_name.startswith('_') or inspect.isclass(attr):
            continue
        functions.append(attr)
    return functions
def get_tool_specs(tool_module: object) -> list[dict]:
    """Build cleaned OpenAI function specs for every public callable on *tool_module*."""
    return [
        clean_openai_tool_schema(
            convert_pydantic_model_to_openai_function_spec(convert_function_to_pydantic_model(tool_function))
        )
        for tool_function in get_functions_from_tool(tool_module)
    ]
def resolve_schema(schema, components, resolved_schemas=None):
    """
    Recursively resolves a JSON schema using OpenAPI components.

    Args:
        schema: The (possibly ``$ref``-bearing) schema fragment to resolve.
        components: The OpenAPI ``components`` object used to look up refs.
        resolved_schemas: Set of schema names already seen on the current
            resolution path; used to break circular references.

    Returns:
        A deep-copied schema with all references inlined. A circular
        reference resolves to ``{}`` instead of recursing forever.

    Fix: the previous version dropped ``resolved_schemas`` when recursing
    into ``properties``/``items``, so a cycle reached through a nested
    property caused unbounded recursion. Each nested branch now receives a
    copy of the path-so-far set (a copy, so legitimate reuse of the same
    ref in sibling properties still resolves normally).
    """
    if not schema:
        return {}

    if resolved_schemas is None:
        resolved_schemas = set()

    if '$ref' in schema:
        ref_path = schema['$ref']
        schema_name = ref_path.split('/')[-1]
        if schema_name in resolved_schemas:
            # Avoid infinite recursion on circular references
            return {}
        resolved_schemas.add(schema_name)

        ref_parts = ref_path.strip('#/').split('/')
        resolved = components
        for part in ref_parts[1:]:  # Skip the initial 'components'
            resolved = resolved.get(part, {})
        return resolve_schema(resolved, components, resolved_schemas)

    resolved_schema = copy.deepcopy(schema)

    # Recursively resolve inner schemas, carrying a branch-local copy of the
    # path set so cycles through nesting are detected.
    if 'properties' in resolved_schema:
        for prop, prop_schema in resolved_schema['properties'].items():
            resolved_schema['properties'][prop] = resolve_schema(prop_schema, components, set(resolved_schemas))

    if 'items' in resolved_schema:
        resolved_schema['items'] = resolve_schema(resolved_schema['items'], components, set(resolved_schemas))

    return resolved_schema
def convert_openapi_to_tool_payload(openapi_spec):
    """
    Converts an OpenAPI specification into a custom tool payload structure.

    Args:
        openapi_spec (dict): The OpenAPI specification as a Python dict.

    Returns:
        list: A list of tool payloads, one per operation with an operationId.
    """
    payloads = []

    for path, path_item in openapi_spec.get('paths', {}).items():
        for method, operation in path_item.items():
            operation_id = operation.get('operationId')
            if not operation_id:
                continue

            tool = {
                'name': operation_id,
                'description': operation.get(
                    'description',
                    operation.get('summary', 'No description available.'),
                ),
                'parameters': {'type': 'object', 'properties': {}, 'required': []},
            }

            # Path/query/header parameters.
            for param in operation.get('parameters', []):
                param_name = param.get('name')
                if not param_name:
                    continue
                param_schema = param.get('schema', {})

                # Prefer the schema-level description, falling back to the
                # parameter-level one.
                description = param_schema.get('description', '') or (param.get('description') or '')

                enum_values = param_schema.get('enum')
                if enum_values and isinstance(enum_values, list):
                    description += f'. Possible values: {", ".join(str(v) for v in enum_values)}'

                prop = {
                    'type': param_schema.get('type') or 'string',
                    'description': description,
                }
                # Include items property for array types (required by OpenAI)
                if param_schema.get('type') == 'array' and 'items' in param_schema:
                    prop['items'] = param_schema['items']

                # Filter out None values to prevent schema validation errors
                tool['parameters']['properties'][param_name] = {
                    key: value for key, value in prop.items() if value is not None
                }

                if param.get('required'):
                    tool['parameters']['required'].append(param_name)

            # Extract and resolve requestBody if available
            request_body = operation.get('requestBody')
            if request_body:
                json_schema = request_body.get('content', {}).get('application/json', {}).get('schema')
                if json_schema:
                    resolved = resolve_schema(json_schema, openapi_spec.get('components', {}))
                    if resolved.get('properties'):
                        tool['parameters']['properties'].update(resolved['properties'])
                        if 'required' in resolved:
                            tool['parameters']['required'] = list(
                                set(tool['parameters']['required'] + resolved['required'])
                            )
                    elif resolved.get('type') == 'array':
                        tool['parameters'] = resolved  # special case for array

            payloads.append(tool)

    return payloads
async def set_tool_servers(request: Request):
    """Rebuild the in-memory tool server cache and mirror it to Redis.

    Recomputes ``request.app.state.TOOL_SERVERS`` from the configured
    TOOL_SERVER_CONNECTIONS and, when a Redis client is attached to app
    state, stores the JSON-serialized result under the shared key so other
    workers can reuse it (see get_tool_servers).
    """
    request.app.state.TOOL_SERVERS = await get_tool_servers_data(request.app.state.config.TOOL_SERVER_CONNECTIONS)

    if request.app.state.redis is not None:
        await request.app.state.redis.set(
            f'{REDIS_KEY_PREFIX}:tool_servers', json.dumps(request.app.state.TOOL_SERVERS)
        )

    return request.app.state.TOOL_SERVERS
async def get_tool_servers(request: Request):
    """Return tool server specs, preferring the Redis cache.

    Falls back to rebuilding the cache via set_tool_servers when Redis is
    unavailable, the key is missing, or the cached value is empty/invalid.

    Fix: a cache miss makes ``redis.get`` return None, and the previous
    ``json.loads(None)`` raised TypeError — logged as a spurious error on
    every cold start. A missing key is now handled explicitly; only real
    Redis/decoding failures are logged.
    """
    tool_servers = []
    if request.app.state.redis is not None:
        try:
            cached = await request.app.state.redis.get(f'{REDIS_KEY_PREFIX}:tool_servers')
            if cached is not None:
                tool_servers = json.loads(cached)
                request.app.state.TOOL_SERVERS = tool_servers
        except Exception as e:
            log.error(f'Error fetching tool_servers from Redis: {e}')

    if not tool_servers:
        tool_servers = await set_tool_servers(request)

    return tool_servers
async def get_terminal_cwd(
    base_url: str,
    headers: dict,
    cookies: Optional[dict] = None,
) -> Optional[str]:
    """Fetch the current working directory from a terminal server.

    Queries ``{base_url}/files/cwd`` with a short timeout and returns the
    reported ``cwd`` value, or None on any error or non-200 response.
    """
    try:
        endpoint = f'{base_url.rstrip("/")}/files/cwd'
        client_timeout = aiohttp.ClientTimeout(total=5)
        async with aiohttp.ClientSession(timeout=client_timeout, trust_env=True) as session:
            async with session.get(
                endpoint, headers=headers, cookies=cookies or {}, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as resp:
                if resp.status == 200:
                    payload = await resp.json()
                    return payload.get('cwd')
    except Exception as e:
        log.debug(f'Failed to fetch terminal CWD: {e}')
    return None
async def get_terminal_system_prompt(
    base_url: str,
    headers: dict,
    cookies: Optional[dict] = None,
) -> Optional[str]:
    """Fetch the system prompt from a terminal server.

    Checks ``/api/config`` for the ``system`` feature flag first;
    only fetches ``/system`` if the flag is present. Returns *None*
    silently when the server doesn't support the endpoint.
    """
    root = base_url.rstrip('/')
    try:
        client_timeout = aiohttp.ClientTimeout(total=3)
        async with aiohttp.ClientSession(timeout=client_timeout, trust_env=True) as session:
            # Step 1: confirm the server advertises the 'system' feature.
            async with session.get(f'{root}/api/config', ssl=AIOHTTP_CLIENT_SESSION_SSL) as resp:
                if resp.status != 200:
                    return None
                config = await resp.json()
                if not config.get('features', {}).get('system'):
                    return None

            # Step 2: fetch the prompt itself.
            async with session.get(
                f'{root}/system', headers=headers, cookies=cookies or {}, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as resp:
                if resp.status == 200:
                    payload = await resp.json()
                    return payload.get('prompt')
    except Exception as e:
        log.debug(f'Failed to fetch terminal system prompt: {e}')
    return None
async def set_terminal_servers(request: Request):
    """Load and cache OpenAPI specs from all TERMINAL_SERVER_CONNECTIONS.

    Translates each connection entry into the server-config shape expected by
    ``get_tool_servers_data``, fetches the specs, attaches each server's
    system prompt (fetched concurrently at cache time, not per-request),
    stores the result on ``app.state.TERMINAL_SERVERS`` (and in Redis when
    available), and returns it.
    """
    connections = request.app.state.config.TERMINAL_SERVER_CONNECTIONS or []

    def _as_server_config(conn):
        """Translate a terminal connection into a get_tool_servers_data config."""
        base_url = conn.get('url', '').rstrip('/')
        policy_id = conn.get('policy_id', '')
        # Orchestrator connections route through /p/{policy_id}/ — the
        # OpenAPI spec lives on the proxied terminal, not the orchestrator.
        if conn.get('server_type') == 'orchestrator' and policy_id:
            base_url = f'{base_url}/p/{policy_id}'
        return {
            'url': base_url,
            'key': conn.get('key', ''),
            'auth_type': conn.get('auth_type', 'bearer'),
            'path': conn.get('path', '/openapi.json'),
            'spec_type': 'url',
            # get_tool_servers_data reads config.enable to filter active servers
            'config': {'enable': conn.get('enabled', True)},
            'info': {
                'id': conn.get('id', ''),
                'name': conn.get('name', ''),
            },
        }

    server_configs = [_as_server_config(c) for c in connections if c.get('url')]
    request.app.state.TERMINAL_SERVERS = await get_tool_servers_data(server_configs)

    connections_by_id = {c.get('id'): c for c in connections if c.get('id')}

    async def _attach_system_prompt(server):
        """Fetch and store the system prompt for one cached server entry."""
        conn = connections_by_id.get(server.get('id'))
        if not conn:
            return
        headers = {}
        if conn.get('auth_type', 'bearer') == 'bearer':
            headers['Authorization'] = f'Bearer {conn.get("key", "")}'
        prompt = await get_terminal_system_prompt(server['url'], headers)
        if prompt:
            server['system_prompt'] = prompt

    # Fetch all prompts concurrently; individual failures must not abort the cache.
    await asyncio.gather(
        *(_attach_system_prompt(server) for server in request.app.state.TERMINAL_SERVERS),
        return_exceptions=True,
    )

    if request.app.state.redis is not None:
        await request.app.state.redis.set(
            f'{REDIS_KEY_PREFIX}:terminal_servers',
            json.dumps(request.app.state.TERMINAL_SERVERS),
        )
    return request.app.state.TERMINAL_SERVERS
async def get_terminal_servers(request: Request):
    """Return terminal server specs, preferring the Redis cache.

    Rebuilds the cache via ``set_terminal_servers`` when Redis is unavailable,
    empty, or unreadable.
    """
    cached = []
    redis = request.app.state.redis
    if redis is not None:
        try:
            raw = await redis.get(f'{REDIS_KEY_PREFIX}:terminal_servers')
            cached = json.loads(raw)
            request.app.state.TERMINAL_SERVERS = cached
        except Exception as e:
            log.error(f'Error fetching terminal_servers from Redis: {e}')
    return cached if cached else await set_terminal_servers(request)
async def get_terminal_tools(
    request: Request,
    terminal_id: str,
    user: UserModel,
    extra_params: dict,
) -> dict[str, dict] | tuple[dict[str, dict], Optional[str]]:
    """Resolve tools for a terminal server identified by terminal_id.
    - Finds the connection in TERMINAL_SERVER_CONNECTIONS
    - Checks access_grants
    - Loads specs from cache
    - Builds callables that route through the terminal proxy

    Returns:
        ``(tools_dict, system_prompt)`` on success, where ``tools_dict`` maps
        each operation name to ``{'tool_id', 'callable', 'spec', 'type'}`` and
        ``system_prompt`` may be ``None``.
        NOTE(review): every failure path returns a bare ``{}`` rather than a
        tuple (see the return annotation) — callers must handle both shapes.
    """
    connections = request.app.state.config.TERMINAL_SERVER_CONNECTIONS or []
    connection = next((c for c in connections if c.get('id') == terminal_id), None)
    if connection is None:
        log.warning(f'Terminal server not found: {terminal_id}')
        return {}
    # Access check: a user may reach the terminal directly or through a group.
    user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id)}
    if not await has_connection_access(user, connection, user_group_ids):
        log.warning(f'Access denied to terminal {terminal_id} for user {user.id}')
        return {}
    # Find the cached spec data for this terminal
    terminal_servers = await get_terminal_servers(request)
    server_data = next((s for s in terminal_servers if s.get('id') == terminal_id), None)
    if server_data is None:
        log.warning(f'Terminal server spec not found for {terminal_id}')
        return {}
    specs = server_data.get('specs', [])
    if not specs:
        return {}
    # Build auth headers
    auth_type = connection.get('auth_type', 'bearer')
    cookies = {}
    headers = {'Content-Type': 'application/json', 'X-User-Id': user.id}
    if auth_type == 'bearer':
        # Static API key stored on the connection.
        headers['Authorization'] = f'Bearer {connection.get("key", "")}'
    elif auth_type == 'session':
        # Forward the caller's own session: cookies plus their bearer token.
        cookies = request.cookies
        headers['Authorization'] = f'Bearer {request.state.token.credentials}'
    elif auth_type == 'system_oauth':
        # Forward the user's OAuth access token supplied via extra_params.
        cookies = request.cookies
        oauth_token = extra_params.get('__oauth_token__', None)
        if oauth_token:
            headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}'
    # auth_type == "none": no Authorization header
    system_prompt = server_data.get('system_prompt')
    # Use chat_id as the per-session key for cwd tracking
    metadata = extra_params.get('__metadata__', {})
    session_id = metadata.get('chat_id')
    if session_id:
        headers['X-Session-Id'] = session_id
    # Best-effort: may be None if the server doesn't expose /files/cwd.
    terminal_cwd = await get_terminal_cwd(connection.get('url', ''), headers, cookies)
    tools_dict = {}
    for spec in specs:
        function_name = spec['name']
        tool_spec = clean_openai_tool_schema(spec)
        if function_name == 'run_command' and terminal_cwd:
            # Surface the live CWD to the model inside the tool description.
            tool_spec['description'] = (
                tool_spec.get('description', '') + f'\n\nThe current working directory is: {terminal_cwd}'
            )
        # Factory binds the loop variables per-iteration, avoiding the
        # classic late-binding-closure pitfall where every tool would
        # otherwise capture the last spec in the loop.
        async def make_tool_function(fn_name, srv_data, hdrs, cks):
            async def tool_function(**kwargs):
                return await execute_tool_server(
                    url=srv_data['url'],
                    headers=hdrs,
                    cookies=cks,
                    name=fn_name,
                    params=kwargs,
                    server_data=srv_data,
                )
            return tool_function
        tool_function = await make_tool_function(function_name, server_data, headers, cookies)
        # NOTE(review): `callable` shadows the builtin within this loop body.
        callable = await get_async_tool_function_and_apply_extra_params(tool_function, {})
        tools_dict[function_name] = {
            'tool_id': f'terminal:{terminal_id}',
            'callable': callable,
            'spec': tool_spec,
            'type': 'terminal',
        }
    return tools_dict, system_prompt
async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
    """Fetch and parse an OpenAPI spec (JSON or YAML) from a tool server URL.

    Args:
        url: Full URL of the spec document. URLs ending in ``.yaml``/``.yml``
            are parsed as YAML; everything else is tried as JSON first, then
            YAML as a fallback.
        headers: Optional extra headers (e.g. Authorization) merged over the
            default JSON Accept/Content-Type headers.

    Returns:
        The parsed spec as a dict.

    Raises:
        Exception: with a human-readable message on any fetch/parse failure;
            a JSON error body's ``detail`` field is surfaced when present.
    """
    _headers = {
        'Accept': 'application/json',
        'Content-Type': 'application/json',
    }
    if headers:
        _headers.update(headers)
    try:
        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL) as response:
                if response.status != 200:
                    # Server error bodies are JSON dicts; re-raise so the
                    # handler below can extract a readable message.
                    error_body = await response.json()
                    raise Exception(error_body)
                text_content = await response.text()
                # Check if URL ends with .yaml or .yml to determine format
                if url.lower().endswith(('.yaml', '.yml')):
                    res = yaml.safe_load(text_content)
                else:
                    try:
                        res = json.loads(text_content)
                    except json.JSONDecodeError:
                        # Some servers serve YAML without a .yaml extension.
                        res = yaml.safe_load(text_content)
    except Exception as err:
        log.exception(f'Could not fetch tool server spec from {url}')
        # BUG FIX: `err` is an Exception, never a dict, so the previous
        # `isinstance(err, dict)` check could never fire. The dict raised
        # above lives in err.args[0]; extract its `detail` when present.
        payload = err.args[0] if err.args else None
        if isinstance(payload, dict) and 'detail' in payload:
            error = payload['detail']
        else:
            error = str(err)
        raise Exception(error)
    log.debug(f'Fetched data: {res}')
    return res
async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Concurrently load the OpenAPI spec for every enabled OpenAPI tool server.

    Each server config may provide its spec by URL (``spec_type == 'url'``) or
    as an inline JSON string (``spec_type == 'json'``). Servers that are
    disabled, non-OpenAPI, unreachable, or return an invalid spec are skipped
    with a log entry. Returns one result dict per successful server with
    ``id``/``idx``/``url``/``openapi``/``info``/``specs`` keys.
    """
    fetch_tasks = []
    entries = []

    for idx, server in enumerate(servers):
        # Only enabled servers of type "openapi" participate.
        if not server.get('config', {}).get('enable'):
            continue
        if server.get('type', 'openapi') != 'openapi':
            continue

        info = server.get('info', {})
        token = server.get('key', '') if server.get('auth_type', 'bearer') == 'bearer' else None
        server_id = info.get('id') or str(idx)
        server_url = server.get('url')

        task = None
        spec_type = server.get('spec_type', 'url')
        if spec_type == 'url':
            # Path (to OpenAPI spec URL) can be either a full URL or a path
            # to append to the base URL.
            spec_url = get_tool_server_url(server_url, server.get('path', 'openapi.json'))
            auth_headers = {'Authorization': f'Bearer {token}'} if token else None
            task = get_tool_server_data(spec_url, auth_headers)
        elif spec_type == 'json' and server.get('spec', ''):
            # Inline spec: parse locally, wrapped in a no-op awaitable so it
            # can be gathered alongside the network fetches.
            parsed_spec = None
            try:
                parsed_spec = json.loads(server.get('spec', ''))
            except Exception as e:
                log.error(f'Error parsing JSON spec for tool server {server_id}: {e}')
            if parsed_spec:
                task = asyncio.sleep(0, result=parsed_spec)

        if task:
            fetch_tasks.append(task)
            entries.append((server_id, idx, server, server_url, info))

    responses = await asyncio.gather(*fetch_tasks, return_exceptions=True)

    results = []
    for (server_id, idx, server, url, info), response in zip(entries, responses):
        if isinstance(response, Exception):
            log.error(f'Failed to connect to {url} OpenAPI tool server')
            continue
        # Guard against invalid or non-OpenAPI specs (e.g., MCP-style configs)
        if not isinstance(response, dict) or 'paths' not in response:
            log.warning(f"Invalid OpenAPI spec from {url}: missing 'paths'")
            continue

        openapi_data = response
        # Capture info/specs before applying display-name overrides below.
        spec_info = openapi_data.get('info', {})
        tool_specs = convert_openapi_to_tool_payload(openapi_data)

        if info and isinstance(openapi_data, dict):
            openapi_data['info'] = openapi_data.get('info', {})
            if 'name' in info:
                openapi_data['info']['title'] = info.get('name', 'Tool Server')
            if 'description' in info:
                openapi_data['info']['description'] = info.get('description', '')

        results.append(
            {
                'id': str(server_id),
                'idx': idx,
                'url': (server.get('url') or '').rstrip('/'),
                'openapi': openapi_data,
                'info': spec_info,
                'specs': tool_specs,
            }
        )
    return results
async def execute_tool_server(
    url: str,
    headers: Dict[str, str],
    cookies: Dict[str, str],
    name: str,
    params: Dict[str, Any],
    server_data: Dict[str, Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Execute a tool-server operation identified by its OpenAPI operationId.

    Looks up the route and HTTP method for ``name`` in the server's cached
    OpenAPI spec, splits ``params`` into path/query/body parameters, performs
    the request, and returns ``(response_data, response_headers)``. JSON
    bodies are decoded; otherwise text is returned, and binary payloads are
    wrapped in a ``data:`` base64 URI. On any failure this returns
    ``({'error': <message>}, None)`` rather than raising.
    """
    # Local import: only this function needs URL-safe encoding helpers.
    from urllib.parse import quote, urlencode

    async def _decode_response(response):
        """Raise on HTTP >= 400; else return JSON, text, or a base64 data URI."""
        if response.status >= 400:
            text = await response.text()
            raise Exception(f'HTTP error {response.status}: {text}')
        try:
            return await response.json()
        except Exception:
            content_type = response.headers.get('Content-Type', '').split(';')[0].strip()
            if content_type.startswith('text/') or not content_type:
                return await response.text()
            raw = await response.read()
            b64 = base64.b64encode(raw).decode()
            return f'data:{content_type};base64,{b64}'

    try:
        openapi = server_data.get('openapi', {})
        paths = openapi.get('paths', {})

        # Locate the route + method whose operationId matches `name`.
        route_path = None
        method_entry = None
        for candidate_path, methods in paths.items():
            for http_method, operation in methods.items():
                if isinstance(operation, dict) and operation.get('operationId') == name:
                    route_path = candidate_path
                    method_entry = (http_method.lower(), operation)
                    break
            if method_entry:
                break
        if not method_entry:
            raise Exception(f'No matching route found for operationId: {name}')
        http_method, operation = method_entry

        path_params = {}
        query_params = {}
        body_params = {}

        for param in operation.get('parameters', []):
            param_name = param.get('name')
            if not param_name or param_name not in params:
                continue
            param_in = param.get('in')
            if param_in == 'path':
                path_params[param_name] = params[param_name]
            elif param_in == 'query':
                value = params[param_name]
                # Skip empty values for optional params (LLMs sometimes
                # pass "" instead of omitting optional parameters).
                if value is None or (value == '' and not param.get('required')):
                    continue
                query_params[param_name] = value

        final_url = f'{url.rstrip("/")}{route_path}'
        for key, value in path_params.items():
            # BUG FIX: percent-encode path segments so values containing
            # '/', '?', '#', spaces, etc. cannot corrupt the URL.
            final_url = final_url.replace(f'{{{key}}}', quote(str(value), safe=''))
        if query_params:
            # BUG FIX: use urlencode instead of raw f-string joining so
            # special characters ('&', '=', spaces, unicode) are escaped.
            final_url = f'{final_url}?{urlencode(query_params)}'

        if operation.get('requestBody', {}).get('content') and params:
            body_params = params

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER)
        ) as session:
            request_method = getattr(session, http_method)
            request_kwargs = {
                'headers': headers,
                'cookies': cookies,
                'ssl': AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
                # Redirects are not followed so auth headers can't leak off-host.
                'allow_redirects': False,
            }
            if http_method in ('post', 'put', 'patch', 'delete'):
                request_kwargs['json'] = body_params
            async with request_method(final_url, **request_kwargs) as response:
                return (await _decode_response(response), response.headers)
    except Exception as err:
        error = str(err)
        log.exception(f'API Request Error: {error}')
        return ({'error': error}, None)
def get_tool_server_url(url: Optional[str], path: str) -> str:
    """
    Build the full URL for a tool server, given a base url and a path.

    - If ``path`` is already a full URL (contains "://"), it is returned as-is.
    - Otherwise ``path`` is appended to ``url`` with exactly one slash between.
    - BUG FIX: a ``None`` (or empty) base url now yields just the normalized
      path, instead of interpolating the literal string "None" into the URL.
    """
    if '://' in path:
        # If it contains "://", it's a full URL
        return path
    base = url.rstrip('/') if url else ''
    if not path.startswith('/'):
        # Ensure the path starts with a slash
        path = f'/{path}'
    return f'{base}{path}'