import os
import base64
import tiktoken
from typing import Collection, Optional, List, Set, Union
from transformers import PreTrainedTokenizer


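# Pre-tokenization regex in the GPT-4 (cl100k) style: it peels off English
# contractions, runs of letters with an optional leading non-letter byte,
# single digits (\p{N} has no "+", so every digit is its own pre-token),
# punctuation runs, and various whitespace shapes.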
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""


class SPTokenizer:
    def __init__(self, model_path):
        self.vocab_file = model_path
        self.pad_token = '<pad>'
        self.unk_token = '<unk>'
        self.mask_token = '<mask>'
        self.eod_token = '<eod>'
        self.eop_token = '<eop>'
        self.im_start_token = '<|im_start|>'
        self.im_end_token = '<|im_end|>'

        self.SPECIAL_TOKENS = (
            self.pad_token,
            self.unk_token,
            self.mask_token,
            self.eod_token,
            self.eop_token,
            '[space2]', '[space3]', '[space4]', '[space8]',
            self.im_start_token, self.im_end_token
        )
        self.build_tokenizer()
        self.out = self.output_core_token()

        # Map each [spaceN] placeholder token to the run of N literal spaces
        # it stands for; convert()/deconvert() translate back and forth.
        self.token2strs = {
            "[space2]": "  ",
            "[space3]": "   ",
            "[space4]": "    ",
            "[space8]": "        ",
        }
        self.str2tokens = {v: k for k, v in self.token2strs.items()}
        # Replace longer space runs first so eight spaces become one [space8]
        # token rather than four [space2] tokens.
        self.sorted_strs = sorted(self.str2tokens.keys(), key=len, reverse=True)

        self.decode_skip_special_tokens = [
            self.pad_token,
            self.unk_token,
            self.mask_token,
            self.eod_token,
            self.eop_token,
            self.im_start_token,
            self.im_end_token]
        self.decode_skip_special_tokens_ids = [
            self.convert_token_to_id(token) for token in self.decode_skip_special_tokens]


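    # The tiktoken vocab file is plain text with one "<base64 token> <rank>"
    # pair per line; e.g. the line "aGVsbG8= 1234" maps b"hello" to rank 1234.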
    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str):
        with open(tiktoken_bpe_file, "rb") as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    def build_tokenizer(self):
        mergeable_ranks = self._load_tiktoken_bpe(self.vocab_file)
        special_tokens = {
            token: index
            for index, token in enumerate(
                self.SPECIAL_TOKENS, start=len(mergeable_ranks)
            )
        }
        encoding = tiktoken.Encoding(
            "zhinao",
            pat_str=PAT_STR,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens
        )
        decoder = {v: k for k, v in mergeable_ranks.items()}
        decoder.update({v: k for k, v in special_tokens.items()})
        decoder_token2id = {v: k for k, v in decoder.items()}

        self.tokenizer = encoding
        self.decoder = decoder
        self.decoder_token2id = decoder_token2id
        self.num_tokens = len(mergeable_ranks) + len(self.SPECIAL_TOKENS)

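    # Id layout: BPE pieces take ranks 0..len(mergeable_ranks)-1 and the
    # special tokens are appended immediately after, so num_tokens is their sum.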
    def output_core_token(self):
        """Return a dict mapping each special token to its id."""
        out = {}
        for t in self.SPECIAL_TOKENS:
            out[t] = self.convert_token_to_id(t)
        return out


    def tokenize(
            self,
            text,
            allowed_special: Union[Set, str] = "all",
            disallowed_special: Union[Collection, str] = ()):
        tokens = []
        text = self.convert(text)
        for idx in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
            tokens.append(self.decoder[idx])
        return tokens


    def encode(self, text, allowed_special="all", disallowed_special=()):
        """Encode text to a list of token ids."""
        text = self.convert(text)
        return self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)

    def decode(self, ids, errors="replace"):
        """Decode a list of token ids back to text."""
        text = self.tokenizer.decode(ids, errors=errors)
        return self.deconvert(text)


    def decode_tokens(self, tokens: List[str]) -> str:
        """Convert a sequence of tokens into a single string."""
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors="replace")
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors="replace")
        return self.deconvert(text)


    def convert_id_to_token(self, idx):
        return self.decoder[idx]

    def convert_token_to_id(self, token):
        return self.decoder_token2id[token]


    def convert(self, text):
        """Replace special character sequences in the text with their special tokens."""
        for k in ["[br]", "<br>"]:
            text = text.replace(k, "\n")
        for k in self.sorted_strs:
            if k in text:
                text = text.replace(k, self.str2tokens[k])
        return text


    def deconvert(self, text):
        """Restore the original characters from special tokens in decoded text."""
        for t in self.token2strs:
            if t in text:
                text = text.replace(t, self.token2strs[t])
        return text
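
    # Round-trip sketch: convert("a    b") yields "a[space4]b" before BPE, and
    # deconvert() maps it back to "a    b", so multi-space runs survive intact.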


class ZhinaoTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab/360.tiktoken"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
        self.name = "ZhinaoTokenizer"
        self.errors = "replace"
        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(model_path=vocab_file)
        # These tokens are exposed as read-only properties below, so drop any
        # conflicting kwargs before calling the parent constructor.
        for key in ("eos_token", "pad_token", "unk_token"):
            kwargs.pop(key, None)
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
        self.pad_token_id = self.tokenizer.convert_token_to_id(self.tokenizer.pad_token)
        self.eod_id = self.tokenizer.convert_token_to_id(self.tokenizer.eod_token)
        self.im_start_id = self.tokenizer.convert_token_to_id(self.tokenizer.im_start_token)
        self.im_end_id = self.tokenizer.convert_token_to_id(self.tokenizer.im_end_token)

    @property
    def unk_token(self) -> str:
        return self.tokenizer.unk_token

    @property
    def pad_token(self) -> str:
        return self.tokenizer.pad_token

    @property
    def eos_token(self) -> str:
        return self.tokenizer.eod_token

    @property
    def eos_token_id(self):
        return self.tokenizer.convert_token_to_id(self.tokenizer.eod_token)

    @property
    def eop_token(self) -> str:
        return self.tokenizer.eop_token

    @property
    def eop_token_id(self):
        return self.tokenizer.convert_token_to_id(self.tokenizer.eop_token)

    @property
    def vocab_size(self):
        return self.tokenizer.num_tokens

    def get_vocab(self):
        """Returns the vocab as a dict."""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
    ) -> List[Union[bytes, str]]:
        tokens = []
        for t in self.tokenizer.encode(
            text, allowed_special=allowed_special, disallowed_special=disallowed_special
        ):
            tokens.append(self.tokenizer.decoder[t])
        return tokens
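
    # Note: the tokens above are a mix of bytes (ordinary BPE pieces) and str
    # (special tokens); convert_tokens_to_string() reassembles them into text.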

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: Optional[str] = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            token_ids = [i for i in token_ids if i not in self.tokenizer.decode_skip_special_tokens_ids]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)

    def _tokenize(self, text, **kwargs):
        raise NotImplementedError

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens into a single string."""
        return self.tokenizer.decode_tokens(tokens)


    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Save only the vocabulary (the tiktoken file) of the tokenizer."""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        # Copy the vocab bytes verbatim; create the vocab/ subdirectory if needed.
        dirname = os.path.dirname(vocab_file)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)
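

# Minimal usage sketch (hypothetical path; the real vocab ships with the model
# checkpoint as vocab/360.tiktoken):
#
#   tokenizer = ZhinaoTokenizer(vocab_file="vocab/360.tiktoken")
#   ids = tokenizer.encode("hello   world")   # the space run becomes one [space3] id
#   tokenizer.decode(ids)                     # -> "hello   world"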