Upload chat_format.py with huggingface_hub
Browse files- chat_format.py +875 -0
chat_format.py
ADDED
|
@@ -0,0 +1,875 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''AntGLM Chat-model data format.
|
| 2 |
+
|
| 3 |
+
格式化 AntGLM 以及各种开源模型的符号系统:
|
| 4 |
+
- 确定 Chat 模型依赖的文件数据结构协议
|
| 5 |
+
- 确定单轮/多轮的统一结构
|
| 6 |
+
- 确定 Chat 符号系统的协议, 包括角色定义、分隔符等
|
| 7 |
+
- 方便做开源模型依赖的 prompt 转换
|
| 8 |
+
- 支持工具、代码、推理等支持
|
| 9 |
+
|
| 10 |
+
参考 FastChat Conversation 对象的设计思路.
|
| 11 |
+
Reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
| 12 |
+
'''
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import dataclasses
|
| 16 |
+
import logging
|
| 17 |
+
import re
|
| 18 |
+
import uuid
|
| 19 |
+
from copy import deepcopy
|
| 20 |
+
from enum import IntEnum, auto
|
| 21 |
+
from typing import Dict, List, Optional, Tuple
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PromptStyle(IntEnum):
    '''Prompt styles — one per supported chat symbol system.'''

    # Original AntGLM format: single-turn instructions carry no structure;
    # multi-turn turns look like `第1轮\n用户: xx\n机器人: xx\n`.
    ANTGLM_RAW = auto()
    # Chat format: single- and multi-turn are unified under one chat format.
    ANTGLM_CHAT = auto()
    # Single-turn instructions unstructured; only multi-turn uses chat format.
    ANTGLM_ONLY_MULTITURN_CHAT = auto()
    # OpenAI ChatML format (also used by Qwen).
    CHATML = auto()
    # LLaMA-2 format.
    LLAMA2 = auto()
    # ChatGLM 1/2 format.
    CHATGLM = auto()
    # ChatGLM3 format.
    CHATGLM3 = auto()
    # Baichuan format.
    BAICHUAN2 = auto()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclasses.dataclass
|
| 48 |
+
class Chat:
|
| 49 |
+
'''Chat 数据符号结构, 格式化 AntGLM 以及各种开源模型的符号系统.
|
| 50 |
+
|
| 51 |
+
Examples:
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
>>> from antllm.data.chat_format import Chat
|
| 55 |
+
|
| 56 |
+
>>> ### 从 json 数据结构创建 chat 对象, 并且 format 结构使用 AntGLM 原始结构
|
| 57 |
+
>>> input_json = {
|
| 58 |
+
... "messages": [
|
| 59 |
+
... {"role": "HUMAN", "content": "讲一个笑话"},
|
| 60 |
+
... {"role": "ASSISTANT", "content": "为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!"},
|
| 61 |
+
... {"role": "HUMAN", "content": "不好笑,换个程序员的笑话"}
|
| 62 |
+
... ],
|
| 63 |
+
... }
|
| 64 |
+
>>> chat = Chat.from_json(input_json, name='antglm_raw')
|
| 65 |
+
|
| 66 |
+
>>> ### 根据 chat 对象创建大模型训练所需 pack 数据
|
| 67 |
+
>>> pack_data = chat.prompt_pack
|
| 68 |
+
>>> print(pack_data)
|
| 69 |
+
|
| 70 |
+
>>> ### 根据 chat 对象创建大模型训练所需 input, output 数据
|
| 71 |
+
>>> data = chat.prompt_inout
|
| 72 |
+
>>> print(data)
|
| 73 |
+
|
| 74 |
+
>>> ### 根据 chat 对象创建大模型预测用的 prompt
|
| 75 |
+
>>> prompt = chat.prompt_str
|
| 76 |
+
>>> print(prompt)
|
| 77 |
+
|
| 78 |
+
>>> ### 从大模型训练数据 {"input": "xx", "output": "xx"} 中创建 chat 对象
|
| 79 |
+
>>> data = {
|
| 80 |
+
... 'input': (
|
| 81 |
+
... '第1轮\n用户: 讲一个笑话\n机器人: 为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\n'
|
| 82 |
+
... '第2轮\n用户: 不好笑,换个程序员的笑话\n机器人:'
|
| 83 |
+
... ),
|
| 84 |
+
... 'output': ''
|
| 85 |
+
... }
|
| 86 |
+
>>> chat = Chat.from_inout(data, name='antglm_raw')
|
| 87 |
+
|
| 88 |
+
>>> ### 从大模型 pack 训练数据创建 chat 对象列表
|
| 89 |
+
>>> pack_data = {
|
| 90 |
+
... 'inputs': ['第1轮\n用户: 讲一个笑话\n机器人:', '第2轮\n用户: 不好笑,换个程序员的笑话\n机器人:', '第1轮\n用户: 写首诗\n机器人:'],
|
| 91 |
+
... 'outputs': [
|
| 92 |
+
... '为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\n',
|
| 93 |
+
... '为什么程序员总是喜欢使用黑色主题?因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!',
|
| 94 |
+
... '']
|
| 95 |
+
... }
|
| 96 |
+
>>> chats = Chat.from_pack(pack_data, name='antglm_raw')
|
| 97 |
+
>>> assert len(chats) == 2
|
| 98 |
+
>>> print(chats[0])
|
| 99 |
+
>>> print(chats[1])
|
| 100 |
+
|
| 101 |
+
>>> ### 显示总交互轮数 (以用户输出多少次为轮数个数)
|
| 102 |
+
>>> print(chat.turns_num)
|
| 103 |
+
|
| 104 |
+
>>> ### 根据 chat 对象创建 json 格式化输出
|
| 105 |
+
>>> data_json = chat.to_json()
|
| 106 |
+
>>> print(data_json)
|
| 107 |
+
|
| 108 |
+
>>> ### 增加轮次信息
|
| 109 |
+
>>> content = (
|
| 110 |
+
... '为什么程序员总是喜欢使用黑色主题?'
|
| 111 |
+
... '因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!'
|
| 112 |
+
... )
|
| 113 |
+
>>> chat.append_message(chat.role_assistant, content)
|
| 114 |
+
|
| 115 |
+
>>> ### 将 chat 对象转成 OpenAI ChatCompletion 接口的入参
|
| 116 |
+
>>> openai_messages = chat.to_openai_api_messages()
|
| 117 |
+
>>> print(openai_messages)
|
| 118 |
+
|
| 119 |
+
>>> ### 复制一个 chat 对象
|
| 120 |
+
>>> chat_new = chat.copy()
|
| 121 |
+
```
|
| 122 |
+
'''
|
| 123 |
+
|
| 124 |
+
# 数据结构名称
|
| 125 |
+
id: str = None
|
| 126 |
+
|
| 127 |
+
# format 支持: antglm_raw, antglm_chat, chatglm1, chatglm2, llama2, qwen, baichuan2
|
| 128 |
+
name: Optional[str] = None
|
| 129 |
+
|
| 130 |
+
# Prompt 风格
|
| 131 |
+
prompt_style: Optional[PromptStyle] = None
|
| 132 |
+
|
| 133 |
+
# System Template 和 message
|
| 134 |
+
system_template: str = '<role>SYSTEM</role>{}'
|
| 135 |
+
system_message: str = ''
|
| 136 |
+
|
| 137 |
+
# 角色定义
|
| 138 |
+
role_human: str = 'HUMAN'
|
| 139 |
+
role_assistant: str = 'ASSISTANT'
|
| 140 |
+
role_observation: str = 'OBSERVATION'
|
| 141 |
+
role_template: str = '<role>{}</role>'
|
| 142 |
+
|
| 143 |
+
# 每轮符号定义
|
| 144 |
+
turn_start: str = ''
|
| 145 |
+
human_end: str = ''
|
| 146 |
+
assistant_start: str = ''
|
| 147 |
+
assistant_end: str = ''
|
| 148 |
+
assistant_end_ids: Optional[List[int]] = None
|
| 149 |
+
general_role_end: str = ''
|
| 150 |
+
|
| 151 |
+
# agent 符号定义
|
| 152 |
+
tool_template = '<tool>{}</tool>'
|
| 153 |
+
code_template = '<code>{}</code>'
|
| 154 |
+
arithemetic_templte = '<arithemetic>{}</arithemetic>'
|
| 155 |
+
image_template = '<image>{}</image>'
|
| 156 |
+
|
| 157 |
+
# All messages. Each item is (role, message).
|
| 158 |
+
messages: List[Tuple[str, str]] = ()
|
| 159 |
+
|
| 160 |
+
# messages 中用于 few-shot messages
|
| 161 |
+
offset: int = 0
|
| 162 |
+
|
| 163 |
+
# 其他 meta data
|
| 164 |
+
source: Optional[str] = None
|
| 165 |
+
lang: Optional[str] = None
|
| 166 |
+
topic: Optional[str] = None
|
| 167 |
+
|
| 168 |
+
# 原始 json 数据
|
| 169 |
+
origin_json: Optional[dict] = None
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def support_names(self) -> Dict[str, str]:
|
| 173 |
+
'''支持的数据对象名称.'''
|
| 174 |
+
return {
|
| 175 |
+
'antglm_raw': '原始 antglm format 格式, 单轮指令没有结构, 多轮 `第1轮\\n用户:xx\\n机器人xx\\n`',
|
| 176 |
+
'antglm_chat': 'Chat format 格式, 单轮多轮统一为 chat format 格式',
|
| 177 |
+
'chatglm1': 'chatglm1 format',
|
| 178 |
+
'chatglm2': 'chatglm2 format',
|
| 179 |
+
'llama2': 'llama2 format',
|
| 180 |
+
'qwen': '千问 format',
|
| 181 |
+
'baichuan2': '百川 2 format',
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
@classmethod
def from_json(
    cls,
    input: dict,
    name: Optional[str] = None,
    prompt_style: Optional[PromptStyle] = None,
):
    '''Build a `Chat` object from the file-level JSON protocol.

    Params:
        input: `dict`, the JSON object from a data file.  Either a
            `messages` list of {"role": "HUMAN"|"OBSERVATION"|"ASSISTANT",
            "content": "..."} items, or a `turns` list of
            {"HUMAN": ..., "OBSERVATION": ..., "ASSISTANT": ...} dicts,
            plus optional metadata fields
            (id / name / source / lang / topic / system_template / system_message).
        name: `Optional[str]`, symbol-system name.  When given it overrides
            `input['name']`.  Supported: antglm_raw, antglm_chat, chatglm1,
            chatglm2, llama2, qwen, baichuan2.
        prompt_style: `Optional[PromptStyle]`, explicit prompt style;
            defaults to the style matching `name`.

    Returns:
        A `Chat` object.

    Raises:
        ValueError: when a message carries an unsupported role.
    '''
    # Only forward the optional system fields that are actually present,
    # so the dataclass defaults apply otherwise.
    extra = {
        key: input[key]
        for key in ('system_template', 'system_message')
        if key in input
    }

    chat = cls(
        id=input.get('id'),
        name=name if name else input.get('name'),
        prompt_style=prompt_style,
        source=input.get('source'),
        lang=input.get('lang'),
        topic=input.get('topic'),
        origin_json=deepcopy(input),
        **extra,
    )

    if 'messages' in input:
        # Map the protocol role names onto this chat's role symbols.
        role_map = {
            'HUMAN': chat.role_human,
            'OBSERVATION': chat.role_observation,
            'ASSISTANT': chat.role_assistant,
        }
        for msg in input['messages']:
            if msg['role'] not in role_map:
                raise ValueError(f'不支持数据集中的 role: {msg["role"]}')
            chat.append_message(role_map[msg['role']], msg['content'])

    elif 'turns' in input:
        # Within each turn: HUMAN first, then OBSERVATION, then ASSISTANT.
        for turn in input['turns']:
            for key, role in (
                ('HUMAN', chat.role_human),
                ('OBSERVATION', chat.role_observation),
                ('ASSISTANT', chat.role_assistant),
            ):
                if key in turn:
                    chat.append_message(role, turn[key])

    return chat
|
| 287 |
+
|
| 288 |
+
@classmethod
def from_pack(
    cls,
    packs: Dict[str, List[str]],
    name: str,
    prompt_style: Optional[PromptStyle] = None,
) -> list:
    '''Split packed training samples into a list of Chat objects.

    Params:
        packs: `dict`, packed sample data (note the singular keys,
            as required by `_format_packs`):
            {
                'input': ['xx', 'xx'],
                'output': ['xx', 'xx'],
            }

        name: `str`, symbol-system name
        prompt_style: `Optional[PromptStyle]`, explicit prompt style; defaults
            to the style matching `name`

    Returns:
        `list` of Chat objects; a new object is started whenever a system
        message or a first-turn marker appears in an input.
    '''
    chat = cls(name=name, prompt_style=prompt_style)
    packs = cls._format_packs(packs)

    # Patterns locating the system header, the per-turn marker and each role tag.
    sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
    turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
    human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
    observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
    assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)

    chats = []
    for input, output in zip(packs['input'], packs['output']):
        # system message
        sys_match = sys_pattern.search(input)
        if sys_match and sys_match.group(0):
            # A system instruction only appears in the first turn, so seeing
            # one while messages already exist starts a new chat object.
            if len(chat.messages) > 0:
                chats.append(chat)
                chat = cls(name=name, prompt_style=prompt_style)

            input = input[sys_match.end() :]
            chat.system_message = sys_match.group(1)

        # turn start
        turn_match = turn_pattern.search(input)
        if turn_match and turn_match.group(0):
            # These formats number the first turn 1; the others start at 0.
            if name in ['antglm', 'antglm_raw', 'chatglm2']:
                round_start = 1
            else:
                round_start = 0

            # A first-turn marker while messages already exist means the
            # previous conversation is complete — start a new chat object.
            if all(
                [
                    len(turn_match.groups()) > 0,
                    int(turn_match.group(1)) == round_start,
                    len(chat.messages) > 0,
                ]
            ):
                chats.append(chat)
                chat = cls(name=name, prompt_style=prompt_style)

            input = input[turn_match.end() :]

        human_iter = human_pattern.finditer(input)
        observe_iter = observe_pattern.finditer(input)
        assistant_iter = assistant_pattern.finditer(input)
        human_match = next(human_iter, None)
        observe_match = next(observe_iter, None)
        assistant_match = next(assistant_iter, None)

        if not human_match and not observe_match:
            # No role tags at all: treat the whole input as one human message.
            chat.append_message(chat.role_human, input)

        while human_match or observe_match:
            next_human_match = next(human_iter, None)
            next_observe_match = next(observe_iter, None)
            # NOTE(review): `_append_human_observation` returns a *shortened*
            # string which is rebound to `input`, but the match objects come
            # from iterators built over the original string, so later spans
            # may no longer line up — confirm against real pack data.
            input = cls._append_human_observation(
                chat,
                input,
                human_match=human_match,
                next_human_match=next_human_match,
                observe_match=observe_match,
                next_observe_match=next_observe_match,
                assistant_match=assistant_match,
            )

            human_match = next_human_match
            observe_match = next_observe_match
            # NOTE(review): these two `next()` calls are repeated at the top
            # of the loop as well, so every iteration consumes two matches
            # from each iterator while only advancing by one — looks like
            # matches can be skipped; verify with multi-turn inputs.
            next_human_match = next(human_iter, None)
            next_observe_match = next(observe_iter, None)

        if output:
            chat.append_message(chat.role_assistant, output)

    if chat.messages:
        chats.append(chat)

    return chats
|
| 386 |
+
|
| 387 |
+
@classmethod
def _append_human_observation(
    cls,
    chat,
    input: str,
    human_match: Optional[re.Match] = None,
    next_human_match: Optional[re.Match] = None,
    observe_match: Optional[re.Match] = None,
    next_observe_match: Optional[re.Match] = None,
    assistant_match: Optional[re.Match] = None,
) -> str:
    '''Append the human/observation message(s) of one turn to `chat`.

    The match objects locate role tags inside `input`; the text between a
    role tag and the following tag is taken as the message content.
    Returns the suffix of `input` after the consumed portion.

    NOTE(review): several branches below look internally inconsistent (see
    inline notes); validate against real observation-bearing samples.
    '''
    if observe_match:
        # observation tag after the human tag
        # NOTE(review): this compares `observe_match` with *itself*, so the
        # branch is never taken — presumably `human_match.span()[0]` was
        # intended on one side; confirm.
        if observe_match.span()[0] > observe_match.span()[0]:
            # NOTE(review): the slice start looks wrong — human text would
            # normally start after `human_match`, not `observe_match`.
            human_str = input[observe_match.span()[1] : observe_match.span()[0]]
            observe_str = input[observe_match.span()[1] : assistant_match.span()[0]]
            chat.append_message(chat.role_human, human_str.strip())
            input_end = observe_match.span()[1]
            # NOTE(review): `next_human_match` / `assistant_match` may be
            # None here, which would raise AttributeError on `.span()`.
            if observe_match.span()[0] < next_human_match.span()[0]:
                chat.append_message(chat.role_observation, observe_str.strip())
                input_end = assistant_match.span()[1]
        else:
            # observation tag before the human tag
            human_str = input[observe_match.span()[1] : assistant_match.span()[0]]
            observe_str = input[observe_match.span()[1] : observe_match.span()[0]]
            chat.append_message(chat.role_observation, observe_str.strip())
            input_end = observe_match.span()[1]
            if observe_match.span()[0] < next_observe_match.span()[0]:
                chat.append_message(chat.role_human, human_str.strip())
                input_end = assistant_match.span()[1]
    else:
        # Plain human message: content runs to the assistant tag, or to the
        # end of input when there is no assistant tag.
        if assistant_match:
            human_str = input[human_match.span()[1] : assistant_match.span()[0]]
            input_end = assistant_match.span()[1]
        else:
            human_str = input[human_match.span()[1] :]
            input_end = len(input)
        chat.append_message(chat.role_human, human_str.strip())

    return input[input_end:]
|
| 428 |
+
|
| 429 |
+
@classmethod
def from_inout(
    cls,
    sample: Dict[str, str],
    name: str,
    prompt_style: Optional[PromptStyle] = None,
):
    '''Create one Chat object from a single input/output training sample.

    Params:
        sample: `Dict[str, str]`, an input/output sample
            {
                "input": "xxx",
                "output": "xxx",
            }

        name: `str`, symbol-system name
        prompt_style: `Optional[PromptStyle]`, explicit prompt style; defaults
            to the style matching `name`
    '''
    chat = cls(name=name, prompt_style=prompt_style)
    input = sample['input']
    output = sample['output']

    # Patterns locating the system header, turn markers and each role tag.
    sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
    turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
    human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
    observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
    assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)

    # Strip all per-turn markers; turn boundaries are recovered from the
    # role tags instead.
    input = turn_pattern.sub('', input)

    # system message search
    sys_match = sys_pattern.search(input)
    if sys_match and sys_match.group(0):
        input = input[sys_match.end() :]
        chat.system_message = sys_match.group(1)

    human_iter = human_pattern.finditer(input)
    observe_iter = observe_pattern.finditer(input)
    assistant_iter = assistant_pattern.finditer(input)
    human_match = next(human_iter, None)
    observe_match = next(observe_iter, None)
    assistant_match = next(assistant_iter, None)
    next_human_match = next(human_iter, None)
    next_observe_match = next(observe_iter, None)

    while any(
        [
            human_match,
            observe_match,
            assistant_match,
        ]
    ):

        # human/observation order can vary and repeat; consume every
        # human/observation tag that occurs before the next assistant tag.
        # NOTE(review): the outer loop admits `assistant_match is None`, in
        # which case `.span()` below raises AttributeError — confirm with
        # samples whose tags are unbalanced.
        while any(
            [
                human_match and human_match.span()[0] < assistant_match.span()[0],
                observe_match and observe_match.span()[0] < assistant_match.span()[0],
                next_human_match and next_human_match.span()[0] < assistant_match.span()[0],
                next_observe_match and next_observe_match.span()[0] < assistant_match.span()[0],
            ]
        ):
            if not input:
                break

            # NOTE(review): unlike `from_pack`, the shortened string returned
            # by `_append_human_observation` is discarded here, so `input`
            # keeps its full value — confirm this is intentional.
            cls._append_human_observation(
                chat,
                input,
                human_match=human_match,
                next_human_match=next_human_match,
                observe_match=observe_match,
                next_observe_match=next_observe_match,
                assistant_match=assistant_match,
            )

            human_match = next_human_match
            observe_match = next_observe_match
            next_human_match = next(human_iter, None)
            next_observe_match = next(observe_iter, None)

        # assistant message
        if assistant_match and assistant_match.span():
            if observe_match:
                if observe_match.span() and observe_match.span()[0] < human_match.span()[0]:
                    assistant_str = input[assistant_match.span()[1] : observe_match.span()[0]]
            elif human_match:
                if human_match.span():
                    assistant_str = input[assistant_match.span()[1] : human_match.span()[0]]
            else:
                assistant_str = input[assistant_match.span()[1] :]

            # NOTE(review): `assistant_str` can be unbound when the inner
            # conditions above fall through (e.g. `observe_match` truthy but
            # its position check fails) — would raise NameError; confirm.
            if assistant_str:
                chat.append_message(chat.role_assistant, assistant_str)

        assistant_match = next(assistant_iter, None)

    if output:
        chat.append_message(chat.role_assistant, output)

    return chat
|
| 532 |
+
|
| 533 |
+
def __hash__(self):
|
| 534 |
+
'''数据对象的 hash 函数.'''
|
| 535 |
+
return hash(self.id)
|
| 536 |
+
|
| 537 |
+
def __post_init__(self):
    '''Normalize the freshly-constructed Chat object.

    - keeps a caller-supplied `id`, generating a uuid4 only when missing
    - resolves `name` aliases and configures the role symbols, templates
      and separators of the chosen symbol system

    Raises:
        ValueError: when neither `name` nor `prompt_style` was given.
    '''
    # Bug fix: the id passed by callers such as `from_json` used to be
    # overwritten unconditionally; only generate one when none was given.
    if not self.id:
        self.id = str(uuid.uuid4())
    if not self.messages:
        self.messages = []

    if not self.name and not self.prompt_style:
        logger.error('构造 Chat 对象至少包含以下一个入参: `name/prompt_style`.\n\n' '`name` 支持以下 format 名称:')
        logger.error('\n'.join([f'{k}: {v}' for k, v in self.support_names.items()]))
        logger.error('\n`prompt_style` 参考 antllm.data.chat_format.PromptStyle')
        # Carry a message so the failure is self-describing in tracebacks.
        raise ValueError('构造 Chat 对象至少包含以下一个入参: `name/prompt_style`')

    if self.name == 'antglm':
        # `antglm` is an alias of the raw format: 第1轮\n用户: xx\n机器人: xx\n
        self.name = 'antglm_raw'

    if not self.name and self.prompt_style == PromptStyle.ANTGLM_CHAT:
        logger.info(
            'Chat 对象入参没有 `name`, 默认使用 `ANTGLM_CHAT`, format:\n'
            f'role_human: {self.role_human}\n'
            f'role_assistant: {self.role_assistant}\n'
            f'role_observation: {self.role_observation}\n'
            f'role_template: {self.role_template}\n'
            f'turn_start: {self.turn_start}\n'
            f'human_end: {self.human_end}\n'
            f'assistant_start: {self.assistant_start}\n'
            f'assistant_end: {self.assistant_end}\n'
            f'assistant_end_ids: {self.assistant_end_ids}\n'
            f'general_role_end: {self.general_role_end}\n'
            f'tool_template: {self.tool_template}\n'
            f'code_template: {self.code_template}\n'
            f'arithemetic_templte: {self.arithemetic_templte}\n'
            f'image_template: {self.image_template}\n'
            f'\n入参 `name` 支持: ``'
        )
        return

    if self.name == 'antglm_raw' or self.prompt_style == PromptStyle.ANTGLM_RAW:
        self.prompt_style = PromptStyle.ANTGLM_RAW
        self.role_template = '{}'
        self.role_human = '用户: '
        self.role_assistant = '机器人: '
        self.turn_start = '第{}轮\n'
        self.general_role_end = '\n'

    if self.name in ['chatglm1', 'chatglm2'] or self.prompt_style == PromptStyle.CHATGLM:
        self.prompt_style = PromptStyle.CHATGLM
        self.role_template = '{}'
        self.role_human = '问:'
        self.role_assistant = '答:'
        self.turn_start = '[Round {}]\n'
        # chatglm1 separates with a single newline, chatglm2 with two.
        if self.name == 'chatglm1':
            self.general_role_end = '\n'
        else:
            self.general_role_end = '\n\n'

    elif self.name == 'chatglm3' or self.prompt_style == PromptStyle.CHATGLM3:
        self.prompt_style = PromptStyle.CHATGLM3
        self.system_template = '<|system|>\n {}'
        self.role_human = '<|user|>\n '
        self.role_assistant = '<|assistant|>\n '
        self.role_template = '{}'

    elif self.name == 'llama2' or self.prompt_style == PromptStyle.LLAMA2:
        self.prompt_style = PromptStyle.LLAMA2
        self.role_template = '{}'
        self.system_template = '[INST] <<SYS>>\n{}\n<</SYS>>\n\n'
        self.role_human = '[INST] '
        self.role_assistant = '[/INST] '
        self.human_end = ' '
        self.assistant_end = ' </s><s>'

    elif self.name == 'qwen':
        self.prompt_style = PromptStyle.CHATML
        self.role_template = '{}'
        self.system_template = '<|im_start|>system\n{}'
        if not self.system_message:
            self.system_message = 'You are a helpful assistant.'
        self.role_human = '<|im_start|>user\n'
        self.role_assistant = '<|im_start|>assistant\n'
        self.general_role_end = '<|im_end|>\n'

    elif self.name in ('baichuan', 'baichuan2'):
        # Bug fix: `support_names` and the class docs advertise `baichuan2`,
        # but only the bare `baichuan` spelling used to be recognized here.
        self.prompt_style = PromptStyle.BAICHUAN2
        self.role_template = '{}'
        self.system_template = '{}'
        self.role_human = '<token_id-195>'
        self.role_assistant = '<token_id-196>'

    if not self.system_template:
        self.system_template = '{}'
|
| 630 |
+
|
| 631 |
+
def readable_messages(self) -> str:
    '''Render the messages as a human-readable string for data analysis.

    NOTE(review): not implemented yet — currently returns None despite the
    `-> str` annotation.
    '''
    pass
|
| 634 |
+
|
| 635 |
+
@property
def prompt_str(self) -> str:
    '''Full formatted prompt: the final input followed by the final output.

    Merges the human/assistant sides produced by `prompt_inout` into one
    format string.
    '''
    # Evaluate the `prompt_inout` property only once — it rebuilds the
    # whole pack on every access, so the old double lookup did the
    # formatting work twice.
    inout = self.prompt_inout
    return f'{inout["input"]}{inout["output"]}'
|
| 639 |
+
|
| 640 |
+
@classmethod
|
| 641 |
+
def _format_packs(cls, packs: Dict[str, List[str]]) -> Dict[str, List[str]]:
|
| 642 |
+
'''格式化 pack 样本, 输出相同 pack inputs, outputs 个数.'''
|
| 643 |
+
_packs = copy.deepcopy(packs)
|
| 644 |
+
if len(_packs['input']) - 1 == len(_packs['output']):
|
| 645 |
+
_packs['output'].append('')
|
| 646 |
+
|
| 647 |
+
if len(_packs['input']) != len(_packs['output']):
|
| 648 |
+
print(packs)
|
| 649 |
+
raise ValueError(
|
| 650 |
+
'输入 input 和 output 数量不匹配, '
|
| 651 |
+
f'input num: {len(packs["input"])}, '
|
| 652 |
+
f'output num: {len(packs["output"])}'
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
return _packs
|
| 656 |
+
|
| 657 |
+
@property
def prompt_inout(self) -> Dict[str, str]:
    '''Render the chat as one (input, output) prompt pair.

    Returns:
        `Dict[str, str]`, e.g.
        {
            "input": "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>你好,有什么可以帮您?<role>ASSISTANT</role>", # noqa
            "output": "你好,有什么可以帮您?"
        }
    '''
    packs = self._format_packs(self.prompt_pack)
    is_raw_style = self.prompt_style == PromptStyle.ANTGLM_RAW

    # Backward-compat: the raw format expects a blank after each input.
    if is_raw_style:
        inputs = [f'{item} ' for item in packs['input']]
    else:
        inputs = packs['input']
    outputs = packs['output']

    # Interleave every earlier input with its output, then append the
    # final input prefix.
    pieces = []
    for past_input, past_output in zip(inputs[:-1], outputs[:-1]):
        pieces.append(past_input)
        pieces.append(past_output)
    pieces.append(inputs[-1])
    prompt_input = ''.join(pieces)

    # Backward-compat: the raw format strips the surrounding whitespace.
    if is_raw_style:
        prompt_input = prompt_input.strip()

    return {
        'input': prompt_input,
        'output': outputs[-1],
    }
|
| 686 |
+
|
| 687 |
+
@property
def prompt_pack(self) -> Dict[str, List[str]]:
    '''Convert the chat into packed lists of input/output prompt strings.

    Returns:
        `Dict[str, List[str]]`, e.g.:

        {
            "input": [
                "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>",
                "<role>HUMAN</role>讲个笑话<role>ASSISTANT</role>",
                "<role>OBSERVATION</role>{\"weather\": \"晴\"}<role>ASSISTANT</role>"
            ],
            "output": [
                "你好,有什么可以帮您?",
                "笑话 1",
                "今天天气 xxx"
            ]
        }

    '''
    inputs: List[str] = []
    outputs: List[str] = []

    # Leading system prompt, if a system message is set.
    system_prompt = ''
    if self.system_message:
        system_prompt = self.system_template.format(self.system_message)

    if system_prompt:
        ret = system_prompt + self.general_role_end
    else:
        ret = ''

    # Some prompt styles apply no role format to single-turn instructions.
    if self.prompt_style in [
        PromptStyle.ANTGLM_RAW,
        PromptStyle.ANTGLM_ONLY_MULTITURN_CHAT,
    ]:
        if len(self.messages) <= 2:
            output = ''
            # Initialized up front: the original used the name `input`
            # (shadowing the builtin) and referenced it without a prior
            # assignment when no non-assistant message existed, raising
            # NameError. Now such chats fall back to the system prefix.
            pack_input = ret
            for role, message in self.messages:
                if role == self.role_assistant:
                    output = message
                else:
                    pack_input = ret + message
            return {
                'input': [pack_input],
                'output': [output],
            }

    # Multi-turn conversation: some templates number rounds starting at 1.
    if self.name in ['antglm_raw', 'chatglm2']:
        round_start = 1
    else:
        round_start = 0

    for i, (role, message) in enumerate(self.messages):
        # Round marker every two messages, for templates that use one.
        if self.name in ['antglm_raw', 'chatglm1', 'chatglm2']:
            if i % 2 == 0:
                ret += self.turn_start.format(i // 2 + round_start)

        # Role tag + content + role terminator.
        role_end = self.general_role_end
        if role == self.role_assistant and self.assistant_end:
            role_end = self.assistant_end
        elif self.human_end:
            role_end = self.human_end

        ret += self.role_template.format(role) + message + role_end

        if role == self.role_assistant:
            # output keeps only the actual assistant content.
            if not message:
                outputs.append('')
            else:
                outputs.append(message + role_end)
            # input must still end with the assistant role tag, so keep
            # everything accumulated except the assistant content itself.
            inputs[-1] += ret[: -len(message + role_end)]
        elif all(
            [
                role == self.role_observation,
                len(self.messages) > 1,
                # NOTE(review): at i == 0 this reads messages[-1] (Python
                # wraparound) — confirm that is the intended behavior.
                self.messages[i - 1][0] != self.role_assistant,
            ]
        ):
            # Observation not preceded by an assistant reply: keep
            # accumulating so it is joined with the previous input.
            continue
        else:
            inputs.append(ret)
        ret = ''

        # Last message is not an assistant reply: append the assistant role
        # tag so the model knows to start generating.
        if i == len(self.messages) - 1 and role != self.role_assistant:
            inputs[-1] += self.role_template.format(self.role_assistant).strip()

    # Legacy compatibility: strip surrounding whitespace for ANTGLM_RAW.
    if self.prompt_style == PromptStyle.ANTGLM_RAW:
        inputs = [item.strip() for item in inputs]

    return {
        'input': inputs,
        'output': outputs,
    }
|
| 792 |
+
|
| 793 |
+
@property
def turns_num(self) -> int:
    '''Number of interaction rounds, counted as how many times the human spoke.'''
    return sum(1 for role, _ in self.messages if role == self.role_human)
|
| 797 |
+
|
| 798 |
+
def to_json(self) -> dict:
    '''Serialize the chat to a JSON-compatible dict, grouping messages into turns.

    Returns
        `dict`, e.g. {
            "id": "xx",
            "messages": [
                {"role": "HUMAN", "content": "xxx"}
            ]
            "turns": [
                {"HUMAN": "xx", "OBSERVATION": "xx", "ASSISTANT": "xx"}
            ]
        }
    '''
    turns = []
    messages = []
    turn = {}
    for msg in self.messages:
        if msg[0] == self.role_assistant:
            messages.append({'role': 'ASSISTANT', 'content': msg[1]})
            turn['ASSISTANT'] = msg[1]
            # An assistant reply closes the current turn.
            turns.append(turn)
            turn = {}

        if msg[0] == self.role_human:
            messages.append({'role': 'HUMAN', 'content': msg[1]})
            turn['HUMAN'] = msg[1]

        if msg[0] == self.role_observation:
            messages.append({'role': 'OBSERVATION', 'content': msg[1]})
            turn['OBSERVATION'] = msg[1]

    # A trailing human message gets an empty assistant slot so the open turn
    # is still emitted. Guard added: an empty message list no longer raises
    # IndexError here.
    if self.messages and self.messages[-1][0] == self.role_human:
        messages.append({'role': 'ASSISTANT', 'content': ''})
        turn['ASSISTANT'] = ''
        turns.append(turn)

    # Work on a copy: the original code did `result = self.origin_json or {}`
    # and then updated it, silently mutating the stored origin_json on every
    # call.
    result = dict(self.origin_json) if self.origin_json else {}
    result.update(
        {
            'id': self.id,
            'name': self.name,
            'source': self.source,
            'lang': self.lang,
            'topic': self.topic,
            'system_template': self.system_template,
            'system_message': self.system_message,
            'turns': turns,
            'messages': messages,
        }
    )

    return result
|
| 851 |
+
|
| 852 |
+
def set_system_message(self, system_message: str):
    '''Replace the conversation's system message with the given text.'''
    self.system_message = system_message
|
| 855 |
+
|
| 856 |
+
def append_message(self, role: str, message: str):
    '''Append one message to the history; falsy content is stored as "".'''
    self.messages.append([role, message or ''])
|
| 861 |
+
|
| 862 |
+
def to_openai_api_messages(self) -> List[dict]:
    '''Convert the conversation to the OpenAI chat-completion message format.

    Messages after `self.offset` alternate user/assistant by position
    (even index = user, odd = assistant); assistant slots whose content is
    None are skipped.
    '''
    result = [{'role': 'system', 'content': self.system_message}]

    for idx, (_, content) in enumerate(self.messages[self.offset :]):
        speaker = 'user' if idx % 2 == 0 else 'assistant'
        if speaker == 'assistant' and content is None:
            continue
        result.append({'role': speaker, 'content': content})
    return result
|
| 873 |
+
|
| 874 |
+
def copy(self):
    '''Return an independent deep copy of this chat object.'''
    duplicate = copy.deepcopy(self)
    return duplicate
|