yarkcy commited on
Commit
b4f8742
·
verified ·
1 Parent(s): 051e58f

Upload chat_format.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. chat_format.py +875 -0
chat_format.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''AntGLM Chat-model data format.
2
+
3
+ 格式化 AntGLM 以及各种开源模型的符号系统:
4
+ - 确定 Chat 模型依赖的文件数据结构协议
5
+ - 确定单轮/多轮的统一结构
6
+ - 确定 Chat 符号系统的协议, 包括角色定义、分隔符等
7
+ - 方便做开源模型依赖的 prompt 转换
8
+ - 支持工具、代码、推理等支持
9
+
10
+ 参考 FastChat Conversation 对象的设计思路.
11
+ Reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
12
+ '''
13
+
14
+ import copy
15
+ import dataclasses
16
+ import logging
17
+ import re
18
+ import uuid
19
+ from copy import deepcopy
20
+ from enum import IntEnum, auto
21
+ from typing import Dict, List, Optional, Tuple
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class PromptStyle(IntEnum):
27
+ '''Prompt styles.'''
28
+
29
+ # 原始 antglm format 格式, 单轮指令没有结构, 多轮 `第1轮\n用户: xx\n机器人: xx\n`
30
+ ANTGLM_RAW = auto()
31
+ # Chat format 格式, 单轮多轮统一为 chat format 格式
32
+ ANTGLM_CHAT = auto()
33
+ # 单轮指令没有结构, 只有多轮为 chat format 格式
34
+ ANTGLM_ONLY_MULTITURN_CHAT = auto()
35
+ # OpenAI ChatML 格式, 包括千问
36
+ CHATML = auto()
37
+ # LLAMA2 格式
38
+ LLAMA2 = auto()
39
+ # ChatGLM 1/2 格式
40
+ CHATGLM = auto()
41
+ # ChatGLM3 格式
42
+ CHATGLM3 = auto()
43
+ # 百川格式
44
+ BAICHUAN2 = auto()
45
+
46
+
47
+ @dataclasses.dataclass
48
+ class Chat:
49
+ '''Chat 数据符号结构, 格式化 AntGLM 以及各种开源模型的符号系统.
50
+
51
+ Examples:
52
+
53
+ ```python
54
+ >>> from antllm.data.chat_format import Chat
55
+
56
+ >>> ### 从 json 数据结构创建 chat 对象, 并且 format 结构使用 AntGLM 原始结构
57
+ >>> input_json = {
58
+ ... "messages": [
59
+ ... {"role": "HUMAN", "content": "讲一个笑话"},
60
+ ... {"role": "ASSISTANT", "content": "为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!"},
61
+ ... {"role": "HUMAN", "content": "不好笑,换个程序员的笑话"}
62
+ ... ],
63
+ ... }
64
+ >>> chat = Chat.from_json(input_json, name='antglm_raw')
65
+
66
+ >>> ### 根据 chat 对象创建大模型训练所需 pack 数据
67
+ >>> pack_data = chat.prompt_pack
68
+ >>> print(pack_data)
69
+
70
+ >>> ### 根据 chat 对象创建大模型训练所需 input, output 数据
71
+ >>> data = chat.prompt_inout
72
+ >>> print(data)
73
+
74
+ >>> ### 根据 chat 对象创建大模型预测用的 prompt
75
+ >>> prompt = chat.prompt_str
76
+ >>> print(prompt)
77
+
78
+ >>> ### 从大模型训练数据 {"input": "xx", "output": "xx"} 中创建 chat 对象
79
+ >>> data = {
80
+ ... 'input': (
81
+ ... '第1轮\n用户: 讲一个笑话\n机器人: 为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\n'
82
+ ... '第2轮\n用户: 不好笑,换个程序员的笑话\n机器人:'
83
+ ... ),
84
+ ... 'output': ''
85
+ ... }
86
+ >>> chat = Chat.from_inout(data, name='antglm_raw')
87
+
88
+ >>> ### 从大模型 pack 训练数据创建 chat 对象列表
89
+ >>> pack_data = {
90
+ ... 'inputs': ['第1轮\n用户: 讲一个笑话\n机器人:', '第2轮\n用户: 不好笑,换个程序员的笑话\n机器人:', '第1轮\n用户: 写首诗\n机器人:'],
91
+ ... 'outputs': [
92
+ ... '为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!\n',
93
+ ... '为什么程序员总是喜欢使用黑色主题?因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!',
94
+ ... '']
95
+ ... }
96
+ >>> chats = Chat.from_pack(pack_data, name='antglm_raw')
97
+ >>> assert len(chats) == 2
98
+ >>> print(chats[0])
99
+ >>> print(chats[1])
100
+
101
+ >>> ### 显示总交互轮数 (以用户输出多少次为轮数个数)
102
+ >>> print(chat.turns_num)
103
+
104
+ >>> ### 根据 chat 对象创建 json 格式化输出
105
+ >>> data_json = chat.to_json()
106
+ >>> print(data_json)
107
+
108
+ >>> ### 增加轮次信息
109
+ >>> content = (
110
+ ... '为什么程序员总是喜欢使用黑色主题?'
111
+ ... '因为他们喜欢“黑暗模式”(Dark Mode),这样他们就可以在晚上加班时更好地隐藏自己的错误!'
112
+ ... )
113
+ >>> chat.append_message(chat.role_assistant, content)
114
+
115
+ >>> ### 将 chat 对象转成 OpenAI ChatCompletion 接口的入参
116
+ >>> openai_messages = chat.to_openai_api_messages()
117
+ >>> print(openai_messages)
118
+
119
+ >>> ### 复制一个 chat 对象
120
+ >>> chat_new = chat.copy()
121
+ ```
122
+ '''
123
+
124
+ # 数据结构名称
125
+ id: str = None
126
+
127
+ # format 支持: antglm_raw, antglm_chat, chatglm1, chatglm2, llama2, qwen, baichuan2
128
+ name: Optional[str] = None
129
+
130
+ # Prompt 风格
131
+ prompt_style: Optional[PromptStyle] = None
132
+
133
+ # System Template 和 message
134
+ system_template: str = '<role>SYSTEM</role>{}'
135
+ system_message: str = ''
136
+
137
+ # 角色定义
138
+ role_human: str = 'HUMAN'
139
+ role_assistant: str = 'ASSISTANT'
140
+ role_observation: str = 'OBSERVATION'
141
+ role_template: str = '<role>{}</role>'
142
+
143
+ # 每轮符号定义
144
+ turn_start: str = ''
145
+ human_end: str = ''
146
+ assistant_start: str = ''
147
+ assistant_end: str = ''
148
+ assistant_end_ids: Optional[List[int]] = None
149
+ general_role_end: str = ''
150
+
151
+ # agent 符号定义
152
+ tool_template = '<tool>{}</tool>'
153
+ code_template = '<code>{}</code>'
154
+ arithemetic_templte = '<arithemetic>{}</arithemetic>'
155
+ image_template = '<image>{}</image>'
156
+
157
+ # All messages. Each item is (role, message).
158
+ messages: List[Tuple[str, str]] = ()
159
+
160
+ # messages 中用于 few-shot messages
161
+ offset: int = 0
162
+
163
+ # 其他 meta data
164
+ source: Optional[str] = None
165
+ lang: Optional[str] = None
166
+ topic: Optional[str] = None
167
+
168
+ # 原始 json 数据
169
+ origin_json: Optional[dict] = None
170
+
171
+ @property
172
+ def support_names(self) -> Dict[str, str]:
173
+ '''支持的数据对象名称.'''
174
+ return {
175
+ 'antglm_raw': '原始 antglm format 格式, 单轮指令没有结构, 多轮 `第1轮\\n用户:xx\\n机器人xx\\n`',
176
+ 'antglm_chat': 'Chat format 格式, 单轮多轮统一为 chat format 格式',
177
+ 'chatglm1': 'chatglm1 format',
178
+ 'chatglm2': 'chatglm2 format',
179
+ 'llama2': 'llama2 format',
180
+ 'qwen': '千问 format',
181
+ 'baichuan2': '百川 2 format',
182
+ }
183
+
184
+ @classmethod
185
+ def from_json(
186
+ cls,
187
+ input: dict,
188
+ name: Optional[str] = None,
189
+ prompt_style: Optional[PromptStyle] = None,
190
+ ):
191
+ '''从文件数据结构到数据对象的转换.
192
+
193
+ Params:
194
+ name: `Optional[str]`, 符号系统名称
195
+ - format 支持: antglm_raw, antglm_chat, chatglm1, chatglm2, llama2, qwen, baichuan2
196
+ - 如果指定了 format name, 使用该 name 符号系统, 否则使用 input 中 `name` 字段
197
+
198
+ prompt_style: `Optional[PromptStyle]`, 指定 prompt 风格, 默认使用和 name 一致的风格
199
+
200
+ input: `dict`, 文件中的 json dict 对象, 协议为:
201
+ - 既支持 `messages` 字段, 也支持 `turns` 字段
202
+ {
203
+ "id": "xxx",
204
+ "name": "antglm",
205
+ "source": "xxx",
206
+ "lang": "xx",
207
+ "topic": "xx",
208
+ "system_template": "",
209
+ "system_message": "xx",
210
+ "messages": [
211
+ {
212
+ "role": "HUMAN",
213
+ "content": "Hi"
214
+ },
215
+ {
216
+ "role": "ASSISTANT",
217
+ "content": "Hello"
218
+ },
219
+ {
220
+ "role": "OBSERVATION",
221
+ "content": "xxx"
222
+ },
223
+ {
224
+ "role": "ASSISTANT",
225
+ "content": "xxx"
226
+ }
227
+ ],
228
+ "turns": [
229
+ {"HUMAN": "xxx", "OBSERVATION": "xx", "ASSISTANT": "xx"}
230
+ ]
231
+ }
232
+
233
+ Returns:
234
+ `Chat` 对象
235
+ '''
236
+ _id = input.get('id')
237
+ if name:
238
+ _name = name
239
+ else:
240
+ _name = input.get('name')
241
+ source = input.get('source')
242
+ lang = input.get('lang')
243
+ topic = input.get('topic')
244
+ kwargs = {}
245
+ if 'system_template' in input:
246
+ kwargs['system_template'] = input['system_template']
247
+ if 'system_message' in input:
248
+ kwargs['system_message'] = input['system_message']
249
+
250
+ # 转换成 Chat 对象
251
+ chat = cls(
252
+ id=_id,
253
+ name=_name,
254
+ prompt_style=prompt_style,
255
+ source=source,
256
+ lang=lang,
257
+ topic=topic,
258
+ origin_json=deepcopy(input),
259
+ **kwargs,
260
+ )
261
+ if 'messages' in input:
262
+ for msg in input['messages']:
263
+ if msg['role'] == 'HUMAN':
264
+ role = chat.role_human
265
+ elif msg['role'] == 'OBSERVATION':
266
+ role = chat.role_observation
267
+ elif msg['role'] == 'ASSISTANT':
268
+ role = chat.role_assistant
269
+ else:
270
+ raise ValueError(f'不支持数据集中的 role: {msg["role"]}')
271
+
272
+ chat.append_message(role, msg['content'])
273
+
274
+ elif 'turns' in input:
275
+ for turn in input['turns']:
276
+ if 'HUMAN' in turn:
277
+ content = turn['HUMAN']
278
+ chat.append_message(chat.role_human, content)
279
+ if 'OBSERVATION' in turn:
280
+ content = turn['OBSERVATION']
281
+ chat.append_message(chat.role_observation, content)
282
+ if 'ASSISTANT' in turn:
283
+ content = turn['ASSISTANT']
284
+ chat.append_message(chat.role_assistant, content)
285
+
286
+ return chat
287
+
288
+ @classmethod
289
+ def from_pack(
290
+ cls,
291
+ packs: Dict[str, List[str]],
292
+ name: str,
293
+ prompt_style: Optional[PromptStyle] = None,
294
+ ) -> list:
295
+ '''根据 pack 数据创建 Chat 对象.
296
+
297
+ Params:
298
+ packs: `dict`, pack 样本数据
299
+ {
300
+ 'inputs': ['xx', 'xx'],
301
+ 'outputs': ['xx', 'xx'],
302
+ }
303
+
304
+ name: `str`, 符号系统名称
305
+ prompt_style: `Optional[PromptStyle]`, 指定 prompt 风格, 默认使用和 name 一致的风格
306
+ '''
307
+ chat = cls(name=name, prompt_style=prompt_style)
308
+ packs = cls._format_packs(packs)
309
+
310
+ sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
311
+ turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
312
+ human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
313
+ observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
314
+ assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)
315
+
316
+ chats = []
317
+ for input, output in zip(packs['input'], packs['output']):
318
+ # system message
319
+ sys_match = sys_pattern.search(input)
320
+ if sys_match and sys_match.group(0):
321
+ # system 指令只在首轮, 新增 chat 对象
322
+ if len(chat.messages) > 0:
323
+ chats.append(chat)
324
+ chat = cls(name=name, prompt_style=prompt_style)
325
+
326
+ input = input[sys_match.end() :]
327
+ chat.system_message = sys_match.group(1)
328
+
329
+ # turn start
330
+ turn_match = turn_pattern.search(input)
331
+ if turn_match and turn_match.group(0):
332
+ # 当出现下一个轮次开始信息, 新增 chat 对象
333
+ if name in ['antglm', 'antglm_raw', 'chatglm2']:
334
+ round_start = 1
335
+ else:
336
+ round_start = 0
337
+
338
+ if all(
339
+ [
340
+ len(turn_match.groups()) > 0,
341
+ int(turn_match.group(1)) == round_start,
342
+ len(chat.messages) > 0,
343
+ ]
344
+ ):
345
+ chats.append(chat)
346
+ chat = cls(name=name, prompt_style=prompt_style)
347
+
348
+ input = input[turn_match.end() :]
349
+
350
+ human_iter = human_pattern.finditer(input)
351
+ observe_iter = observe_pattern.finditer(input)
352
+ assistant_iter = assistant_pattern.finditer(input)
353
+ human_match = next(human_iter, None)
354
+ observe_match = next(observe_iter, None)
355
+ assistant_match = next(assistant_iter, None)
356
+
357
+ if not human_match and not observe_match:
358
+ # 无 role format
359
+ chat.append_message(chat.role_human, input)
360
+
361
+ while human_match or observe_match:
362
+ next_human_match = next(human_iter, None)
363
+ next_observe_match = next(observe_iter, None)
364
+ input = cls._append_human_observation(
365
+ chat,
366
+ input,
367
+ human_match=human_match,
368
+ next_human_match=next_human_match,
369
+ observe_match=observe_match,
370
+ next_observe_match=next_observe_match,
371
+ assistant_match=assistant_match,
372
+ )
373
+
374
+ human_match = next_human_match
375
+ observe_match = next_observe_match
376
+ next_human_match = next(human_iter, None)
377
+ next_observe_match = next(observe_iter, None)
378
+
379
+ if output:
380
+ chat.append_message(chat.role_assistant, output)
381
+
382
+ if chat.messages:
383
+ chats.append(chat)
384
+
385
+ return chats
386
+
387
+ @classmethod
388
+ def _append_human_observation(
389
+ cls,
390
+ chat,
391
+ input: str,
392
+ human_match: Optional[re.Match] = None,
393
+ next_human_match: Optional[re.Match] = None,
394
+ observe_match: Optional[re.Match] = None,
395
+ next_observe_match: Optional[re.Match] = None,
396
+ assistant_match: Optional[re.Match] = None,
397
+ ) -> str:
398
+ '''给 chat 对象增加 human/observation message.'''
399
+ if observe_match:
400
+ # observation 在 human 之后
401
+ if observe_match.span()[0] > observe_match.span()[0]:
402
+ human_str = input[observe_match.span()[1] : observe_match.span()[0]]
403
+ observe_str = input[observe_match.span()[1] : assistant_match.span()[0]]
404
+ chat.append_message(chat.role_human, human_str.strip())
405
+ input_end = observe_match.span()[1]
406
+ if observe_match.span()[0] < next_human_match.span()[0]:
407
+ chat.append_message(chat.role_observation, observe_str.strip())
408
+ input_end = assistant_match.span()[1]
409
+ else:
410
+ # observation 在 human 之前
411
+ human_str = input[observe_match.span()[1] : assistant_match.span()[0]]
412
+ observe_str = input[observe_match.span()[1] : observe_match.span()[0]]
413
+ chat.append_message(chat.role_observation, observe_str.strip())
414
+ input_end = observe_match.span()[1]
415
+ if observe_match.span()[0] < next_observe_match.span()[0]:
416
+ chat.append_message(chat.role_human, human_str.strip())
417
+ input_end = assistant_match.span()[1]
418
+ else:
419
+ if assistant_match:
420
+ human_str = input[human_match.span()[1] : assistant_match.span()[0]]
421
+ input_end = assistant_match.span()[1]
422
+ else:
423
+ human_str = input[human_match.span()[1] :]
424
+ input_end = len(input)
425
+ chat.append_message(chat.role_human, human_str.strip())
426
+
427
+ return input[input_end:]
428
+
429
+ @classmethod
430
+ def from_inout(
431
+ cls,
432
+ sample: Dict[str, str],
433
+ name: str,
434
+ prompt_style: Optional[PromptStyle] = None,
435
+ ):
436
+ '''根据单样本创建一个 Chat 对象.
437
+
438
+ Params:
439
+ sample: `Dict[str, str]`, input/output 数据样本
440
+ {
441
+ "input": "xxx",
442
+ "output": "xxx",
443
+ }
444
+
445
+ name: `str`, 符号系统名称
446
+ prompt_style: `Optional[PromptStyle]`, 指定 prompt 风格, 默认使用和 name 一致的风格
447
+ '''
448
+ chat = cls(name=name, prompt_style=prompt_style)
449
+ input = sample['input']
450
+ output = sample['output']
451
+
452
+ sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
453
+ turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
454
+ human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
455
+ observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
456
+ assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)
457
+
458
+ # 去除轮次信息
459
+ input = turn_pattern.sub('', input)
460
+
461
+ # system message search
462
+ sys_match = sys_pattern.search(input)
463
+ if sys_match and sys_match.group(0):
464
+ input = input[sys_match.end() :]
465
+ chat.system_message = sys_match.group(1)
466
+
467
+ human_iter = human_pattern.finditer(input)
468
+ observe_iter = observe_pattern.finditer(input)
469
+ assistant_iter = assistant_pattern.finditer(input)
470
+ human_match = next(human_iter, None)
471
+ observe_match = next(observe_iter, None)
472
+ assistant_match = next(assistant_iter, None)
473
+ next_human_match = next(human_iter, None)
474
+ next_observe_match = next(observe_iter, None)
475
+
476
+ while any(
477
+ [
478
+ human_match,
479
+ observe_match,
480
+ assistant_match,
481
+ ]
482
+ ):
483
+
484
+ # human/observation 先后顺序可能不一样, 并且有可能有多个
485
+ # 判断 assitant 之前是否还有 human/observation
486
+ while any(
487
+ [
488
+ human_match and human_match.span()[0] < assistant_match.span()[0],
489
+ observe_match and observe_match.span()[0] < assistant_match.span()[0],
490
+ next_human_match and next_human_match.span()[0] < assistant_match.span()[0],
491
+ next_observe_match and next_observe_match.span()[0] < assistant_match.span()[0],
492
+ ]
493
+ ):
494
+ if not input:
495
+ break
496
+
497
+ cls._append_human_observation(
498
+ chat,
499
+ input,
500
+ human_match=human_match,
501
+ next_human_match=next_human_match,
502
+ observe_match=observe_match,
503
+ next_observe_match=next_observe_match,
504
+ assistant_match=assistant_match,
505
+ )
506
+
507
+ human_match = next_human_match
508
+ observe_match = next_observe_match
509
+ next_human_match = next(human_iter, None)
510
+ next_observe_match = next(observe_iter, None)
511
+
512
+ # assistant message
513
+ if assistant_match and assistant_match.span():
514
+ if observe_match:
515
+ if observe_match.span() and observe_match.span()[0] < human_match.span()[0]:
516
+ assistant_str = input[assistant_match.span()[1] : observe_match.span()[0]]
517
+ elif human_match:
518
+ if human_match.span():
519
+ assistant_str = input[assistant_match.span()[1] : human_match.span()[0]]
520
+ else:
521
+ assistant_str = input[assistant_match.span()[1] :]
522
+
523
+ if assistant_str:
524
+ chat.append_message(chat.role_assistant, assistant_str)
525
+
526
+ assistant_match = next(assistant_iter, None)
527
+
528
+ if output:
529
+ chat.append_message(chat.role_assistant, output)
530
+
531
+ return chat
532
+
533
+ def __hash__(self):
534
+ '''数据对象的 hash 函数.'''
535
+ return hash(self.id)
536
+
537
+ def __post_init__(self):
538
+ '''对象初始化后的处理, 处理包括:
539
+ - 根据数据对象名称, 支持转成其他开源数据对象的基本信息
540
+ '''
541
+ self.id = str(uuid.uuid4())
542
+ if not self.messages:
543
+ self.messages = []
544
+
545
+ if not self.name and not self.prompt_style:
546
+ logger.error('构造 Chat 对象至少包含以下一个入参: `name/prompt_style`.\n\n' '`name` 支持以下 format 名称:')
547
+ logger.error('\n'.join([f'{k}: {v}' for k, v in self.support_names.items()]))
548
+ logger.error('\n`prompt_style` 参考 antllm.data.chat_format.PromptStyle')
549
+ raise ValueError
550
+
551
+ if self.name == 'antglm':
552
+ # 默认 antglm 使用原始 antglm_raw - 第1轮\n用户: xx\n机器人: xx\n
553
+ self.name = 'antglm_raw'
554
+
555
+ if not self.name and self.prompt_style == PromptStyle.ANTGLM_CHAT:
556
+ logger.info(
557
+ 'Chat 对象入参没有 `name`, 默认使用 `ANTGLM_CHAT`, format:\n'
558
+ f'role_human: {self.role_human}\n'
559
+ f'role_assistant: {self.role_assistant}\n'
560
+ f'role_observation: {self.role_observation}\n'
561
+ f'role_template: {self.role_template}\n'
562
+ f'turn_start: {self.turn_start}\n'
563
+ f'human_end: {self.human_end}\n'
564
+ f'assistant_start: {self.assistant_start}\n'
565
+ f'assistant_end: {self.assistant_end}\n'
566
+ f'assistant_end_ids: {self.assistant_end_ids}\n'
567
+ f'general_role_end: {self.general_role_end}\n'
568
+ f'tool_template: {self.tool_template}\n'
569
+ f'code_template: {self.code_template}\n'
570
+ f'arithemetic_templte: {self.arithemetic_templte}\n'
571
+ f'image_template: {self.image_template}\n'
572
+ f'\n入参 `name` 支持: ``'
573
+ )
574
+ return
575
+
576
+ if self.name == 'antglm_raw' or self.prompt_style == PromptStyle.ANTGLM_RAW:
577
+ self.prompt_style = PromptStyle.ANTGLM_RAW
578
+ self.role_template = '{}'
579
+ self.role_human = '用户: '
580
+ self.role_assistant = '机器人: '
581
+ self.turn_start = '第{}轮\n'
582
+ self.general_role_end = '\n'
583
+
584
+ if self.name in ['chatglm1', 'chatglm2'] or self.prompt_style == PromptStyle.CHATGLM:
585
+ self.prompt_style = PromptStyle.CHATGLM
586
+ self.role_template = '{}'
587
+ self.role_human = '问:'
588
+ self.role_assistant = '答:'
589
+ self.turn_start = '[Round {}]\n'
590
+ if self.name == 'chatglm1':
591
+ self.general_role_end = '\n'
592
+ else:
593
+ self.general_role_end = '\n\n'
594
+
595
+ elif self.name == 'chatglm3' or self.prompt_style == PromptStyle.CHATGLM3:
596
+ self.prompt_style = PromptStyle.CHATGLM3
597
+ self.system_template = '<|system|>\n {}'
598
+ self.role_human = '<|user|>\n '
599
+ self.role_assistant = '<|assistant|>\n '
600
+ self.role_template = '{}'
601
+
602
+ elif self.name == 'llama2' or self.prompt_style == PromptStyle.LLAMA2:
603
+ self.prompt_style = PromptStyle.LLAMA2
604
+ self.role_template = '{}'
605
+ self.system_template = '[INST] <<SYS>>\n{}\n<</SYS>>\n\n'
606
+ self.role_human = '[INST] '
607
+ self.role_assistant = '[/INST] '
608
+ self.human_end = ' '
609
+ self.assistant_end = ' </s><s>'
610
+
611
+ elif self.name == 'qwen':
612
+ self.prompt_style = PromptStyle.CHATML
613
+ self.role_template = '{}'
614
+ self.system_template = '<|im_start|>system\n{}'
615
+ if not self.system_message:
616
+ self.system_message = 'You are a helpful assistant.'
617
+ self.role_human = '<|im_start|>user\n'
618
+ self.role_assistant = '<|im_start|>assistant\n'
619
+ self.general_role_end = '<|im_end|>\n'
620
+
621
+ elif self.name == 'baichuan':
622
+ self.prompt_style = PromptStyle.BAICHUAN2
623
+ self.role_template = '{}'
624
+ self.system_template = '{}'
625
+ self.role_human = '<token_id-195>'
626
+ self.role_assistant = '<token_id-196>'
627
+
628
+ if not self.system_template:
629
+ self.system_template = '{}'
630
+
631
+ def readable_messages(self) -> str:
632
+ '''将 messages 输出为人类可读的字符串, 方便分析数据.'''
633
+ pass
634
+
635
+ @property
636
+ def prompt_str(self) -> str:
637
+ '''将 Chat 对象转成 prompt str, 合并 human/assitant 输出为 format 字符串.'''
638
+ return f'{self.prompt_inout["input"]}{self.prompt_inout["output"]}'
639
+
640
+ @classmethod
641
+ def _format_packs(cls, packs: Dict[str, List[str]]) -> Dict[str, List[str]]:
642
+ '''格式化 pack 样本, 输出相同 pack inputs, outputs 个数.'''
643
+ _packs = copy.deepcopy(packs)
644
+ if len(_packs['input']) - 1 == len(_packs['output']):
645
+ _packs['output'].append('')
646
+
647
+ if len(_packs['input']) != len(_packs['output']):
648
+ print(packs)
649
+ raise ValueError(
650
+ '输入 input 和 output 数量不匹配, '
651
+ f'input num: {len(packs["input"])}, '
652
+ f'output num: {len(packs["output"])}'
653
+ )
654
+
655
+ return _packs
656
+
657
+ @property
658
+ def prompt_inout(self) -> Dict[str, str]:
659
+ '''将 Chat 对象转成 input prompt, output prompt 字符串.
660
+
661
+ Returns:
662
+ `Dict[str, str]`, 示例:
663
+ {
664
+ "input": "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>你好,有什么可以帮您?<role>ASSISTANT</role>", # noqa
665
+ "output": "你好,有什么可以帮您?"
666
+ }
667
+ '''
668
+ packs = self._format_packs(self.prompt_pack)
669
+
670
+ # 兼容逻辑
671
+ if self.prompt_style == PromptStyle.ANTGLM_RAW:
672
+ packs['input'] = [f'{item} ' for item in packs['input']]
673
+
674
+ prompt_input = ''.join([f'{x}{y}' for x, y in zip(packs['input'][:-1], packs['output'][:-1])])
675
+ prompt_input += packs['input'][-1]
676
+ prompt_output = packs['output'][-1]
677
+
678
+ # 兼容逻辑
679
+ if self.prompt_style == PromptStyle.ANTGLM_RAW:
680
+ prompt_input = prompt_input.strip()
681
+
682
+ return {
683
+ 'input': prompt_input,
684
+ 'output': prompt_output,
685
+ }
686
+
687
+ @property
688
+ def prompt_pack(self) -> Dict[str, List[str]]:
689
+ '''将数据对象转成 pack input prompt, output prompt 字符串列表.:
690
+
691
+ Returns:
692
+ `Dict[str, List[str]]`, 示例:
693
+
694
+ {
695
+ "input": [
696
+ "<role>SYSTEM</role>xxxx<role>HUMAN</role>你好<role>ASSISTANT</role>",
697
+ "<role>HUMAN</role>讲个笑话<role>ASSISTANT</role>",
698
+ "<role>OBSERVATION</role>{\"weather\": \"晴\"}<role>ASSISTANT</role>"
699
+ ],
700
+ "output": [
701
+ "你好,有什么可以帮您?",
702
+ "笑话 1",
703
+ "今天天气 xxx"
704
+ ]
705
+ }
706
+
707
+ '''
708
+ inputs = []
709
+ outputs = []
710
+
711
+ # 最开始 system 构造
712
+ system_prompt = ''
713
+ if self.system_message:
714
+ system_prompt = self.system_template.format(self.system_message)
715
+
716
+ if system_prompt:
717
+ ret = system_prompt + self.general_role_end
718
+ else:
719
+ ret = ''
720
+
721
+ # 有些 prompt style 单轮指令没有 format
722
+ if self.prompt_style in [
723
+ PromptStyle.ANTGLM_RAW,
724
+ PromptStyle.ANTGLM_ONLY_MULTITURN_CHAT,
725
+ ]:
726
+ if len(self.messages) <= 2:
727
+ output = ''
728
+ for role, message in self.messages:
729
+ if role == self.role_assistant:
730
+ output = message
731
+ else:
732
+ input = ret + message
733
+ return {
734
+ 'input': [input],
735
+ 'output': [output],
736
+ }
737
+
738
+ # 多轮对话
739
+ if self.name in ['antglm_raw', 'chatglm2']:
740
+ round_start = 1
741
+ else:
742
+ round_start = 0
743
+
744
+ for i, (role, message) in enumerate(self.messages):
745
+ # 轮次信息
746
+ if self.name in ['antglm_raw', 'chatglm1', 'chatglm2']:
747
+ if i % 2 == 0:
748
+ ret += self.turn_start.format(i // 2 + round_start)
749
+
750
+ # 角色 + 内容
751
+ role_end = self.general_role_end
752
+ if role == self.role_assistant and self.assistant_end:
753
+ role_end = self.assistant_end
754
+ elif self.human_end:
755
+ role_end = self.human_end
756
+
757
+ ret += self.role_template.format(role) + message + role_end
758
+
759
+ if role == self.role_assistant:
760
+ # output 只保留实际 assistant 内容
761
+ if not message:
762
+ outputs.append('')
763
+ else:
764
+ outputs.append(message + role_end)
765
+ # input 需要连接 assistant role
766
+ inputs[-1] += ret[: -len(message + role_end)]
767
+ elif all(
768
+ [
769
+ role == self.role_observation,
770
+ len(self.messages) > 1,
771
+ self.messages[i - 1][0] != self.role_assistant,
772
+ ]
773
+ ):
774
+ # observation 之前不是 assistant, 需要将 observation 和上一个 input 连接一起
775
+ continue
776
+ else:
777
+ inputs.append(ret)
778
+ ret = ''
779
+
780
+ # 最后一轮不是机器人回复, 需要拼接机器人 role, 用于模型生成
781
+ if i == len(self.messages) - 1 and role != self.role_assistant:
782
+ inputs[-1] += self.role_template.format(self.role_assistant).strip()
783
+
784
+ # 兼容逻辑, 去除 inputs 最后空格符号
785
+ if self.prompt_style == PromptStyle.ANTGLM_RAW:
786
+ inputs = [item.strip() for item in inputs]
787
+
788
+ return {
789
+ 'input': inputs,
790
+ 'output': outputs,
791
+ }
792
+
793
+ @property
794
+ def turns_num(self) -> int:
795
+ '''和机器人的交互轮数, 以用户输出多少次为轮数个数.'''
796
+ return sum([1 if msg[0] == self.role_human else 0 for msg in self.messages])
797
+
798
+ def to_json(self) -> dict:
799
+ '''输出 chat json dict 格式, 包含不同角色和机器人交互的每轮信息.
800
+
801
+ Returns
802
+ `List[dict]`, {
803
+ "id": "xx",
804
+ "messages": [
805
+ {"role": "HUMAN", "content": "xxx"}
806
+ ]
807
+ "turns": [
808
+ {"HUMAN": "xx", "OBSERVATION": "xx", "ASSISTANT": "xx"}
809
+ ]
810
+ }
811
+ '''
812
+ turns = []
813
+ messages = []
814
+ turn = {}
815
+ for msg in self.messages:
816
+ if msg[0] == self.role_assistant:
817
+ messages.append({'role': 'ASSISTANT', 'content': msg[1]})
818
+ turn['ASSISTANT'] = msg[1]
819
+ turns.append(turn)
820
+ turn = {}
821
+
822
+ if msg[0] == self.role_human:
823
+ messages.append({'role': 'HUMAN', 'content': msg[1]})
824
+ turn['HUMAN'] = msg[1]
825
+
826
+ if msg[0] == self.role_observation:
827
+ messages.append({'role': 'OBSERVATION', 'content': msg[1]})
828
+ turn['OBSERVATION'] = msg[1]
829
+
830
+ if self.messages[-1][0] == self.role_human:
831
+ messages.append({'role': 'ASSISTANT', 'content': ''})
832
+ turn['ASSISTANT'] = ''
833
+ turns.append(turn)
834
+
835
+ result = self.origin_json or {}
836
+ result.update(
837
+ {
838
+ 'id': self.id,
839
+ 'name': self.name,
840
+ 'source': self.source,
841
+ 'lang': self.lang,
842
+ 'topic': self.topic,
843
+ 'system_template': self.system_template,
844
+ 'system_message': self.system_message,
845
+ 'turns': turns,
846
+ 'messages': messages,
847
+ }
848
+ )
849
+
850
+ return result
851
+
852
+ def set_system_message(self, system_message: str):
853
+ '''Set the system message.'''
854
+ self.system_message = system_message
855
+
856
+ def append_message(self, role: str, message: str):
857
+ '''Append a new message.'''
858
+ if not message:
859
+ message = ''
860
+ self.messages.append([role, message])
861
+
862
+ def to_openai_api_messages(self) -> List[dict]:
863
+ '''Convert the conversation to OpenAI chat completion format.'''
864
+ ret = [{'role': 'system', 'content': self.system_message}]
865
+
866
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
867
+ if i % 2 == 0:
868
+ ret.append({'role': 'user', 'content': msg})
869
+ else:
870
+ if msg is not None:
871
+ ret.append({'role': 'assistant', 'content': msg})
872
+ return ret
873
+
874
+ def copy(self):
875
+ return copy.deepcopy(self)