import json
import os
import shutil
import time
from pathlib import Path
from typing import List, Union, Optional

import tokenizers
import torch
from huggingface_hub import hf_hub_download
from pypinyin import pinyin, Style
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils import PaddingStrategy
from transformers.utils.generic import TensorType

try:
    from tokenizers import BertWordPieceTokenizer
except ImportError:
    from tokenizers.implementations import BertWordPieceTokenizer

from transformers import BertTokenizerFast, BatchEncoding

# Files downloaded for this tokenizer are cached next to this module.
cache_path = Path(os.path.abspath(__file__)).parent


def download_file(filename: str, path: Path):
    """Fetch `filename` into the cache directory next to this module.

    A previously cached copy is reused; otherwise the file is copied from
    `path`, and as a last resort downloaded from the
    iioSnail/ChineseBERT-for-csc repository on the Hugging Face Hub.
    """
    if os.path.exists(cache_path / filename):
        return

    if os.path.exists(path / filename):
        shutil.copyfile(path / filename, cache_path / filename)
        return

    hf_hub_download(
        "iioSnail/ChineseBERT-for-csc",
        filename,
        local_dir=cache_path
    )
    # Short pause after the download before the file is read back.
    time.sleep(0.2)


class ChineseBertTokenizer(BertTokenizerFast):
    """
    A BertTokenizerFast that, in addition to the standard BERT fields, returns
    an 8-element pinyin id sequence for every token under the key
    "pinyin_ids", as expected by ChineseBERT-style models.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.path = Path(kwargs['name_or_path'])
        vocab_file = cache_path / 'vocab.txt'
        config_path = cache_path / 'config'
        os.makedirs(config_path, exist_ok=True)

        self.max_length = 512

        download_file('vocab.txt', self.path)
        self.tokenizer = BertWordPieceTokenizer(str(vocab_file))

        # Mapping from pinyin letters and tone digits to ids ("char2idx").
        download_file('config/pinyin_map.json', self.path)
        with open(config_path / 'pinyin_map.json', encoding='utf8') as fin:
            self.pinyin_dict = json.load(fin)

        # Id-to-pinyin mapping (loaded for downstream use; not read in this module).
        download_file('config/id2pinyin.json', self.path)
        with open(config_path / 'id2pinyin.json', encoding='utf8') as fin:
            self.id2pinyin = json.load(fin)

        # Precomputed 8-element id lists for known pinyin strings.
        download_file('config/pinyin2tensor.json', self.path)
        with open(config_path / 'pinyin2tensor.json', encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)

    def __call__(self,
                 text: Union[str, List[str], List[List[str]]] = None,
                 text_pair: Optional[Union[str, List[str], List[List[str]]]] = None,
                 text_target: Union[str, List[str], List[List[str]]] = None,
                 text_pair_target: Optional[Union[str, List[str], List[List[str]]]] = None,
                 add_special_tokens: bool = True,
                 padding: Union[bool, str, PaddingStrategy] = False,
                 truncation: Union[bool, str, TruncationStrategy] = None,
                 max_length: Optional[int] = None,
                 stride: int = 0,
                 is_split_into_words: bool = False,
                 pad_to_multiple_of: Optional[int] = None,
                 return_tensors: Optional[Union[str, TensorType]] = None,
                 return_token_type_ids: Optional[bool] = None,
                 return_attention_mask: Optional[bool] = None,
                 return_overflowing_tokens: bool = False,
                 return_special_tokens_mask: bool = False,
                 return_offsets_mapping: bool = False,
                 return_length: bool = False,
                 verbose: bool = True, **kwargs) -> BatchEncoding:
        """Tokenize like BertTokenizerFast and attach a "pinyin_ids" entry."""
        # Offsets are always requested here because they are needed to align
        # pinyin ids with tokens; they are dropped again below unless the
        # caller asked for them.
        encoding = super().__call__(
            text=text,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=True,
            return_length=return_length,
            verbose=verbose,
        )

        pinyin_ids = None
        if isinstance(text, str):
            offsets = encoding.offset_mapping
            if torch.is_tensor(offsets):
                offsets = offsets[0].tolist()
            tokens = self.sentence_to_tokens(text, offsets)
            pinyin_ids = [self.convert_sentence_to_pinyin_ids(text, tokens, offsets)]

        if isinstance(text, (list, tuple)):
            pinyin_ids = []
            for i, sentence in enumerate(text):
                offsets = encoding.offset_mapping[i]
                if torch.is_tensor(offsets):
                    offsets = offsets.tolist()
                tokens = self.sentence_to_tokens(sentence, offsets)
                pinyin_ids.append(self.convert_sentence_to_pinyin_ids(sentence, tokens, offsets))

        if torch.is_tensor(encoding.input_ids):
            pinyin_ids = torch.LongTensor(pinyin_ids)

        encoding['pinyin_ids'] = pinyin_ids

        if not return_offsets_mapping:
            del encoding['offset_mapping']

        return encoding

    def sentence_to_tokens(self, sentence, offsets):
        """Slice `sentence` into the surface string covered by each offset pair."""
        tokens = []
        for start, end in offsets:
            tokens.append(sentence[start:end])
        return tokens

    def convert_sentence_to_pinyin_ids(self, sentence: str, tokens, offsets):
        """Return one 8-element pinyin id list per token ([0] * 8 when no pinyin applies)."""
        # Pinyin for every character of the sentence; non-Chinese characters
        # are tagged so they can be skipped below. For heteronyms only the
        # first reading returned by pypinyin is used.
        pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True,
                             errors=lambda x: [['not chinese'] for _ in x])

        # Character position -> 8-element pinyin id list.
        pinyin_locs = {}
        for index, item in enumerate(pinyin_list):
            pinyin_string = item[0]
            if pinyin_string == "not chinese":
                continue
            if pinyin_string in self.pinyin2tensor:
                pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
            else:
                ids = [0] * 8
                for i, p in enumerate(pinyin_string):
                    if p not in self.pinyin_dict["char2idx"]:
                        ids = [0] * 8
                        break
                    ids[i] = self.pinyin_dict["char2idx"][p]
                pinyin_locs[index] = ids

        # Tokens that do not cover exactly one character (special tokens,
        # padding, multi-character pieces) carry no pinyin information.
        pinyin_ids = []
        for token, offset in zip(tokens, offsets):
            if offset[1] - offset[0] != 1:
                pinyin_ids.append([0] * 8)
                continue
            if offset[0] in pinyin_locs:
                pinyin_ids.append(pinyin_locs[offset[0]])
            else:
                pinyin_ids.append([0] * 8)

        return pinyin_ids
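

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the tokenizer). It
# assumes the tokenizer files are hosted in the iioSnail/ChineseBERT-for-csc
# repository referenced above; point `from_pretrained` at your own checkpoint
# if the files live elsewhere.
if __name__ == "__main__":
    tokenizer = ChineseBertTokenizer.from_pretrained("iioSnail/ChineseBERT-for-csc")

    batch = tokenizer(["我喜欢北京", "今天天气很好"], padding=True, return_tensors="pt")

    # Standard BERT fields plus the extra pinyin channel:
    #   input_ids  -> (batch, seq_len)
    #   pinyin_ids -> (batch, seq_len, 8); e.g. "北" ("bei3") is encoded as the
    #                 ids of 'b', 'e', 'i', '3' padded with zeros to length 8.
    print(batch["input_ids"].shape)
    print(batch["pinyin_ids"].shape)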