| '''AntGLM Chat-model data format. |
| |
| 格式化 AntGLM 以及各种开源模型的符号系统: |
| - 确定 Chat 模型依赖的文件数据结构协议 |
| - 确定单轮/多轮的统一结构 |
| - 确定 Chat 符号系统的协议, 包括角色定义、分隔符等 |
| - 方便做开源模型依赖的 prompt 转换 |
| - 支持工具、代码、推理等能力 |
| |
| 参考 FastChat Conversation 对象的设计思路. |
| Reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py |
| ''' |
|
|
| import copy |
| import dataclasses |
| import logging |
| import re |
| import uuid |
| from copy import deepcopy |
| from enum import IntEnum, auto |
| from typing import Dict, List, Optional, Tuple |
|
|
# Module-level logger, named after this module so handlers can filter on it.
logger = logging.getLogger(name=__name__)
|
|
|
|
class PromptStyle(IntEnum):
    '''Prompt layout families that a chat object can be rendered in.

    Values are consecutive integers starting at 1, in declaration order.
    '''

    # Original AntGLM layout: numbered turns with plain-text role prefixes.
    ANTGLM_RAW = 1

    # AntGLM chat layout built from `<role>...</role>` tags.
    ANTGLM_CHAT = 2

    # AntGLM variant that only applies the chat structure to multi-turn data.
    ANTGLM_ONLY_MULTITURN_CHAT = 3

    # ChatML layout (`<|im_start|>` / `<|im_end|>` markers).
    CHATML = 4

    # Llama 2 layout (`[INST]` / `[/INST]` markers).
    LLAMA2 = 5

    # ChatGLM (v1 / v2) layout with `[Round N]` turn headers.
    CHATGLM = 6

    # ChatGLM3 layout (`<|user|>` / `<|assistant|>` markers).
    CHATGLM3 = 7

    # Baichuan 2 layout (reserved-token role markers).
    BAICHUAN2 = 8
|
|
|
|
@dataclasses.dataclass
class Chat:
    '''Symbol system (roles, separators, templates) for chat-model prompts.

    A ``Chat`` holds an ordered list of ``[role, content]`` messages together
    with the role tags and separators of one prompt format, and converts
    between:

    - the JSON sample layout (``from_json`` / ``to_json``),
    - packed training data with one prompt/answer pair per assistant reply
      (``from_pack`` / ``prompt_pack``),
    - a single input/output pair (``from_inout`` / ``prompt_inout``),
    - one prompt string (``prompt_str``).

    The format is selected with ``name`` (see ``support_names``) and/or
    ``prompt_style``; ``__post_init__`` then overwrites the role and separator
    fields with the values of that format.

    Examples:

    ```python
    >>> from antllm.data.chat_format import Chat

    >>> ### Build a chat from the JSON layout, rendered in the raw AntGLM format
    >>> input_json = {
    ...     "messages": [
    ...         {"role": "HUMAN", "content": "讲一个笑话"},
    ...         {"role": "ASSISTANT", "content": "为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!"},
    ...         {"role": "HUMAN", "content": "不好笑,换个程序员的笑话"}
    ...     ],
    ... }
    >>> chat = Chat.from_json(input_json, name='antglm_raw')

    >>> ### Packed training data derived from the chat
    >>> pack_data = chat.prompt_pack
    >>> print(pack_data)

    >>> ### Single input/output training pair derived from the chat
    >>> data = chat.prompt_inout
    >>> print(data)

    >>> ### Prompt string used for inference
    >>> prompt = chat.prompt_str
    >>> print(prompt)

    >>> ### Build a chat from an {"input": "xx", "output": "xx"} sample
    >>> data = {
    ...     'input': (
    ...         '第1轮\\n用户: 讲一个笑话\\n机器人: 为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\\n'
    ...         '第2轮\\n用户: 不好笑,换个程序员的笑话\\n机器人:'
    ...     ),
    ...     'output': ''
    ... }
    >>> chat = Chat.from_inout(data, name='antglm_raw')

    >>> ### Build a list of chats from packed training data
    >>> pack_data = {
    ...     'input': ['第1轮\\n用户: 讲一个笑话\\n机器人:', '第2轮\\n用户: 不好笑,换个程序员的笑话\\n机器人:', '第1轮\\n用户: 写首诗\\n机器人:'],
    ...     'output': [
    ...         '为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\\n',
    ...         '为什么程序员总是喜欢使用黑色主题?因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!',
    ...         '']
    ... }
    >>> chats = Chat.from_pack(pack_data, name='antglm_raw')
    >>> assert len(chats) == 2
    >>> print(chats[0])
    >>> print(chats[1])

    >>> ### Number of turns (counted as the number of human messages)
    >>> print(chat.turns_num)

    >>> ### JSON dump of the chat
    >>> data_json = chat.to_json()
    >>> print(data_json)

    >>> ### Append a message
    >>> content = (
    ...     '为什么程序员总是喜欢使用黑色主题?'
    ...     '因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!'
    ... )
    >>> chat.append_message(chat.role_assistant, content)

    >>> ### Convert to the message list of the OpenAI ChatCompletion API
    >>> openai_messages = chat.to_openai_api_messages()
    >>> print(openai_messages)

    >>> ### Copy a chat
    >>> chat_new = chat.copy()
    ```
    '''

    # Sample identifier; a random UUID is generated when none is supplied.
    id: Optional[str] = None

    # Name of the symbol system (one of the keys of `support_names`).
    name: Optional[str] = None

    # Prompt style; normally derived from `name` in `__post_init__`.
    prompt_style: Optional[PromptStyle] = None

    # System prompt: template with one `{}` placeholder, and its content.
    system_template: str = '<role>SYSTEM</role>{}'
    system_message: str = ''

    # Role names and the template that turns a role name into a role tag.
    role_human: str = 'HUMAN'
    role_assistant: str = 'ASSISTANT'
    role_observation: str = 'OBSERVATION'
    role_template: str = '<role>{}</role>'

    # Separators placed around turns and messages.
    turn_start: str = ''
    human_end: str = ''
    assistant_start: str = ''
    assistant_end: str = ''
    assistant_end_ids: Optional[List[int]] = None
    general_role_end: str = ''

    # Templates for special content. They carry no annotation, so they are
    # plain class attributes rather than dataclass fields.
    tool_template = '<tool>{}</tool>'
    code_template = '<code>{}</code>'
    arithemetic_templte = '<arithemetic>{}</arithemetic>'
    image_template = '<image>{}</image>'

    # Conversation content; `append_message` stores 2-item `[role, content]` lists.
    messages: List[Tuple[str, str]] = dataclasses.field(default_factory=list)

    # Number of leading messages skipped by `to_openai_api_messages`.
    offset: int = 0

    # Free-form sample metadata copied from / to the JSON layout.
    source: Optional[str] = None
    lang: Optional[str] = None
    topic: Optional[str] = None

    # Copy of the JSON dict the chat was created from (see `from_json`).
    origin_json: Optional[dict] = None

    @property
    def support_names(self) -> Dict[str, str]:
        '''Return the supported format names mapped to a short description.'''
        return {
            'antglm_raw': '原始 antglm format 格式, 单轮指令没有结构, 多轮 `第1轮\\n用户:xx\\n机器人xx\\n`',
            'antglm_chat': 'Chat format 格式, 单轮多轮统一为 chat format 格式',
            'chatglm1': 'chatglm1 format',
            'chatglm2': 'chatglm2 format',
            'chatglm3': 'chatglm3 format',
            'llama2': 'llama2 format',
            'qwen': '千问 format',
            'baichuan2': '百川 2 format',
        }

    @classmethod
    def from_json(
        cls,
        input: dict,
        name: Optional[str] = None,
        prompt_style: Optional[PromptStyle] = None,
    ):
        '''Create a ``Chat`` from the JSON sample layout.

        Params:
            input: `dict`, one JSON sample. Either `messages` or `turns` is
                read (`messages` wins when both are present):
                {
                    "id": "xxx",
                    "name": "antglm",
                    "source": "xxx",
                    "lang": "xx",
                    "topic": "xx",
                    "system_template": "",
                    "system_message": "xx",
                    "messages": [
                        {"role": "HUMAN", "content": "Hi"},
                        {"role": "ASSISTANT", "content": "Hello"},
                        {"role": "OBSERVATION", "content": "xxx"},
                        {"role": "ASSISTANT", "content": "xxx"}
                    ],
                    "turns": [
                        {"HUMAN": "xxx", "OBSERVATION": "xx", "ASSISTANT": "xx"}
                    ]
                }

            name: `Optional[str]`, format name (see `support_names`). When
                given it takes precedence over the `name` field of `input`.

            prompt_style: `Optional[PromptStyle]`, explicit prompt style; by
                default the style that belongs to `name` is used.

        Returns:
            `Chat`

        Raises:
            ValueError: a message uses a role other than HUMAN / OBSERVATION /
                ASSISTANT, or neither a name nor a prompt style is available.
        '''
        data = input
        overrides = {key: data[key] for key in ('system_template', 'system_message') if key in data}

        chat = cls(
            id=data.get('id'),
            name=name if name else data.get('name'),
            prompt_style=prompt_style,
            source=data.get('source'),
            lang=data.get('lang'),
            topic=data.get('topic'),
            origin_json=deepcopy(data),
            **overrides,
        )

        # The JSON layout always uses the canonical role names; translate
        # them into the role strings of the selected format.
        canonical_roles = (
            ('HUMAN', chat.role_human),
            ('OBSERVATION', chat.role_observation),
            ('ASSISTANT', chat.role_assistant),
        )

        if 'messages' in data:
            role_map = dict(canonical_roles)
            for msg in data['messages']:
                if msg['role'] not in role_map:
                    raise ValueError(f'不支持数据集中的 role: {msg["role"]}')
                chat.append_message(role_map[msg['role']], msg['content'])
        elif 'turns' in data:
            for turn in data['turns']:
                # Within one turn the order is fixed: human, observation, assistant.
                for key, role in canonical_roles:
                    if key in turn:
                        chat.append_message(role, turn[key])

        return chat

    @classmethod
    def from_pack(
        cls,
        packs: Dict[str, List[str]],
        name: str,
        prompt_style: Optional[PromptStyle] = None,
    ) -> List['Chat']:
        '''Create ``Chat`` objects from packed training data.

        Consecutive prompt/answer pairs are collected into one chat. A new
        chat is started when a prompt carries a system section, or when its
        turn marker restarts at the first round number, and the current chat
        already has messages.

        Params:
            packs: `dict`, packed sample. Both the `input`/`output` keys
                produced by `prompt_pack` and `inputs`/`outputs` are accepted:
                {
                    'input': ['xx', 'xx'],
                    'output': ['xx', 'xx'],
                }
                `output` may be one element shorter than `input`.

            name: `str`, format name (see `support_names`).
            prompt_style: `Optional[PromptStyle]`, explicit prompt style; by
                default the style that belongs to `name` is used.

        Returns:
            `List[Chat]`. Assistant answers are stored exactly as given,
            including any trailing separator.

        Raises:
            ValueError: the number of prompts and answers does not match.
        '''
        chat = cls(name=name, prompt_style=prompt_style)
        packs = cls._format_packs(packs)

        turn_pattern = chat._turn_pattern()
        round_start = 1 if name in ('antglm', 'antglm_raw', 'chatglm2') else 0

        chats = []
        for prompt, output in zip(packs['input'], packs['output']):
            # A system section marks the beginning of a new conversation.
            system, rest = chat._pop_system_message(prompt)
            if system is not None:
                if chat.messages:
                    chats.append(chat)
                    chat = cls(name=name, prompt_style=prompt_style)
                chat.system_message = system
                prompt = rest

            # So does a turn counter that restarts at the first round.
            turn_match = turn_pattern.search(prompt) if turn_pattern is not None else None
            if turn_match:
                restarted = bool(turn_match.groups()) and int(turn_match.group(1)) == round_start
                if restarted and chat.messages:
                    chats.append(chat)
                    chat = cls(name=name, prompt_style=prompt_style)
                prompt = prompt[turn_match.end():]

            chat._append_parsed_messages(prompt)

            if output:
                chat.append_message(chat.role_assistant, output)

        if chat.messages:
            chats.append(chat)

        return chats

    @classmethod
    def _append_human_observation(
        cls,
        chat,
        input: str,
        human_match: Optional[re.Match] = None,
        next_human_match: Optional[re.Match] = None,
        observe_match: Optional[re.Match] = None,
        next_observe_match: Optional[re.Match] = None,
        assistant_match: Optional[re.Match] = None,
    ) -> str:
        '''Append the human / observation message of one turn to ``chat``.

        All match objects must have been produced on ``input``. The content of
        a message runs from the end of its role tag to the start of the next
        tag of the turn (the other role tag, or the assistant tag).

        `from_pack` and `from_inout` no longer call this helper; it is kept,
        in corrected form, for callers outside this class.

        Params:
            chat: `Chat`, receives the messages.
            input: `str`, prompt text the matches refer to.
            human_match / observe_match: role tags of the current turn.
            next_human_match / next_observe_match: following tag of the same
                role, used to tell whether the second tag still belongs to
                the current turn.
            assistant_match: assistant tag that closes the turn, if any.

        Returns:
            `str`, the part of ``input`` that was not consumed.
        '''
        # Without an assistant tag the turn extends to the end of the text.
        prompt_end = assistant_match.start() if assistant_match else len(input)
        resume_at = assistant_match.end() if assistant_match else len(input)

        if human_match and observe_match:
            if observe_match.start() > human_match.start():
                first, first_role = human_match, chat.role_human
                second, second_role = observe_match, chat.role_observation
                following = next_human_match
            else:
                first, first_role = observe_match, chat.role_observation
                second, second_role = human_match, chat.role_human
                following = next_observe_match

            chat.append_message(first_role, input[first.end():second.start()].strip())
            input_end = second.start()

            # The second tag belongs to this turn only when it comes before
            # the next tag of the first role.
            if following is None or second.start() < following.start():
                chat.append_message(second_role, input[second.end():prompt_end].strip())
                input_end = resume_at
        elif human_match or observe_match:
            only = human_match or observe_match
            role = chat.role_human if human_match else chat.role_observation
            chat.append_message(role, input[only.end():prompt_end].strip())
            input_end = resume_at
        else:
            input_end = 0

        return input[input_end:]

    @classmethod
    def from_inout(
        cls,
        sample: Dict[str, str],
        name: str,
        prompt_style: Optional[PromptStyle] = None,
    ):
        '''Create a ``Chat`` from a single input/output sample.

        Turn markers are removed from the input, an optional system section
        is extracted, and the remaining text is split at the role tags. Input
        without any role tag is taken as one human message.

        Params:
            sample: `Dict[str, str]`, input/output sample
                {
                    "input": "xxx",
                    "output": "xxx",
                }

            name: `str`, format name (see `support_names`).
            prompt_style: `Optional[PromptStyle]`, explicit prompt style; by
                default the style that belongs to `name` is used.

        Returns:
            `Chat`. Human and observation contents are stripped; assistant
            contents found in the input, and the output, are kept verbatim.
        '''
        chat = cls(name=name, prompt_style=prompt_style)
        text = sample['input']
        output = sample['output']

        # Turn markers carry no content of their own.
        turn_pattern = chat._turn_pattern()
        if turn_pattern is not None:
            text = turn_pattern.sub('', text)

        system, rest = chat._pop_system_message(text)
        if system is not None:
            chat.system_message = system
            text = rest

        chat._append_parsed_messages(text)

        if output:
            chat.append_message(chat.role_assistant, output)

        return chat

    def __hash__(self):
        '''Hash a chat by its `id`.'''
        return hash(self.id)

    def __post_init__(self):
        '''Finish initialisation: assign an id and apply the selected format.

        Raises:
            ValueError: neither `name` nor `prompt_style` was given.
        '''
        # Keep an id supplied by the caller (e.g. from `from_json`).
        if self.id is None or self.id == '':
            self.id = str(uuid.uuid4())

        # `append_message` needs a mutable list.
        if not self.messages:
            self.messages = []
        elif not isinstance(self.messages, list):
            self.messages = list(self.messages)

        if not self.name and not self.prompt_style:
            logger.error('构造 Chat 对象至少包含以下一个入参: `name/prompt_style`.\n\n' '`name` 支持以下 format 名称:')
            logger.error('\n'.join([f'{k}: {v}' for k, v in self.support_names.items()]))
            logger.error('\n`prompt_style` 参考 antllm.data.chat_format.PromptStyle')
            raise ValueError('Chat requires at least one of `name` / `prompt_style`.')

        if self.name == 'antglm':
            # 'antglm' is an alias of the raw AntGLM format.
            self.name = 'antglm_raw'

        if not self.name and self.prompt_style == PromptStyle.ANTGLM_CHAT:
            logger.info(
                'Chat 对象入参没有 `name`, 默认使用 `ANTGLM_CHAT`, format:\n'
                f'role_human: {self.role_human}\n'
                f'role_assistant: {self.role_assistant}\n'
                f'role_observation: {self.role_observation}\n'
                f'role_template: {self.role_template}\n'
                f'turn_start: {self.turn_start}\n'
                f'human_end: {self.human_end}\n'
                f'assistant_start: {self.assistant_start}\n'
                f'assistant_end: {self.assistant_end}\n'
                f'assistant_end_ids: {self.assistant_end_ids}\n'
                f'general_role_end: {self.general_role_end}\n'
                f'tool_template: {self.tool_template}\n'
                f'code_template: {self.code_template}\n'
                f'arithemetic_templte: {self.arithemetic_templte}\n'
                f'image_template: {self.image_template}\n'
                f'\n入参 `name` 支持: `{", ".join(self.support_names)}`'
            )
            return

        if self.name == 'antglm_raw' or self.prompt_style == PromptStyle.ANTGLM_RAW:
            self.prompt_style = PromptStyle.ANTGLM_RAW
            self.role_template = '{}'
            self.role_human = '用户: '
            self.role_assistant = '机器人: '
            self.turn_start = '第{}轮\n'
            self.general_role_end = '\n'

        # Deliberately a new `if`: a name-based match below overrides the
        # raw AntGLM settings applied above.
        if self.name in ['chatglm1', 'chatglm2'] or self.prompt_style == PromptStyle.CHATGLM:
            self.prompt_style = PromptStyle.CHATGLM
            self.role_template = '{}'
            self.role_human = '问:'
            self.role_assistant = '答:'
            self.turn_start = '[Round {}]\n'
            if self.name == 'chatglm1':
                self.general_role_end = '\n'
            else:
                self.general_role_end = '\n\n'

        elif self.name == 'chatglm3' or self.prompt_style == PromptStyle.CHATGLM3:
            self.prompt_style = PromptStyle.CHATGLM3
            self.system_template = '<|system|>\n {}'
            self.role_human = '<|user|>\n '
            self.role_assistant = '<|assistant|>\n '
            self.role_template = '{}'

        elif self.name == 'llama2' or self.prompt_style == PromptStyle.LLAMA2:
            self.prompt_style = PromptStyle.LLAMA2
            self.role_template = '{}'
            self.system_template = '[INST] <<SYS>>\n{}\n<</SYS>>\n\n'
            self.role_human = '[INST] '
            self.role_assistant = '[/INST] '
            self.human_end = ' '
            self.assistant_end = ' </s><s>'

        elif self.name == 'qwen' or self.prompt_style == PromptStyle.CHATML:
            self.prompt_style = PromptStyle.CHATML
            self.role_template = '{}'
            self.system_template = '<|im_start|>system\n{}'
            if not self.system_message:
                self.system_message = 'You are a helpful assistant.'
            self.role_human = '<|im_start|>user\n'
            self.role_assistant = '<|im_start|>assistant\n'
            self.general_role_end = '<|im_end|>\n'

        # `support_names` advertises 'baichuan2'; 'baichuan' stays accepted.
        elif self.name in ('baichuan', 'baichuan2') or self.prompt_style == PromptStyle.BAICHUAN2:
            self.prompt_style = PromptStyle.BAICHUAN2
            self.role_template = '{}'
            self.system_template = '{}'
            self.role_human = '<token_id-195>'
            self.role_assistant = '<token_id-196>'

        if not self.system_template:
            self.system_template = '{}'

    @staticmethod
    def _template_regex(template: str, placeholder: str) -> str:
        '''Return a regex that matches ``template`` literally, with every `{}`
        replaced by the regex fragment ``placeholder``.'''
        return placeholder.join(re.escape(part) for part in template.split('{}'))

    def _turn_pattern(self) -> Optional['re.Pattern']:
        '''Return the compiled turn-marker pattern, or None without `turn_start`.

        The turn number is captured as group 1 when `turn_start` has a `{}`.
        '''
        if not self.turn_start:
            return None
        return re.compile(self._template_regex(self.turn_start, r'(\d+)'))

    def _role_tag_pattern(self) -> Optional['re.Pattern']:
        '''Return one compiled pattern that matches any role tag.

        Each alternative is a named group (`human`, `observation`,
        `assistant`), so `match.lastgroup` tells which role matched. Tags are
        the whitespace-stripped result of `role_template.format(role)`.
        '''
        tags = [
            ('human', self.role_template.format(self.role_human).strip()),
            ('observation', self.role_template.format(self.role_observation).strip()),
            ('assistant', self.role_template.format(self.role_assistant).strip()),
        ]
        tags = [(key, tag) for key, tag in tags if tag]
        if not tags:
            return None
        # Longer tags first, so a tag that is a prefix of another cannot shadow it.
        tags.sort(key=lambda item: len(item[1]), reverse=True)
        return re.compile('|'.join(f'(?P<{key}>{re.escape(tag)})' for key, tag in tags))

    def _pop_system_message(self, text: str) -> Tuple[Optional[str], str]:
        '''Split a leading system section off ``text``.

        Returns:
            `(system_message, rest)`. `system_message` is None when ``text``
            has no system section (then `rest` is ``text`` unchanged).
        '''
        prefix, placeholder, suffix = self.system_template.partition('{}')
        # A template without literal text cannot be located in the prompt.
        if not placeholder or not (prefix or suffix):
            return None, text

        start = text.find(prefix)
        if start < 0:
            return None, text
        body_start = start + len(prefix)

        if suffix:
            body_end = text.find(suffix, body_start)
            if body_end < 0:
                return None, text
            return text[body_start:body_end], text[body_end + len(suffix):]

        # Open-ended template: the section stops at the first role tag or
        # turn marker.
        stops = []
        for pattern in (self._role_tag_pattern(), self._turn_pattern()):
            found = pattern.search(text, body_start) if pattern is not None else None
            if found:
                stops.append(found.start())

        role_end = self.general_role_end
        if stops:
            body_end = rest_start = min(stops)
        elif role_end and text.find(role_end, body_start) >= 0:
            # Unstructured single-turn prompt: the separator written after
            # the system section is the only boundary available.
            body_end = text.find(role_end, body_start)
            rest_start = body_end + len(role_end)
        else:
            body_end = rest_start = len(text)

        system = text[body_start:body_end]
        if role_end and system.endswith(role_end):
            system = system[: -len(role_end)]
        return system, text[rest_start:]

    def _clean_prompt_content(self, content: str) -> str:
        '''Strip whitespace and a trailing role separator from human /
        observation content.'''
        content = content.strip()
        role_end = (self.human_end or self.general_role_end).strip()
        if role_end and content.endswith(role_end):
            content = content[: -len(role_end)].rstrip()
        return content

    def _append_parsed_messages(self, text: str) -> None:
        '''Split ``text`` at the role tags and append the resulting messages.

        The content of a message runs from the end of its tag to the start of
        the next tag. Text in front of the first tag is ignored, unless the
        text has no human / observation tag at all: then it is appended,
        unmodified, as one human message.
        '''
        pattern = self._role_tag_pattern()
        matches = list(pattern.finditer(text)) if pattern is not None else []

        if all(match.lastgroup == 'assistant' for match in matches):
            leading = text[: matches[0].start()] if matches else text
            if leading.strip():
                self.append_message(self.role_human, leading)

        for index, match in enumerate(matches):
            stop = matches[index + 1].start() if index + 1 < len(matches) else len(text)
            content = text[match.end():stop]
            if match.lastgroup == 'assistant':
                # A bare trailing assistant tag only opens the reply.
                if content.strip():
                    self.append_message(self.role_assistant, content)
            elif match.lastgroup == 'observation':
                self.append_message(self.role_observation, self._clean_prompt_content(content))
            else:
                self.append_message(self.role_human, self._clean_prompt_content(content))

    def readable_messages(self) -> str:
        '''Render the conversation as human-readable text, one message per line
        block, each prefixed with its canonical role name.'''
        labels = {
            self.role_human: 'HUMAN',
            self.role_observation: 'OBSERVATION',
            self.role_assistant: 'ASSISTANT',
        }
        lines = []
        if self.system_message:
            lines.append(f'SYSTEM: {self.system_message}')
        for role, message in self.messages:
            lines.append(f'{labels.get(role, role)}: {message}')
        return '\n'.join(lines)

    @property
    def prompt_str(self) -> str:
        '''Return the whole conversation as one prompt string (input followed
        by output of `prompt_inout`).'''
        # `prompt_inout` re-renders the conversation on every access.
        inout = self.prompt_inout
        return f'{inout["input"]}{inout["output"]}'

    @classmethod
    def _format_packs(cls, packs: Dict[str, List[str]]) -> Dict[str, List[str]]:
        '''Return a copy of ``packs`` with equally long `input` / `output` lists.

        The keys `inputs` / `outputs` are accepted as aliases. When `output`
        is exactly one element short, an empty answer is appended.

        Raises:
            KeyError: a prompt or answer list is missing.
            ValueError: the list lengths differ by more than the tolerated one.
        '''
        _packs = copy.deepcopy(packs)
        for key, alias in (('input', 'inputs'), ('output', 'outputs')):
            if key not in _packs and alias in _packs:
                _packs[key] = _packs.pop(alias)

        if len(_packs['input']) - 1 == len(_packs['output']):
            _packs['output'].append('')

        if len(_packs['input']) != len(_packs['output']):
            logger.error('%s', packs)
            raise ValueError(
                '输入 input 和 output 数量不匹配, '
                f'input num: {len(_packs["input"])}, '
                f'output num: {len(_packs["output"])}'
            )

        return _packs

    @property
    def prompt_inout(self) -> Dict[str, str]:
        '''Render the conversation as one input prompt and one output string.

        All turns but the last are folded into the input; the output is the
        last assistant answer (empty when the conversation ends with a
        non-assistant message).

        Returns:
            `Dict[str, str]`, for example:
            {
                "input": "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>",
                "output": "你好,有什么可以帮您?"
            }
        '''
        packs = self._format_packs(self.prompt_pack)
        inputs, outputs = packs['input'], packs['output']
        if not inputs:
            return {'input': '', 'output': ''}

        raw_style = self.prompt_style == PromptStyle.ANTGLM_RAW
        if raw_style:
            # The raw format separates the role prefix from the answer by a space.
            inputs = [f'{item} ' for item in inputs]

        history = ''.join(f'{x}{y}' for x, y in zip(inputs[:-1], outputs[:-1]))
        prompt_input = history + inputs[-1]
        if raw_style:
            prompt_input = prompt_input.strip()

        return {
            'input': prompt_input,
            'output': outputs[-1],
        }

    @property
    def prompt_pack(self) -> Dict[str, List[str]]:
        '''Render the conversation as packed prompts and answers.

        Every assistant message yields one answer; the messages in front of
        it (since the previous answer) form the matching prompt, which ends
        with the assistant role tag. A conversation that ends with a
        non-assistant message yields one more prompt than answers.

        Returns:
            `Dict[str, List[str]]`, for example:

            {
                "input": [
                    "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>",
                    "<role>HUMAN</role>讲个笑话<role>ASSISTANT</role>",
                    "<role>OBSERVATION</role>{\\"weather\\": \\"晴\\"}<role>ASSISTANT</role>"
                ],
                "output": [
                    "你好,有什么可以帮您?",
                    "笑话 1",
                    "今天天气 xxx"
                ]
            }
        '''
        inputs: List[str] = []
        outputs: List[str] = []

        # Text waiting to be written in front of the next message.
        pending = ''
        if self.system_message:
            system_prompt = self.system_template.format(self.system_message)
            if system_prompt:
                pending = system_prompt + self.general_role_end

        # Single-turn data of the raw styles carries no role structure at all.
        unstructured_styles = (PromptStyle.ANTGLM_RAW, PromptStyle.ANTGLM_ONLY_MULTITURN_CHAT)
        if self.prompt_style in unstructured_styles and len(self.messages) <= 2:
            prompt, answer = pending, ''
            for role, message in self.messages:
                if role == self.role_assistant:
                    answer = message
                else:
                    prompt = pending + message
            return {
                'input': [prompt],
                'output': [answer],
            }

        round_start = 1 if self.name in ('antglm_raw', 'chatglm2') else 0
        numbered_turns = self.name in ('antglm_raw', 'chatglm1', 'chatglm2')

        for i, (role, message) in enumerate(self.messages):
            # A turn header is written in front of every second message.
            if numbered_turns and i % 2 == 0:
                pending += self.turn_start.format(i // 2 + round_start)

            role_end = self.general_role_end
            if role == self.role_assistant and self.assistant_end:
                role_end = self.assistant_end
            elif self.human_end:
                role_end = self.human_end

            tag = self.role_template.format(role)
            # The last prompt is "open" while it has no answer yet.
            prompt_open = len(inputs) > len(outputs)

            if role == self.role_assistant:
                text = pending + tag
                outputs.append(message + role_end if message else '')
            else:
                text = pending + tag + message + role_end

            if prompt_open:
                inputs[-1] += text
            else:
                inputs.append(text)
            pending = ''

        # A conversation waiting for a reply ends with the assistant tag.
        if self.messages and self.messages[-1][0] != self.role_assistant:
            inputs[-1] += self.role_template.format(self.role_assistant).strip()

        if self.prompt_style == PromptStyle.ANTGLM_RAW:
            inputs = [item.strip() for item in inputs]

        return {
            'input': inputs,
            'output': outputs,
        }

    @property
    def turns_num(self) -> int:
        '''Number of turns, counted as the number of human messages.'''
        return sum(1 for msg in self.messages if msg[0] == self.role_human)

    def to_json(self) -> dict:
        '''Dump the chat in the JSON sample layout.

        The result starts from a copy of `origin_json` (if any), so extra
        fields of the source sample survive; `origin_json` itself is left
        untouched. A conversation ending with a human message gets an empty
        assistant message appended.

        Returns:
            `dict`, {
                "id": "xx",
                "messages": [
                    {"role": "HUMAN", "content": "xxx"}
                ],
                "turns": [
                    {"HUMAN": "xx", "OBSERVATION": "xx", "ASSISTANT": "xx"}
                ]
            }
        '''
        turns = []
        messages = []
        turn = {}
        for role, content in self.messages:
            if role == self.role_assistant:
                messages.append({'role': 'ASSISTANT', 'content': content})
                turn['ASSISTANT'] = content
                # An assistant message closes the current turn.
                turns.append(turn)
                turn = {}

            if role == self.role_human:
                messages.append({'role': 'HUMAN', 'content': content})
                turn['HUMAN'] = content

            if role == self.role_observation:
                messages.append({'role': 'OBSERVATION', 'content': content})
                turn['OBSERVATION'] = content

        if self.messages and self.messages[-1][0] == self.role_human:
            messages.append({'role': 'ASSISTANT', 'content': ''})
            turn['ASSISTANT'] = ''
            turns.append(turn)

        result = deepcopy(self.origin_json) if self.origin_json else {}
        result.update(
            {
                'id': self.id,
                'name': self.name,
                'source': self.source,
                'lang': self.lang,
                'topic': self.topic,
                'system_template': self.system_template,
                'system_message': self.system_message,
                'turns': turns,
                'messages': messages,
            }
        )

        return result

    def set_system_message(self, system_message: str):
        '''Set the system message.'''
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        '''Append a new message; a missing message is stored as an empty string.'''
        self.messages.append([role, message if message else ''])

    def to_openai_api_messages(self) -> List[dict]:
        '''Convert the conversation to the OpenAI chat completion format.

        Messages after `offset` are labelled by position: even positions as
        `user`, odd positions as `assistant`.
        NOTE(review): this assumes strictly alternating human / assistant
        messages; observation messages shift the labels — confirm with callers.
        '''
        ret = [{'role': 'system', 'content': self.system_message}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            elif msg is not None:
                ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self):
        '''Return a deep copy of the chat.'''
        return copy.deepcopy(self)
|
|