Spaces:
Runtime error
Runtime error
Upload chat_template.py
Browse files- chat_template.py +41 -0
chat_template.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ChatTemplate:
    """Builds ChatML-style prompts (``<|im_start|>role\\n...<|im_end|>\\n``)
    as lists of token ids.

    The ``model`` object must expose ``tokenize(bytes, add_bos=..., special=...)``
    returning a list of token ids, and a ``_token_eos`` attribute
    (llama-cpp-python style API — assumed from usage; confirm against caller).
    """

    def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
        """Pre-tokenize the template delimiters for *model*.

        Args:
            model: tokenizer-bearing model object (see class docstring).
            im_start: opening delimiter of a message header.
            im_end: closing delimiter of a message.
            nl: newline string used inside the template.
        """
        self.model = model
        self.nl = nl
        self.im_start = im_start
        self.im_start_token = model.tokenize(self.im_start.encode('utf-8'), add_bos=False, special=True)
        self.im_end = im_end
        self.im_end_nl = model.tokenize((self.im_end + self.nl).encode('utf-8'), add_bos=False, special=True)
        # Ids that should terminate generation: the model's own EOS token plus
        # the first token of the "<im_end><nl>" sequence.
        self.eos = [model._token_eos, self.im_end_nl[0]]
        # Single-newline token ids; also accept "\r<nl>" when the vocab maps it
        # to a single token.
        self.onenl = [self.im_end_nl[-1]]
        tmp = model.tokenize(('\r' + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.onenl.append(tmp[0])
        self.onerl = model.tokenize(b'\r', add_bos=False, special=True)
        # Token id of a double newline when the vocab has a single id for it,
        # otherwise None.
        self.nlnl = None
        tmp = model.tokenize((self.nl + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.nlnl = tmp[0]
        # FIX: the cache used to be a class attribute shared by every instance,
        # so two templates with different delimiters (or different models)
        # would serve each other's cached token lists. Keep it per-instance.
        self.cache = {}
        print('ChatTemplate', self.eos, self.im_end_nl, self.onerl, self.onenl, self.nlnl)

    def _get(self, key: str):
        """Return the token ids of ``im_start + key + nl``, cached per role.

        Deep copies guard the cache: callers may mutate the returned list
        (e.g. extend it with prompt tokens) without corrupting the cached entry.
        """
        if key in self.cache:
            return copy.deepcopy(self.cache[key])  # hand out a private copy
        value = self.model.tokenize((self.im_start + key + self.nl).encode('utf-8'), add_bos=False, special=True)
        self.cache[key] = copy.deepcopy(value)  # keep a pristine entry
        return value

    def __call__(self, _role, prompt=None):
        """Tokenize one chat message.

        With only ``_role`` given, returns the tokenized message header
        (``im_start + role + nl``). With ``prompt`` as well, returns the full
        message: header + prompt text + ``im_end`` + ``nl``.
        """
        if prompt is None:
            return self._get(_role)
        text = self.im_start + _role + self.nl + prompt
        return self.model.tokenize(text.encode('utf-8'), add_bos=False, special=True) + self.im_end_nl
|