import torch
from transformers import BertTokenizer
from functools import partial
import logging
import re
import itertools
import Nested.data.datasets

logger = logging.getLogger(__name__)


class BertSeqTransform:
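    """
    Encode a segment of tokens into BERT subword ids with a single tag id per
    subword position. A token's gold tag is kept on its first subword and
    continuation subwords are labeled "O"; the sequence is wrapped in
    [CLS]/[SEP] and truncated to max_seq_len.
    """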
    def __init__(self, bert_model, vocab, max_seq_len=512):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        subwords, tags, tokens = list(), list(), list()
        unk_token = Nested.data.datasets.Token(text="UNK")

        stoi = self.vocab.tags[0].get_stoi()

        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no
            # input_ids; in that case, fall back to the input_id for [UNK]
            token_subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            subwords += token_subwords

            # Keep the gold tag on the first subword, label continuation subwords "O"
            tags += [stoi[token.gold_tag[0]]] + [stoi["O"]] * (len(token_subwords) - 1)
            tokens += [token] + [unk_token] * (len(token_subwords) - 1)

        # Truncate to max_seq_len
        if len(subwords) > self.max_seq_len - 2:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, self.max_seq_len - 2)
            subwords = subwords[:self.max_seq_len - 2]
            tags = tags[:self.max_seq_len - 2]
            tokens = tokens[:self.max_seq_len - 2]

        # Wrap the sequence in [CLS]/[SEP] and give those positions "O" tags
        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)

        tags.insert(0, stoi["O"])
        tags.append(stoi["O"])

        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        return torch.LongTensor(subwords), torch.LongTensor(tags), tokens, len(tokens)
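
# Usage sketch (hypothetical, for illustration only): `vocab` is assumed to be
# built elsewhere with .tags[0] exposing get_stoi(), `segment` is a list of
# Nested.data.datasets.Token objects, and any BERT checkpoint name can be used.
#
#   transform = BertSeqTransform("bert-base-multilingual-cased", vocab)
#   subword_ids, tag_ids, tokens, seq_len = transform(segment)
#   # subword_ids: LongTensor wrapped in [CLS]/[SEP], tag_ids aligned 1:1 with it,
#   # tokens: per-subword Token list, seq_len == len(tokens)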


class NestedTagsTransform:
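    """
    Encode a segment of tokens into BERT subword ids with one row of tag ids
    per tag type (nested tags). The tags form a NUM_TAG_TYPES x SEQ_LEN matrix,
    returned together with the tokens, a mask, and the sequence length.
    """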
    def __init__(self, bert_model, vocab, max_seq_len=512):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        tags, tokens, subwords = list(), list(), list()
        unk_token = Nested.data.datasets.Token(text="UNK")

        # Encode each token and get its subwords and IDs
        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no
            # input_ids; in that case, fall back to the input_id for [UNK]
            token.subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            subwords += token.subwords
            tokens += [token] + [unk_token] * (len(token.subwords) - 1)

        # Construct the labels for each tag type
        # The sequence will have a list of tags for each type
        # The final tags for a sequence form a NUM_TAG_TYPES x SEQ_LEN matrix
        # Example:
        #   [
        #       [O,     O,     B-PERS, I-PERS, O, O, O]
        #       [B-ORG, I-ORG, O,      O,      O, O, O]
        #       [O,     O,     O,      O,      O, O, B-GPE]
        #   ]
        for vocab in self.vocab.tags[1:]:
            vocab_tags = "|".join(["^" + t + "$" for t in vocab.get_itos() if "-" in t])
            r = re.compile(vocab_tags)

            # For a given token, find a tag that matches this tag type. A token may
            # match more than once (e.g. it can be labeled both B-ORG and I-ORG);
            # in that case keep only the first match, since entities of the same
            # type do not overlap. Continuation subwords are labeled "O".
            single_type_tags = [
                [(list(filter(r.match, token.gold_tag)) or ["O"])[0]]
                + ["O"] * (len(token.subwords) - 1)
                for token in segment
            ]
            single_type_tags = list(itertools.chain(*single_type_tags))
            tags.append([vocab.get_stoi()[tag] for tag in single_type_tags])

        # Truncate to max_seq_len
        if len(subwords) > self.max_seq_len - 2:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, self.max_seq_len - 2)
            subwords = subwords[:self.max_seq_len - 2]
            tags = [t[:self.max_seq_len - 2] for t in tags]
            tokens = tokens[:self.max_seq_len - 2]

        # Add a dummy token at the start and end of the sequence
        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        # Add [CLS] and [SEP] at the start and end of the subwords
        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)
        subwords = torch.LongTensor(subwords)

        # Add "O" tags for the first and last subwords
        tags = torch.Tensor(tags)
        tags = torch.column_stack((
            torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
            tags,
            torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
        )).unsqueeze(0)

        mask = torch.ones_like(tags)
        return subwords, tags, tokens, mask, len(tokens)
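
# Usage sketch (hypothetical, for illustration only): `vocab.tags[1:]` is assumed
# to hold one vocabulary per entity type, each exposing get_itos()/get_stoi().
#
#   transform = NestedTagsTransform("bert-base-multilingual-cased", vocab)
#   subword_ids, tags, tokens, mask, seq_len = transform(segment)
#   # subword_ids: LongTensor of len(tokens) subword ids wrapped in [CLS]/[SEP]
#   # tags/mask:   Tensors of shape (1, num_tag_types, len(tokens))
#   # seq_len:     len(tokens), i.e. subword positions including [CLS]/[SEP]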