| from typing import List |
| from typing import Union |
|
|
| import torch |
| from nltk.tokenize import TreebankWordTokenizer |
| from torchrua import C |
| from transformers.models.roberta import RobertaTokenizer |
| from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, EncodedInput |
|
|
# Module-level NLTK Penn-Treebank-style word tokenizer, shared by all
# ParserkerTokenizer instances to pre-split raw sentences into words
# before RoBERTa sub-word tokenization.
nltk_tokenizer = TreebankWordTokenizer()
|
|
# Whole-token replacements: typographic quote/dash characters that the PTB
# convention maps onto plain ASCII equivalents. Applied only when the entire
# token matches a key.
PTB_UNESCAPE_MAPPING = {
    "«": '"',
    "»": '"',
    "‘": "'",
    "’": "'",
    "“": '"',
    "”": '"',
    "„": '"',
    "‹": "'",
    "›": "'",
    "\u2013": "--",
    "\u2014": "--",
}


# Substring replacements applied to every token, in order. NOTE: the double
# backtick pair must be handled before the single backtick, otherwise "``"
# would be turned into two apostrophes instead of one double quote.
_PTB_SUBSTRING_REPLACEMENTS = (
    ("\\/", "/"),
    ("\\*", "*"),
    ("-LSB-", "["),
    ("-RSB-", "]"),
    ("-LRB-", "("),
    ("-RRB-", ")"),
    ("-LCB-", "{"),
    ("-RCB-", "}"),
    ("``", '"'),
    ("`", "'"),
    ("''", '"'),
)


def ptb_unescape(words: List[str]) -> List[str]:
    """Map Penn-Treebank-escaped tokens back to their plain-text forms.

    Each token is first looked up as a whole in ``PTB_UNESCAPE_MAPPING``
    (typographic quotes/dashes), then every PTB escape sequence occurring
    inside it (bracket codes, escaped slashes, backtick quotes) is replaced.

    :param words: tokens as produced by a PTB-style tokenizer.
    :return: a new list with all escapes resolved; the input is not mutated.
    """
    def _clean(token: str) -> str:
        token = PTB_UNESCAPE_MAPPING.get(token, token)
        for escaped, plain in _PTB_SUBSTRING_REPLACEMENTS:
            token = token.replace(escaped, plain)
        return token

    return [_clean(word) for word in words]
|
|
|
|
class ParserkerTokenizer(RobertaTokenizer):
    def __call__(self, text: Union[TextInput, PreTokenizedInput, EncodedInput], **kwargs):
        """Tokenize text into RoBERTa sub-word ids with per-word sub-token counts.

        Each sentence is word-split with the NLTK Treebank tokenizer and
        PTB-unescaped, then every word is sub-word tokenized independently so
        that the number of sub-tokens per word ("duration") is known.

        :param text: a single sentence, or an iterable of sentences.
        :return: a 3-tuple of
            - the per-sentence word lists (after PTB un-escaping),
            - sub-word id sequences packed via ``C.new`` (torchrua catted
              sequences — presumably variable-length batching; confirm against
              downstream consumers), bracketed by ``<s>``/``</s>``,
            - matching duration sequences, where the two special tokens each
              count as a single word of duration 1.
        """
        # Normalize the input to a batch: a bare string is a one-sentence batch.
        if isinstance(text, str):
            sentences = [text]
        else:
            sentences = list(text)

        tokens_list = [ptb_unescape(nltk_tokenizer.tokenize(sentence)) for sentence in sentences]

        input_ids_list = []
        duration_list = []
        for words in tokens_list:
            # With is_split_into_words=False, the word list is treated as a
            # batch of tiny texts, so encoded['input_ids'] holds one sub-token
            # list per word — exactly what the duration bookkeeping needs.
            encoded = super().__call__(
                words,
                return_attention_mask=False,
                add_special_tokens=False,
                is_split_into_words=False,
                return_tensors=None,
            )

            flat_ids = [token_id for word_ids in encoded['input_ids'] for token_id in word_ids]
            word_lengths = [len(word_ids) for word_ids in encoded['input_ids']]

            input_ids_list.append([self.bos_token_id, *flat_ids, self.eos_token_id])
            duration_list.append([1, *word_lengths, 1])

        input_ids = C.new([torch.tensor(ids, dtype=torch.long) for ids in input_ids_list])
        duration = C.new([torch.tensor(lengths, dtype=torch.long) for lengths in duration_list])

        return tokens_list, input_ids, duration
|
|