shivavardhineedi commited on
Commit
91c804f
·
verified ·
1 Parent(s): 25f9e25

Delete conversation.py

Browse files
Files changed (1) hide show
  1. conversation.py +0 -383
conversation.py DELETED
@@ -1,383 +0,0 @@
1
- """
2
- Conversation prompt templates.
3
-
4
- We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
- If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
- """
7
-
8
- import dataclasses
9
- from enum import IntEnum, auto
10
- from typing import Any, Dict, List, Tuple, Union
11
-
12
-
13
- class SeparatorStyle(IntEnum):
14
- """Separator styles."""
15
-
16
- ADD_COLON_SINGLE = auto()
17
- ADD_COLON_TWO = auto()
18
- ADD_COLON_SPACE_SINGLE = auto()
19
- NO_COLON_SINGLE = auto()
20
- NO_COLON_TWO = auto()
21
- ADD_NEW_LINE_SINGLE = auto()
22
- LLAMA2 = auto()
23
- CHATGLM = auto()
24
- CHATML = auto()
25
- CHATINTERN = auto()
26
- DOLLY = auto()
27
- RWKV = auto()
28
- PHOENIX = auto()
29
- ROBIN = auto()
30
- FALCON_CHAT = auto()
31
- CHATGLM3 = auto()
32
- INTERNVL_ZH = auto()
33
- MPT = auto()
34
-
35
-
36
- @dataclasses.dataclass
37
- class Conversation:
38
- """A class that manages prompt templates and keeps all conversation history."""
39
-
40
- # The name of this template
41
- name: str
42
- # The template of the system prompt
43
- system_template: str = '{system_message}'
44
- # The system message
45
- system_message: str = ''
46
- # The names of two roles
47
- roles: Tuple[str] = ('USER', 'ASSISTANT')
48
- # All messages. Each item is (role, message).
49
- messages: List[List[str]] = ()
50
- # The number of few shot examples
51
- offset: int = 0
52
- # The separator style and configurations
53
- sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
54
- sep: str = '\n'
55
- sep2: str = None
56
- # Stop criteria (the default one is EOS token)
57
- stop_str: Union[str, List[str]] = None
58
- # Stops generation if meeting any token in this list
59
- stop_token_ids: List[int] = None
60
-
61
- def get_prompt(self) -> str:
62
- """Get the prompt for generation."""
63
- system_prompt = self.system_template.format(system_message=self.system_message)
64
- if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
65
- ret = system_prompt + self.sep
66
- for role, message in self.messages:
67
- if message:
68
- ret += role + ': ' + message + self.sep
69
- else:
70
- ret += role + ':'
71
- return ret
72
- elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
73
- seps = [self.sep, self.sep2]
74
- ret = system_prompt + seps[0]
75
- for i, (role, message) in enumerate(self.messages):
76
- if message:
77
- ret += role + ': ' + message + seps[i % 2]
78
- else:
79
- ret += role + ':'
80
- return ret
81
- elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
82
- ret = system_prompt + self.sep
83
- for role, message in self.messages:
84
- if message:
85
- ret += role + ': ' + message + self.sep
86
- else:
87
- ret += role + ': ' # must be end with a space
88
- return ret
89
- elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
90
- ret = '' if system_prompt == '' else system_prompt + self.sep
91
- for role, message in self.messages:
92
- if message:
93
- ret += role + '\n' + message + self.sep
94
- else:
95
- ret += role + '\n'
96
- return ret
97
- elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
98
- ret = system_prompt
99
- for role, message in self.messages:
100
- if message:
101
- ret += role + message + self.sep
102
- else:
103
- ret += role
104
- return ret
105
- elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
106
- seps = [self.sep, self.sep2]
107
- ret = system_prompt
108
- for i, (role, message) in enumerate(self.messages):
109
- if message:
110
- ret += role + message + seps[i % 2]
111
- else:
112
- ret += role
113
- return ret
114
- elif self.sep_style == SeparatorStyle.RWKV:
115
- ret = system_prompt
116
- for i, (role, message) in enumerate(self.messages):
117
- if message:
118
- ret += (
119
- role
120
- + ': '
121
- + message.replace('\r\n', '\n').replace('\n\n', '\n')
122
- )
123
- ret += '\n\n'
124
- else:
125
- ret += role + ':'
126
- return ret
127
- elif self.sep_style == SeparatorStyle.LLAMA2:
128
- seps = [self.sep, self.sep2]
129
- if self.system_message:
130
- ret = system_prompt
131
- else:
132
- ret = '[INST] '
133
- for i, (role, message) in enumerate(self.messages):
134
- tag = self.roles[i % 2]
135
- if message:
136
- if i == 0:
137
- ret += message + ' '
138
- else:
139
- ret += tag + ' ' + message + seps[i % 2]
140
- else:
141
- ret += tag
142
- return ret
143
- elif self.sep_style == SeparatorStyle.CHATGLM:
144
- # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
145
- # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
146
- round_add_n = 1 if self.name == 'chatglm2' else 0
147
- if system_prompt:
148
- ret = system_prompt + self.sep
149
- else:
150
- ret = ''
151
-
152
- for i, (role, message) in enumerate(self.messages):
153
- if i % 2 == 0:
154
- ret += f'[Round {i//2 + round_add_n}]{self.sep}'
155
-
156
- if message:
157
- ret += f'{role}:{message}{self.sep}'
158
- else:
159
- ret += f'{role}:'
160
- return ret
161
- elif self.sep_style == SeparatorStyle.CHATML:
162
- ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
163
- for role, message in self.messages:
164
- if message:
165
- ret += role + '\n' + message + self.sep + '\n'
166
- else:
167
- ret += role + '\n'
168
- return ret
169
- elif self.sep_style == SeparatorStyle.CHATGLM3:
170
- ret = ''
171
- if self.system_message:
172
- ret += system_prompt
173
- for role, message in self.messages:
174
- if message:
175
- ret += role + '\n' + ' ' + message
176
- else:
177
- ret += role
178
- return ret
179
- elif self.sep_style == SeparatorStyle.CHATINTERN:
180
- # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
181
- seps = [self.sep, self.sep2]
182
- ret = system_prompt
183
- for i, (role, message) in enumerate(self.messages):
184
- # if i % 2 == 0:
185
- # ret += "<s>"
186
- if message:
187
- ret += role + ':' + message + seps[i % 2] + '\n'
188
- else:
189
- ret += role + ':'
190
- return ret
191
- elif self.sep_style == SeparatorStyle.DOLLY:
192
- seps = [self.sep, self.sep2]
193
- ret = system_prompt
194
- for i, (role, message) in enumerate(self.messages):
195
- if message:
196
- ret += role + ':\n' + message + seps[i % 2]
197
- if i % 2 == 1:
198
- ret += '\n\n'
199
- else:
200
- ret += role + ':\n'
201
- return ret
202
- elif self.sep_style == SeparatorStyle.PHOENIX:
203
- ret = system_prompt
204
- for role, message in self.messages:
205
- if message:
206
- ret += role + ': ' + '<s>' + message + '</s>'
207
- else:
208
- ret += role + ': ' + '<s>'
209
- return ret
210
- elif self.sep_style == SeparatorStyle.ROBIN:
211
- ret = system_prompt + self.sep
212
- for role, message in self.messages:
213
- if message:
214
- ret += role + ':\n' + message + self.sep
215
- else:
216
- ret += role + ':\n'
217
- return ret
218
- elif self.sep_style == SeparatorStyle.FALCON_CHAT:
219
- ret = ''
220
- if self.system_message:
221
- ret += system_prompt + self.sep
222
- for role, message in self.messages:
223
- if message:
224
- ret += role + ': ' + message + self.sep
225
- else:
226
- ret += role + ':'
227
-
228
- return ret
229
- elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
230
- seps = [self.sep, self.sep2]
231
- ret = self.system_message + seps[0]
232
- for i, (role, message) in enumerate(self.messages):
233
- if message:
234
- ret += role + ': ' + message + seps[i % 2]
235
- else:
236
- ret += role + ':'
237
- return ret
238
- elif self.sep_style == SeparatorStyle.MPT:
239
- ret = system_prompt + self.sep
240
- for role, message in self.messages:
241
- if message:
242
- if type(message) is tuple:
243
- message, _, _ = message
244
- ret += role + message + self.sep
245
- else:
246
- ret += role
247
- return ret
248
- else:
249
- raise ValueError(f'Invalid style: {self.sep_style}')
250
-
251
- def set_system_message(self, system_message: str):
252
- """Set the system message."""
253
- self.system_message = system_message
254
-
255
- def append_message(self, role: str, message: str):
256
- """Append a new message."""
257
- self.messages.append([role, message])
258
-
259
- def update_last_message(self, message: str):
260
- """Update the last output.
261
-
262
- The last message is typically set to be None when constructing the prompt,
263
- so we need to update it in-place after getting the response from a model.
264
- """
265
- self.messages[-1][1] = message
266
-
267
- def to_gradio_chatbot(self):
268
- """Convert the conversation to gradio chatbot format."""
269
- ret = []
270
- for i, (role, msg) in enumerate(self.messages[self.offset :]):
271
- if i % 2 == 0:
272
- ret.append([msg, None])
273
- else:
274
- ret[-1][-1] = msg
275
- return ret
276
-
277
- def to_openai_api_messages(self):
278
- """Convert the conversation to OpenAI chat completion format."""
279
- ret = [{'role': 'system', 'content': self.system_message}]
280
-
281
- for i, (_, msg) in enumerate(self.messages[self.offset :]):
282
- if i % 2 == 0:
283
- ret.append({'role': 'user', 'content': msg})
284
- else:
285
- if msg is not None:
286
- ret.append({'role': 'assistant', 'content': msg})
287
- return ret
288
-
289
- def copy(self):
290
- return Conversation(
291
- name=self.name,
292
- system_template=self.system_template,
293
- system_message=self.system_message,
294
- roles=self.roles,
295
- messages=[[x, y] for x, y in self.messages],
296
- offset=self.offset,
297
- sep_style=self.sep_style,
298
- sep=self.sep,
299
- sep2=self.sep2,
300
- stop_str=self.stop_str,
301
- stop_token_ids=self.stop_token_ids,
302
- )
303
-
304
- def dict(self):
305
- return {
306
- 'template_name': self.name,
307
- 'system_message': self.system_message,
308
- 'roles': self.roles,
309
- 'messages': self.messages,
310
- 'offset': self.offset,
311
- }
312
-
313
-
314
- # A global registry for all conversation templates
315
- conv_templates: Dict[str, Conversation] = {}
316
-
317
-
318
- def register_conv_template(template: Conversation, override: bool = False):
319
- """Register a new conversation template."""
320
- if not override:
321
- assert (
322
- template.name not in conv_templates
323
- ), f'{template.name} has been registered.'
324
-
325
- conv_templates[template.name] = template
326
-
327
-
328
- def get_conv_template(name: str) -> Conversation:
329
- """Get a conversation template."""
330
- return conv_templates[name].copy()
331
-
332
-
333
- register_conv_template(
334
- Conversation(
335
- name='Hermes-2',
336
- system_template='<|im_start|>system\n{system_message}',
337
- system_message='Answer the questions.',
338
- roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
339
- sep_style=SeparatorStyle.MPT,
340
- sep='<|im_end|>',
341
- stop_token_ids=[
342
- 2,
343
- 6,
344
- 7,
345
- 8,
346
- ],
347
- stop_str='<|endoftext|>',
348
- )
349
- )
350
-
351
-
352
- register_conv_template(
353
- Conversation(
354
- name='internlm2-chat',
355
- system_template='<|im_start|>system\n{system_message}',
356
- system_message='You are an AI assistant whose name is InternLM (书生·浦语).',
357
- roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
358
- sep_style=SeparatorStyle.MPT,
359
- sep='<|im_end|>',
360
- stop_token_ids=[
361
- 2,
362
- 92543,
363
- 92542
364
- ]
365
- )
366
- )
367
-
368
-
369
- register_conv_template(
370
- Conversation(
371
- name='phi3-chat',
372
- system_template='<|system|>\n{system_message}',
373
- system_message='You are an AI assistant whose name is Phi-3.',
374
- roles=('<|user|>\n', '<|assistant|>\n'),
375
- sep_style=SeparatorStyle.MPT,
376
- sep='<|end|>',
377
- stop_token_ids=[
378
- 2,
379
- 32000,
380
- 32007
381
- ]
382
- )
383
- )