|
|
|
|
|
|
|
|
from copy import deepcopy |
|
|
from dataclasses import dataclass, field |
|
|
from typing import List, Optional, Type, Union |
|
|
|
|
|
from transformers import PreTrainedTokenizerBase |
|
|
|
|
|
from .base import Template |
|
|
from .utils import Prompt, Word |
|
|
|
|
|
|
|
|
@dataclass
class TemplateMeta:
    r"""Declarative description of how a chat conversation is laid out as text.

    Each ``Prompt`` field is a list of segments.  A segment is either a string
    (which may contain placeholders such as ``{{SYSTEM}}`` or ``{{QUERY}}``) or
    a list whose string items are tokenizer attribute names (for example
    ``['eos_token_id']``); those names are resolved to concrete values by
    :meth:`init`.

    Examples:
        chatml (with bos):
            prefix: <s>
            prompt: <|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n
            chat_sep: <|im_end|>\n
            suffix: <|im_end|>
            system_prefix: <s><|im_start|>system\n{{SYSTEM}}<|im_end|>\n

        <s><|im_start|>system  # prefix or system_prefix
        {{SYSTEM}}<|im_end|>
        <|im_start|>user  # prompt
        {{QUERY}}<|im_end|>
        <|im_start|>assistant
        {{RESPONSE}}<|im_end|>  # chat_sep
        <|im_start|>user  # prompt
        {{QUERY}}<|im_end|>
        <|im_start|>assistant
        {{RESPONSE}}<|im_end|>  # suffix

    Attributes derived in ``__post_init__``:
        is_post_system: True when ``prompt`` contains ``{{SYSTEM}}``.
        system_prompt: the full prompt including its ``{{SYSTEM}}`` segments
            (only set when ``is_post_system`` is True).
        support_system: True when a system message can be rendered at all.
        support_multi_round: True when ``chat_sep`` is not None.

    Attributes derived in ``init``:
        stop_token_id: id of the token that terminates generation.
    """
    template_type: str
    # Rendered at the start of the conversation when there is no system message.
    prefix: Prompt
    # Rendered for every user turn; contains the `{{QUERY}}` placeholder.
    prompt: Prompt
    # Rendered between rounds; None means multi-round chat is not supported.
    chat_sep: Optional[Prompt]
    # Rendered after the final response; defaults to the tokenizer's EOS id.
    suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
    template_cls: Type[Template] = Template
    # Variant of `prefix` that contains `{{SYSTEM}}`.
    system_prefix: Optional[Prompt] = None
    # Validated against `support_system` at construction time.
    default_system: Optional[str] = None
    # NOTE(review): not read inside this class; presumably consumed by `template_cls`.
    response_prefix: str = ''

    # NOTE(review): not read inside this class; presumably consumed by `template_cls`.
    auto_add_bos: bool = False
    # Extended in place by `init()` with the last suffix segment and the EOS token.
    stop_words: List[Word] = field(default_factory=list)
    # NOTE(review): not read inside this class; presumably consumed by `template_cls`.
    agent_template: str = 'react_en'

    def to_generate_template_meta(self) -> 'TemplateMeta':
        """Return a bare-bones meta that forwards the query without chat markup.

        The result has an empty prefix, a prompt of just ``{{QUERY}}``, no
        ``chat_sep`` (single round only), the default suffix and
        ``auto_add_bos=True``.  ``template_type``, ``template_cls`` and the
        stop words are carried over from this instance.
        """
        return TemplateMeta(
            self.template_type,
            prefix=[],
            prompt=['{{QUERY}}'],
            chat_sep=None,
            template_cls=self.template_cls,
            auto_add_bos=True,
            # Copy the list so that `init()` on the new meta, which appends to
            # `stop_words`, can never mutate this instance's list.
            stop_words=deepcopy(self.stop_words),
        )

    @staticmethod
    def _has_system(prefix_or_prompt: Prompt) -> bool:
        """Return True if any segment contains the ``{{SYSTEM}}`` placeholder.

        For string segments this is a substring test; for list segments it is
        a membership test.
        """
        return any('{{SYSTEM}}' in p for p in prefix_or_prompt)

    @staticmethod
    def _replace_system(prefix: Prompt) -> Prompt:
        """Return a copy of ``prefix`` with ``{{SYSTEM}}`` removed from its strings.

        Non-string segments (e.g. ``['bos_token_id']``) are passed through
        unchanged rather than dropped, so the no-system prefix keeps its
        special tokens.
        """
        return [p.replace('{{SYSTEM}}', '') if isinstance(p, str) else p for p in prefix]

    def _check_template_meta(self):
        """Assert that every prompt-like field is a list (or None where optional).

        Raises:
            AssertionError: if a field has an unexpected type.
        """
        for x in [self.prefix, self.prompt, self.suffix]:
            assert isinstance(x, list)
        for x in [self.chat_sep, self.system_prefix]:
            assert x is None or isinstance(x, list)

    def __post_init__(self):
        """Derive the system-related attributes from the declared fields.

        Raises:
            AssertionError: if ``prefix`` contains ``{{SYSTEM}}`` while
                ``system_prefix`` is also given, or if ``default_system`` is
                set on a template that cannot render a system message.
        """
        # A prefix that embeds `{{SYSTEM}}` is really a system prefix: keep it
        # as such and derive the plain prefix by removing the placeholder.
        if self._has_system(self.prefix):
            assert self.system_prefix is None, 'The prefix already contains {{SYSTEM}}.'
            self.system_prefix = self.prefix
            self.prefix = self._replace_system(self.prefix)

        # The system message may instead live inside the prompt itself.
        self.is_post_system = self._has_system(self.prompt)
        if self.is_post_system:
            # Remember the full prompt *before* stripping; otherwise
            # `system_prompt` would lose its `{{SYSTEM}}` segments as well.
            self.system_prompt = self.prompt
            self.prompt = [context for context in self.prompt if '{{SYSTEM}}' not in context]

        self.support_system = self.system_prefix is not None or self.is_post_system
        self.check_system(self.default_system)

        self.support_multi_round = self.chat_sep is not None

    @staticmethod
    def _token_attr_to_id(tokenizer: PreTrainedTokenizerBase, value: Optional[Prompt]) -> Optional[Prompt]:
        """Turn `eos_token_id` to token id

        e.g. [['eos_token_id']] -> [[2]]

        String segments are returned untouched; inside list segments every
        string item is looked up as an attribute of ``tokenizer`` and other
        items are kept as they are.  ``None`` is returned for ``None``.
        """
        if value is None:
            return None
        resolved = []
        for segment in value:
            if isinstance(segment, list):
                segment = [getattr(tokenizer, item) if isinstance(item, str) else item for item in segment]
            resolved.append(segment)
        return resolved

    def init(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Bind this meta to ``tokenizer``.

        Resolves tokenizer attribute names in all prompt-like attributes,
        extends ``stop_words`` in place and sets ``stop_token_id``.
        """
        for key in ['prefix', 'prompt', 'chat_sep', 'suffix', 'system_prefix']:
            setattr(self, key, self._token_attr_to_id(tokenizer, getattr(self, key)))
        if self.is_post_system:
            # `system_prompt` is a prompt-like attribute too; resolve it the
            # same way so it never carries unresolved attribute names.
            self.system_prompt = self._token_attr_to_id(tokenizer, self.system_prompt)

        if self.suffix and self.suffix[-1] not in self.stop_words:
            self.stop_words.append(self.suffix[-1])
        # A tokenizer without an EOS token reports None; never add that as a stop word.
        eos_token = tokenizer.eos_token
        if eos_token is not None and eos_token not in self.stop_words:
            self.stop_words.append(eos_token)

        # Default to EOS; prefer the last suffix segment when it maps to one token id.
        self.stop_token_id = tokenizer.eos_token_id
        if self.suffix:
            suffix_tokens = self.suffix[-1]
            if isinstance(suffix_tokens, str):
                stop_token_id = tokenizer.convert_tokens_to_ids(suffix_tokens)
            elif isinstance(suffix_tokens, list) and len(suffix_tokens) == 1:
                stop_token_id = suffix_tokens[0]
            else:
                stop_token_id = None
            if stop_token_id is not None:
                self.stop_token_id = stop_token_id

    def check_system(self, system: Optional[str]) -> None:
        """Assert that ``system`` can be rendered by this template.

        ``None`` is always accepted.

        Raises:
            AssertionError: if ``system`` is given but ``support_system`` is False.
        """
        if system is not None:
            assert self.support_system, (
                f'The template does not support `system`, template_type: {self.template_type}, system: {system}')
|
|
|