| | '''AntGLM Chat-model data format. |
| | |
| | 格式化 AntGLM 以及各种开源模型的符号系统: |
| | - 确定 Chat 模型依赖的文件数据结构协议 |
| | - 确定单轮/多轮的统一结构 |
| | - 确定 Chat 符号系统的协议, 包括角色定义、分隔符等 |
| | - 方便做开源模型依赖的 prompt 转换 |
| | - 支持工具、代码、推理等支持 |
| | |
| | 参考 FastChat Conversation 对象的设计思路. |
| | Reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py |
| | ''' |
| |
|
| | import copy |
| | import dataclasses |
| | import logging |
| | import re |
| | import uuid |
| | from copy import deepcopy |
| | from enum import IntEnum, auto |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class PromptStyle(IntEnum):
    '''Prompt styles.

    Enumerates the symbol systems (chat templates) this module can render and
    parse.  Values are auto-assigned ints; only identity matters, not the
    concrete numbers, so new members must be appended, never inserted.
    '''

    # Raw AntGLM format: unstructured single turn, `第N轮\n用户:..\n机器人:..`
    # headers for multi-turn (see `Chat.__post_init__`).
    ANTGLM_RAW = auto()
    # Unified AntGLM chat format using <role>..</role> tags.
    ANTGLM_CHAT = auto()
    # Presumably the AntGLM chat format restricted to multi-turn data —
    # TODO confirm; only consulted together with ANTGLM_RAW in `prompt_pack`.
    ANTGLM_ONLY_MULTITURN_CHAT = auto()
    # ChatML style (<|im_start|> .. <|im_end|>), used for the 'qwen' name.
    CHATML = auto()
    # Llama-2 `[INST] .. [/INST]` style.
    LLAMA2 = auto()
    # ChatGLM 1/2 `[Round N]` style with 问/答 roles.
    CHATGLM = auto()
    # ChatGLM3 <|user|>/<|assistant|> style.
    CHATGLM3 = auto()
    # Baichuan-2 special-token style (<token_id-195>/<token_id-196>).
    BAICHUAN2 = auto()
| |
|
| |
|
@dataclasses.dataclass
class Chat:
    '''Chat data structure normalizing the symbol systems of AntGLM and
    various open-source chat models.

    It fixes the on-file data protocol for chat models, unifies the
    single-turn / multi-turn structure, defines the chat symbol protocol
    (role tokens, separators, ...) and helps convert prompts between the
    supported open-source formats.

    Examples:

    ```python
    >>> from antllm.data.chat_format import Chat

    >>> ### Build a chat object from a json structure, formatted with the raw AntGLM style
    >>> input_json = {
    ...     "messages": [
    ...         {"role": "HUMAN", "content": "讲一个笑话"},
    ...         {"role": "ASSISTANT", "content": "为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!"},
    ...         {"role": "HUMAN", "content": "不好笑,换个程序员的笑话"}
    ...     ],
    ... }
    >>> chat = Chat.from_json(input_json, name='antglm_raw')

    >>> ### Packed training data for the model
    >>> pack_data = chat.prompt_pack
    >>> print(pack_data)

    >>> ### input/output training data
    >>> data = chat.prompt_inout
    >>> print(data)

    >>> ### Prompt string used at inference time
    >>> prompt = chat.prompt_str
    >>> print(prompt)

    >>> ### Build a chat object from {"input": "xx", "output": "xx"} training data
    >>> data = {
    ...     'input': (
    ...         '第1轮\n用户: 讲一个笑话\n机器人: 为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\n'
    ...         '第2轮\n用户: 不好笑,换个程序员的笑话\n机器人:'
    ...     ),
    ...     'output': ''
    ... }
    >>> chat = Chat.from_inout(data, name='antglm_raw')

    >>> ### Build a list of chat objects from packed training data
    >>> pack_data = {
    ...     'inputs': ['第1轮\n用户: 讲一个笑话\n机器人:', '第2轮\n用户: 不好笑,换个程序员的笑话\n机器人:', '第1轮\n用户: 写首诗\n机器人:'],
    ...     'outputs': [
    ...         '为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\n',
    ...         '为什么程序员总是喜欢使用黑色主题?因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!',
    ...         '']
    ... }
    >>> chats = Chat.from_pack(pack_data, name='antglm_raw')
    >>> assert len(chats) == 2
    >>> print(chats[0])
    >>> print(chats[1])

    >>> ### Total number of interaction turns (counted by user messages)
    >>> print(chat.turns_num)

    >>> ### json export of the chat object
    >>> data_json = chat.to_json()
    >>> print(data_json)

    >>> ### Append one more message
    >>> content = (
    ...     '为什么程序员总是喜欢使用黑色主题?'
    ...     '因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!'
    ... )
    >>> chat.append_message(chat.role_assistant, content)

    >>> ### Convert the chat object into OpenAI ChatCompletion messages
    >>> openai_messages = chat.to_openai_api_messages()
    >>> print(openai_messages)

    >>> ### Deep-copy a chat object
    >>> chat_new = chat.copy()
    ```
    '''

    # Unique conversation id (a uuid4 is assigned in `__post_init__`).
    id: Optional[str] = None

    # Symbol-system (format) name; see `support_names` for accepted values.
    name: Optional[str] = None

    # Prompt style; `__post_init__` derives it from `name` when not given.
    prompt_style: Optional[PromptStyle] = None

    # System prompt template and its content.
    system_template: str = '<role>SYSTEM</role>{}'
    system_message: str = ''

    # Role tokens and the template used to render a role tag.
    role_human: str = 'HUMAN'
    role_assistant: str = 'ASSISTANT'
    role_observation: str = 'OBSERVATION'
    role_template: str = '<role>{}</role>'

    # Separators: turn header plus per-role start/end markers.
    turn_start: str = ''
    human_end: str = ''
    assistant_start: str = ''
    assistant_end: str = ''
    assistant_end_ids: Optional[List[int]] = None
    general_role_end: str = ''

    # Wrappers for tool / code / arithmetic / image spans.  Deliberately left
    # unannotated so the dataclass does not turn them into init fields.
    tool_template = '<tool>{}</tool>'
    code_template = '<code>{}</code>'
    arithemetic_templte = '<arithemetic>{}</arithemetic>'
    image_template = '<image>{}</image>'

    # Conversation content as [role, content] pairs (made a list in __post_init__).
    messages: List[Tuple[str, str]] = ()

    # Number of leading messages skipped by `to_openai_api_messages`.
    offset: int = 0

    # Optional provenance metadata.
    source: Optional[str] = None
    lang: Optional[str] = None
    topic: Optional[str] = None

    # Original json dict this object was built from (merged back in `to_json`).
    origin_json: Optional[dict] = None
| |
|
| | @property |
| | def support_names(self) -> Dict[str, str]: |
| | '''支持的数据对象名称.''' |
| | return { |
| | 'antglm_raw': '原始 antglm format 格式, 单轮指令没有结构, 多轮 `第1轮\\n用户:xx\\n机器人xx\\n`', |
| | 'antglm_chat': 'Chat format 格式, 单轮多轮统一为 chat format 格式', |
| | 'chatglm1': 'chatglm1 format', |
| | 'chatglm2': 'chatglm2 format', |
| | 'llama2': 'llama2 format', |
| | 'qwen': '千问 format', |
| | 'baichuan2': '百川 2 format', |
| | } |
| |
|
| | @classmethod |
| | def from_json( |
| | cls, |
| | input: dict, |
| | name: Optional[str] = None, |
| | prompt_style: Optional[PromptStyle] = None, |
| | ): |
| | '''从文件数据结构到数据对象的转换. |
| | |
| | Params: |
| | name: `Optional[str]`, 符号系统名称 |
| | - format 支持: antglm_raw, antglm_chat, chatglm1, chatglm2, llama2, qwen, baichuan2 |
| | - 如果指定了 format name, 使用该 name 符号系统, 否则使用 input 中 `name` 字段 |
| | |
| | prompt_style: `Optional[PromptStyle]`, 指定 prompt 风格, 默认使用和 name 一致的风格 |
| | |
| | input: `dict`, 文件中的 json dict 对象, 协议为: |
| | - 既支持 `messages` 字段, 也支持 `turns` 字段 |
| | { |
| | "id": "xxx", |
| | "name": "antglm", |
| | "source": "xxx", |
| | "lang": "xx", |
| | "topic": "xx", |
| | "system_template": "", |
| | "system_message": "xx", |
| | "messages": [ |
| | { |
| | "role": "HUMAN", |
| | "content": "Hi" |
| | }, |
| | { |
| | "role": "ASSISTANT", |
| | "content": "Hello" |
| | }, |
| | { |
| | "role": "OBSERVATION", |
| | "content": "xxx" |
| | }, |
| | { |
| | "role": "ASSISTANT", |
| | "content": "xxx" |
| | } |
| | ], |
| | "turns": [ |
| | {"HUMAN": "xxx", "OBSERVATION": "xx", "ASSISTANT": "xx"} |
| | ] |
| | } |
| | |
| | Returns: |
| | `Chat` 对象 |
| | ''' |
| | _id = input.get('id') |
| | if name: |
| | _name = name |
| | else: |
| | _name = input.get('name') |
| | source = input.get('source') |
| | lang = input.get('lang') |
| | topic = input.get('topic') |
| | kwargs = {} |
| | if 'system_template' in input: |
| | kwargs['system_template'] = input['system_template'] |
| | if 'system_message' in input: |
| | kwargs['system_message'] = input['system_message'] |
| |
|
| | |
| | chat = cls( |
| | id=_id, |
| | name=_name, |
| | prompt_style=prompt_style, |
| | source=source, |
| | lang=lang, |
| | topic=topic, |
| | origin_json=deepcopy(input), |
| | **kwargs, |
| | ) |
| | if 'messages' in input: |
| | for msg in input['messages']: |
| | if msg['role'] == 'HUMAN': |
| | role = chat.role_human |
| | elif msg['role'] == 'OBSERVATION': |
| | role = chat.role_observation |
| | elif msg['role'] == 'ASSISTANT': |
| | role = chat.role_assistant |
| | else: |
| | raise ValueError(f'不支持数据集中的 role: {msg["role"]}') |
| |
|
| | chat.append_message(role, msg['content']) |
| |
|
| | elif 'turns' in input: |
| | for turn in input['turns']: |
| | if 'HUMAN' in turn: |
| | content = turn['HUMAN'] |
| | chat.append_message(chat.role_human, content) |
| | if 'OBSERVATION' in turn: |
| | content = turn['OBSERVATION'] |
| | chat.append_message(chat.role_observation, content) |
| | if 'ASSISTANT' in turn: |
| | content = turn['ASSISTANT'] |
| | chat.append_message(chat.role_assistant, content) |
| |
|
| | return chat |
| |
|
    @classmethod
    def from_pack(
        cls,
        packs: Dict[str, List[str]],
        name: str,
        prompt_style: Optional[PromptStyle] = None,
    ) -> list:
        '''Build a list of Chat objects from packed training samples.

        A pack concatenates several conversations; a new Chat is started
        whenever a system prompt or a restarting turn counter is seen.

        Params:
            packs: `dict`, pack samples
                {
                    'inputs': ['xx', 'xx'],
                    'outputs': ['xx', 'xx'],
                }

            name: `str`, symbol-system name
            prompt_style: `Optional[PromptStyle]`, prompt style; defaults to
                the style matching `name`

        Returns:
            `list` of `Chat` objects.
        '''
        chat = cls(name=name, prompt_style=prompt_style)
        # NOTE(review): `_format_packs` reads the 'input'/'output' keys while
        # the docstring above shows 'inputs'/'outputs' — confirm which key
        # spelling callers actually pass.
        packs = cls._format_packs(packs)

        # Patterns for the current symbol system: system prompt, turn header
        # and the three role tags.
        sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
        turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
        human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
        observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
        assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)

        chats = []
        for input, output in zip(packs['input'], packs['output']):
            # A system prompt marks the start of a new conversation: flush the
            # one being accumulated, if any.
            sys_match = sys_pattern.search(input)
            if sys_match and sys_match.group(0):
                if len(chat.messages) > 0:
                    chats.append(chat)
                    chat = cls(name=name, prompt_style=prompt_style)

                input = input[sys_match.end() :]
                chat.system_message = sys_match.group(1)

            # A turn header equal to the style's first round number also marks
            # the start of a new conversation.
            turn_match = turn_pattern.search(input)
            if turn_match and turn_match.group(0):
                # antglm/chatglm2 number rounds from 1, the others from 0.
                if name in ['antglm', 'antglm_raw', 'chatglm2']:
                    round_start = 1
                else:
                    round_start = 0

                if all(
                    [
                        len(turn_match.groups()) > 0,
                        int(turn_match.group(1)) == round_start,
                        len(chat.messages) > 0,
                    ]
                ):
                    chats.append(chat)
                    chat = cls(name=name, prompt_style=prompt_style)

                input = input[turn_match.end() :]

            human_iter = human_pattern.finditer(input)
            observe_iter = observe_pattern.finditer(input)
            assistant_iter = assistant_pattern.finditer(input)
            human_match = next(human_iter, None)
            observe_match = next(observe_iter, None)
            assistant_match = next(assistant_iter, None)

            # No role tag at all: treat the whole remainder as one human turn.
            if not human_match and not observe_match:
                chat.append_message(chat.role_human, input)

            while human_match or observe_match:
                next_human_match = next(human_iter, None)
                next_observe_match = next(observe_iter, None)
                # NOTE(review): the returned remainder replaces `input`, yet
                # the match objects still carry offsets into the pre-slice
                # string — confirm inputs with several role tags parse right.
                input = cls._append_human_observation(
                    chat,
                    input,
                    human_match=human_match,
                    next_human_match=next_human_match,
                    observe_match=observe_match,
                    next_observe_match=next_observe_match,
                    assistant_match=assistant_match,
                )

                human_match = next_human_match
                observe_match = next_observe_match
                # NOTE(review): these two values are overwritten at the top of
                # the next iteration without ever being read, so one match from
                # each iterator is silently discarded — looks like a bug, confirm.
                next_human_match = next(human_iter, None)
                next_observe_match = next(observe_iter, None)

            # The pack's output, when present, is the assistant reply closing
            # this sample.
            if output:
                chat.append_message(chat.role_assistant, output)

        # Flush the trailing conversation.
        if chat.messages:
            chats.append(chat)

        return chats
| |
|
    @classmethod
    def _append_human_observation(
        cls,
        chat,
        input: str,
        human_match: Optional[re.Match] = None,
        next_human_match: Optional[re.Match] = None,
        observe_match: Optional[re.Match] = None,
        next_observe_match: Optional[re.Match] = None,
        assistant_match: Optional[re.Match] = None,
    ) -> str:
        '''Append the human/observation message(s) delimited by the given role
        tag matches to `chat` and return the unconsumed remainder of `input`.

        All `re.Match` arguments are assumed to hold offsets into `input` —
        TODO confirm; `from_pack` re-slices `input` between calls.
        '''
        if observe_match:
            # NOTE(review): this compares `observe_match` with itself and is
            # therefore always False — presumably one side should be
            # `human_match`; confirm the intended condition.
            if observe_match.span()[0] > observe_match.span()[0]:
                # NOTE(review): both slices start at the end of `observe_match`;
                # the human slice presumably should use `human_match` offsets.
                human_str = input[observe_match.span()[1] : observe_match.span()[0]]
                observe_str = input[observe_match.span()[1] : assistant_match.span()[0]]
                chat.append_message(chat.role_human, human_str.strip())
                input_end = observe_match.span()[1]
                # NOTE(review): `next_human_match` / `assistant_match` may be
                # None here, which would raise AttributeError — confirm.
                if observe_match.span()[0] < next_human_match.span()[0]:
                    chat.append_message(chat.role_observation, observe_str.strip())
                    input_end = assistant_match.span()[1]
            else:
                # Branch taken whenever an observation tag exists (see the
                # always-False test above).
                human_str = input[observe_match.span()[1] : assistant_match.span()[0]]
                # NOTE(review): this slice runs from the end of the match back
                # to its own start and is therefore always empty — confirm intent.
                observe_str = input[observe_match.span()[1] : observe_match.span()[0]]
                chat.append_message(chat.role_observation, observe_str.strip())
                input_end = observe_match.span()[1]
                if observe_match.span()[0] < next_observe_match.span()[0]:
                    chat.append_message(chat.role_human, human_str.strip())
                    input_end = assistant_match.span()[1]
        else:
            # Human-only message: content runs up to the assistant tag, or to
            # the end of input when there is none.
            if assistant_match:
                human_str = input[human_match.span()[1] : assistant_match.span()[0]]
                input_end = assistant_match.span()[1]
            else:
                human_str = input[human_match.span()[1] :]
                input_end = len(input)
            chat.append_message(chat.role_human, human_str.strip())

        return input[input_end:]
| |
|
    @classmethod
    def from_inout(
        cls,
        sample: Dict[str, str],
        name: str,
        prompt_style: Optional[PromptStyle] = None,
    ):
        '''Build one Chat object from a single input/output training sample.

        Params:
            sample: `Dict[str, str]`, input/output sample
                {
                    "input": "xxx",
                    "output": "xxx",
                }

            name: `str`, symbol-system name
            prompt_style: `Optional[PromptStyle]`, prompt style; defaults to
                the style matching `name`

        Returns:
            a `Chat` object.
        '''
        chat = cls(name=name, prompt_style=prompt_style)
        input = sample['input']
        output = sample['output']

        # Patterns for the current symbol system: system prompt, turn header
        # and the three role tags.
        sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
        turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
        human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
        observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
        assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)

        # Turn headers carry no content; drop them before parsing roles.
        input = turn_pattern.sub('', input)

        # Strip a leading system prompt into `system_message`.
        sys_match = sys_pattern.search(input)
        if sys_match and sys_match.group(0):
            input = input[sys_match.end() :]
            chat.system_message = sys_match.group(1)

        human_iter = human_pattern.finditer(input)
        observe_iter = observe_pattern.finditer(input)
        assistant_iter = assistant_pattern.finditer(input)
        human_match = next(human_iter, None)
        observe_match = next(observe_iter, None)
        assistant_match = next(assistant_iter, None)
        next_human_match = next(human_iter, None)
        next_observe_match = next(observe_iter, None)

        while any(
            [
                human_match,
                observe_match,
                assistant_match,
            ]
        ):
            # Consume every human/observation tag that precedes the current
            # assistant tag.
            # NOTE(review): if `assistant_match` is None while human/observe
            # matches remain, `assistant_match.span()` below raises
            # AttributeError — confirm inputs always end with an assistant tag.
            while any(
                [
                    human_match and human_match.span()[0] < assistant_match.span()[0],
                    observe_match and observe_match.span()[0] < assistant_match.span()[0],
                    next_human_match and next_human_match.span()[0] < assistant_match.span()[0],
                    next_observe_match and next_observe_match.span()[0] < assistant_match.span()[0],
                ]
            ):
                if not input:
                    break

                # NOTE(review): the sliced remainder returned here is
                # discarded, unlike in `from_pack` — confirm this is intended.
                cls._append_human_observation(
                    chat,
                    input,
                    human_match=human_match,
                    next_human_match=next_human_match,
                    observe_match=observe_match,
                    next_observe_match=next_observe_match,
                    assistant_match=assistant_match,
                )

                human_match = next_human_match
                observe_match = next_observe_match
                next_human_match = next(human_iter, None)
                next_observe_match = next(observe_iter, None)

            # Everything between this assistant tag and the next role tag is
            # the assistant reply.
            if assistant_match and assistant_match.span():
                if observe_match:
                    # NOTE(review): `assistant_str` stays unbound when this
                    # inner condition is False (and `human_match.span()` may
                    # hit a None `human_match`), raising NameError /
                    # AttributeError below — confirm.
                    if observe_match.span() and observe_match.span()[0] < human_match.span()[0]:
                        assistant_str = input[assistant_match.span()[1] : observe_match.span()[0]]
                elif human_match:
                    if human_match.span():
                        assistant_str = input[assistant_match.span()[1] : human_match.span()[0]]
                else:
                    assistant_str = input[assistant_match.span()[1] :]

                if assistant_str:
                    chat.append_message(chat.role_assistant, assistant_str)

            assistant_match = next(assistant_iter, None)

        # The sample's output, when present, is the final assistant reply.
        if output:
            chat.append_message(chat.role_assistant, output)

        return chat
| |
|
    def __hash__(self):
        '''Hash a Chat by its unique id (messages are mutable and excluded).'''
        return hash(self.id)
| |
|
| | def __post_init__(self): |
| | '''对象初始化后的处理, 处理包括: |
| | - 根据数据对象名称, 支持转成其他开源数据对象的基本信息 |
| | ''' |
| | self.id = str(uuid.uuid4()) |
| | if not self.messages: |
| | self.messages = [] |
| |
|
| | if not self.name and not self.prompt_style: |
| | logger.error('构造 Chat 对象至少包含以下一个入参: `name/prompt_style`.\n\n' '`name` 支持以下 format 名称:') |
| | logger.error('\n'.join([f'{k}: {v}' for k, v in self.support_names.items()])) |
| | logger.error('\n`prompt_style` 参考 antllm.data.chat_format.PromptStyle') |
| | raise ValueError |
| |
|
| | if self.name == 'antglm': |
| | |
| | self.name = 'antglm_raw' |
| |
|
| | if not self.name and self.prompt_style == PromptStyle.ANTGLM_CHAT: |
| | logger.info( |
| | 'Chat 对象入参没有 `name`, 默认使用 `ANTGLM_CHAT`, format:\n' |
| | f'role_human: {self.role_human}\n' |
| | f'role_assistant: {self.role_assistant}\n' |
| | f'role_observation: {self.role_observation}\n' |
| | f'role_template: {self.role_template}\n' |
| | f'turn_start: {self.turn_start}\n' |
| | f'human_end: {self.human_end}\n' |
| | f'assistant_start: {self.assistant_start}\n' |
| | f'assistant_end: {self.assistant_end}\n' |
| | f'assistant_end_ids: {self.assistant_end_ids}\n' |
| | f'general_role_end: {self.general_role_end}\n' |
| | f'tool_template: {self.tool_template}\n' |
| | f'code_template: {self.code_template}\n' |
| | f'arithemetic_templte: {self.arithemetic_templte}\n' |
| | f'image_template: {self.image_template}\n' |
| | f'\n入参 `name` 支持: ``' |
| | ) |
| | return |
| |
|
| | if self.name == 'antglm_raw' or self.prompt_style == PromptStyle.ANTGLM_RAW: |
| | self.prompt_style = PromptStyle.ANTGLM_RAW |
| | self.role_template = '{}' |
| | self.role_human = '用户: ' |
| | self.role_assistant = '机器人: ' |
| | self.turn_start = '第{}轮\n' |
| | self.general_role_end = '\n' |
| |
|
| | if self.name in ['chatglm1', 'chatglm2'] or self.prompt_style == PromptStyle.CHATGLM: |
| | self.prompt_style = PromptStyle.CHATGLM |
| | self.role_template = '{}' |
| | self.role_human = '问:' |
| | self.role_assistant = '答:' |
| | self.turn_start = '[Round {}]\n' |
| | if self.name == 'chatglm1': |
| | self.general_role_end = '\n' |
| | else: |
| | self.general_role_end = '\n\n' |
| |
|
| | elif self.name == 'chatglm3' or self.prompt_style == PromptStyle.CHATGLM3: |
| | self.prompt_style = PromptStyle.CHATGLM3 |
| | self.system_template = '<|system|>\n {}' |
| | self.role_human = '<|user|>\n ' |
| | self.role_assistant = '<|assistant|>\n ' |
| | self.role_template = '{}' |
| |
|
| | elif self.name == 'llama2' or self.prompt_style == PromptStyle.LLAMA2: |
| | self.prompt_style = PromptStyle.LLAMA2 |
| | self.role_template = '{}' |
| | self.system_template = '[INST] <<SYS>>\n{}\n<</SYS>>\n\n' |
| | self.role_human = '[INST] ' |
| | self.role_assistant = '[/INST] ' |
| | self.human_end = ' ' |
| | self.assistant_end = ' </s><s>' |
| |
|
| | elif self.name == 'qwen': |
| | self.prompt_style = PromptStyle.CHATML |
| | self.role_template = '{}' |
| | self.system_template = '<|im_start|>system\n{}' |
| | if not self.system_message: |
| | self.system_message = 'You are a helpful assistant.' |
| | self.role_human = '<|im_start|>user\n' |
| | self.role_assistant = '<|im_start|>assistant\n' |
| | self.general_role_end = '<|im_end|>\n' |
| |
|
| | elif self.name == 'baichuan': |
| | self.prompt_style = PromptStyle.BAICHUAN2 |
| | self.role_template = '{}' |
| | self.system_template = '{}' |
| | self.role_human = '<token_id-195>' |
| | self.role_assistant = '<token_id-196>' |
| |
|
| | if not self.system_template: |
| | self.system_template = '{}' |
| |
|
| | def readable_messages(self) -> str: |
| | '''将 messages 输出为人类可读的字符串, 方便分析数据.''' |
| | pass |
| |
|
| | @property |
| | def prompt_str(self) -> str: |
| | '''将 Chat 对象转成 prompt str, 合并 human/assitant 输出为 format 字符串.''' |
| | return f'{self.prompt_inout["input"]}{self.prompt_inout["output"]}' |
| |
|
| | @classmethod |
| | def _format_packs(cls, packs: Dict[str, List[str]]) -> Dict[str, List[str]]: |
| | '''格式化 pack 样本, 输出相同 pack inputs, outputs 个数.''' |
| | _packs = copy.deepcopy(packs) |
| | if len(_packs['input']) - 1 == len(_packs['output']): |
| | _packs['output'].append('') |
| |
|
| | if len(_packs['input']) != len(_packs['output']): |
| | print(packs) |
| | raise ValueError( |
| | '输入 input 和 output 数量不匹配, ' |
| | f'input num: {len(packs["input"])}, ' |
| | f'output num: {len(packs["output"])}' |
| | ) |
| |
|
| | return _packs |
| |
|
| | @property |
| | def prompt_inout(self) -> Dict[str, str]: |
| | '''将 Chat 对象转成 input prompt, output prompt 字符串. |
| | |
| | Returns: |
| | `Dict[str, str]`, 示例: |
| | { |
| | "input": "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>你好,有什么可以帮您?<role>ASSISTANT</role>", # noqa |
| | "output": "你好,有什么可以帮您?" |
| | } |
| | ''' |
| | packs = self._format_packs(self.prompt_pack) |
| |
|
| | |
| | if self.prompt_style == PromptStyle.ANTGLM_RAW: |
| | packs['input'] = [f'{item} ' for item in packs['input']] |
| |
|
| | prompt_input = ''.join([f'{x}{y}' for x, y in zip(packs['input'][:-1], packs['output'][:-1])]) |
| | prompt_input += packs['input'][-1] |
| | prompt_output = packs['output'][-1] |
| |
|
| | |
| | if self.prompt_style == PromptStyle.ANTGLM_RAW: |
| | prompt_input = prompt_input.strip() |
| |
|
| | return { |
| | 'input': prompt_input, |
| | 'output': prompt_output, |
| | } |
| |
|
    @property
    def prompt_pack(self) -> Dict[str, List[str]]:
        '''Convert the chat object into packed input/output prompt lists.

        Returns:
            `Dict[str, List[str]]`, e.g.

            {
                "input": [
                    "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>",
                    "<role>HUMAN</role>讲个笑话<role>ASSISTANT</role>",
                    "<role>OBSERVATION</role>{\"weather\": \"晴\"}<role>ASSISTANT</role>"
                ],
                "output": [
                    "你好,有什么可以帮您?",
                    "笑话 1",
                    "今天天气 xxx"
                ]
            }

        '''
        inputs = []
        outputs = []

        # Rendered system prompt, prepended to the first input chunk.
        system_prompt = ''
        if self.system_message:
            system_prompt = self.system_template.format(self.system_message)

        if system_prompt:
            ret = system_prompt + self.general_role_end
        else:
            ret = ''

        # Raw AntGLM (and multiturn-only chat) renders a single exchange
        # without any role structure at all.
        if self.prompt_style in [
            PromptStyle.ANTGLM_RAW,
            PromptStyle.ANTGLM_ONLY_MULTITURN_CHAT,
        ]:
            if len(self.messages) <= 2:
                output = ''
                # NOTE(review): `input` stays unbound when every message is an
                # assistant one (or `messages` is empty) — confirm callers
                # always provide a human message here.
                for role, message in self.messages:
                    if role == self.role_assistant:
                        output = message
                    else:
                        input = ret + message
                return {
                    'input': [input],
                    'output': [output],
                }

        # antglm/chatglm2 number turns from 1, the others from 0.
        if self.name in ['antglm_raw', 'chatglm2']:
            round_start = 1
        else:
            round_start = 0

        for i, (role, message) in enumerate(self.messages):
            # Styles with turn headers emit one before every message pair.
            if self.name in ['antglm_raw', 'chatglm1', 'chatglm2']:
                if i % 2 == 0:
                    ret += self.turn_start.format(i // 2 + round_start)

            # Role-specific terminator, falling back to the general one.
            role_end = self.general_role_end
            if role == self.role_assistant and self.assistant_end:
                role_end = self.assistant_end
            elif self.human_end:
                role_end = self.human_end

            ret += self.role_template.format(role) + message + role_end

            if role == self.role_assistant:
                # The assistant message becomes an output; the accumulated
                # text up to (and including) the assistant tag is appended to
                # the matching input chunk.
                if not message:
                    outputs.append('')
                else:
                    outputs.append(message + role_end)
                # NOTE(review): when `message + role_end` is empty this is
                # `ret[:0]` and drops the whole accumulation; also `ret` is
                # not cleared here, so the assistant text appears to leak into
                # the next input chunk — confirm against real pack samples.
                inputs[-1] += ret[: -len(message + role_end)]
            elif all(
                [
                    role == self.role_observation,
                    len(self.messages) > 1,
                    self.messages[i - 1][0] != self.role_assistant,
                ]
            ):
                # An observation directly following a non-assistant message
                # stays in the same input chunk.
                continue
            else:
                inputs.append(ret)
                ret = ''

            # A conversation ending on a non-assistant message still gets the
            # assistant tag appended so the model can complete it.
            if i == len(self.messages) - 1 and role != self.role_assistant:
                inputs[-1] += self.role_template.format(self.role_assistant).strip()

        # Raw AntGLM prompts are whitespace-trimmed.
        if self.prompt_style == PromptStyle.ANTGLM_RAW:
            inputs = [item.strip() for item in inputs]

        return {
            'input': inputs,
            'output': outputs,
        }
| |
|
| | @property |
| | def turns_num(self) -> int: |
| | '''和机器人的交互轮数, 以用户输出多少次为轮数个数.''' |
| | return sum([1 if msg[0] == self.role_human else 0 for msg in self.messages]) |
| |
|
| | def to_json(self) -> dict: |
| | '''输出 chat json dict 格式, 包含不同角色和机器人交互的每轮信息. |
| | |
| | Returns |
| | `List[dict]`, { |
| | "id": "xx", |
| | "messages": [ |
| | {"role": "HUMAN", "content": "xxx"} |
| | ] |
| | "turns": [ |
| | {"HUMAN": "xx", "OBSERVATION": "xx", "ASSISTANT": "xx"} |
| | ] |
| | } |
| | ''' |
| | turns = [] |
| | messages = [] |
| | turn = {} |
| | for msg in self.messages: |
| | if msg[0] == self.role_assistant: |
| | messages.append({'role': 'ASSISTANT', 'content': msg[1]}) |
| | turn['ASSISTANT'] = msg[1] |
| | turns.append(turn) |
| | turn = {} |
| |
|
| | if msg[0] == self.role_human: |
| | messages.append({'role': 'HUMAN', 'content': msg[1]}) |
| | turn['HUMAN'] = msg[1] |
| |
|
| | if msg[0] == self.role_observation: |
| | messages.append({'role': 'OBSERVATION', 'content': msg[1]}) |
| | turn['OBSERVATION'] = msg[1] |
| |
|
| | if self.messages[-1][0] == self.role_human: |
| | messages.append({'role': 'ASSISTANT', 'content': ''}) |
| | turn['ASSISTANT'] = '' |
| | turns.append(turn) |
| |
|
| | result = self.origin_json or {} |
| | result.update( |
| | { |
| | 'id': self.id, |
| | 'name': self.name, |
| | 'source': self.source, |
| | 'lang': self.lang, |
| | 'topic': self.topic, |
| | 'system_template': self.system_template, |
| | 'system_message': self.system_message, |
| | 'turns': turns, |
| | 'messages': messages, |
| | } |
| | ) |
| |
|
| | return result |
| |
|
    def set_system_message(self, system_message: str):
        '''Set the system message (the content substituted into `system_template`).'''
        self.system_message = system_message
| |
|
| | def append_message(self, role: str, message: str): |
| | '''Append a new message.''' |
| | if not message: |
| | message = '' |
| | self.messages.append([role, message]) |
| |
|
| | def to_openai_api_messages(self) -> List[dict]: |
| | '''Convert the conversation to OpenAI chat completion format.''' |
| | ret = [{'role': 'system', 'content': self.system_message}] |
| |
|
| | for i, (_, msg) in enumerate(self.messages[self.offset :]): |
| | if i % 2 == 0: |
| | ret.append({'role': 'user', 'content': msg}) |
| | else: |
| | if msg is not None: |
| | ret.append({'role': 'assistant', 'content': msg}) |
| | return ret |
| |
|
    def copy(self):
        '''Return a deep copy of this chat object.'''
        return copy.deepcopy(self)
| |
|