import os
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe


class Llama3Tokenizer:
    """Thin wrapper around tiktoken that loads the Llama 3 BPE vocabulary."""

    def __init__(self, model_path):
        assert os.path.isfile(model_path), f"Model file {model_path} not found"

        mergeable_ranks = load_tiktoken_bpe(model_path)

        # Llama 3's fixed special-token IDs; the remaining IDs in this range
        # (e.g. 128002-128005, 128008) are registered as reserved tokens below.
        self.special_tokens = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        self.special_tokens.update({
            f"<|reserved_{i}|>": 128002 + i
            for i in range(256)
            if (128002 + i) not in self.special_tokens.values()
        })

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            # Llama 3's pre-tokenization regex: contractions, letter runs,
            # 1-3 digit number chunks, punctuation runs, and whitespace.
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):
        # Optionally prepend the BOS token before the regular BPE tokens.
        if bos:
            tokens = [self.special_tokens["<|begin_of_text|>"]]
        else:
            tokens = []

        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)

        if eos:
            tokens.append(self.special_tokens["<|end_of_text|>"])
        return tokens

    def decode(self, tokens):
        return self.model.decode(tokens)
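
# A minimal usage sketch (the "tokenizer.model" path below is a placeholder;
# the actual BPE file ships with the Llama 3 weights):
#
#   tok = Llama3Tokenizer("tokenizer.model")
#   ids = tok.encode("Hello, world!", bos=True, eos=True)
#   tok.decode(ids)  # "<|begin_of_text|>Hello, world!<|end_of_text|>"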


class ChatFormat:
    """Renders messages into Llama 3's chat template as token IDs."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message):
        # <|start_header_id|>{role}<|end_header_id|>\n\n
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode(self, text):
        message = {
            "role": "user",
            "content": text,
        }

        # Header, then the stripped message body, then the end-of-turn token.
        tokens = self.encode_header(message)
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Keep only the assistant's reply: drop everything up to and including
    # the final assistant header, if one is present.
    index = text.find(header_end)
    if index != -1:
        return text[index + len(header_end):].strip()
    else:
        return text
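

# A hedged end-to-end sketch, guarded so importing this module has no side
# effects. The model path is a placeholder, and generate() stands in for
# whatever (hypothetical) inference function this tokenizer is paired with:
if __name__ == "__main__":
    tokenizer = Llama3Tokenizer("tokenizer.model")  # placeholder path
    chat = ChatFormat(tokenizer)

    token_ids = chat.encode("What do llamas eat?")
    print(token_ids[:5], "...")  # chat-formatted prompt tokens

    # After generation, decode and strip the header scaffolding:
    # output_ids = generate(model, token_ids)           # hypothetical
    # print(clean_text(tokenizer.decode(output_ids)))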