File size: 2,501 Bytes
b1c4dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import List
from typing import Union

import torch
from nltk.tokenize import TreebankWordTokenizer
from torchrua import C
from transformers.models.roberta import RobertaTokenizer
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, EncodedInput

# Module-level word tokenizer, built once and reused for every call.
nltk_tokenizer = TreebankWordTokenizer()

# Whole-word mapping from typographic quotes/dashes to their plain-ASCII
# equivalents (applied only when the entire word matches a key).
PTB_UNESCAPE_MAPPING = {
    "«": '"',
    "»": '"',
    "‘": "'",
    "’": "'",
    "“": '"',
    "”": '"',
    "„": '"',
    "‹": "'",
    "›": "'",
    "\u2013": "--",  # en dash
    "\u2014": "--",  # em dash
}


def ptb_unescape(words: List[str]) -> List[str]:
    """Undo Penn Treebank escaping on a list of word tokens.

    Each word is first looked up whole in ``PTB_UNESCAPE_MAPPING`` (typographic
    quotes/dashes), then substring escapes are expanded in a fixed order.

    :param words: word tokens, possibly containing PTB escape sequences.
    :return: a new list with every token un-escaped.
    """
    # Substring replacements, applied in order.  The order matters for the
    # backtick family: "``" must be expanded before lone "`".
    # This un-escaping for / and * was not yet added for the
    # parser version in https://arxiv.org/abs/1812.11760v1
    # and related model releases (e.g. benepar2_en2)
    # Mid-token punctuation occurs in biomedical text.
    substitutions = (
        ("\\/", "/"), ("\\*", "*"),
        ("-LSB-", "["), ("-RSB-", "]"),
        ("-LRB-", "("), ("-RRB-", ")"),
        ("-LCB-", "{"), ("-RCB-", "}"),
        ("``", '"'), ("`", "'"), ("''", '"'),
    )

    def _clean(word: str) -> str:
        word = PTB_UNESCAPE_MAPPING.get(word, word)
        for escaped, plain in substitutions:
            word = word.replace(escaped, plain)
        return word

    return [_clean(word) for word in words]


class ParserkerTokenizer(RobertaTokenizer):
    """RoBERTa tokenizer that first word-tokenizes with NLTK, then records
    how many sub-word pieces each word was split into ("duration")."""

    def __call__(self, text: Union[TextInput, PreTokenizedInput, EncodedInput], **kwargs):
        """Tokenize ``text`` (one string or a batch of strings).

        Returns a 3-tuple ``(tokens_list, input_ids, duration)`` where
        ``tokens_list`` is the un-escaped NLTK word tokens per sentence, and
        ``input_ids`` / ``duration`` are torchrua ``C`` sequences of sub-word
        ids and per-word piece counts (BOS/EOS each counting as one piece).
        NOTE(review): extra ``**kwargs`` are accepted but not forwarded.
        """
        # Normalise to a batch: a bare string becomes a batch of one.
        texts = [text] if isinstance(text, str) else list(text)
        tokens_list = [ptb_unescape(nltk_tokenizer.tokenize(item)) for item in texts]

        ids_per_sentence = []
        dur_per_sentence = []
        for words in tokens_list:
            # With is_split_into_words=False and a list input, each word is
            # encoded as its own sequence — presumably to recover per-word
            # sub-word counts; confirm against transformers' batch semantics.
            encoded = super().__call__(
                words,
                return_attention_mask=False,
                add_special_tokens=False,
                is_split_into_words=False,
                return_tensors=None,
            )
            pieces = encoded['input_ids']  # one id-list per word
            flat_ids = [piece for word_ids in pieces for piece in word_ids]
            widths = [len(word_ids) for word_ids in pieces]

            # Wrap with BOS/EOS; each special token is a single-piece "word".
            ids_per_sentence.append([self.bos_token_id, *flat_ids, self.eos_token_id])
            dur_per_sentence.append([1, *widths, 1])

        input_ids = C.new([torch.tensor(seq, dtype=torch.long) for seq in ids_per_sentence])
        duration = C.new([torch.tensor(seq, dtype=torch.long) for seq in dur_per_sentence])

        return tokens_list, input_ids, duration