import ast
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union

import jieba
import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'


@MODELS.register_module()
class OpenAI(BaseAPIModel):
| | """Model wrapper around OpenAI's models. |
| | |
| | Args: |
| | path (str): The name of OpenAI's model. |
| | max_seq_len (int): The maximum allowed sequence length of a model. |
| | Note that the length of prompt + generated tokens shall not exceed |
| | this value. Defaults to 2048. |
| | query_per_second (int): The maximum queries allowed per second |
| | between two consecutive calls of the API. Defaults to 1. |
| | retry (int): Number of retires if the API call fails. Defaults to 2. |
| | key (str or List[str]): OpenAI key(s). In particular, when it |
| | is set to "ENV", the key will be fetched from the environment |
| | variable $OPENAI_API_KEY, as how openai defaults to be. If it's a |
| | list, the keys will be used in round-robin manner. Defaults to |
| | 'ENV'. |
| | org (str or List[str], optional): OpenAI organization(s). If not |
| | specified, OpenAI uses the default organization bound to each API |
| | key. If specified, the orgs will be posted with each request in |
| | round-robin manner. Defaults to None. |
| | meta_template (Dict, optional): The model's meta prompt |
| | template if needed, in case the requirement of injecting or |
| | wrapping of any meta instructions. |
| | openai_api_base (str): The base url of OpenAI's API. Defaults to |
| | 'https://api.openai.com/v1/chat/completions'. |
| | mode (str, optional): The method of input truncation when input length |
| | exceeds max_seq_len. 'front','mid' and 'rear' represents the part |
| | of input to truncate. Defaults to 'none'. |
| | temperature (float, optional): What sampling temperature to use. |
| | If not None, will override the temperature in the `generate()` |
| | call. Defaults to None. |
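
    Example:
        A minimal usage sketch; in a real run OpenCompass builds the model
        from a config, and 'ENV' assumes $OPENAI_API_KEY is set::

            model = OpenAI(path='gpt-3.5-turbo', key='ENV', mode='rear')
            print(model.generate(['Hello!'], max_out_len=64))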
| | """ |
| |
|
| | is_api: bool = True |
| |
|
| | def __init__(self, |
| | path: str = 'gpt-3.5-turbo', |
| | max_seq_len: int = 4096, |
| | query_per_second: int = 1, |
| | rpm_verbose: bool = False, |
| | retry: int = 2, |
| | key: Union[str, List[str]] = 'ENV', |
| | org: Optional[Union[str, List[str]]] = None, |
| | meta_template: Optional[Dict] = None, |
| | openai_api_base: str = OPENAI_API_BASE, |
| | mode: str = 'none', |
| | temperature: Optional[float] = None): |
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         rpm_verbose=rpm_verbose,
                         retry=retry)
        import tiktoken
        self.tiktoken = tiktoken
        self.temperature = temperature
        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key

        # A shared lock; constructing a fresh threading.Lock() at each use
        # would synchronize nothing, so the key/org counters below are
        # guarded with this instance-level lock instead.
        self._lock = Lock()
        # Keys that have run out of quota are collected here and skipped
        # during round-robin key selection.
        self.invalid_keys = set()

        self.key_ctr = 0
        if isinstance(org, str):
            self.orgs = [org]
        else:
            self.orgs = org
        self.org_ctr = 0
        self.url = openai_api_base
        self.path = path

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
| | """Generate results given a list of inputs. |
| | |
| | Args: |
| | inputs (List[str or PromptList]): A list of strings or PromptDicts. |
| | The PromptDict should be organized in OpenCompass' |
| | API format. |
| | max_out_len (int): The maximum length of the output. |
| | temperature (float): What sampling temperature to use, |
| | between 0 and 2. Higher values like 0.8 will make the output |
| | more random, while lower values like 0.2 will make it more |
| | focused and deterministic. Defaults to 0.7. |
| | |
| | Returns: |
| | List[str]: A list of generated strings. |
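
        Example:
            A minimal sketch of a direct call; in practice OpenCompass
            inferencers construct the inputs and drive this method::

                model = OpenAI(path='gpt-3.5-turbo', key='ENV')
                results = model.generate(['1 + 1 = ?'], max_out_len=16,
                                         temperature=0.0)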
| | """ |
| | if self.temperature is not None: |
| | temperature = self.temperature |
| |
|
| | with ThreadPoolExecutor() as executor: |
| | results = list( |
| | executor.map(self._generate, inputs, |
| | [max_out_len] * len(inputs), |
| | [temperature] * len(inputs))) |
| | return results |
| |
|
| | def _generate(self, input: str or PromptList, max_out_len: int, |
| | temperature: float) -> str: |
| | """Generate results given a list of inputs. |
| | |
| | Args: |
| | inputs (str or PromptList): A string or PromptDict. |
| | The PromptDict should be organized in OpenCompass' |
| | API format. |
| | max_out_len (int): The maximum length of the output. |
| | temperature (float): What sampling temperature to use, |
| | between 0 and 2. Higher values like 0.8 will make the output |
| | more random, while lower values like 0.2 will make it more |
| | focused and deterministic. |
| | |
| | Returns: |
| | str: The generated string. |
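
        Example:
            A sketch of the PromptList form; roles are mapped onto OpenAI
            roles by the loop below (HUMAN -> user, BOT -> assistant,
            SYSTEM -> system)::

                input = PromptList([
                    {'role': 'SYSTEM', 'prompt': 'You are helpful.'},
                    {'role': 'HUMAN', 'prompt': 'Hello!'},
                ])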
| | """ |
| | assert isinstance(input, (str, PromptList)) |
| |
|
| | |
| | context_window = 4096 |
| | if '32k' in self.path: |
| | context_window = 32768 |
| | elif '16k' in self.path: |
| | context_window = 16384 |
| | elif 'gpt-4' in self.path: |
| | context_window = 8192 |
| |
|
| | |
| | if isinstance(input, str) and self.mode != 'none': |
| | context_window = self.max_seq_len |
| | input = self.bin_trim(input, context_window - 100 - max_out_len) |
| |
|
| | if isinstance(input, str): |
| | messages = [{'role': 'user', 'content': input}] |
| | else: |
| | messages = [] |
| | for item in input: |
| | msg = {'content': item['prompt']} |
| | if item['role'] == 'HUMAN': |
| | msg['role'] = 'user' |
| | elif item['role'] == 'BOT': |
| | msg['role'] = 'assistant' |
| | elif item['role'] == 'SYSTEM': |
| | msg['role'] = 'system' |
| | messages.append(msg) |
| |
|
| | |
| | max_out_len = min( |
| | max_out_len, context_window - self.get_token_len(str(input)) - 100) |
| | if max_out_len <= 0: |
| | return '' |
| |
|
| | max_num_retries = 0 |
| | while max_num_retries < self.retry: |
| | self.wait() |
| |
|
| | with Lock(): |
| | if len(self.invalid_keys) == len(self.keys): |
| | raise RuntimeError('All keys have insufficient quota.') |
| |
|
| | |
| | while True: |
| | self.key_ctr += 1 |
| | if self.key_ctr == len(self.keys): |
| | self.key_ctr = 0 |
| |
|
| | if self.keys[self.key_ctr] not in self.invalid_keys: |
| | break |
| |
|
| | key = self.keys[self.key_ctr] |
| |
|
| | header = { |
| | 'Authorization': f'Bearer {key}', |
| | 'content-type': 'application/json', |
| | } |
| |
|
| | if self.orgs: |
| | with Lock(): |
| | self.org_ctr += 1 |
| | if self.org_ctr == len(self.orgs): |
| | self.org_ctr = 0 |
| | header['OpenAI-Organization'] = self.orgs[self.org_ctr] |
| |
|
| | try: |
| | data = dict( |
| | model=self.path, |
| | messages=messages, |
| | max_tokens=max_out_len, |
| | n=1, |
| | stop=None, |
| | temperature=temperature, |
| | ) |
| | raw_response = requests.post(self.url, |
| | headers=header, |
| | data=json.dumps(data)) |
| | except requests.ConnectionError: |
| | self.logger.error('Got connection error, retrying...') |
| | continue |
| | try: |
| | response = raw_response.json() |
| | except requests.JSONDecodeError: |
| | self.logger.error('JsonDecode error, got', |
| | str(raw_response.content)) |
| | continue |
| | try: |
| | return response['choices'][0]['message']['content'].strip() |
| | except KeyError: |
| | if 'error' in response: |
| | if response['error']['code'] == 'rate_limit_exceeded': |
| | time.sleep(1) |
| | continue |
| | elif response['error']['code'] == 'insufficient_quota': |
| | self.invalid_keys.add(key) |
| | self.logger.warn(f'insufficient_quota key: {key}') |
| | continue |
| |
|
| | self.logger.error('Find error message in response: ', |
| | str(response['error'])) |
| | max_num_retries += 1 |
| |
|
| | raise RuntimeError('Calling OpenAI failed after retrying for ' |
| | f'{max_num_retries} times. Check the logs for ' |
| | 'details.') |
| |
|
| | def get_token_len(self, prompt: str) -> int: |
| | """Get lengths of the tokenized string. Only English and Chinese |
| | characters are counted for now. Users are encouraged to override this |
| | method if more accurate length is needed. |
| | |
| | Args: |
| | prompt (str): Input string. |
| | |
| | Returns: |
| | int: Length of the input tokens |
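
        Example:
            A sketch, assuming tiktoken has an encoding registered for the
            model name::

                n = model.get_token_len('Hello world')  # 2 with cl100k_base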
| | """ |
| | enc = self.tiktoken.encoding_for_model(self.path) |
| | return len(enc.encode(prompt)) |
| |
|
| | def bin_trim(self, prompt: str, num_token: int) -> str: |
| | """Get a suffix of prompt which is no longer than num_token tokens. |
| | |
| | Args: |
| | prompt (str): Input string. |
| | num_token (int): The upper bound of token numbers. |
| | |
| | Returns: |
| | str: The trimmed prompt. |
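
        Example:
            A sketch: with mode='rear' the head of the prompt is kept and
            the tail is dropped until the result fits the budget::

                model.mode = 'rear'
                short = model.bin_trim(long_prompt, num_token=1000)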
| | """ |
| | token_len = self.get_token_len(prompt) |
| | if token_len <= num_token: |
| | return prompt |
| | pattern = re.compile(r'[\u4e00-\u9fa5]') |
| | if pattern.search(prompt): |
| | words = list(jieba.cut(prompt, cut_all=False)) |
| | sep = '' |
| | else: |
| | words = prompt.split(' ') |
| | sep = ' ' |
| |
|
| | l, r = 1, len(words) |
| | while l + 2 < r: |
| | mid = (l + r) // 2 |
| | if self.mode == 'front': |
| | cur_prompt = sep.join(words[-mid:]) |
| | elif self.mode == 'mid': |
| | cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:]) |
| | elif self.mode == 'rear': |
| | cur_prompt = sep.join(words[:mid]) |
| |
|
| | if self.get_token_len(cur_prompt) <= num_token: |
| | l = mid |
| | else: |
| | r = mid |
| |
|
| | if self.mode == 'front': |
| | prompt = sep.join(words[-l:]) |
| | elif self.mode == 'mid': |
| | prompt = sep.join(words[:l]) + sep.join(words[-l:]) |
| | elif self.mode == 'rear': |
| | prompt = sep.join(words[:l]) |
| | return prompt |
| |
|
| |
|
| | class OpenAIAllesAPIN(OpenAI): |
| | """Model wrapper around OpenAI-AllesAPIN. |
| | |
| | Args: |
| | path (str): The name of OpenAI's model. |
| | url (str): URL to AllesAPIN. |
| | key (str): AllesAPIN key. |
| | query_per_second (int): The maximum queries allowed per second |
| | between two consecutive calls of the API. Defaults to 1. |
| | max_seq_len (int): Unused here. |
| | meta_template (Dict, optional): The model's meta prompt |
| | template if needed, in case the requirement of injecting or |
| | wrapping of any meta instructions. |
| | retry (int): Number of retires if the API call fails. Defaults to 2. |
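
    Example:
        A minimal sketch; the url and key are placeholders for a deployed
        AllesAPIN gateway::

            model = OpenAIAllesAPIN(path='gpt-4',
                                    url='https://example.com/v1/chat',
                                    key='YOUR_TOKEN')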
| | """ |
| |
|
| | is_api: bool = True |
| |
|
| | def __init__(self, |
| | path: str, |
| | url: str, |
| | key: str, |
| | temperature: float = 1.0, |
| | query_per_second: int = 1, |
| | rpm_verbose: bool = False, |
| | max_seq_len: int = 2048, |
| | meta_template: Optional[Dict] = None, |
| | retry: int = 2): |
| | super().__init__(path=path, |
| | max_seq_len=max_seq_len, |
| | query_per_second=query_per_second, |
| | rpm_verbose=rpm_verbose, |
| | meta_template=meta_template, |
| | retry=retry) |
| | self.url = url |
| | self.temperature = temperature |
| | self.headers = { |
| | 'alles-apin-token': key, |
| | 'content-type': 'application/json', |
| | } |
| |
|
| | def _generate(self, input: str or PromptList, max_out_len: int, |
| | temperature: float) -> str: |
| | """Generate results given an input. |
| | |
| | Args: |
| | inputs (str or PromptList): A string or PromptDict. |
| | The PromptDict should be organized in OpenCompass' |
| | API format. |
| | max_out_len (int): The maximum length of the output. |
| | temperature (float): What sampling temperature to use, |
| | between 0 and 2. Higher values like 0.8 will make the output |
| | more random, while lower values like 0.2 will make it more |
| | focused and deterministic. |
| | |
| | Returns: |
| | str: The generated string. |
| | """ |
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

            # The final turn must come from the user or the system; the
            # assistant should never hold the last message.
            assert msg['role'] in ['user', 'system']

        data = {
            'model': self.path,
            'messages': messages,
            'temperature': temperature
        }
        for _ in range(self.retry):
            self.wait()
            raw_response = requests.post(self.url,
                                         headers=self.headers,
                                         data=json.dumps(data))
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
                time.sleep(1)
                continue
            if raw_response.status_code == 200 and response[
                    'msgCode'] == '10000':
                # Use a separate name so the request payload in `data` is
                # not clobbered before a potential retry.
                body = response['data']
                choices = body['choices']
                if choices is None:
                    self.logger.error(body)
                else:
                    return choices[0]['message']['content'].strip()
            # Content-filter rejections are surfaced as the returned message
            # instead of being retried. The error payload is a Python
            # literal, so it is parsed with ast.literal_eval rather than
            # eval.
            try:
                match = re.match(r'Error code: \d+ - (.*)', response['data'])
                err = ast.literal_eval(match.group(1))['error']
                if err['code'] == 'content_filter' and err['status'] == 400:
                    return err['message']
            except Exception:
                pass
            self.logger.error(response['msg'])
            self.logger.error(response)
            time.sleep(1)

        raise RuntimeError('API call failed.')

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string.

        The length is computed with tiktoken's encoding for ``self.path``.
        Users are encouraged to override this method if a different
        tokenizer is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
        """
        enc = self.tiktoken.encoding_for_model(self.path)
        return len(enc.encode(prompt))