prithivMLmods commited on
Commit
9e3a243
·
verified ·
1 Parent(s): 6831f55

update conversation

Browse files
Files changed (1) hide show
  1. conversation.py +280 -280
conversation.py CHANGED
@@ -1,280 +1,280 @@
1
- """
2
- From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
3
- """
4
-
5
- import dataclasses
6
- from enum import IntEnum, auto
7
- from typing import Any, Dict, List
8
-
9
-
10
- class SeparatorStyle(IntEnum):
11
- """Separator styles."""
12
-
13
- DeepSeek = auto()
14
- DeepSeekV2 = auto()
15
- PLAIN = auto()
16
- ALIGNMENT = auto()
17
-
18
-
19
- @dataclasses.dataclass
20
- class Conversation:
21
- """A class that manages prompt templates and keeps all conversation history."""
22
-
23
- # The name of this template
24
- name: str
25
- # The template of the system prompt
26
- system_template: str = "{system_message}"
27
- # The system message
28
- system_message: str = ""
29
- # The names of two roles
30
- roles: List[str] = (("USER", "ASSISTANT"),)
31
- # All messages. Each item is (role, message).
32
- messages: List[List[str]] = ()
33
- # The number of few shot examples
34
- offset: int = 0
35
- # The separator style and configurations
36
- sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
37
- sep: str = "\n"
38
- sep2: str = None
39
- # Stop criteria (the default one is EOS token)
40
- stop_str: str = None
41
- # Stops generation if meeting any token in this list
42
- stop_token_ids: List[int] = None
43
-
44
- def get_prompt(self) -> str:
45
- """Get the prompt for generation."""
46
- system_prompt = self.system_template.format(system_message=self.system_message)
47
- if self.sep_style == SeparatorStyle.DeepSeek:
48
- seps = [self.sep, self.sep2]
49
- if system_prompt == "" or system_prompt is None:
50
- ret = ""
51
- else:
52
- ret = system_prompt + seps[0]
53
- for i, (role, message) in enumerate(self.messages):
54
- if message:
55
- ret += role + ": " + message + seps[i % 2]
56
- else:
57
- ret += role + ":"
58
- return ret
59
- elif self.sep_style == SeparatorStyle.DeepSeekV2:
60
- seps = [self.sep, self.sep2]
61
- if system_prompt == "" or system_prompt is None:
62
- ret = ""
63
- else:
64
- ret = system_prompt + seps[0]
65
- for i, (role, message) in enumerate(self.messages):
66
- if message:
67
- if role == "User":
68
- ret += "<|sft▁begin|>\n" + message + self.sep #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
69
- else:
70
- ret += message + self.sep2
71
- else:
72
- ret = ret
73
- return ret
74
-
75
- elif self.sep_style == SeparatorStyle.PLAIN:
76
- seps = [self.sep, self.sep2]
77
- ret = ""
78
- for i, (role, message) in enumerate(self.messages):
79
- if message:
80
- if type(message) is tuple:
81
- message, _, _ = message
82
- if i % 2 == 0:
83
- ret += message + seps[i % 2]
84
- else:
85
- ret += message + seps[i % 2]
86
- else:
87
- ret += ""
88
- return ret
89
- elif self.sep_style == SeparatorStyle.ALIGNMENT:
90
- seps = [self.sep, self.sep2]
91
- ret = ""
92
- for i, (role, message) in enumerate(self.messages):
93
- if message:
94
- if type(message) is tuple:
95
- message, _, _ = message
96
- if i % 2 == 0:
97
- ret += '<image>\n' + seps[i % 2]
98
- else:
99
- ret += message + seps[i % 2]
100
- else:
101
- ret += ""
102
- return ret
103
- else:
104
- raise ValueError(f"Invalid style: {self.sep_style}")
105
-
106
- def set_system_message(self, system_message: str):
107
- """Set the system message."""
108
- self.system_message = system_message
109
-
110
- def append_message(self, role: str, message: str):
111
- """Append a new message."""
112
- self.messages.append([role, message])
113
-
114
- def update_last_message(self, message: str):
115
- """Update the last output.
116
-
117
- The last message is typically set to be None when constructing the prompt,
118
- so we need to update it in-place after getting the response from a model.
119
- """
120
- self.messages[-1][1] = message
121
-
122
- def reset_message(self):
123
- """Reset a new message."""
124
- self.messages = []
125
-
126
- def to_gradio_chatbot(self):
127
- """Convert the conversation to gradio chatbot format."""
128
- ret = []
129
- for i, (role, msg) in enumerate(self.messages[self.offset :]):
130
- if i % 2 == 0:
131
- ret.append([msg, None])
132
- else:
133
- ret[-1][-1] = msg
134
- return ret
135
-
136
- def to_openai_api_messages(self):
137
- """Convert the conversation to OpenAI chat completion format."""
138
- system_prompt = self.system_template.format(system_message=self.system_message)
139
- ret = [{"role": "system", "content": system_prompt}]
140
-
141
- for i, (_, msg) in enumerate(self.messages[self.offset :]):
142
- if i % 2 == 0:
143
- ret.append({"role": "user", "content": msg})
144
- else:
145
- if msg is not None:
146
- ret.append({"role": "assistant", "content": msg})
147
- return ret
148
-
149
- def copy(self):
150
- return Conversation(
151
- name=self.name,
152
- system_template=self.system_template,
153
- system_message=self.system_message,
154
- roles=self.roles,
155
- messages=[[x, y] for x, y in self.messages],
156
- offset=self.offset,
157
- sep_style=self.sep_style,
158
- sep=self.sep,
159
- sep2=self.sep2,
160
- stop_str=self.stop_str,
161
- stop_token_ids=self.stop_token_ids,
162
- )
163
-
164
- def dict(self):
165
- return {
166
- "template_name": self.name,
167
- "system_message": self.system_message,
168
- "roles": self.roles,
169
- "messages": self.messages,
170
- "offset": self.offset,
171
- }
172
-
173
-
174
- # A global registry for all conversation templates
175
- conv_templates: Dict[str, Conversation] = {}
176
-
177
-
178
- def register_conv_template(template: Conversation, override: bool = False):
179
- """Register a new conversation template."""
180
- if not override:
181
- assert template.name not in conv_templates, f"{template.name} has been registered."
182
-
183
- conv_templates[template.name] = template
184
-
185
-
186
- def get_conv_template(name: str) -> Conversation:
187
- """Get a conversation template."""
188
- return conv_templates[name].copy()
189
-
190
-
191
- register_conv_template(
192
- Conversation(
193
- name="deepseek",
194
- system_template="{system_message}",
195
- # system_message="You are a helpful assistant. Please answer truthfully and write out your "
196
- # "thinking step by step to be sure you get the right answer.",
197
- system_message="",
198
- roles=("<|User|>", "<|Assistant|>"),
199
- messages=(),
200
- offset=0,
201
- sep_style=SeparatorStyle.DeepSeek,
202
- sep="\n\n",
203
- sep2="<|end▁of▁sentence|>",
204
- stop_token_ids=[100001],
205
- stop_str=["User:", "<|end▁of▁sentence|>"]
206
- )
207
- )
208
- register_conv_template(
209
- Conversation(
210
- name="deepseekv2",
211
- system_template="{system_message}",
212
- # system_message="You are a helpful assistant. Please answer truthfully and write out your "
213
- # "thinking step by step to be sure you get the right answer.",
214
- system_message="",
215
- roles=("<|User|>", "<|Assistant|>"),
216
- messages=(),
217
- offset=0,
218
- sep_style=SeparatorStyle.DeepSeek,
219
- sep="",
220
- sep2="<|end▁of▁sentence|>",
221
- stop_token_ids=[100001],
222
- stop_str=["User:", "<|end▁of▁sentence|>"]
223
- )
224
- )
225
-
226
-
227
- register_conv_template(
228
- Conversation(
229
- name="plain",
230
- system_template="",
231
- system_message="",
232
- roles=("", ""),
233
- messages=(),
234
- offset=0,
235
- sep_style=SeparatorStyle.PLAIN,
236
- sep="",
237
- sep2="",
238
- stop_token_ids=[100001],
239
- stop_str=['</s>'],
240
- )
241
- )
242
-
243
-
244
- register_conv_template(
245
- Conversation(
246
- name="alignment",
247
- system_template="",
248
- system_message="",
249
- roles=("", ""),
250
- messages=(),
251
- offset=0,
252
- sep_style=SeparatorStyle.ALIGNMENT,
253
- sep="",
254
- sep2="",
255
- stop_token_ids=[100001],
256
- stop_str=['</s>'],
257
- )
258
- )
259
-
260
-
261
- if __name__ == "__main__":
262
- print("deepseek template:")
263
- conv = get_conv_template("deepseek")
264
- conv.append_message(conv.roles[0], "Hello!")
265
- conv.append_message(conv.roles[1], "Hi! This is Tony.")
266
- conv.append_message(conv.roles[0], "Who are you?")
267
- conv.append_message(conv.roles[1], "I am a helpful assistant.")
268
- conv.append_message(conv.roles[0], "How are you?")
269
- conv.append_message(conv.roles[1], None)
270
- print(conv.get_prompt())
271
-
272
- print("deepseekv2 template:")
273
- conv = get_conv_template("deepseekv2")
274
- conv.append_message(conv.roles[0], "Hello!")
275
- conv.append_message(conv.roles[1], "Hi! This is Tony.")
276
- conv.append_message(conv.roles[0], "Who are you?")
277
- conv.append_message(conv.roles[1], "I am a helpful assistant.")
278
- conv.append_message(conv.roles[0], "How are you?")
279
- conv.append_message(conv.roles[1], None)
280
- print(conv.get_prompt())
 
1
+ """
2
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
3
+ """
4
+
5
+ import dataclasses
6
+ from enum import IntEnum, auto
7
+ from typing import Any, Dict, List
8
+
9
+
10
+ class SeparatorStyle(IntEnum):
11
+ """Separator styles."""
12
+
13
+ DeepSeek = auto()
14
+ DeepSeekV2 = auto()
15
+ PLAIN = auto()
16
+ ALIGNMENT = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that manages prompt templates and keeps all conversation history."""
22
+
23
+ # The name of this template
24
+ name: str
25
+ # The template of the system prompt
26
+ system_template: str = "{system_message}"
27
+ # The system message
28
+ system_message: str = ""
29
+ # The names of two roles
30
+ roles: List[str] = (("USER", "ASSISTANT"),)
31
+ # All messages. Each item is (role, message).
32
+ messages: List[List[str]] = ()
33
+ # The number of few shot examples
34
+ offset: int = 0
35
+ # The separator style and configurations
36
+ sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
37
+ sep: str = "\n"
38
+ sep2: str = None
39
+ # Stop criteria (the default one is EOS token)
40
+ stop_str: str = None
41
+ # Stops generation if meeting any token in this list
42
+ stop_token_ids: List[int] = None
43
+
44
+ def get_prompt(self) -> str:
45
+ """Get the prompt for generation."""
46
+ system_prompt = self.system_template.format(system_message=self.system_message)
47
+ if self.sep_style == SeparatorStyle.DeepSeek:
48
+ seps = [self.sep, self.sep2]
49
+ if system_prompt == "" or system_prompt is None:
50
+ ret = ""
51
+ else:
52
+ ret = system_prompt + seps[0]
53
+ for i, (role, message) in enumerate(self.messages):
54
+ if message:
55
+ ret += role + ": " + message + seps[i % 2]
56
+ else:
57
+ ret += role + ":"
58
+ return ret
59
+ elif self.sep_style == SeparatorStyle.DeepSeekV2:
60
+ seps = [self.sep, self.sep2]
61
+ if system_prompt == "" or system_prompt is None:
62
+ ret = ""
63
+ else:
64
+ ret = system_prompt + seps[0]
65
+ for i, (role, message) in enumerate(self.messages):
66
+ if message:
67
+ if role == "User":
68
+ ret += "<|sft▁begin|>\n" + message + self.sep #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
69
+ else:
70
+ ret += message + self.sep2
71
+ else:
72
+ ret = ret
73
+ return ret
74
+
75
+ elif self.sep_style == SeparatorStyle.PLAIN:
76
+ seps = [self.sep, self.sep2]
77
+ ret = ""
78
+ for i, (role, message) in enumerate(self.messages):
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ if i % 2 == 0:
83
+ ret += message + seps[i % 2]
84
+ else:
85
+ ret += message + seps[i % 2]
86
+ else:
87
+ ret += ""
88
+ return ret
89
+ elif self.sep_style == SeparatorStyle.ALIGNMENT:
90
+ seps = [self.sep, self.sep2]
91
+ ret = ""
92
+ for i, (role, message) in enumerate(self.messages):
93
+ if message:
94
+ if type(message) is tuple:
95
+ message, _, _ = message
96
+ if i % 2 == 0:
97
+ ret += '<image>\n' + seps[i % 2]
98
+ else:
99
+ ret += message + seps[i % 2]
100
+ else:
101
+ ret += ""
102
+ return ret
103
+ else:
104
+ raise ValueError(f"Invalid style: {self.sep_style}")
105
+
106
+ def set_system_message(self, system_message: str):
107
+ """Set the system message."""
108
+ self.system_message = system_message
109
+
110
+ def append_message(self, role: str, message: str):
111
+ """Append a new message."""
112
+ self.messages.append([role, message])
113
+
114
+ def update_last_message(self, message: str):
115
+ """Update the last output.
116
+
117
+ The last message is typically set to be None when constructing the prompt,
118
+ so we need to update it in-place after getting the response from a model.
119
+ """
120
+ self.messages[-1][1] = message
121
+
122
+ def reset_message(self):
123
+ """Reset a new message."""
124
+ self.messages = []
125
+
126
+ def to_gradio_chatbot(self):
127
+ """Convert the conversation to gradio chatbot format."""
128
+ ret = []
129
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
130
+ if i % 2 == 0:
131
+ ret.append([msg, None])
132
+ else:
133
+ ret[-1][-1] = msg
134
+ return ret
135
+
136
+ def to_openai_api_messages(self):
137
+ """Convert the conversation to OpenAI chat completion format."""
138
+ system_prompt = self.system_template.format(system_message=self.system_message)
139
+ ret = [{"role": "system", "content": system_prompt}]
140
+
141
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
142
+ if i % 2 == 0:
143
+ ret.append({"role": "user", "content": msg})
144
+ else:
145
+ if msg is not None:
146
+ ret.append({"role": "assistant", "content": msg})
147
+ return ret
148
+
149
+ def copy(self):
150
+ return Conversation(
151
+ name=self.name,
152
+ system_template=self.system_template,
153
+ system_message=self.system_message,
154
+ roles=self.roles,
155
+ messages=[[x, y] for x, y in self.messages],
156
+ offset=self.offset,
157
+ sep_style=self.sep_style,
158
+ sep=self.sep,
159
+ sep2=self.sep2,
160
+ stop_str=self.stop_str,
161
+ stop_token_ids=self.stop_token_ids,
162
+ )
163
+
164
+ def dict(self):
165
+ return {
166
+ "template_name": self.name,
167
+ "system_message": self.system_message,
168
+ "roles": self.roles,
169
+ "messages": self.messages,
170
+ "offset": self.offset,
171
+ }
172
+
173
+
174
+ # A global registry for all conversation templates
175
+ conv_templates: Dict[str, Conversation] = {}
176
+
177
+
178
+ def register_conv_template(template: Conversation, override: bool = False):
179
+ """Register a new conversation template."""
180
+ if not override:
181
+ assert template.name not in conv_templates, f"{template.name} has been registered."
182
+
183
+ conv_templates[template.name] = template
184
+
185
+
186
+ def get_conv_template(name: str) -> Conversation:
187
+ """Get a conversation template."""
188
+ return conv_templates[name].copy()
189
+
190
+
191
+ register_conv_template(
192
+ Conversation(
193
+ name="deepseek",
194
+ system_template="{system_message}",
195
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
196
+ # "thinking step by step to be sure you get the right answer.",
197
+ system_message="",
198
+ roles=("<|User|>", "<|Assistant|>"),
199
+ messages=(),
200
+ offset=0,
201
+ sep_style=SeparatorStyle.DeepSeek,
202
+ sep="\n\n",
203
+ sep2="<|end▁of▁sentence|>",
204
+ stop_token_ids=[100001],
205
+ stop_str=["User:", "<|end▁of▁sentence|>"]
206
+ )
207
+ )
208
+ register_conv_template(
209
+ Conversation(
210
+ name="deepseekv2",
211
+ system_template="{system_message}",
212
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
213
+ # "thinking step by step to be sure you get the right answer.",
214
+ system_message="",
215
+ roles=("<|User|>", "<|Assistant|>"),
216
+ messages=(),
217
+ offset=0,
218
+ sep_style=SeparatorStyle.DeepSeek,
219
+ sep="",
220
+ sep2="<|end▁of▁sentence|>",
221
+ stop_token_ids=[100001],
222
+ stop_str=["User:", "<|end▁of▁sentence|>"]
223
+ )
224
+ )
225
+
226
+
227
+ register_conv_template(
228
+ Conversation(
229
+ name="plain",
230
+ system_template="",
231
+ system_message="",
232
+ roles=("", ""),
233
+ messages=(),
234
+ offset=0,
235
+ sep_style=SeparatorStyle.PLAIN,
236
+ sep="",
237
+ sep2="",
238
+ stop_token_ids=[100001],
239
+ stop_str=['</s>'],
240
+ )
241
+ )
242
+
243
+
244
+ register_conv_template(
245
+ Conversation(
246
+ name="alignment",
247
+ system_template="",
248
+ system_message="",
249
+ roles=("", ""),
250
+ messages=(),
251
+ offset=0,
252
+ sep_style=SeparatorStyle.ALIGNMENT,
253
+ sep="",
254
+ sep2="",
255
+ stop_token_ids=[100001],
256
+ stop_str=['</s>'],
257
+ )
258
+ )
259
+
260
+
261
+ if __name__ == "__main__":
262
+ print("deepseek template:")
263
+ conv = get_conv_template("deepseek")
264
+ conv.append_message(conv.roles[0], "Hello!")
265
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
266
+ conv.append_message(conv.roles[0], "Who are you?")
267
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
268
+ conv.append_message(conv.roles[0], "How are you?")
269
+ conv.append_message(conv.roles[1], None)
270
+ print(conv.get_prompt())
271
+
272
+ print("deepseekv2 template:")
273
+ conv = get_conv_template("deepseekv2")
274
+ conv.append_message(conv.roles[0], "Hello!")
275
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
276
+ conv.append_message(conv.roles[0], "Who are you?")
277
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
278
+ conv.append_message(conv.roles[0], "How are you?")
279
+ conv.append_message(conv.roles[1], None)
280
+ print(conv.get_prompt())