import jieba
import sentencepiece
import torch

from transformers.tokenization_utils import PreTrainedTokenizer
|
|
class GPTPanguTokenizer(PreTrainedTokenizer):
    """PanGu-Alpha tokenizer: jieba word segmentation followed by SentencePiece.

    Spaces and newlines are translated to the placeholder characters U+2582 and
    U+2583 before encoding so that whitespace survives the round trip; `decode`
    restores them.
    """

    vocab_files_names = {"model_file": "vocab.model"}
|
    def __init__(
        self,
        model_file,
        **kwargs
    ):
        # Forward kwargs so special-token arguments reach the base class.
        super().__init__(**kwargs)

        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.Load(model_file=model_file)
        # Translation table mapping " " -> U+2582 and "\n" -> U+2583.
        self.translator = str.maketrans(" \n", "\u2582\u2583")

        # "<eot>" is the end-of-text piece in the PanGu vocabulary.
        self.eos_token_id = self.sp.piece_to_id("<eot>")
|
    def tokenize(self, text, **kwargs):
        """Tokenize a string: jieba segmentation, whitespace protection, then
        SentencePiece encoding.

        Note: `sp.encode` returns token ids here, not string pieces, which is
        why `convert_tokens_to_ids` below is the identity.
        """
        # Protect spaces/newlines inside each jieba segment with placeholders.
        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
        new_seg = " ".join(seg_list)
        return self.sp.encode(new_seg)
|
    def convert_tokens_to_ids(self, tokens):
        # `tokenize` already produces ids, so there is nothing to convert.
        return tokens
|
    def convert_ids_to_tokens(self, ids):
        # Ids decode directly back to text; this tokenizer never materializes
        # individual string pieces.
        return self.decode(ids)
|
    def decode(self, tokens, **kwargs):
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()

        text = self.sp.decode(tokens)
        # Remove the separator spaces added in `tokenize`, then restore real
        # spaces and newlines from their placeholders.
        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
        return text
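

# A minimal usage sketch (an addition, not part of the original file) that
# exercises the encode/decode round trip. "vocab.model" is a hypothetical
# local path to the PanGu SentencePiece model.
if __name__ == "__main__":
    tokenizer = GPTPanguTokenizer(model_file="vocab.model")

    text = "今天天气不错。\n我们出去散步吧。"
    ids = tokenizer.tokenize(text)  # jieba segments -> placeholder translation -> ids
    print(ids)
    # Placeholders are restored to real spaces/newlines; this should reproduce
    # the original text whenever every character is covered by the vocabulary.
    print(tokenizer.decode(ids))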
|
|