Spaces:
Paused
Paused
| import json | |
| from os import getenv | |
| from typing import Any | |
| from urllib.parse import urlencode | |
| import httpx | |
| from core.helper import ssrf_proxy | |
| from core.tools.entities.tool_bundle import ApiToolBundle | |
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |
| from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError | |
| from core.tools.tool.tool import Tool | |
# (connect, read) timeout pair passed as ``timeout=`` to the ssrf_proxy request
# helpers. Each half can be overridden through its environment variable; the
# fallbacks are 10 and 60.
API_TOOL_DEFAULT_TIMEOUT = tuple(
    int(getenv(env_name, fallback))
    for env_name, fallback in (
        ("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10"),
        ("API_TOOL_DEFAULT_READ_TIMEOUT", "60"),
    )
)
class ApiTool(Tool):
    """
    Api tool

    Sends the HTTP request described by ``api_bundle`` (server url, method and
    OpenAPI operation) and returns the response body as a text message.
    """

    api_bundle: ApiToolBundle

    def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
        """
        fork a new tool with meta data

        :param runtime: the meta data of a tool call processing (the original
            note says tenant_id is required); expanded into ``Tool.Runtime``
        :return: the new tool
        """
        return self.__class__(
            identity=self.identity.model_copy() if self.identity else None,
            parameters=self.parameters.copy() if self.parameters else None,
            description=self.description.model_copy() if self.description else None,
            api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
            runtime=Tool.Runtime(**runtime),
        )

    def validate_credentials(
        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
    ) -> str:
        """
        validate the credentials for Api tool

        :param credentials: not read here; the request is assembled from
            ``self.runtime.credentials``
        :param parameters: the tool parameters used for the validation request
        :param format_only: when True only the credential/parameter format is
            checked and no HTTP request is sent
        :return: "" when ``format_only`` is set, otherwise the parsed response text
        """
        # assemble validate request and request parameters
        headers = self.assembling_request(parameters)

        if format_only:
            return ""

        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
        # validate response
        return self.validate_and_parse_response(response)

    def tool_provider_type(self) -> ToolProviderType:
        """
        :return: the provider type of this tool, always ``ToolProviderType.API``
        """
        return ToolProviderType.API

    def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        build the auth headers and fill in defaults of required parameters

        :param parameters: the tool parameters; missing required parameters that
            declare a default are added to this dict in place
        :return: the request headers (the api key header when auth_type is "api_key")
        :raises ToolProviderCredentialValidationError: auth_type or api_key_value
            is missing, or api_key_value is not a string
        :raises ToolParameterValidationError: a required parameter is missing
            and has no default
        """
        headers = {}
        credentials = self.runtime.credentials or {}

        if "auth_type" not in credentials:
            raise ToolProviderCredentialValidationError("Missing auth_type")

        if credentials["auth_type"] == "api_key":
            api_key_header = credentials.get("api_key_header", "api_key")

            if "api_key_value" not in credentials:
                raise ToolProviderCredentialValidationError("Missing api_key_value")
            api_key_value = credentials["api_key_value"]
            if not isinstance(api_key_value, str):
                raise ToolProviderCredentialValidationError("api_key_value must be a string")

            # Work on a local copy of the value: writing the prefixed key back
            # into the runtime credentials would stack the prefix on every call.
            prefix = credentials.get("api_key_header_prefix")
            if api_key_value:
                if prefix == "basic":
                    api_key_value = f"Basic {api_key_value}"
                elif prefix == "bearer":
                    api_key_value = f"Bearer {api_key_value}"
            # prefix "custom" (or no prefix at all): the value is sent as is

            headers[api_key_header] = api_key_value

        for parameter in self.api_bundle.parameters or []:
            if not parameter.required or parameter.name in parameters:
                continue
            # a missing required parameter falls back to its default when it has one
            if parameter.default is not None:
                parameters[parameter.name] = parameter.default
            else:
                raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")

        return headers

    def validate_and_parse_response(self, response: httpx.Response) -> str:
        """
        validate the response

        :param response: the response returned by ``do_http_request``
        :return: the JSON body re-serialized as a string, the raw text when the
            body is not JSON, or a fixed hint when the body is empty
        :raises ToolInvokeError: the status code is 400 or above
        :raises ValueError: ``response`` is not an ``httpx.Response``
        """
        if not isinstance(response, httpx.Response):
            raise ValueError(f"Invalid response type {type(response)}")

        if response.status_code >= 400:
            raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
        if not response.content:
            return "Empty response from the tool, please check your parameters and try again."

        # Deliberately best-effort: any body that cannot be decoded as JSON is
        # handed back as plain text.
        try:
            payload = response.json()
        except Exception:
            return response.text

        # Prefer readable non-ASCII output, then the escaped form, then raw text.
        for ensure_ascii in (False, True):
            try:
                return json.dumps(payload, ensure_ascii=ensure_ascii)
            except Exception:
                continue
        return response.text

    @staticmethod
    def get_parameter_value(parameter, parameters):
        """
        pick the value of one OpenAPI parameter out of the tool parameters

        :param parameter: the OpenAPI parameter object (a dict with at least "name")
        :param parameters: the tool parameters
        :return: the supplied value, else the schema default, else ""
        :raises ToolParameterValidationError: the parameter is required but missing
        """
        if parameter["name"] in parameters:
            return parameters[parameter["name"]]
        elif parameter.get("required", False):
            raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
        else:
            return (parameter.get("schema", {}) or {}).get("default", "")

    def do_http_request(
        self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
    ) -> httpx.Response:
        """
        do http request depending on api bundle

        :param url: the server url, may contain ``{name}`` path placeholders
        :param method: the http method, case-insensitive
        :param headers: the request headers; header parameters and Content-Type
            are added to this dict in place
        :param parameters: the tool parameters
        :return: the response of the request
        :raises ToolParameterValidationError: a required parameter is missing
        :raises ValueError: the http method is not supported
        """
        method = method.lower()

        params = {}
        path_params = {}
        body = {}
        cookies = {}
        openapi = self.api_bundle.openapi

        # check parameters
        for parameter in openapi.get("parameters") or []:
            value = self.get_parameter_value(parameter, parameters)
            location = parameter["in"]
            if location == "path":
                path_params[parameter["name"]] = value
            elif location == "query":
                # empty query values are left out of the query string
                if value != "":
                    params[parameter["name"]] = value
            elif location == "cookie":
                cookies[parameter["name"]] = value
            elif location == "header":
                headers[parameter["name"]] = value

        # check if there is a request body and handle it
        request_body = openapi.get("requestBody")
        if request_body is not None and "content" in request_body:
            content = request_body["content"]
            for content_type in content:
                headers["Content-Type"] = content_type
                # a media type object may come without a schema
                body_schema = content[content_type].get("schema") or {}
                required = body_schema.get("required", [])
                properties = body_schema.get("properties", {})
                for name, property in properties.items():
                    if name in parameters:
                        # convert type
                        body[name] = self._convert_body_property_type(property, parameters[name])
                    elif name in required:
                        raise ToolParameterValidationError(
                            f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
                        )
                    elif "default" in property:
                        body[name] = property["default"]
                    else:
                        body[name] = None
                # only the first declared content type is used
                break

        # replace path parameters
        for name, value in path_params.items():
            url = url.replace(f"{{{name}}}", f"{value}")

        # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
        payload: Any = body
        content_type = headers.get("Content-Type")
        if content_type == "application/json":
            payload = json.dumps(body)
        elif content_type == "application/x-www-form-urlencoded":
            payload = urlencode(body)

        if method not in {"get", "head", "post", "put", "delete", "patch"}:
            raise ValueError(f"Invalid http method {method}")

        return getattr(ssrf_proxy, method)(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            data=payload,
            timeout=API_TOOL_DEFAULT_TIMEOUT,
            follow_redirects=True,
        )

    @staticmethod
    def _to_number(value: Any) -> int | float:
        """Convert to float when the textual form contains a ".", else to int."""
        if "." in str(value):
            return float(value)
        return int(value)

    @staticmethod
    def _parse_bool(value: Any) -> bool | None:
        """Map true/1 and false/0 (case-insensitive) to a bool, anything else to None."""
        text = str(value).lower()
        if text in {"true", "1"}:
            return True
        if text in {"false", "0"}:
            return False
        return None

    def _convert_body_property_any_of(
        self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
    ) -> Any:
        """
        convert a body value to the first ``anyOf`` option it fits

        :param property: the property schema owning the ``anyOf``
        :param value: the raw value
        :param any_of: the ``anyOf`` options, tried in order
        :param max_recursive: remaining depth for nested ``anyOf``
        :return: the converted value, or ``value`` unchanged when no option fits
        :raises Exception: the nesting is deeper than ``max_recursive``
        """
        if max_recursive <= 0:
            raise Exception("Max recursion depth reached")

        options = any_of or []

        # None for a nullable schema stays None instead of becoming "None" or
        # failing inside int()/float().
        if value is None and any(option.get("type") == "null" for option in options):
            return None

        for option in options:
            try:
                if "type" in option:
                    kind = option["type"]
                    # Attempt to convert the value based on the type.
                    if kind in ("integer", "int"):
                        return int(value)
                    elif kind == "number":
                        return self._to_number(value)
                    elif kind == "string":
                        return str(value)
                    elif kind == "boolean":
                        parsed = self._parse_bool(value)
                        if parsed is None:
                            continue  # Not a boolean, try next option
                        return parsed
                    elif kind == "null" and not value:
                        return None
                    else:
                        continue  # Unsupported type, try next option
                elif "anyOf" in option and isinstance(option["anyOf"], list):
                    # Recursive call to handle nested anyOf
                    return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
            except (ValueError, TypeError):
                continue  # Conversion failed, try next option

        # no option succeeded, hand the value back as is
        return value

    def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
        """
        convert a body value to the type declared by its property schema

        :param property: the property schema (reads "type" or "anyOf")
        :param value: the raw value
        :return: the converted value; ``value`` unchanged when the conversion
            fails or the schema declares no supported type
        """
        try:
            if "type" in property:
                kind = property["type"]
                if kind in ("integer", "int"):
                    return int(value)
                elif kind == "number":
                    return self._to_number(value)
                elif kind == "string":
                    return str(value)
                elif kind == "boolean":
                    # textual booleans ("false", "0") must not be truthy strings
                    parsed = self._parse_bool(value)
                    return bool(value) if parsed is None else parsed
                elif kind == "null":
                    return None
                elif kind in ("object", "array"):
                    if isinstance(value, str):
                        # an array str like '[1,2]' also can convert to list [1,2] through json.loads
                        # json not support single quote, but we can support it: try the
                        # raw text first so valid JSON containing apostrophes survives
                        for candidate in (value, value.replace("'", '"')):
                            try:
                                return json.loads(candidate)
                            except ValueError:
                                continue
                    return value
                # unsupported type: fall through and keep the value
            elif "anyOf" in property and isinstance(property["anyOf"], list):
                return self._convert_body_property_any_of(property, value, property["anyOf"])
        except (ValueError, TypeError):
            return value
        # neither a supported "type" nor "anyOf": keep the value instead of dropping it
        return value

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        invoke http request

        :param user_id: the id of the calling user (not read here)
        :param tool_parameters: the tool parameters
        :return: a text message holding the parsed response
        """
        # assemble request
        headers = self.assembling_request(tool_parameters)

        # do http request
        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)

        # validate response
        response = self.validate_and_parse_response(response)

        # assemble invoke message
        return self.create_text_message(response)