SamMikaelson commited on
Commit
c56154b
·
verified ·
1 Parent(s): 2068f0a

Add conversation.py from official DeepSeek OCR repo

Browse files
Files changed (1) hide show
  1. conversation.py +280 -0
conversation.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
3
+ """
4
+
5
+ import dataclasses
6
+ from enum import IntEnum, auto
7
+ from typing import Any, Dict, List
8
+
9
+
10
+ class SeparatorStyle(IntEnum):
11
+ """Separator styles."""
12
+
13
+ DeepSeek = auto()
14
+ DeepSeekV2 = auto()
15
+ PLAIN = auto()
16
+ ALIGNMENT = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that manages prompt templates and keeps all conversation history."""
22
+
23
+ # The name of this template
24
+ name: str
25
+ # The template of the system prompt
26
+ system_template: str = "{system_message}"
27
+ # The system message
28
+ system_message: str = ""
29
+ # The names of two roles
30
+ roles: List[str] = (("USER", "ASSISTANT"),)
31
+ # All messages. Each item is (role, message).
32
+ messages: List[List[str]] = ()
33
+ # The number of few shot examples
34
+ offset: int = 0
35
+ # The separator style and configurations
36
+ sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
37
+ sep: str = "\n"
38
+ sep2: str = None
39
+ # Stop criteria (the default one is EOS token)
40
+ stop_str: str = None
41
+ # Stops generation if meeting any token in this list
42
+ stop_token_ids: List[int] = None
43
+
44
+ def get_prompt(self) -> str:
45
+ """Get the prompt for generation."""
46
+ system_prompt = self.system_template.format(system_message=self.system_message)
47
+ if self.sep_style == SeparatorStyle.DeepSeek:
48
+ seps = [self.sep, self.sep2]
49
+ if system_prompt == "" or system_prompt is None:
50
+ ret = ""
51
+ else:
52
+ ret = system_prompt + seps[0]
53
+ for i, (role, message) in enumerate(self.messages):
54
+ if message:
55
+ ret += role + ": " + message + seps[i % 2]
56
+ else:
57
+ ret += role + ":"
58
+ return ret
59
+ elif self.sep_style == SeparatorStyle.DeepSeekV2:
60
+ seps = [self.sep, self.sep2]
61
+ if system_prompt == "" or system_prompt is None:
62
+ ret = ""
63
+ else:
64
+ ret = system_prompt + seps[0]
65
+ for i, (role, message) in enumerate(self.messages):
66
+ if message:
67
+ if role == "User":
68
+ ret += "<|sft▁begin|>\n" + message + self.sep #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
69
+ else:
70
+ ret += message + self.sep2
71
+ else:
72
+ ret = ret
73
+ return ret
74
+
75
+ elif self.sep_style == SeparatorStyle.PLAIN:
76
+ seps = [self.sep, self.sep2]
77
+ ret = ""
78
+ for i, (role, message) in enumerate(self.messages):
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ if i % 2 == 0:
83
+ ret += message + seps[i % 2]
84
+ else:
85
+ ret += message + seps[i % 2]
86
+ else:
87
+ ret += ""
88
+ return ret
89
+ elif self.sep_style == SeparatorStyle.ALIGNMENT:
90
+ seps = [self.sep, self.sep2]
91
+ ret = ""
92
+ for i, (role, message) in enumerate(self.messages):
93
+ if message:
94
+ if type(message) is tuple:
95
+ message, _, _ = message
96
+ if i % 2 == 0:
97
+ ret += '<image>\n' + seps[i % 2]
98
+ else:
99
+ ret += message + seps[i % 2]
100
+ else:
101
+ ret += ""
102
+ return ret
103
+ else:
104
+ raise ValueError(f"Invalid style: {self.sep_style}")
105
+
106
+ def set_system_message(self, system_message: str):
107
+ """Set the system message."""
108
+ self.system_message = system_message
109
+
110
+ def append_message(self, role: str, message: str):
111
+ """Append a new message."""
112
+ self.messages.append([role, message])
113
+
114
+ def update_last_message(self, message: str):
115
+ """Update the last output.
116
+
117
+ The last message is typically set to be None when constructing the prompt,
118
+ so we need to update it in-place after getting the response from a model.
119
+ """
120
+ self.messages[-1][1] = message
121
+
122
+ def reset_message(self):
123
+ """Reset a new message."""
124
+ self.messages = []
125
+
126
+ def to_gradio_chatbot(self):
127
+ """Convert the conversation to gradio chatbot format."""
128
+ ret = []
129
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
130
+ if i % 2 == 0:
131
+ ret.append([msg, None])
132
+ else:
133
+ ret[-1][-1] = msg
134
+ return ret
135
+
136
+ def to_openai_api_messages(self):
137
+ """Convert the conversation to OpenAI chat completion format."""
138
+ system_prompt = self.system_template.format(system_message=self.system_message)
139
+ ret = [{"role": "system", "content": system_prompt}]
140
+
141
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
142
+ if i % 2 == 0:
143
+ ret.append({"role": "user", "content": msg})
144
+ else:
145
+ if msg is not None:
146
+ ret.append({"role": "assistant", "content": msg})
147
+ return ret
148
+
149
+ def copy(self):
150
+ return Conversation(
151
+ name=self.name,
152
+ system_template=self.system_template,
153
+ system_message=self.system_message,
154
+ roles=self.roles,
155
+ messages=[[x, y] for x, y in self.messages],
156
+ offset=self.offset,
157
+ sep_style=self.sep_style,
158
+ sep=self.sep,
159
+ sep2=self.sep2,
160
+ stop_str=self.stop_str,
161
+ stop_token_ids=self.stop_token_ids,
162
+ )
163
+
164
+ def dict(self):
165
+ return {
166
+ "template_name": self.name,
167
+ "system_message": self.system_message,
168
+ "roles": self.roles,
169
+ "messages": self.messages,
170
+ "offset": self.offset,
171
+ }
172
+
173
+
174
+ # A global registry for all conversation templates
175
+ conv_templates: Dict[str, Conversation] = {}
176
+
177
+
178
+ def register_conv_template(template: Conversation, override: bool = False):
179
+ """Register a new conversation template."""
180
+ if not override:
181
+ assert template.name not in conv_templates, f"{template.name} has been registered."
182
+
183
+ conv_templates[template.name] = template
184
+
185
+
186
+ def get_conv_template(name: str) -> Conversation:
187
+ """Get a conversation template."""
188
+ return conv_templates[name].copy()
189
+
190
+
191
+ register_conv_template(
192
+ Conversation(
193
+ name="deepseek",
194
+ system_template="{system_message}",
195
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
196
+ # "thinking step by step to be sure you get the right answer.",
197
+ system_message="",
198
+ roles=("<|User|>", "<|Assistant|>"),
199
+ messages=(),
200
+ offset=0,
201
+ sep_style=SeparatorStyle.DeepSeek,
202
+ sep="\n\n",
203
+ sep2="<|end▁of▁sentence|>",
204
+ stop_token_ids=[100001],
205
+ stop_str=["User:", "<|end▁of▁sentence|>"]
206
+ )
207
+ )
208
+ register_conv_template(
209
+ Conversation(
210
+ name="deepseekv2",
211
+ system_template="{system_message}",
212
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
213
+ # "thinking step by step to be sure you get the right answer.",
214
+ system_message="",
215
+ roles=("<|User|>", "<|Assistant|>"),
216
+ messages=(),
217
+ offset=0,
218
+ sep_style=SeparatorStyle.DeepSeek,
219
+ sep="",
220
+ sep2="<|end▁of▁sentence|>",
221
+ stop_token_ids=[100001],
222
+ stop_str=["User:", "<|end▁of▁sentence|>"]
223
+ )
224
+ )
225
+
226
+
227
+ register_conv_template(
228
+ Conversation(
229
+ name="plain",
230
+ system_template="",
231
+ system_message="",
232
+ roles=("", ""),
233
+ messages=(),
234
+ offset=0,
235
+ sep_style=SeparatorStyle.PLAIN,
236
+ sep="",
237
+ sep2="",
238
+ stop_token_ids=[100001],
239
+ stop_str=['</s>'],
240
+ )
241
+ )
242
+
243
+
244
+ register_conv_template(
245
+ Conversation(
246
+ name="alignment",
247
+ system_template="",
248
+ system_message="",
249
+ roles=("", ""),
250
+ messages=(),
251
+ offset=0,
252
+ sep_style=SeparatorStyle.ALIGNMENT,
253
+ sep="",
254
+ sep2="",
255
+ stop_token_ids=[100001],
256
+ stop_str=['</s>'],
257
+ )
258
+ )
259
+
260
+
261
+ if __name__ == "__main__":
262
+ print("deepseek template:")
263
+ conv = get_conv_template("deepseek")
264
+ conv.append_message(conv.roles[0], "Hello!")
265
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
266
+ conv.append_message(conv.roles[0], "Who are you?")
267
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
268
+ conv.append_message(conv.roles[0], "How are you?")
269
+ conv.append_message(conv.roles[1], None)
270
+ print(conv.get_prompt())
271
+
272
+ print("deepseekv2 template:")
273
+ conv = get_conv_template("deepseekv2")
274
+ conv.append_message(conv.roles[0], "Hello!")
275
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
276
+ conv.append_message(conv.roles[0], "Who are you?")
277
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
278
+ conv.append_message(conv.roles[0], "How are you?")
279
+ conv.append_message(conv.roles[1], None)
280
+ print(conv.get_prompt())