| | from copy import copy |
| | from typing import Dict, List, Optional, Tuple, Union |
| |
|
| |
|
class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (list of dict, optional): The meta template for the
            model. Each item configures one role (keys such as ``role``,
            ``begin``, ``end``, ``generate``, ``fallback_role``).
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        if meta_template:
            assert isinstance(meta_template, list)
            self.roles: Dict[str, dict] = dict()
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                # Shallow-copy so later mutation of the caller's template
                # does not leak into the parser's role table.
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (str or List[str or dict]): A prompt template
                (potentially before being wrapped by meta template).

        Returns:
            str: The final string.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            # Plain strings are passed through untouched.
            return dialog
        if self.meta_template:
            prompt = ''
            for index, item in enumerate(dialog):
                if isinstance(item, str):
                    prompt += item
                else:
                    # Only the last message may trigger the generation
                    # prefix, hence the `index == len(dialog) - 1` flag.
                    new_str = self._prompt2str(item, index == len(dialog) - 1)
                    prompt += new_str
        else:
            # No meta template: join the non-empty pieces with newlines.
            prompt = ''
            last_sep = ''
            for item in dialog:
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('content', ''):
                    # BUGFIX: previously appended `item.get('prompt', '')`,
                    # which is empty for standard {'role', 'content'} message
                    # dicts even though the condition checks 'content'; use
                    # the checked 'content' field instead.
                    prompt += last_sep + item.get('content', '')
                last_sep = '\n'
        return prompt

    def _format_begin(self, role_cfg, message):
        """Resolve the ``begin`` marker of a role, substituting an optional
        speaker name into a ``with_name`` format string when provided."""
        name = message.get('name', None)
        if name is not None:
            begin = role_cfg['begin'].get('with_name', '')
            if name in role_cfg['begin'].get('name', {}):
                # Map the public name to the template's alias when one exists.
                begin = begin.format(name=role_cfg['begin']['name'][name])
            else:
                begin = begin.format(name=name)
        else:
            # NOTE(review): if 'begin' is neither str nor dict, `begin` stays
            # unbound and a NameError follows — presumably the meta template
            # is validated upstream; confirm before relying on it.
            if isinstance(role_cfg.get('begin', ''), str):
                begin = role_cfg.get('begin', '')
            elif isinstance(role_cfg['begin'], dict):
                begin = role_cfg['begin'].get('without_name', '')
        return begin

    def _prompt2str(self,
                    prompt: Union[str, Dict],
                    last: bool = False) -> str:
        """Render a single message with its role's begin/end markers.

        BUGFIX: the return annotation previously claimed
        ``Tuple[str, bool]``; the method always returns a plain ``str``.

        Args:
            prompt (str or dict): A raw string (returned as-is) or a message
                dict with at least a ``role`` key.
            last (bool): Whether this is the final message of the dialog.

        Returns:
            str: The formatted message fragment.
        """
        if isinstance(prompt, str):
            return prompt
        merged_prompt = self.roles.get(prompt['role'])

        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles.get(merged_prompt['fallback_role'])
        begin = self._format_begin(merged_prompt, prompt)
        res = begin
        if last and merged_prompt.get('generate', False):
            # A generating role at the end leaves the prompt open (no 'end'
            # marker) so the model continues from here.
            res += prompt.get('content', '')
            return res
        res += prompt.get('content', '') + merged_prompt.get('end', '')
        if last and merged_prompt['role'] != 'assistant':
            # Append the assistant's begin marker to cue the model's turn.
            res += self._format_begin(self.roles['assistant'], {})
            return res
        return res
| |
|
| |
|
class BaseLLM:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        tokenizer_only (bool): If True, only the tokenizer will be
            initialized. Defaults to False.
        template_parser (type): Parser class instantiated with
            ``meta_template``; defaults to :class:`LMTemplateParser`.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.

    Keyword Args:
        max_new_tokens (int): Maximum length of output expected to be
            generated by the model. Defaults to 512.
        top_p (float): Nucleus sampling threshold. Defaults to 0.8.
        top_k (float): Top-k sampling cutoff. Defaults to 40.
        temperature (float): Sampling temperature. Defaults to 0.8.
        repetition_penalty (float): Penalty for repeated tokens.
            Defaults to 1.0 (no penalty).
        stop_words (str or list of str, optional): Stop words for
            generation; a bare string is normalized to a one-element list.
    """

    def __init__(self,
                 path: str,
                 tokenizer_only: bool = False,
                 template_parser: 'LMTemplateParser' = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: float = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 stop_words: Union[List[str], str] = None):
        self.path = path
        self.tokenizer_only = tokenizer_only

        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        if meta_template:
            # BUGFIX: ``meta_template`` is a list of role dicts, so the old
            # check `'eos_token_id' in meta_template` could never match, and
            # `meta_template['eos_token_id']` would have raised TypeError on
            # a list. Look the key up inside the role items instead; a plain
            # dict is still honored for backward compatibility.
            if isinstance(meta_template, dict):
                self.eos_token_id = meta_template.get('eos_token_id')
            else:
                for item in meta_template:
                    if isinstance(item, dict) and 'eos_token_id' in item:
                        self.eos_token_id = item['eos_token_id']
                        break

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]): A prompt or batch of prompts.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Args:
            inputs (str): A prompt string.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             session_ids: Union[int, List[int]] = None,
             **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): A single
                conversation (list of message dicts) or a batch of them.
            session_ids (int or list of int, optional): Accepted for API
                symmetry with the async interface; not forwarded to
                :meth:`generate` here.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: Whatever :meth:`generate` returns.
        """
        if isinstance(inputs[0], list):
            # Batch of conversations: render each through the template parser.
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Args:
            inputs (List[dict]): A conversation as a list of message dicts.
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict],
                                      List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        """Return a copy of ``self.gen_params`` overridden by ``kwargs``;
        the stored defaults are left untouched."""
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params
| |
|
| |
|
class AsyncLLMMixin:
    """Mixin declaring the asynchronous counterparts of the LLM API."""

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       **gen_params) -> str:
        """Asynchronously generate results for a prompt or batch of prompts.

        Args:
            inputs (Union[str, List[str]]): A prompt or batch of prompts.
            session_ids (int or list of int, optional): Session identifiers
                matching the inputs.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings; a string
            input yields a string, a batch yields a list.
        """
        raise NotImplementedError

    async def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Asynchronously generate results as a stream for a string input.

        Args:
            inputs (str): A prompt string.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    async def chat(self,
                   inputs: Union[List[dict], List[List[dict]]],
                   session_ids: Union[int, List[int]] = None,
                   **gen_params):
        """Asynchronously generate completion from chat templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): A single
                conversation (list of message dicts) or a batch of them.
            session_ids (int or list of int, optional): Forwarded to
                :meth:`generate`.
            gen_params (dict): The input params for generation.

        Returns:
            Whatever :meth:`generate` returns.
        """
        if isinstance(inputs[0], list):
            # Batch of conversations: render each one through the parser.
            parsed = [self.template_parser(conversation)
                      for conversation in inputs]
        else:
            parsed = self.template_parser(inputs)
        return await self.generate(parsed, session_ids, **gen_params)

    async def stream_chat(self, inputs: List[dict], **gen_params):
        """Asynchronously generate streaming results from chat templates.

        Args:
            inputs (List[dict]): A conversation as a list of message dicts.
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    async def tokenize(self, prompts: Union[str, List[str], List[dict],
                                            List[List[dict]]]):
        """Asynchronously tokenize the input prompts.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError
| |
|
| |
|
class AsyncBaseLLM(AsyncLLMMixin, BaseLLM):
    """Concrete base combining the asynchronous API of ``AsyncLLMMixin``
    with the configuration handling of ``BaseLLM``; adds no behavior of
    its own."""
| |
|