File size: 16,478 Bytes
edbfc07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c886682
edbfc07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c886682
 
edbfc07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59a876a
 
 
 
 
 
 
 
edbfc07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59a876a
edbfc07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d0d6e1
 
 
 
 
 
 
 
 
 
edbfc07
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
"""Tokenization utilities for RNNLM - entity extraction, replacement, and decoding."""

import re
import numpy as np

# RNG for adapt_tok_seq_ents when sampling from sub_ent_probs
_rng = np.random.RandomState(0)


def segment(encoder, seq):
    """Split *seq* into a list of stripped sentence strings.

    encoder is a spaCy-style pipeline: encoder(seq) returns a doc whose
    .sents iterator yields sentence spans (with .text on spaCy >=2, or
    .string on older versions).
    """
    sentences = []
    for sentence in encoder(seq).sents:
        text = getattr(sentence, 'text', getattr(sentence, 'string', str(sentence)))
        sentences.append(text.strip())
    return sentences


def tokenize(encoder, seq, lowercase=True, recognize_ents=False,
             lemmatize=False, include_tags=None, include_pos=None, prepend_start=False):
    """Tokenize seq into a list of non-empty token strings.

    Parameters:
        encoder: spaCy-style pipeline; encoder(seq) yields tokens exposing
            .text (or .string on older spaCy), .tag_, .pos_, .lemma_, .ent_type_.
        seq: raw text string to tokenize.
        lowercase: lowercase tokens (skipped for 'ENT_'-prefixed entity tokens).
        recognize_ents: merge each multi-word named entity into a single token (Span).
        lemmatize: return lemmas instead of surface forms (takes precedence
            over lowercase, matching the original if/elif order).
        include_tags: if given, keep only tokens with these fine-grained POS tags.
        include_pos: if given, keep only tokens with these coarse-grained POS tags.
        prepend_start: insert the u"<START>" marker token at position 0.

    Returns a list of token strings with empty strings filtered out.
    """
    # Fix: previously include_tags=[] / include_pos=[] were mutable default
    # arguments; use None sentinels instead (behavior unchanged for callers).
    include_tags = include_tags if include_tags is not None else []
    include_pos = include_pos if include_pos is not None else []

    def _wtext(w):
        # spaCy >=2 uses .text; older versions used .string
        return getattr(w, 'text', getattr(w, 'string', str(w))).strip()

    seq = encoder(seq)
    if recognize_ents:  # merge named entities into single tokens
        ent_start_idxs = {ent.start: ent for ent in seq.ents
                          if getattr(ent, 'text', getattr(ent, 'string', '')).strip()}
        # keep a token only if it is outside any entity or is an entity's first
        # token; replace that first token with the whole entity Span
        seq = [ent_start_idxs[word_idx] if word_idx in ent_start_idxs else word
               for word_idx, word in enumerate(seq)
               if (not word.ent_type_ or word_idx in ent_start_idxs)]

    # Don't apply POS filtering to phrases (words with underscores)
    if include_tags:  # fine-grained POS tags
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.tag_ in include_tags)]
    if include_pos:  # coarse-grained POS tags
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.pos_ in include_pos)]
    if lemmatize:
        seq = [word.lemma_ if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    # don't lowercase if token is an entity (entities will be of type span
    # instead of token; or will be prefixed with 'ENT_' if already
    # transformed to types)
    elif lowercase:
        seq = [_wtext(word).lower() if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    else:
        seq = [_wtext(word) for word in seq]
    # some words may be empty strings, so filter
    seq = [word for word in seq if word]
    if prepend_start:
        seq.insert(0, u"<START>")
    return seq


def ent_counts_to_probs(ent_counts):
    """Convert entity counts to probabilities for sampling when adapting entities.

    Parameters:
        ent_counts: dict mapping entity type -> {entity: count}.

    Returns a dict of the same shape with counts normalized to probabilities
    (each inner dict sums to 1.0).
    """
    probs = {}
    for ent_type, counts in ent_counts.items():
        # Hoisted: the total was previously recomputed inside the comprehension
        # for every entity, making normalization O(n^2) per type.
        total = float(sum(counts.values()))
        probs[ent_type] = {ent: count / total for ent, count in counts.items()}
    return probs


def get_ents(encoder, seq, include_ent_types=('PERSON', 'NORP', 'ORG', 'GPE')):
    '''Return (ents, ent_counts) for the named entities found in seq.

    ents maps each entity's text to its entity-type label (e.g. 'PERSON');
    ent_counts maps each entity's text to its number of occurrences.
    Only entities whose label is in include_ent_types are kept.

    Note: the original code built a one-element list per entity and joined it
    with "_" — the result is just the label string, so that indirection has
    been removed (behavior is identical).
    '''
    ents = {}
    ent_counts = {}
    for ent in encoder(seq).ents:
        ent_type = ent.label_
        if ent_type in include_ent_types:
            ent_text = getattr(ent, 'text', getattr(
                ent, 'string', str(ent))).strip()
            # whitespace can be detected as an ent, so skip empty strings
            if ent_text:
                ents[ent_text] = ent_type
                ent_counts[ent_text] = ent_counts.get(ent_text, 0) + 1
    return ents, ent_counts


def number_ents(encoder, ents, ent_counts):
    '''return dict of all entities in seq mapped to their entity types,
    with numerical suffixes to distinguish entities of the same type

    encoder: spaCy-style pipeline (used to tokenize entity strings).
    ents: {entity text: entity type} as produced by get_ents.
    ent_counts: {entity text: occurrence count} as produced by get_ents.
    Returns {entity text: 'TYPE_N[...]'} e.g. {'John Smith': 'PERSON_0'}.
    More frequent entities get lower numbers; entities judged co-referring
    share the same numbered type.
    '''
    # sort entities by descending count so the most frequent entity of each
    # type is assigned id 0 (the parameter is rebound to a list of
    # (count, entity, type) tuples here)
    ent_counts = sorted([(count, ent, ents[ent])
                        for ent, count in ent_counts.items()])[::-1]
    ent_type_counts = {}  # per-base-type counter used to assign numeric ids
    num_ents = {}
    for count, ent, ent_type in ent_counts:
        tok_ent = tokenize(encoder, ent, lowercase=False)
        # heuristic co-reference: an already-numbered entity of the same type
        # whose first OR last token matches this entity's (e.g. 'John' and
        # 'John Smith') is treated as the same entity
        coref_ent = [num_ent for num_ent in num_ents
                     if (tokenize(encoder, num_ent, lowercase=False)[0] == tok_ent[0]
                         or tokenize(encoder, num_ent, lowercase=False)[-1] == tok_ent[-1])
                     # treat ents with same first or last word as co-referring
                     and ents[num_ent] == ent_type]
        if coref_ent:
            # reuse the numbered label of the first co-referring entity
            num_ents[ent] = num_ents[coref_ent[0]]
        else:
            # split the type label on '_' (types may carry extra tags)
            ent_type = ent_type.split("_")
            if ent_type[0] in ent_type_counts:
                ent_type_counts[ent_type[0]] += 1
            else:
                ent_type_counts[ent_type[0]] = 1
            # num_ents[ent] aliases the ent_type list; the insert below
            # mutates it in place before it is joined back into a string
            num_ents[ent] = ent_type
            # insert number id after entity type (and before tag, if it exists)
            num_ents[ent].insert(1, str(ent_type_counts[ent_type[0]] - 1))
            num_ents[ent] = "_".join(num_ents[ent])
    return num_ents


def replace_ents_in_seq(encoder, seq):
    '''Extract named entities from seq and return the text with each entity
    replaced by its numbered entity-type token (e.g. "ENT_PERSON_0").'''
    found_ents, found_counts = get_ents(encoder, seq)
    numbered_ents = number_ents(encoder, found_ents, found_counts)

    def _surface(token):
        # token may be a Token or a Span; normalize to stripped text
        return (getattr(token, 'text', None) or getattr(token, 'string', None) or str(token)).strip()

    replaced_tokens = []
    for token in tokenize(encoder, seq, lowercase=False, recognize_ents=True):
        text = _surface(token)
        if text in numbered_ents:
            replaced_tokens.append('ENT_' + numbered_ents[text])
        else:
            replaced_tokens.append(text)
    return " ".join(replaced_tokens)


def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, max_new_sents=None, eos_tokens=None,
                    detokenize=False, ents=None, capitalize_ents=False, adapt_ents=False,
                    sub_ent_probs=None, begin_sentence=True):
    """Transform numerical sequences (token IDs) back into strings.

    Parameters:
        encoder: spaCy-style pipeline (used for detokenization/segmentation).
        lexicon_lookup: list mapping token ID -> token string; falsy or
            out-of-range entries decode to unk_word.
        unk_word: string substituted for unknown token IDs.
        seqs: one sequence or a list of sequences; elements may be ints,
            lists, numpy arrays, or torch tensors (1D, or 2D batches such as
            the (1, seq_len) output of model.generate).
        max_new_sents: if set (and eos_tokens is not), keep only the first
            N sentences of each decoded sequence.
        eos_tokens: cut each decoded sequence at the first of these tokens.
        detokenize: apply formatting rules (punctuation, contractions, ...).
        ents: per-sequence entity dicts used for adapting/capitalizing.
        capitalize_ents: capitalize entity words during detokenization.
        adapt_ents: replace ENT_* placeholder tokens with concrete entities.
        sub_ent_probs: fallback sampling distribution for adapt_ents.
        begin_sentence: capitalize the first character of each sentence.

    Returns a list of decoded strings (one per input row).
    """
    # Fix: previously eos_tokens=[] and ents=[] were mutable default
    # arguments; use None sentinels instead (behavior unchanged for callers).
    eos_tokens = eos_tokens if eos_tokens is not None else []
    ents = ents if ents is not None else []
    if not seqs:
        return []
    if type(seqs[0]) not in (list, np.ndarray, tuple):
        seqs = [seqs]
    decoded_seqs = []
    # transform numerical seq back into string (seq elements are token IDs)
    for seq_idx, seq in enumerate(seqs):
        # Flatten to list of Python ints (handles 2D tensors from model.generate, e.g. (1, seq_len))
        if hasattr(seq, 'cpu'):
            seq = seq.cpu()
        if hasattr(seq, 'tolist'):
            seq = seq.tolist()
        elif seq and hasattr(seq[0], 'tolist'):
            # list(tensor) gives list of row tensors - convert each to list
            seq = [row.tolist() for row in seq]
        else:
            seq = list(seq)
        # If 2D (batch, seq_len), take each row; else single sequence
        if seq and isinstance(seq[0], list):
            rows = seq
        else:
            rows = [seq]

        def _to_int(x):
            # recursively coerce tensor scalars / numpy ints to Python ints
            if isinstance(x, (list, tuple)):
                return [_to_int(v) for v in x]
            return int(x.item()) if hasattr(x, 'item') else int(x)

        for row_idx, row in enumerate(rows):
            tok_seq = []
            flat_row = _to_int(row) if isinstance(
                row, (list, tuple)) else [_to_int(row)]
            if isinstance(flat_row[0], list):
                # row was itself nested one level deeper; flatten it
                flat_row = [v for sub in flat_row for v in (
                    sub if isinstance(sub, list) else [sub])]
            for w in flat_row:
                i = w if isinstance(w, int) else int(w)
                # falsy lexicon entries (e.g. '' padding) also map to unk_word
                tok_seq.append(
                    lexicon_lookup[i] if (0 <= i < len(lexicon_lookup) and lexicon_lookup[i])
                    else unk_word
                )
            seq = tok_seq
            if adapt_ents:  # replace ENT_* with entities from ents, or sub_ent_probs/UNK as fallback
                ent_idx = min(seq_idx + row_idx, len(ents) - 1) if ents else 0
                seq_ents = ents[ent_idx] if ents else {}
                seq = adapt_tok_seq_ents(
                    seq, ents=seq_ents, sub_ent_probs=sub_ent_probs or {})
            if detokenize:  # apply rules for transforming token list into formatted sequence
                if ents and capitalize_ents:
                    ent_idx = min(seq_idx + row_idx,
                                  len(ents) - 1) if ents else 0
                    seq = detokenize_tok_seq(
                        encoder, seq, ents=ents[ent_idx], begin_sentence=begin_sentence)
                else:
                    seq = detokenize_tok_seq(
                        encoder, seq, ents=[], begin_sentence=begin_sentence)
            else:
                # otherwise just join tokens with whitespace between each
                seq = " ".join(seq)
            if eos_tokens:  # if filter_n_sents is a number, filter generated sequence to only the first N=filter_n_sents sentences
                seq = filter_gen_seq(encoder, seq, eos_tokens=eos_tokens)
            elif max_new_sents:
                seq = filter_gen_seq(encoder, seq, n_sents=max_new_sents)
            decoded_seqs.append(seq)
    return decoded_seqs


def adapt_tok_seq_ents(seq, ents=None, sub_ent_probs=None):
    """Replace ENT_<TYPE>_<N> placeholder tokens in seq with concrete entities.

    Resolution order for each placeholder type:
      1. exact type match in ents (e.g. 'PERSON_0');
      2. an unused entity whose base type matches (e.g. any leftover PERSON);
      3. a sample from sub_ent_probs for that base type;
      4. any remaining known entity, or the placeholder itself if none exist.

    Parameters:
        seq: list of token strings, possibly containing 'ENT_'-prefixed tokens.
        ents: dict mapping entity name -> numbered type (e.g. {'Alice': 'PERSON_0'}).
        sub_ent_probs: dict mapping base type -> {entity: probability},
            as produced by ent_counts_to_probs; sampled via the module RNG.

    Returns the token list with placeholders substituted.
    """
    # Fix: previously ents={} and sub_ent_probs={} were mutable default
    # arguments; use None sentinels instead (behavior unchanged for callers).
    ents = ents if ents is not None else {}
    sub_ent_probs = sub_ent_probs if sub_ent_probs is not None else {}

    # reverse ents so that types map to names (builds a new dict, so the
    # caller's mapping is never mutated by the del statements below)
    ents = {ent_type: ent for ent, ent_type in ents.items()}
    # placeholder type ('PERSON_0') -> chosen entity name (None = unresolved)
    adapted_seq_ents = {"_".join(token.split("_")[1:]): None
                        for token in seq if token.startswith('ENT_')}

    if not adapted_seq_ents:
        return seq

    # 1) exact type match
    for seq_ent_type in {ent_type: adapted_ent for ent_type, adapted_ent in adapted_seq_ents.items() if not adapted_ent}:
        if seq_ent_type in ents:
            adapted_seq_ents[seq_ent_type] = ents[seq_ent_type]
            del ents[seq_ent_type]

    # 2) base-type match (e.g. placeholder PERSON_1 matched to a leftover PERSON)
    if ents:
        for seq_ent_type in {ent_type: adapted_ent for ent_type, adapted_ent in adapted_seq_ents.items() if not adapted_ent}:
            for ent_type, ent in ents.items():
                if seq_ent_type.split("_")[0] in ent_type.split("_")[0]:
                    adapted_seq_ents[seq_ent_type] = ents[ent_type]
                    del ents[ent_type]
                    break

    # 3) sample a substitute entity for the base type from sub_ent_probs
    for seq_ent_type in {ent_type: adapted_ent for ent_type, adapted_ent in adapted_seq_ents.items() if not adapted_ent}:
        if seq_ent_type.split("_")[0] in sub_ent_probs:
            sub_ents, sub_probs = zip(
                *sub_ent_probs[seq_ent_type.split("_")[0]].items())
            rand_ent_idx = _rng.choice(len(sub_ents), p=np.array(sub_probs))
            adapted_seq_ents[seq_ent_type] = sub_ents[rand_ent_idx]

    # 4) use ANY available entity (any type) when no type-specific match found
    all_entities = list(ents.values())
    for base_type, type_ents in sub_ent_probs.items():
        all_entities.extend(type_ents.keys())
    for seq_ent_type in {ent_type: adapted_ent for ent_type, adapted_ent in adapted_seq_ents.items() if not adapted_ent}:
        if all_entities:
            adapted_seq_ents[seq_ent_type] = _rng.choice(all_entities)
        else:
            # nothing to substitute: keep the placeholder token
            adapted_seq_ents[seq_ent_type] = "ENT_" + seq_ent_type

    seq = [adapted_seq_ents["_".join(token.split("_")[1:])] if "_".join(
        token.split("_")[1:]) in adapted_seq_ents else token for token in seq]
    return seq


def detokenize_tok_seq(encoder, seq, ents=[], begin_sentence=True):
    '''use simple rules for transforming list of tokens back into string
    ents is optional list of words (named entities) that should be capitalized

    encoder: spaCy-style pipeline, used (via segment) to find sentence boundaries.
    seq: list of token strings.
    begin_sentence: capitalize the first alphabetic character of each sentence.
    Returns the detokenized string; unless it starts with punctuation or a
    contraction suffix, a single leading space is prepended so the result can
    be appended directly to prior text.

    NOTE(review): ents=[] is a mutable default argument; it is never mutated
    here, but ents=None would be the safer convention.
    '''
    seq = [sent.split() for sent
           in segment(encoder, " ".join(seq))]  # split sequence into sentences
    detok_seq = []
    for sent_idx, sent in enumerate(seq):

        assert (type(sent) in (list, tuple))

        if ents:
            token_idx = 0
            # capitalize all tokens that appear in cap_ents
            # (restore original casing wherever an entity phrase appears
            # lowercased in the sentence; multi-word entities are matched
            # as contiguous slices)
            while token_idx < len(sent):
                for ent in ents:
                    ent = ent.split()
                    if sent[token_idx:token_idx + len(ent)] == [token.lower() for token in ent]:
                        # import pdb;pdb.set_trace()
                        sent[token_idx:token_idx + len(ent)] = list(ent)
                        token_idx += len(ent) - 1
                        break
                token_idx += 1

        detok_sent = " ".join(sent)

        detok_sent = re.sub("\'", "'", detok_sent)

        # capitalize first-person "I" pronoun
        detok_sent = re.sub(r"(^| )i ", r"\1I ", detok_sent)

        # rules for contractions (pattern: raw string for \s; replacement: no backslash)
        # reattach contraction suffixes that tokenization split off
        detok_sent = re.sub(r" n'\s*t ", "n't ", detok_sent)
        detok_sent = re.sub(r" '\s*d ", "'d ", detok_sent)
        detok_sent = re.sub(r" '\s*s ", "'s ", detok_sent)
        detok_sent = re.sub(r" '\s*ve ", "'ve ", detok_sent)
        detok_sent = re.sub(r" '\s*ll ", "'ll ", detok_sent)
        detok_sent = re.sub(r" '\s*m ", "'m ", detok_sent)
        detok_sent = re.sub(r" '\s*re ", "'re ", detok_sent)

        # rules for formatting punctuation
        # (remove the space before closing punctuation, and after '$')
        detok_sent = re.sub(" \.", ".", detok_sent)
        detok_sent = re.sub(" \!", "!", detok_sent)
        detok_sent = re.sub(" \?", "?", detok_sent)
        detok_sent = re.sub(" ,", ",", detok_sent)
        detok_sent = re.sub(" \- ", "-", detok_sent)
        detok_sent = re.sub(" :", ":", detok_sent)
        detok_sent = re.sub(" ;", ";", detok_sent)
        detok_sent = re.sub("\$ ", "$", detok_sent)
        detok_sent = re.sub("\' \'", "\'\'", detok_sent)
        detok_sent = re.sub("\` \`", "\`\`", detok_sent)

        # replace repeated single quotes with double quotation mark.
        detok_sent = re.sub("\'\'", "\"", detok_sent)
        detok_sent = re.sub("\`\`", "\"", detok_sent)

        # filter repetitive characters
        detok_sent = re.sub(r'(["\']\s*){2,}', '" ', detok_sent)

        # map each opening puncutation mark to closing mark
        # NOTE(review): the literal repeats the key "'" with the same value,
        # so the first entry is redundant (harmless)
        punc_pairs = {"\'": "\'", "\'": "\'",
                      "`": "\'", "\"": "\"", "(": ")", "[": "]"}
        open_punc = []
        char_idx = 0
        # remove the stray space just inside paired quotes/parentheses:
        # after an opening mark ('( x' -> '(x') and before a closing one
        # ('x )' -> 'x)'); char_idx advances only when the current position
        # was not shifted by a deletion — each pass either advances or
        # shrinks the string, so the loop appears to terminate (verify)
        while char_idx < len(detok_sent):  # check for quotes and parenthesis
            char = detok_sent[char_idx]
            # end quote/parenthesis
            if open_punc and char == punc_pairs[open_punc[-1]]:
                if char_idx > 0 and detok_sent[char_idx - 1] == " ":
                    detok_sent = detok_sent[:char_idx -
                                            1] + detok_sent[char_idx:]
                    open_punc.pop()
            elif char in punc_pairs:
                if char_idx < len(detok_sent) - 1 and detok_sent[char_idx + 1] == " ":
                    open_punc.append(char)
                    detok_sent = detok_sent[:char_idx +
                                            1] + detok_sent[char_idx + 2:]
            if char_idx < len(detok_sent) and detok_sent[char_idx] == char:
                char_idx += 1

        detok_sent = detok_sent.strip()
        # capitalize first alphabetic character if begin_sentence is True
        if begin_sentence:
            for char_idx, char in enumerate(detok_sent):
                if char.isalpha():
                    detok_sent = detok_sent[:char_idx +
                                            1].upper() + detok_sent[char_idx + 1:]
                    break
        detok_seq.append(detok_sent)

    detok_seq = " ".join(detok_seq)
    contraction_patterns = ("'s", "'re", "'ve", "'d", "'ll", "'m", "n't")
    punctuation_patterns = (".", "!", "?", ",", "-", ":", ";", ")", "]")
    # Only prepend space if detok_seq doesn't start with these
    starts_with_pattern = detok_seq.startswith(
        contraction_patterns) or detok_seq.startswith(punctuation_patterns)
    if not starts_with_pattern and detok_seq:
        detok_seq = " " + detok_seq
    return detok_seq


def filter_gen_seq(encoder, seq, n_sents=1, eos_tokens=[]):
    '''Trim a generated sequence: if eos_tokens is given, cut at the first
    occurrence of one of those tokens; otherwise keep only the first n_sents
    sentences (as inferred by the segmenter). A leading space on the input is
    preserved on the output.'''
    had_leading_space = bool(seq) and seq.startswith(" ")
    if eos_tokens:
        # cut off sequence at the first end-of-sentence token
        doc = encoder(seq)
        trimmed = None
        for idx, word in enumerate(doc):
            word_text = getattr(word, 'text', getattr(
                word, 'string', str(word))).strip()
            if word_text in eos_tokens:
                span = doc[:idx + 1]
                trimmed = getattr(span, 'text', getattr(
                    span, 'string', str(span))).strip()
                break
        if trimmed is None:
            # no eos token found: keep the full (unstripped) document text
            trimmed = getattr(doc, 'text', getattr(doc, 'string', str(doc)))
        seq = trimmed
    else:
        # use the segmenter to infer sentence boundaries
        sentences = segment(encoder, seq)
        seq = ""
        n = n_sents
        # grow the window until the joined prefix is non-empty
        while n <= len(sentences):
            candidate = " ".join(sentences[:n]).strip()
            if candidate:
                seq = candidate
                break
            n += 1
        if not seq and sentences:
            seq = " ".join(sentences).strip()
    if had_leading_space and seq:
        seq = " " + seq.lstrip()
    return seq