import base64
import json
import os
from pathlib import Path
from typing import Dict, List, Optional

try:
    import tiktoken

    HAVE_TIKTOKEN = True
except (ImportError, ModuleNotFoundError):
    # tiktoken is an optional dependency; constructing TiktokenTokenizer
    # without it raises a clear error in __init__ below.
    HAVE_TIKTOKEN = False

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

__all__ = ['TiktokenTokenizer']


def reload_mergeable_ranks(
    path: str,
    max_vocab: Optional[int] = None,
) -> Dict[bytes, int]:
    """
    Reload a tokenizer vocabulary JSON file (a list of {"rank", "token_bytes", "token_str"}
    entries) and convert it to the mergeable-ranks mapping that tiktoken expects.
    """
    assert path.endswith(".json")

    # load the vocab list from JSON
    with open(path, "r") as f:
        vocab = json.load(f)
    assert isinstance(vocab, list)
    print(f"Vocab size: {len(vocab)}")
    if max_vocab is not None:
        vocab = vocab[:max_vocab]
        print(f"Cutting vocab to first {len(vocab)} tokens.")

    # build the mergeable ranks: token bytes -> rank
    ranks: Dict[bytes, int] = {}
    for i, x in enumerate(vocab):
        assert x.keys() == {"rank", "token_bytes", "token_str"}
        assert x["rank"] == i
        merge = base64.b64decode(x["token_bytes"])
        # the first 256 entries must be the raw single bytes 0..255
        assert i >= 256 or merge == bytes([i])
        ranks[merge] = x["rank"]

    # sanity check: ranks must form a contiguous range starting at 0
    assert len(ranks) == len(vocab)
    assert set(ranks.values()) == set(range(len(ranks)))

    return ranks
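
# The vocabulary JSON consumed above is a list of entries of the form
# {"rank": <int>, "token_bytes": <base64 str>, "token_str": <str>}, where "rank"
# equals the list index and the first 256 entries are the raw bytes 0..255.
# Illustrative entries (hand-written sketch, not taken from any shipped vocab file):
#
#   [
#     {"rank": 0,  "token_bytes": "AA==", "token_str": "\u0000"},
#     ...
#     {"rank": 65, "token_bytes": "QQ==", "token_str": "A"},
#     ...
#     # entries with rank >= 256 hold merged multi-byte tokens
#   ]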


# Default pre-tokenization regex (Unicode-aware split into word pieces, digits,
# punctuation runs, and whitespace).
PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17  # 131072
SPECIAL_TOKENS = ["<unk>", "<s>", "</s>", "<mask>", "<pad>", "<cls>", "<sep>"]
SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"


class TiktokenTokenizer(TokenizerSpec):
    """
    TiktokenTokenizer https://github.com/openai/tiktoken.

    Args:
        vocab_file: path to the tokenizer vocabulary JSON file
        pattern: regex pattern used to split the text
        vocab_size: total vocabulary size, including the special tokens
        num_special_tokens: number of special tokens to generate
        special_tokens: template for user-defined special tokens
    """

    def __init__(
        self,
        vocab_file: str,
        pattern: str = PATTERN_TIKTOKEN,
        vocab_size: int = DEFAULT_TIKTOKEN_MAX_VOCAB,
        num_special_tokens: int = 1000,
        special_tokens: Optional[List[str]] = None,
    ):
        if not HAVE_TIKTOKEN:
            raise ImportError("TiktokenTokenizer requires the tiktoken package, which could not be imported.")
        if not vocab_file or not os.path.exists(vocab_file):
            raise ValueError(f"vocab_file: {vocab_file} is invalid")

        if special_tokens is None:
            special_tokens = SPECIAL_TOKENS.copy()

        assert len(special_tokens) == len(set(special_tokens)), f"Special tokens should be unique: {special_tokens}"
        assert len(special_tokens) <= num_special_tokens < vocab_size
        assert set(SPECIAL_TOKENS) <= set(special_tokens), f"Custom special tokens should include {SPECIAL_TOKENS}"

        # positions of the named special tokens within the special-token block
        self._unk_id = special_tokens.index("<unk>")
        self._bos_id = special_tokens.index("<s>")
        self._eos_id = special_tokens.index("</s>")
        self._mask_id = special_tokens.index("<mask>")
        self._pad_id = special_tokens.index("<pad>")
        self._cls_id = special_tokens.index("<cls>")
        self._sep_id = special_tokens.index("<sep>")

        self._vocab_size = vocab_size
        print(f'{self._vocab_size = }')
        self.num_special_tokens = num_special_tokens
        # pad the special-token block up to num_special_tokens with filler tokens
        special_filler = [SPECIAL_TOKEN_TEMPLATE.format(id=i) for i in range(len(special_tokens), num_special_tokens)]
        self.special_filler = special_filler
        if special_filler:
            print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}")
        self.special_tokens = special_tokens + special_filler
        assert len(set(self.special_tokens)) == len(self.special_tokens) == num_special_tokens, self.special_tokens
        self.inner_vocab_size = vocab_size - num_special_tokens

        # load the tiktoken-format vocab for the non-special tokens
        self.token2id = reload_mergeable_ranks(vocab_file, max_vocab=self.inner_vocab_size)
        self.id2token = {v: k for k, v in self.token2id.items()}
        assert set(range(self.inner_vocab_size)) == set(self.id2token.keys())

        # full decoding table over the shifted id space: special tokens first,
        # then the tiktoken tokens offset by num_special_tokens
        self.shifted_id2token = {i: tok for i, tok in enumerate(self.special_tokens)}
        for key, value in self.id2token.items():
            self.shifted_id2token[key + self.num_special_tokens] = value.decode('utf-8', errors='replace')

        self.tokenizer = tiktoken.Encoding(
            name=Path(vocab_file).parent.name,
            pat_str=pattern,
            mergeable_ranks=self.token2id,
            special_tokens={},  # special tokens are handled by this wrapper, not by tiktoken
        )

    def text_to_tokens(self, text: str):
        token_ids = self.tokenizer.encode(text)
        return [self.tokenizer.decode_single_token_bytes(token) for token in token_ids]

    def tokens_to_text(self, tokens: List[bytes]):
        token_ids = [self.tokenizer.encode_single_token(token) for token in tokens]
        return self.tokenizer.decode(token_ids)

    def token_to_id(self, token):
        if token in self.special_tokens:
            return self.special_tokens.index(token)
        else:
            return self.tokenizer.encode_single_token(token) + self.num_special_tokens

    def tokens_to_ids(self, tokens):
        return [self.token_to_id(token) for token in tokens]

    def id_to_token(self, token_id):
        if token_id < self.num_special_tokens:
            return self.special_tokens[token_id]
        else:
            token_id -= self.num_special_tokens
            token_bytes = self.tokenizer.decode_single_token_bytes(token_id)
            return token_bytes.decode('utf-8', errors='replace')

    def ids_to_tokens(self, token_ids):
        return [self.id_to_token(token_id) for token_id in token_ids]

    def text_to_ids(self, text: str):
        tokens = self.tokenizer.encode(text)
        # shift tiktoken ids past the special-token block
        tokens = [t + self.num_special_tokens for t in tokens]
        return tokens

    def ids_to_text(self, tokens: List[int], remove_special_tokens: bool = True):
        if remove_special_tokens:
            # drop BOS/EOS and any other ids that fall inside the special-token range
            adjusted_tokens = [
                t for t in tokens if t not in {self.bos_id, self.eos_id} and t >= self.num_special_tokens
            ]
        else:
            adjusted_tokens = tokens

        if adjusted_tokens:
            return "".join(self.ids_to_tokens(adjusted_tokens))
        else:
            return ""

    @property
    def bos_id(self):
        return self._bos_id

    @property
    def eos_id(self):
        return self._eos_id

    @property
    def unk_id(self):
        return self._unk_id

    @property
    def mask_id(self):
        return self._mask_id

    @property
    def pad_id(self):
        return self._pad_id

    @property
    def cls_id(self):
        return self._cls_id

    @property
    def sep_id(self):
        return self._sep_id

    @property
    def vocab(self):
        return self.token2id

    @property
    def additional_special_tokens_ids(self):
        """
        Returns a list of the additional special token ids, excluding <unk>, <s>, </s>
        and the special filler tokens. Used to return sentinel tokens for e.g. T5.
        """
        excluding_tokens = self.ids_to_tokens([self._unk_id, self._bos_id, self._eos_id]) + self.special_filler
        return [self.token_to_id(token) for token in self.special_tokens if token not in excluding_tokens]

    @property
    def decoder(self):
        return self.shifted_id2token

    @property
    def encoder(self):
        return self.vocab

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    @property
    def inv_vocab(self):
        return self.shifted_id2token
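

if __name__ == "__main__":
    # Minimal usage sketch (not part of the NeMo API; the vocab path below is a
    # hypothetical placeholder). The demo runs only if such a file actually exists.
    _example_vocab = "/path/to/tiktoken_vocab.json"
    if os.path.exists(_example_vocab):
        tokenizer = TiktokenTokenizer(vocab_file=_example_vocab)
        ids = tokenizer.text_to_ids("Hello, tiktoken!")
        print(ids)
        print(tokenizer.ids_to_text(ids))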