Spaces:
Running
Running
| import os | |
| import re | |
| from typing import List, Optional | |
| import json | |
| import requests | |
| from jsonschema import RefResolver | |
| from pydantic import BaseModel, ValidationError | |
| from requests.exceptions import RequestException, Timeout | |
| from .tool import Tool | |
# Upper bound on how many times a remote HTTP request is attempted; the retry
# loops only go round again after a timeout.
MAX_RETRY_TIMES = 3
class ParametersSchema(BaseModel):
    """Validation model for a single entry of a tool's ``parameters`` list."""

    # Parameter name.
    name: str
    # Human-readable description of the parameter.
    description: str
    # Whether the parameter must be supplied; defaults to True when omitted.
    required: Optional[bool] = True
class ToolSchema(BaseModel):
    """Validation model for a tool definition: name, description, parameters."""

    # Tool name.
    name: str
    # Human-readable description of the tool.
    description: str
    # Parameter definitions, each validated as a ``ParametersSchema``.
    parameters: List[ParametersSchema]
class OpenAPIPluginTool(Tool):
    """Tool backed by a remote HTTP endpoint described by a config entry.

    The entry ``cfg[name]`` supplies the endpoint ``url``, the HTTP ``method``
    (``'POST'`` or ``'GET'``), the request ``header`` and the ``parameters``
    list, which is validated against ``ToolSchema`` at construction time.
    """
    name: str = 'api tool'
    description: str = 'This is a api tool that ...'
    parameters: list = []

    def __init__(self, cfg, name):
        """Load the tool settings stored under ``cfg[name]``.

        Args:
            cfg: Mapping of tool name to that tool's config dict.
            name: Key of this tool inside ``cfg``; also used as the tool name.

        Raises:
            ValueError: If name/description/parameters do not validate against
                ``ToolSchema``.
        """
        self.name = name
        self.cfg = cfg.get(self.name, {})
        self.is_remote_tool = self.cfg.get('is_remote_tool', False)
        # remote call
        self.url = self.cfg.get('url', '')
        # Stored for completeness; the code in this class does not read it.
        self.token = self.cfg.get('token', '')
        self.header = self.cfg.get('header', '')
        self.method = self.cfg.get('method', '')
        self.parameters = self.cfg.get('parameters', [])
        self.description = self.cfg.get('description',
                                        'This is a api tool that ...')
        self.responses_param = self.cfg.get('responses_param', [])
        all_para = {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters
        }
        try:
            self.tool_schema = ToolSchema(**all_para)
        except ValidationError as err:
            raise ValueError(
                f'Error when parsing parameters of {self.name}') from err
        self._str = self.tool_schema.model_dump_json()
        self._function = self.parse_pydantic_model_to_openai_function(all_para)

    def _remote_call(self, *args, **kwargs):
        """Call the configured endpoint with ``kwargs`` as the tool input.

        POST sends the (un-flattened) input as a JSON body. GET first fills
        ``{placeholder}`` segments of the URL from ``kwargs`` and sends the
        input as query parameters.

        Returns:
            Whatever ``self._parse_output(..., remote=True)`` returns for the
            decoded JSON response.

        Raises:
            ValueError: If no URL is configured, the method is neither POST
                nor GET, the request fails, or every attempt timed out.
        """
        if self.url == '':
            raise ValueError(
                f"Could not use remote call for {self.name} since this tool doesn't have a remote endpoint"
            )

        parsed_input = self._remote_parse_input(*args, **kwargs)

        if self.method == 'POST':
            print(f'data: {kwargs}')
            # NOTE(review): this prints the request headers, which may contain
            # the Authorization token - consider redacting.
            print(f'header: {self.header}')
            return self._request_with_retry(
                'POST', self.url, data=json.dumps(parsed_input))

        if self.method == 'GET':
            new_url = self.url
            for match in re.findall(r'\{(.*?)\}', self.url):
                if match in kwargs:
                    # str() so non-string values (e.g. integer ids) can be
                    # substituted into the path.
                    new_url = new_url.replace('{' + match + '}',
                                              str(kwargs[match]))
                else:
                    print(
                        f'The parameter {match} was not generated by the model.'
                    )
            print('GET:', new_url)
            print('GET:', self.url)
            # Pass the dict itself: a JSON string given as ``params`` would be
            # appended verbatim as the query string instead of key=value pairs.
            return self._request_with_retry(
                'GET', new_url, params=parsed_input)

        raise ValueError(
            'Remote call method is invalid!We have POST and GET method.')

    def _request_with_retry(self, method, url, **request_kwargs):
        """Send one HTTP request, retrying up to MAX_RETRY_TIMES on timeout."""
        for _ in range(MAX_RETRY_TIMES):
            try:
                response = requests.request(
                    method, url=url, headers=self.header, **request_kwargs)
                if response.status_code != requests.codes.ok:
                    response.raise_for_status()
                origin_result = json.loads(response.content.decode('utf-8'))
                return self._parse_output(origin_result, remote=True)
            except Timeout:
                # Only timeouts are retried; any other failure is reported.
                continue
            except RequestException as e:
                # Connection-level failures carry no response object.
                if e.response is None:
                    raise ValueError(
                        f'Remote call failed with error: {e}') from e
                raise ValueError(
                    f'Remote call failed with error code: {e.response.status_code}, '
                    f'error message: {e.response.content.decode("utf-8")}'
                ) from e
        raise ValueError(
            'Remote call max retry times exceeded! Please try to use local call.'
        )

    def _remote_parse_input(self, *args, **kwargs):
        """Rebuild nested dicts from dotted keyword names.

        A key such as ``'a.b.c'`` becomes ``{'a': {'b': {'c': value}}}``; keys
        without a dot are copied unchanged. Positional ``args`` are ignored.
        """
        restored_dict = {}
        for key, value in kwargs.items():
            if '.' in key:
                # Split keys by "." and create nested dictionary structures
                keys = key.split('.')
                temp_dict = restored_dict
                for k in keys[:-1]:
                    temp_dict = temp_dict.setdefault(k, {})
                temp_dict[keys[-1]] = value
            else:
                # If the key does not contain ".", store the pair directly
                restored_dict[key] = value
        kwargs = restored_dict
        print('传给tool的参数:', kwargs)
        return kwargs
# Helpers for openapi_schema_convert: turn an OpenAPI schema into entries to register in tool_config.json
def extract_references(schema_content):
    """Collect every ``$ref`` value found anywhere inside ``schema_content``.

    Dicts and lists are walked recursively, depth-first; a dict's own ``$ref``
    is reported before anything found in its values. Any other value yields an
    empty list.
    """
    if isinstance(schema_content, dict):
        found = [schema_content['$ref']] if '$ref' in schema_content else []
        children = schema_content.values()
    elif isinstance(schema_content, list):
        found = []
        children = schema_content
    else:
        return []

    for child in children:
        found.extend(extract_references(child))
    return found
def parse_nested_parameters(param_name, param_info, parameters_list, content):
    """Flatten one request-body property into ``parameters_list``.

    A non-object property is appended as a single entry. An object property
    with ``properties`` is expanded recursively, each leaf being named with a
    dotted path such as ``'outer.inner'``. An object without ``properties``
    adds nothing.

    Args:
        param_name: Property name (dotted when called recursively).
        param_info: Schema dict of the property; must contain ``'type'``.
        parameters_list: List that receives the flattened entries (mutated).
        content: Enclosing schema; its optional ``'required'`` list decides
            the ``required`` flag, based on the top-level property name.

    Raises:
        ValueError: If the schema is malformed (e.g. a missing ``'type'``).
    """
    # ``required`` is optional in a schema; a missing list means that no
    # property is required.
    required_names = content.get('required') or []
    try:
        param_type = param_info['type']
        if param_type != 'object':
            # Non-nested parameters are added directly to the parameter list
            parameters_list.append({
                'name': param_name,
                # Default description; adjust as needed.
                'description': param_info.get('description',
                                              f'用户输入的{param_name}'),
                'required': param_name in required_names,
                'type': param_type,
                'value': param_info.get('enum', '')
            })
            return

        properties = param_info.get('properties')
        if not properties:
            return

        # Nested fields inherit the flag of their top-level ancestor.
        inner_required = param_name.split('.')[0] in required_names
        for inner_name, inner_info in properties.items():
            full_name = f'{param_name}.{inner_name}'
            if inner_info['type'] == 'object':
                # Recursively handle nested objects
                parse_nested_parameters(full_name, inner_info,
                                        parameters_list, content)
            else:
                parameters_list.append({
                    'name': full_name,
                    'description': inner_info.get('description',
                                                  f'用户输入的{full_name}'),
                    'required': inner_required,
                    'type': inner_info['type'],
                    'value': inner_info.get('enum', '')
                })
    except (KeyError, TypeError, AttributeError) as e:
        # A ValueError raised by a nested call passes through unchanged, so
        # the message suffix is not repeated.
        raise ValueError(f'{e}:schema结构出错') from e
def parse_responses_parameters(param_name, param_info, parameters_list):
    """Append the description of one response property to ``parameters_list``.

    A non-object property becomes one entry. An object property contributes
    one entry per direct child in ``properties`` (named ``'outer.inner'``,
    one level only); an object without ``properties`` adds nothing.

    Raises:
        KeyError: If ``param_info`` has no ``'type'``.
        ValueError: If anything fails while reading the property's structure.
    """
    top_type = param_info['type']
    # Default description; adjust as needed.
    top_description = param_info.get('description', f'调用api返回的{param_name}')
    try:
        if top_type != 'object':
            # Non-nested parameters are added directly to the parameter list
            parameters_list.append({
                'name': param_name,
                'description': top_description,
                'type': top_type,
            })
            return

        fields = param_info.get('properties')
        if not fields:
            return

        # Expand the direct children of the object
        for field_name, field_info in fields.items():
            field_type = field_info['type']
            full_name = f'{param_name}.{field_name}'
            field_description = field_info.get('description',
                                               f'调用api返回的{full_name}')
            parameters_list.append({
                'name': full_name,
                'description': field_description,
                'type': field_type,
            })
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错')
def _resolve_authorization(security, auth):
    """Return the Authorization header value for an operation ('' if none)."""
    authorization = ''
    for sec in security or []:
        if 'BearerAuth' in sec:
            # Only fall back to the environment when ``auth`` lacks the key, so
            # a missing environment variable does not break explicit credentials.
            api_token = (auth['apikey']
                         if 'apikey' in auth else os.environ['apikey'])
            api_token_type = (auth['apikey_type'] if 'apikey_type' in auth else
                              os.environ['apikey_type'])
            authorization = f'{api_token_type} {api_token}'
    return authorization


def _collect_body_parameters(schema_content, resolver):
    """Flatten the properties of a request-body schema into a parameter list."""
    schemas = [
        resolver.resolve(reference)[1]
        for reference in extract_references(schema_content)
    ]
    if not schemas and isinstance(schema_content, dict):
        # Inline schema without any ``$ref``: use it as it is.
        schemas = [schema_content]

    parameters_list = []
    for content in schemas:
        # Each resolved schema replaces the previous result (last one wins).
        parameters_list = []
        for param_name, param_info in content.get('properties', {}).items():
            parse_nested_parameters(param_name, param_info, parameters_list,
                                    content)
    return parameters_list


def _build_config_entry(name, description, url, method, parameters, header):
    """Assemble one tool config entry for a remote endpoint."""
    return {
        'name': name,
        'description': description,
        'is_active': True,
        'is_remote_tool': True,
        'url': url,
        'method': method,
        'parameters': parameters,
        'header': header
    }


def openapi_schema_convert(schema, auth):
    """Convert an OpenAPI schema into tool config entries.

    Args:
        schema: Parsed OpenAPI document (dict). The first ``servers`` entry
            provides the base URL; ``info.description`` is used as the
            description of every generated tool.
        auth: Dict that may hold ``'apikey'`` and ``'apikey_type'``; keys it
            lacks are read from the environment variables of the same name
            when an operation declares ``BearerAuth`` security.

    Returns:
        Dict mapping each operation's ``summary`` (spaces replaced by
        underscores) to its config entry.

    Raises:
        ValueError: If an operation uses a method other than POST or GET.
        KeyError: If a required credential is in neither ``auth`` nor the
            environment.
    """
    resolver = RefResolver.from_schema(schema)
    servers = schema.get('servers', [])
    # Fall back to an empty base URL so the endpoint loop does not hit an
    # undefined variable when the schema declares no server.
    servers_url = (servers[0].get('url') or '') if servers else ''
    if not servers_url:
        print('No URL found in the schema.')

    # Extract endpoints
    endpoints = schema.get('paths', {})
    description = schema.get('info', {}).get('description',
                                             'This is a api tool that ...')
    config_data = {}

    # Iterate over each endpoint and its contents
    for endpoint_path, methods in endpoints.items():
        for method, details in methods.items():
            http_method = method.upper()
            if http_method not in ('POST', 'GET'):
                raise ValueError('method is not POST or GET')

            summary = details.get('summary', 'No summary').replace(' ', '_')
            name = details.get('operationId', 'No operationId')
            url = f'{servers_url}{endpoint_path}'
            # Security (Bearer Token)
            authorization = _resolve_authorization(
                details.get('security', [{}]), auth)

            if http_method == 'POST':
                # Defaults for a POST without a usable request body.
                parameters_list = []
                header = {
                    'Content-Type': 'application/json',
                    'Authorization': authorization
                }
                request_body = details.get('requestBody', {})
                if request_body:
                    # With several content types, the last one listed wins.
                    for content_type, content_details in request_body.get(
                            'content', {}).items():
                        parameters_list = _collect_body_parameters(
                            content_details.get('schema', {}), resolver)
                        header = {
                            'Content-Type': content_type,
                            'Authorization': authorization
                        }
                        if request_body.get('X-DashScope-Async', '') != '':
                            header['X-DashScope-Async'] = 'enable'
            else:
                parameters_list = details.get('parameters', [])
                header = {'Authorization': authorization}

            config_data[summary] = _build_config_entry(
                name, description, url, http_method, parameters_list, header)
    return config_data