from copy import copy
from typing import Dict, List, Optional, Union


class LMTemplateParser:
| """Intermidate prompt template parser, specifically for language models. |
| |
| Args: |
| meta_template (list of dict, optional): The meta template for the |
| model. |
| """ |

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        if meta_template:
            assert isinstance(meta_template, list)
            self.roles: Dict[str, dict] = dict()
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with the meta template if
        applicable.

        Args:
            dialog (str or list): A prompt string, or a list of strings and
                message dicts (potentially before being wrapped by the meta
                template).

        Returns:
            str: The final prompt string.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog
        if self.meta_template:
            # wrap each message with its configured role template
            prompt = ''
            for index, item in enumerate(dialog):
                if isinstance(item, str):
                    prompt += item
                else:
                    new_str = self._prompt2str(item, index == len(dialog) - 1)
                    prompt += new_str
        else:
            # in case the model does not have a meta template, simply join
            # the message contents with newlines
            prompt = ''
            last_sep = ''
            for item in dialog:
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('content', ''):
                    prompt += last_sep + item.get('content', '')
                last_sep = '\n'
        return prompt

    def _format_begin(self, role_cfg, message):
        name = message.get('name', None)
        if name is not None:
            # named messages use the 'with_name' begin template, optionally
            # mapping the name through the configured aliases
            begin = role_cfg['begin'].get('with_name', '')
            if name in role_cfg['begin'].get('name', {}):
                begin = begin.format(name=role_cfg['begin']['name'][name])
            else:
                begin = begin.format(name=name)
        else:
            if isinstance(role_cfg.get('begin', ''), str):
                begin = role_cfg.get('begin', '')
            elif isinstance(role_cfg['begin'], dict):
                begin = role_cfg['begin'].get('without_name', '')
        return begin
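
    # Illustrative sketch of the two 'begin' shapes handled above; the tokens
    # and the 'tool' -> 'plugin' alias are made-up examples, not tied to any
    # particular model:
    #
    #   begin='<|user|>\n'                         # plain string form
    #   begin=dict(
    #       without_name='<|user|>\n',             # no 'name' in the message
    #       with_name='<|{name}|>\n',              # message carries a 'name'
    #       name={'tool': 'plugin'})               # optional name aliases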

    def _prompt2str(self,
                    prompt: Union[str, Dict],
                    last: bool = False) -> str:
        if isinstance(prompt, str):
            return prompt
        merged_prompt = self.roles.get(prompt['role'])
        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles.get(merged_prompt['fallback_role'])
        begin = self._format_begin(merged_prompt, prompt)
        res = begin
        if last and merged_prompt.get('generate', False):
            res += prompt.get('content', '')
            return res
        res += prompt.get('content', '') + merged_prompt.get('end', '')
        if last and merged_prompt['role'] != 'assistant':
            # leave the prompt open for the model to continue as assistant
            res += self._format_begin(self.roles['assistant'], {})
            return res
        return res
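
# A minimal usage sketch of LMTemplateParser; the template values below are
# illustrative assumptions, not tied to any particular model:
#
#   parser = LMTemplateParser(meta_template=[
#       dict(role='user', begin='<|user|>\n', end='\n'),
#       dict(role='assistant', begin='<|assistant|>\n', end='\n',
#            generate=True),
#   ])
#   parser([dict(role='user', content='Hi there')])
#   # -> '<|user|>\nHi there\n<|assistant|>\n'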


class BaseLLM:
    """Base class for model wrappers.

    Args:
        path (str): The path to the model.
        tokenizer_only (bool): If True, only the tokenizer will be
            initialized. Defaults to False.
        template_parser (LMTemplateParser): The parser class used to wrap
            prompts with the meta template. Defaults to ``LMTemplateParser``.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case any meta instructions have to be
            injected into or wrapped around the prompt.
        max_new_tokens (int): Maximum number of tokens expected to be
            generated by the model. Defaults to 512.
        top_p (float): Nucleus sampling probability mass. Defaults to 0.8.
        top_k (int): Number of highest-probability tokens kept for sampling.
            Defaults to 40.
        temperature (float): Sampling temperature. Defaults to 0.8.
        repetition_penalty (float): Penalty applied to repeated tokens.
            Defaults to 1.0.
        stop_words (str or list of str, optional): Word(s) at which
            generation stops.
    """

    def __init__(self,
                 path: str,
                 tokenizer_only: bool = False,
                 template_parser: 'LMTemplateParser' = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 stop_words: Optional[Union[List[str], str]] = None):
        self.path = path
        self.tokenizer_only = tokenizer_only
        self.template_parser = template_parser(meta_template)
        # pick up the eos token id from the meta template, if any of its
        # role configs carries one (meta_template is a list of dicts)
        self.eos_token_id = None
        if meta_template:
            for item in meta_template:
                if 'eos_token_id' in item:
                    self.eos_token_id = item['eos_token_id']
                    break

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def generate(self, inputs: Union[str, List[str]],
                 **gen_params) -> Union[str, List[str]]:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]): The input prompt(s).
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        A subclass is expected to preserve the batching of the input, e.g.::

            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results in a streaming fashion given a str input.

        Args:
            inputs (str): The input prompt.
            gen_params (dict): The input params for generation.

        Returns:
            List[str]: The streamed (partial) generation results.
        """
        raise NotImplementedError

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             session_ids: Union[int, List[int]] = None,
             **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): A list of
                messages, or a batch of such lists.
            session_ids (Union[int, List[int]], optional): Session id(s);
                unused here, kept for interface parity with the async
                version.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.
        """
        if isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)
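
    # For example (a sketch; the message contents are arbitrary):
    #
    #   llm.chat([dict(role='user', content='Hello')])      # single dialog
    #   llm.chat([[dict(role='user', content='Hi')],
    #             [dict(role='user', content='Hey')]])       # batched dialogs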

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results in a streaming fashion given a list of templates.

        Args:
            inputs (List[dict]): A list of messages.
            gen_params (dict): The input params for generation.

        Returns:
            The streamed generation results.
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict],
                                      List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts (str or list): The user's prompt, or a batch of prompts,
                given as plain strings or as message dicts.

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): The prompts'
                token ids, the ids' lengths, and the requested output
                lengths.
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        """Return a copy of the default generation params, updated with
        ``kwargs``."""
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params
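
# A minimal subclass sketch; 'EchoLLM' and its echo behaviour are purely
# hypothetical, to illustrate the expected `generate` contract:
#
#   class EchoLLM(BaseLLM):
#       def generate(self, inputs, **gen_params):
#           batched = True
#           if isinstance(inputs, str):
#               inputs = [inputs]
#               batched = False
#           response = ['echo: ' + x for x in inputs]
#           return response if batched else response[0]
#
#   llm = EchoLLM(path='dummy')
#   llm.chat([dict(role='user', content='hello')])  # -> 'echo: hello'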


class AsyncLLMMixin:
    """Mixin providing asynchronous counterparts of the ``BaseLLM``
    interface."""

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       **gen_params) -> Union[str, List[str]]:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]): The input prompt(s).
            session_ids (Union[int, List[int]], optional): Session id(s)
                for the generation request(s).
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        A subclass is expected to preserve the batching of the input, e.g.::

            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    async def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results in a streaming fashion given a str input.

        Args:
            inputs (str): The input prompt.
            gen_params (dict): The input params for generation.

        Returns:
            List[str]: The streamed (partial) generation results.
        """
        raise NotImplementedError

    async def chat(self,
                   inputs: Union[List[dict], List[List[dict]]],
                   session_ids: Union[int, List[int]] = None,
                   **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): A list of
                messages, or a batch of such lists.
            session_ids (Union[int, List[int]], optional): Session id(s),
                forwarded to the underlying ``generate`` call.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.
        """
        if isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return await self.generate(_inputs, session_ids, **gen_params)

    async def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results in a streaming fashion given a list of templates.

        Args:
            inputs (List[dict]): A list of messages.
            gen_params (dict): The input params for generation.

        Returns:
            The streamed generation results.
        """
        raise NotImplementedError

    async def tokenize(self, prompts: Union[str, List[str], List[dict],
                                            List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts (str or list): The user's prompt, or a batch of prompts,
                given as plain strings or as message dicts.

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): The prompts'
                token ids, the ids' lengths, and the requested output
                lengths.
        """
        raise NotImplementedError


class AsyncBaseLLM(AsyncLLMMixin, BaseLLM):
    """Asynchronous base class for model wrappers, combining the async
    interface of ``AsyncLLMMixin`` with the setup logic of ``BaseLLM``."""
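
# An async counterpart sketch; 'AsyncEchoLLM' is hypothetical, illustrating
# how the mixin's overrides resolve (drive it with `asyncio.run`):
#
#   class AsyncEchoLLM(AsyncBaseLLM):
#       async def generate(self, inputs, session_ids=None, **gen_params):
#           batched = isinstance(inputs, list)
#           outputs = ['echo: ' + x
#                      for x in (inputs if batched else [inputs])]
#           return outputs if batched else outputs[0]
#
#   llm = AsyncEchoLLM(path='dummy')
#   asyncio.run(llm.chat([dict(role='user', content='hello')]))
#   # -> 'echo: hello'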