# parserker / tokenization_parserker.py
# yehzw's picture
# Upload folder using huggingface_hub
# b1c4dc3 verified
from typing import List
from typing import Union
import torch
from nltk.tokenize import TreebankWordTokenizer
from torchrua import C
from transformers.models.roberta import RobertaTokenizer
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, EncodedInput
# Shared word-level pre-tokenizer: splits raw text into Penn-Treebank-style
# tokens before sub-word encoding by the Roberta tokenizer below.
nltk_tokenizer = TreebankWordTokenizer()
# Typographic quote/dash characters mapped to their ASCII PTB equivalents.
PTB_UNESCAPE_MAPPING = {
    "«": '"',
    "»": '"',
    "‘": "'",
    "’": "'",
    "“": '"',
    "”": '"',
    "„": '"',
    "‹": "'",
    "›": "'",
    "\u2013": "--",  # en dash
    "\u2014": "--",  # em dash
}
# Multi-character PTB escape sequences, applied to each word in this exact
# order (the double-backtick/backtick ordering matters).
_PTB_REPLACEMENTS = (
    # Un-escaping for / and * was not yet present in the parser version of
    # https://arxiv.org/abs/1812.11760v1 and related releases (e.g. benepar2_en2).
    ("\\/", "/"),
    ("\\*", "*"),
    # Bracket escapes; mid-token punctuation occurs in biomedical text.
    ("-LSB-", "["),
    ("-RSB-", "]"),
    ("-LRB-", "("),
    ("-RRB-", ")"),
    ("-LCB-", "{"),
    ("-RCB-", "}"),
    # PTB quote conventions back to plain ASCII quotes.
    ("``", '"'),
    ("`", "'"),
    ("''", '"'),
)


def ptb_unescape(words: List[str]) -> List[str]:
    """Undo Penn-Treebank-style escaping on a list of word tokens.

    Each token is first looked up in ``PTB_UNESCAPE_MAPPING`` (whole-token
    replacement of typographic quotes/dashes), then every multi-character
    escape in ``_PTB_REPLACEMENTS`` is substituted in order.
    """
    unescaped = []
    for token in words:
        token = PTB_UNESCAPE_MAPPING.get(token, token)
        for escape, plain in _PTB_REPLACEMENTS:
            token = token.replace(escape, plain)
        unescaped.append(token)
    return unescaped
class ParserkerTokenizer(RobertaTokenizer):
    """Roberta tokenizer that keeps word-level alignment for parsing.

    Text is first split into PTB-unescaped words with the NLTK Treebank
    tokenizer; each word is then sub-word encoded independently so that the
    number of sub-tokens per word ("duration") is known.
    """

    def __call__(self, text: Union[TextInput, PreTokenizedInput, EncodedInput], **kwargs):
        """Tokenize a string (or a batch of strings).

        Returns a triple ``(tokens_list, input_ids, duration)`` where
        ``tokens_list`` holds the unescaped words per sentence, and
        ``input_ids`` / ``duration`` are packed ``torchrua`` sequences with
        bos/eos ids (duration 1 each) added around every sentence.
        """
        # Normalize a single string into a one-element batch.
        batch = [text] if isinstance(text, str) else text
        tokens_list = [ptb_unescape(nltk_tokenizer.tokenize(item)) for item in batch]

        ids_per_sentence = []
        counts_per_sentence = []
        for words in tokens_list:
            # With is_split_into_words=False each word is encoded as its own
            # "text", so encoded['input_ids'] is one sub-token id list per word.
            encoded = super().__call__(
                words,
                return_attention_mask=False,
                add_special_tokens=False,
                is_split_into_words=False,
                return_tensors=None,
            )
            per_word_ids = encoded['input_ids']
            flat_ids = [token_id for word_ids in per_word_ids for token_id in word_ids]
            word_lengths = [len(word_ids) for word_ids in per_word_ids]
            ids_per_sentence.append([self.bos_token_id, *flat_ids, self.eos_token_id])
            counts_per_sentence.append([1, *word_lengths, 1])

        input_ids = C.new([torch.tensor(seq, dtype=torch.long) for seq in ids_per_sentence])
        duration = C.new([torch.tensor(seq, dtype=torch.long) for seq in counts_per_sentence])
        return tokens_list, input_ids, duration