"""Megatron tokenizers."""

from transformers.tokenization_utils import PreTrainedTokenizer
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Sequence,
    Union,
    Optional,
)
from tiktoken.load import load_tiktoken_bpe
import tiktoken
from pathlib import Path
import os
import logging
from tokenizers import AddedToken

logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}


class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 293 + 128

    pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"

    vocab_files_names = VOCAB_FILES_NAMES
    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken] = "[BOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        unk_token: Union[str, AddedToken] = "[UNK]",
        pad_token: Union[str, AddedToken] = "[PAD]",
        additional_special_tokens: Optional[List[str]] = None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            vocab_file (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(vocab_file), vocab_file

        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)

        used_special_tokens = [
            "[BOS]",
            "[EOS]",
            "<|im_msg_end|>",
            "<|im_user_msg_start|>",
            "<|im_assistant_msg_start|>",
            "<|reserved_token_0|>",
            "<|reserved_token_1|>",
            "<|reserved_token_2|>",
            "<|reserved_token_3|>",
            "[EOT]",
            "<|reserved_token_4|>",
            "<|reserved_token_5|>",
            "<|reserved_token_6|>",
            "<|reserved_token_7|>",
            "<|reserved_token_8|>",
            "<|reserved_token_9|>",
            "<|reserved_token_10|>",
            "<|reserved_token_11|>",
            "<|im_media_begin|>",
            "<|reserved_token_12|>",
            "<|im_media_end|>",
            "<|reserved_token_13|>",
            "<|reserved_token_14|>",
            "<|im_kimia_text_blank|>",
            "<|im_kimia_text_eos|>",
            "<|reserved_token_15|>",
            "<|reserved_token_16|>",
            "<|im_kimia_user_msg_start|>",
            "<|im_kimia_assistant_msg_start|>",
            "<|reserved_token_17|>",
            "<|reserved_token_18|>",
            "<|reserved_token_19|>",
            "<|im_kimia_speech_ct_id|>",
            "<|im_kimia_speech_ctd_id|>",
        ]
        autoset_special_tokens = [
            f"<|reserved_token_{i}|>"
            for i in range(
                20, self.num_reserved_special_tokens - len(used_special_tokens) + 20
            )
        ]
        special_tokens = used_special_tokens + autoset_special_tokens
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_token = "[BOS]"
        self.bos_id: int = self.special_tokens["[BOS]"]
        self.eos_token = "[EOS]"
        self.eos_id: int = self.special_tokens["[EOS]"]

        # Use the last two entries of the special-token block as PAD and UNK.
        self.pad_token: str = special_tokens[-1]
        self.pad_id: int = self.special_tokens[self.pad_token]

        self.unk_token: str = special_tokens[-2]
        self.unk_id: int = self.special_tokens[self.unk_token]

        self.stop_tokens = {
            self.special_tokens["[EOS]"],
            self.special_tokens["[EOT]"],
        }

        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed to appear in the string.
            disallowed_special ("all"|set[str]): special tokens that raise an error when found in the string.

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str

        # tiktoken cannot encode arbitrarily long strings in a single call, so the
        # input is processed in chunks of at most this many characters.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # Each chunk is split further whenever a run of consecutive whitespace or
        # non-whitespace characters exceeds this length, to keep encoding fast and
        # robust on degenerate inputs.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t
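        # A small illustration of the special-token switches above (a hypothetical
        # call, assuming "[EOS]" is registered as a special token):
        #     self.encode("[EOS]", bos=False, eos=False)
        #     -> the IDs of the literal text "[EOS]"
        #     self.encode("[EOS]", bos=False, eos=False, allowed_special="all")
        #     -> [self.eos_id]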
    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # The cast is for typing only; the sequence is passed through to tiktoken unchanged.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
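    # For example, with max_consecutive_slice_len=3 the run of four "a" characters
    # is cut after three, and everything that follows stays in one slice:
    #     list(TikTokenTokenizer._split_whitespaces_or_nonwhitespaces("aaaa  bb", 3))
    #     -> ["aaa", "a  bb"]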
    """ ----- Below are the abstract methods required by megatron ----- """

    @property
    def vocab_size(self):
        return self.n_words

    @property
    def vocab(self):
        if hasattr(self, "str_vocab"):
            return self.str_vocab
        self.str_vocab = {}

        # Build a string-keyed vocabulary from the byte-level BPE ranks.
        utf8_num, unicode_num = 0, 0
        for byte_key, index in self.model._mergeable_ranks.items():
            try:
                str_key = byte_key.decode("utf-8")
                utf8_num += 1
            except UnicodeDecodeError:
                # Some byte sequences are not valid UTF-8 on their own; keep a
                # backslash-escaped representation tagged with a marker suffix.
                str_key = byte_key.decode("utf-8", "backslashreplace") + "_unicode_"
                unicode_num += 1

            self.str_vocab[str_key] = index
        logger.info(f"num utf8: {utf8_num}, num unicode: {unicode_num}")

        # Special tokens are already strings; add them as-is.
        self.str_vocab.update(self.model._special_tokens)

        assert len(self.str_vocab) == self.vocab_size
        return self.str_vocab
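    # Note: keys in `vocab` are the decoded BPE byte strings (non-UTF-8 sequences
    # get a backslash-escaped form ending in "_unicode_") plus the special-token
    # strings; `inv_vocab` below simply inverts this string-to-ID mapping.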
    @property
    def inv_vocab(self):
        return {v: k for k, v in self.vocab.items()}

    def tokenize(self, text, eos=True):
        # BOS is always prepended; EOS is appended unless the caller disables it.
        return self.encode(text, bos=True, eos=eos)

    def detokenize(self, tokens):
        # Accept plain lists as well as tensors/arrays that expose tolist().
        if not isinstance(tokens, list):
            tokens = tokens.tolist()
        return self.decode(tokens)

    @property
    def eod(self):
        return self.eos_id

    def bod(self):
        return self.bos_id

    @property
    def msk_start_id(self):
        # `msk_start` is expected to be assigned externally; it is not set in __init__.
        return self.msk_start

    @property
    def msk_end_id(self):
        # `msk_end` is expected to be assigned externally; it is not set in __init__.
        return self.msk_end

    def _get_index_2_bytes(self):
        if hasattr(self, "index_2_bytes"):
            return self.index_2_bytes

        # Map each token ID to the number of bytes of text it covers.
        self.index_2_bytes = [0] * self.model.n_vocab
        for byte_key, index in self.model._mergeable_ranks.items():
            self.index_2_bytes[index] = len(byte_key)

        for _, index in self.model._special_tokens.items():
            # Count each special token as a single byte.
            self.index_2_bytes[index] = 1

        return self.index_2_bytes

    def get_array_bytes(self, array):
        index_2_bytes = self._get_index_2_bytes()
        return sum(index_2_bytes[i] for i in array)
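    # Rough illustration (hypothetical IDs): for a token sequence whose pieces cover
    # the UTF-8 bytes of "hi " plus one special token, get_array_bytes returns 3 + 1 = 4.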
    @property
    def eos_token_id(self):
        return self.eos_id

    @property
    def pad_token_id(self):
        return self.pad_id
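

# Minimal usage sketch, assuming a local "tiktoken.model" BPE file is available;
# the path and example text are illustrative only.
if __name__ == "__main__":
    tokenizer = TikTokenTokenizer("tiktoken.model")  # hypothetical local path
    ids = tokenizer.tokenize("hello world", eos=True)
    print("token ids:", ids)
    print("round trip:", tokenizer.detokenize(ids))
    print("vocab size:", tokenizer.vocab_size)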