import torch
import sentencepiece
import jieba
import numpy as np

from transformers.tokenization_utils import PreTrainedTokenizer

# Register the model's special tokens with jieba so that word segmentation
# keeps them intact instead of splitting them apart.
jieba.add_word('<s>')
jieba.add_word('</s>')
jieba.add_word('<eot>')
jieba.add_word('<unk>')
jieba.add_word('<sep>')
jieba.add_word('<pad>')
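
# Illustrative (assumed) behavior: without these registrations, jieba would
# split a marker such as "<eot>" into "<", "eot", ">"; with them, a call like
# jieba.cut("你好<eot>世界") yields ["你好", "<eot>", "世界"].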


class GPTPanguTokenizer(PreTrainedTokenizer):

    vocab_files_names = {
        "model_file": "vocab.model"
    }

    def __init__(
        self,
        model_file,
        **kwargs
    ):
        # Load the SentencePiece model before calling super().__init__():
        # recent `transformers` releases may query `vocab_size` while the base
        # constructor runs, which requires `self.sp` to already exist.
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.Load(model_file=model_file)
        # Map spaces and newlines to the placeholders \u2582 and \u2583 so that
        # they survive SentencePiece encoding; decode() reverses this mapping.
        self.translator = str.maketrans(" \n", "\u2582\u2583")

        super().__init__(**kwargs)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by
        concatenating and adding special tokens. A GPT-Pangu sequence has the
        following format:

        - single sequence: `[BOS] X [EOS]`
        - pair of sequences: `[BOS] A [SEP] B [EOS]`

        If no BOS token is configured, the leading `[BOS]` is omitted.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if self.bos_token_id is not None:
            if token_ids_1 is None:
                return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
            bos = [self.bos_token_id]
            sep = [self.sep_token_id]
            eos = [self.eos_token_id]
            return bos + token_ids_0 + sep + token_ids_1 + eos
        else:
            if token_ids_1 is None:
                return token_ids_0 + [self.eos_token_id]
            sep = [self.sep_token_id]
            eos = [self.eos_token_id]
            return token_ids_0 + sep + token_ids_1 + eos
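
    # Illustrative example with assumed IDs (not the real vocab): if
    # bos_token_id=9, sep_token_id=10 and eos_token_id=11, then
    # build_inputs_with_special_tokens([1, 2], [3]) -> [9, 1, 2, 10, 3, 11]
    # and build_inputs_with_special_tokens([1, 2]) -> [9, 1, 2, 11].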

    def tokenize(self, text, **kwargs):
        """Tokenize a string: segment it with jieba, then replace spaces and
        newlines with placeholder symbols that SentencePiece can encode."""
        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
        return seg_list

    def convert_tokens_to_ids(self, tokens):
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        # Special tokens must not be run through SentencePiece, so locate them
        # and encode the stretches of ordinary tokens between them separately.
        special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]

        ids = []
        i = 0
        for j in special_tokens_index:
            # Encode the ordinary tokens before this special token...
            new_seg = " ".join(tokens[i:j])
            ids.extend(self.sp.encode(new_seg))
            # ...then map the special token to its ID directly.
            ids.append(self._convert_token_to_id(tokens[j]))
            i = j + 1

        # Encode whatever ordinary tokens remain after the last special token.
        new_seg = " ".join(tokens[i:])
        ids.extend(self.sp.encode(new_seg))

        return ids
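
    # Illustrative walk-through with assumed IDs: for
    # ["你好", "<sep>", "世界"], "<sep>" is looked up directly while "你好"
    # and "世界" go through self.sp.encode, so the marker is never split
    # into SentencePiece sub-pieces.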

    def _convert_token_to_id(self, token):
        return self.sp.piece_to_id(token)

    def _convert_id_to_token(self, index):
        return self.sp.id_to_piece(index)

    def convert_ids_to_tokens(self, ids):
        # Note: unlike the base-class contract, this returns the decoded text
        # rather than a list of token strings.
        return self.decode(ids)

    def decode(self, ids, **kwargs):
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()

        if kwargs.get('skip_special_tokens', False):
            ids = [token_id for token_id in ids if token_id not in self.all_special_ids]
        text = self.sp.decode(ids)
        # sp.decode returns a list for batched input; keep the first entry.
        if isinstance(text, list):
            text = text[0]
        # Drop the spaces inserted between segments, then restore the real
        # spaces and newlines that tokenize() mapped to \u2582 and \u2583.
        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
        return text
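
    # Illustrative round trip: tokenize("你好 世界") typically yields
    # ["你好", "\u2582", "世界"] after jieba segmentation and the translator,
    # and decode() maps the placeholder back, returning "你好 世界".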

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return len(self.sp)
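

# Minimal usage sketch, not part of the tokenizer itself. Assumptions: a
# SentencePiece model exists at "vocab.model" and the special-token strings
# below match the checkpoint's vocabulary; adjust both for your setup.
if __name__ == "__main__":
    tokenizer = GPTPanguTokenizer(
        model_file="vocab.model",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="<sep>",
        unk_token="<unk>",
        pad_token="<pad>",
    )
    tokens = tokenizer.tokenize("你好 世界")       # jieba segments + placeholders
    ids = tokenizer.convert_tokens_to_ids(tokens)   # SentencePiece IDs
    print(tokens, ids)
    print(tokenizer.decode(ids))                    # round-trips to "你好 世界"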