"""Word-level tokenization for parsing: pairs Roberta subword ids with
PTB-style word boundaries.

A sentence is first split into words with NLTK's Treebank tokenizer,
PTB escape sequences are undone, and each word is then sub-tokenized by
the Roberta tokenizer.  The number of subword pieces per word
("duration") is kept so downstream code can pool subwords back to words.
"""

from typing import List, Union

import torch
from nltk.tokenize import TreebankWordTokenizer
from torchrua import C
from transformers.models.roberta import RobertaTokenizer
from transformers.tokenization_utils_base import (
    EncodedInput,
    PreTokenizedInput,
    TextInput,
)

# Shared word-level tokenizer; TreebankWordTokenizer is stateless, so a
# single module-level instance is safe to reuse.
nltk_tokenizer = TreebankWordTokenizer()

# Map PTB-escaped punctuation / typographic quotes back to plain ASCII.
PTB_UNESCAPE_MAPPING = {
    "«": '"',
    "»": '"',
    "‘": "'",
    "’": "'",
    "“": '"',
    "”": '"',
    "„": '"',
    "‹": "'",
    "›": "'",
    "\u2013": "--",  # en dash
    "\u2014": "--",  # em dash
}


def ptb_unescape(words: List[str]) -> List[str]:
    """Undo Penn Treebank escaping on a list of word tokens.

    Replaces PTB bracket tokens (``-LRB-`` etc.), escaped ``/`` and ``*``,
    and backtick-style quotes with their plain-text equivalents, and maps
    typographic quotes/dashes via ``PTB_UNESCAPE_MAPPING``.

    Args:
        words: word tokens, possibly containing PTB escape sequences.

    Returns:
        A new list of the same length with escapes undone.
    """
    cleaned_words = []
    for word in words:
        word = PTB_UNESCAPE_MAPPING.get(word, word)
        # This un-escaping for / and * was not yet added for the
        # parser version in https://arxiv.org/abs/1812.11760v1
        # and related model releases (e.g. benepar2_en2)
        word = word.replace("\\/", "/").replace("\\*", "*")
        # Mid-token punctuation occurs in biomedical text
        word = word.replace("-LSB-", "[").replace("-RSB-", "]")
        word = word.replace("-LRB-", "(").replace("-RRB-", ")")
        word = word.replace("-LCB-", "{").replace("-RCB-", "}")
        word = word.replace("``", '"').replace("`", "'").replace("''", '"')
        cleaned_words.append(word)
    return cleaned_words


class ParserkerTokenizer(RobertaTokenizer):
    """RobertaTokenizer variant that keeps word↔subword alignment.

    ``__call__`` returns, per sentence, the Treebank word tokens, the
    flattened subword id sequence (with bos/eos), and the number of
    subword pieces contributed by each word.
    """

    def __call__(self, text: Union[TextInput, PreTokenizedInput, EncodedInput], **kwargs):
        """Tokenize one sentence or a batch of sentences.

        Args:
            text: a single string, or an iterable of strings.  Each item
                is word-tokenized with NLTK and PTB-unescaped before
                subword tokenization.
                NOTE(review): an iterable of non-string items (the
                ``EncodedInput`` case in the annotation) would fail in
                ``nltk_tokenizer.tokenize`` — presumably only str /
                list-of-str inputs are used; confirm against callers.
            **kwargs: accepted for interface compatibility but ignored.

        Returns:
            A 3-tuple ``(tokens_list, input_ids, duration)`` where
            ``tokens_list`` is a list of word-token lists, and
            ``input_ids`` / ``duration`` are ``torchrua.C`` packings of
            per-sentence LongTensors.  ``duration[i][j]`` is the number
            of subword ids produced by word ``j`` of sentence ``i``;
            the surrounding bos/eos each count as duration 1.
        """
        input_ids_list = []
        duration_list = []

        # Normalize to a batch of word-token lists.
        if isinstance(text, str):
            tokens_list = [ptb_unescape(nltk_tokenizer.tokenize(text))]
        else:
            tokens_list = [ptb_unescape(nltk_tokenizer.tokenize(t)) for t in text]

        for tokens in tokens_list:
            # With is_split_into_words=False, the list of words is treated
            # as a batch: out['input_ids'] is one id-list per word, which
            # is exactly what we need to count subwords per word.
            out = super().__call__(
                tokens,
                return_attention_mask=False,
                add_special_tokens=False,
                is_split_into_words=False,
                return_tensors=None,
            )
            input_ids = [t for ts in out['input_ids'] for t in ts]
            duration = [len(ts) for ts in out['input_ids']]

            input_ids_list.append([self.bos_token_id, *input_ids, self.eos_token_id])
            duration_list.append([1, *duration, 1])

        # Pack the ragged per-sentence tensors with torchrua.
        input_ids = C.new([torch.tensor(t, dtype=torch.long) for t in input_ids_list])
        duration = C.new([torch.tensor(t, dtype=torch.long) for t in duration_list])
        return tokens_list, input_ids, duration