|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64 |
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional |
|
|
|
|
|
import numpy as np |
|
|
import tiktoken |
|
|
import torch |
|
|
|
|
|
# Pre-tokenization split regex handed to tiktoken. Uses Unicode property
# classes (\p{L}, \p{N}, ...), which require tiktoken's `regex` engine rather
# than the stdlib `re` module. It carves text into word-like runs, digits,
# punctuation runs, and whitespace before BPE merging.
PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"


# Total token-id budget (131072); special tokens plus regular vocab must fit.
DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17


# Named special tokens; their list positions become their token ids
# (<unk>=0, <s>=1, </s>=2).
SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"]


# Naming template for additional reserved special-token slots.
# NOTE(review): not referenced anywhere in this chunk — confirm it is used
# elsewhere before removing.
SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"
|
|
|
|
|
|
|
|
def reload_mergeable_ranks(
    path: str,
    max_vocab: Optional[int] = None,
) -> Dict[bytes, int]:
    """
    Reload the tokenizer JSON file and convert it to Tiktoken format.

    The file must contain a JSON list of entries, each a dict with exactly
    the keys ``rank``, ``token_bytes`` (base64-encoded) and ``token_str``,
    ordered so that entry ``i`` has ``rank == i`` and the first 256 entries
    are the raw single bytes 0..255.

    Args:
        path: Path to a ``.json`` vocab file.
        max_vocab: If given, truncate the vocab to its first ``max_vocab``
            entries before conversion.

    Returns:
        Mapping from token byte sequence to rank, suitable for
        ``tiktoken.Encoding(mergeable_ranks=...)``.

    Raises:
        ValueError: If the path is not a ``.json`` file or the contents
            violate any of the invariants above.
    """
    # Validate with explicit raises rather than `assert`: asserts are
    # stripped when Python runs with -O, silently disabling all checks.
    if not path.endswith(".json"):
        raise ValueError(f"Expected a .json vocab file, got: {path!r}")

    with open(path, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    if not isinstance(vocab, list):
        raise ValueError("Vocab JSON must be a list of token entries.")
    print(f"Vocab size: {len(vocab)}")
    if max_vocab is not None:
        vocab = vocab[:max_vocab]
        print(f"Cutting vocab to first {len(vocab)} tokens.")

    ranks: Dict[bytes, int] = {}
    for i, x in enumerate(vocab):
        if x.keys() != {"rank", "token_bytes", "token_str"}:
            raise ValueError(f"Unexpected keys in vocab entry {i}: {sorted(x.keys())}")
        if x["rank"] != i:
            raise ValueError(f"Vocab entry {i} has out-of-order rank {x['rank']}.")
        merge = base64.b64decode(x["token_bytes"])
        # The first 256 ranks must be exactly the raw single bytes 0..255.
        if i < 256 and merge != bytes([i]):
            raise ValueError(f"Rank {i} must map to the single byte {bytes([i])!r}.")
        ranks[merge] = x["rank"]

    # Sanity checks: no duplicate byte sequences; ranks form 0..N-1.
    if len(ranks) != len(vocab):
        raise ValueError("Duplicate token byte sequences in vocab.")
    if set(ranks.values()) != set(range(len(ranks))):
        raise ValueError("Ranks do not form a contiguous range starting at 0.")

    return ranks
|
|
|
|
|
|
|
|
class TiktokenTokenizer:
    """Tokenizer wrapping a ``tiktoken.Encoding`` built from a JSON vocab.

    The bottom ``num_special_tokens`` ids are reserved for special tokens;
    every regular token id from the inner tiktoken encoding is shifted up
    by that amount on encode and shifted back down on decode.
    """

    def __init__(self, vocab_file: str):
        """Load ``vocab_file`` and construct the inner tiktoken encoding."""
        # Reserve a fixed budget at the bottom of the id space for specials.
        self.num_special_tokens = 1000
        specials = list(SPECIAL_TOKENS)

        # The regular vocab fills whatever remains of the total id budget.
        inner_size = DEFAULT_TIKTOKEN_MAX_VOCAB - self.num_special_tokens
        self.tokenizer = tiktoken.Encoding(
            name=Path(vocab_file).parent.name,
            pat_str=PATTERN_TIKTOKEN,
            mergeable_ranks=reload_mergeable_ranks(vocab_file, max_vocab=inner_size),
            # Specials are handled via the id offset, not by tiktoken itself.
            special_tokens={},
        )

        # BOS/EOS ids are their positions in SPECIAL_TOKENS ("<s>"=1, "</s>"=2).
        self._bos_id = specials.index("<s>")
        self._eos_id = specials.index("</s>")

    def encode(self, text):
        """Encode ``text`` and shift every id past the special-token range."""
        return [tok + self.num_special_tokens for tok in self.tokenizer.encode(text)]

    def decode(self, tokens):
        """Decode ``tokens`` to text, dropping specials and undoing the shift."""
        offset = self.num_special_tokens
        skip = {self._bos_id, self._eos_id}
        shifted = [t - offset for t in tokens if t not in skip and t >= offset]
        # An all-special (or empty) input decodes to the empty string.
        return self.tokenizer.decode(shifted) if shifted else ""

    def batch_decode(self, ids):
        """Decode an array/tensor/list of ids into a single string.

        NOTE(review): for a nested list (a true batch) only the first
        sequence is decoded — confirm callers rely on this behavior.
        """
        if torch.is_tensor(ids) or isinstance(ids, np.ndarray):
            ids = ids.tolist()
        if isinstance(ids[0], list):
            ids = ids[0]
        return self.decode(ids)

    @property
    def pad_id(self):
        # Padding reuses the EOS id.
        return self._eos_id

    @property
    def bos_token_id(self):
        return self._bos_id

    @property
    def eos_token_id(self):
        return self._eos_id
|
|
|