| | """Tokenization utilities for RNNLM - entity extraction, replacement, and decoding.""" |
| |
|
| | import re |
| | import numpy as np |
| |
|
| | |
| | _rng = np.random.RandomState(0) |
| |
|
| |
|
def segment(encoder, seq):
    """Split seq into sentences using the encoder's sentence segmenter."""
    doc = encoder(seq)
    # .text exists in spaCy >= 2; fall back to .string for spaCy 1.x.
    return [getattr(sent, 'text', getattr(sent, 'string', str(sent))).strip()
            for sent in doc.sents]

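# Example usage (a sketch assuming a spaCy pipeline as the encoder; output is illustrative):
#   nlp = spacy.load("en_core_web_sm")
#   segment(nlp, "It was dark. The rain had stopped.")
#   -> ["It was dark.", "The rain had stopped."]
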
def tokenize(encoder, seq, lowercase=True, recognize_ents=False,
             lemmatize=False, include_tags=(), include_pos=(), prepend_start=False):
    """Tokenize seq into a list of strings, optionally merging named entities into
    single tokens, filtering by tag/POS, lemmatizing, and lowercasing."""
    seq = encoder(seq)
    if recognize_ents:
        # Map each entity's start index to its span so the full entity can
        # replace its first token.
        ent_start_idxs = {ent.start: ent for ent in seq.ents
                          if getattr(ent, 'text', getattr(ent, 'string', '')).strip()}
        # Substitute the entity span at its start index and drop the remaining
        # tokens inside the span.
        seq = [ent_start_idxs[word_idx] if word_idx in ent_start_idxs else word
               for word_idx, word in enumerate(seq)
               if (not word.ent_type_ or word_idx in ent_start_idxs)]

    def _wtext(w):
        # Token/span text across spaCy versions (.text in >= 2, .string in 1.x).
        return getattr(w, 'text', getattr(w, 'string', str(w))).strip()

    # Tokens containing "_" are ENT_* placeholders and always pass the filters.
    if include_tags:
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.tag_ in include_tags)]
    if include_pos:
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.pos_ in include_pos)]
    if lemmatize:
        seq = [word.lemma_ if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    elif lowercase:
        seq = [_wtext(word).lower() if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    else:
        seq = [_wtext(word) for word in seq]

    seq = [word for word in seq if word]  # drop empty tokens
    if prepend_start:
        seq.insert(0, u"<START>")
    return seq

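# Example usage (a sketch; entity detection depends on the NER model):
#   tokenize(nlp, "Mary flew to New York.", recognize_ents=True)
#   -> ['mary', 'flew', 'to', 'new york', '.']
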
def ent_counts_to_probs(ent_counts):
    """Convert entity counts to probabilities for sampling when adapting entities."""
    probs = {}
    for ent_type, counts in ent_counts.items():
        total = float(sum(counts.values()))
        probs[ent_type] = {ent: count / total for ent, count in counts.items()}
    return probs

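# Example (hypothetical counts):
#   ent_counts_to_probs({'PERSON': {'Mary': 3, 'John': 1}})
#   -> {'PERSON': {'Mary': 0.75, 'John': 0.25}}
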
def get_ents(encoder, seq, include_ent_types=('PERSON', 'NORP', 'ORG', 'GPE')):
    '''return a dict mapping each entity in seq to its entity type,
    along with a dict of occurrence counts per entity'''
    ents = {}
    ent_counts = {}
    for ent in encoder(seq).ents:
        ent_type = ent.label_
        if ent_type in include_ent_types:
            ent = getattr(ent, 'text', getattr(ent, 'string', str(ent))).strip()
            if ent:
                ents[ent] = ent_type
                ent_counts[ent] = ent_counts.get(ent, 0) + 1
    return ents, ent_counts

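# Example usage (a sketch; actual entities depend on the NER model):
#   get_ents(nlp, "Mary met John in Boston. Mary smiled.")
#   -> ({'Mary': 'PERSON', 'John': 'PERSON', 'Boston': 'GPE'},
#       {'Mary': 2, 'John': 1, 'Boston': 1})
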
def number_ents(encoder, ents, ent_counts):
    '''return dict of all entities in seq mapped to their entity types,
    with numerical suffixes to distinguish entities of the same type'''
    # Process entities from most to least frequent.
    ent_counts = sorted([(count, ent, ents[ent])
                         for ent, count in ent_counts.items()])[::-1]
    ent_type_counts = {}
    num_ents = {}
    for count, ent, ent_type in ent_counts:
        tok_ent = tokenize(encoder, ent, lowercase=False)
        # Treat an entity as coreferent with an already-numbered entity of the same
        # type when they share a first or last token (e.g. "Mary" / "Mary Smith").
        coref_ent = [num_ent for num_ent in num_ents
                     if (tokenize(encoder, num_ent, lowercase=False)[0] == tok_ent[0]
                         or tokenize(encoder, num_ent, lowercase=False)[-1] == tok_ent[-1])
                     and ents[num_ent] == ent_type]
        if coref_ent:
            num_ents[ent] = num_ents[coref_ent[0]]
        else:
            ent_type = ent_type.split("_")
            ent_type_counts[ent_type[0]] = ent_type_counts.get(ent_type[0], 0) + 1
            # Insert the per-type index after the base type, e.g. ['PERSON', '0'].
            num_ents[ent] = ent_type
            num_ents[ent].insert(1, str(ent_type_counts[ent_type[0]] - 1))
            num_ents[ent] = "_".join(num_ents[ent])
    return num_ents

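# Example (hypothetical inputs mirroring get_ents output):
#   number_ents(nlp, {'Mary': 'PERSON', 'John': 'PERSON', 'Boston': 'GPE'},
#               {'Mary': 2, 'John': 1, 'Boston': 1})
#   -> {'Mary': 'PERSON_0', 'John': 'PERSON_1', 'Boston': 'GPE_0'}
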
def replace_ents_in_seq(encoder, seq):
    '''extract entities from seq and replace them with numbered entity-type tokens'''
    ents, ent_counts = get_ents(encoder, seq)
    ents = number_ents(encoder, ents, ent_counts)
    # tokenize() returns stripped strings, so tokens can be looked up directly.
    seq = tokenize(encoder, seq, lowercase=False, recognize_ents=True)
    seq = ['ENT_' + ents[word] if word in ents else word for word in seq]
    return " ".join(seq)

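# Example usage (a sketch; entity detection depends on the NER model):
#   replace_ents_in_seq(nlp, "Mary met John in Boston.")
#   -> 'ENT_PERSON_0 met ENT_PERSON_1 in ENT_GPE_0 .'
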
def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, max_new_sents=None, eos_tokens=(),
                    detokenize=False, ents=(), capitalize_ents=False, adapt_ents=False,
                    sub_ent_probs=None, begin_sentence=True):
    '''map sequences of word indices back to strings via lexicon_lookup, optionally
    adapting entity placeholders, detokenizing, and truncating the output'''
    if not seqs:
        return []
    # Accept a single sequence as well as a batch of sequences.
    if type(seqs[0]) not in (list, np.ndarray, tuple):
        seqs = [seqs]
    decoded_seqs = []
    for seq_idx, seq in enumerate(seqs):
        # Normalize torch tensors / numpy arrays to plain Python lists.
        if hasattr(seq, 'cpu'):
            seq = seq.cpu()
        if hasattr(seq, 'tolist'):
            seq = seq.tolist()
        elif seq and hasattr(seq[0], 'tolist'):
            seq = [row.tolist() for row in seq]
        else:
            seq = list(seq)
        # A 2-D sequence is a batch of rows; a 1-D sequence is a single row.
        rows = seq if (seq and isinstance(seq[0], list)) else [seq]

        def _to_int(x):
            if isinstance(x, (list, tuple)):
                return [_to_int(v) for v in x]
            return int(x.item()) if hasattr(x, 'item') else int(x)

        for row_idx, row in enumerate(rows):
            flat_row = _to_int(row) if isinstance(row, (list, tuple)) else [_to_int(row)]
            if isinstance(flat_row[0], list):
                flat_row = [v for sub in flat_row
                            for v in (sub if isinstance(sub, list) else [sub])]
            # Out-of-range indices and empty lexicon entries decode to the unknown word.
            seq = [lexicon_lookup[i] if (0 <= i < len(lexicon_lookup) and lexicon_lookup[i])
                   else unk_word for i in flat_row]
            if adapt_ents:
                ent_idx = min(seq_idx + row_idx, len(ents) - 1) if ents else 0
                seq_ents = ents[ent_idx] if ents else {}
                seq = adapt_tok_seq_ents(seq, ents=seq_ents,
                                         sub_ent_probs=sub_ent_probs or {})
            if detokenize:
                if ents and capitalize_ents:
                    ent_idx = min(seq_idx + row_idx, len(ents) - 1)
                    seq = detokenize_tok_seq(encoder, seq, ents=ents[ent_idx],
                                             begin_sentence=begin_sentence)
                else:
                    seq = detokenize_tok_seq(encoder, seq, ents=(),
                                             begin_sentence=begin_sentence)
            else:
                seq = " ".join(seq)
            if eos_tokens:
                seq = filter_gen_seq(encoder, seq, eos_tokens=eos_tokens)
            elif max_new_sents:
                seq = filter_gen_seq(encoder, seq, n_sents=max_new_sents)
            decoded_seqs.append(seq)
    return decoded_seqs

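# Example (hypothetical lexicon; index 0 is a reserved empty entry):
#   lexicon = ['', 'the', 'dog', 'ran', '.']
#   decode_num_seqs(nlp, lexicon, '<UNK>', [[1, 2, 3, 4], [1, 2, 9]])
#   -> ['the dog ran .', 'the dog <UNK>']
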
def adapt_tok_seq_ents(seq, ents=None, sub_ent_probs=None):
    '''replace ENT_* placeholder tokens in seq with concrete entities, matching
    placeholder types against the given entities and sampling substitutes as a fallback'''
    ents = dict(ents or {})
    sub_ent_probs = sub_ent_probs or {}
    # Invert the mapping so entity types index entity strings.
    ents = {ent_type: ent for ent, ent_type in ents.items()}
    # Collect the placeholder types in seq (e.g. "PERSON_0" from "ENT_PERSON_0").
    adapted_seq_ents = {"_".join(token.split("_")[1:]): None
                        for token in seq if token.startswith('ENT_')}
    if not adapted_seq_ents:
        return seq

    # First pass: exact type matches.
    for seq_ent_type in [ent_type for ent_type, adapted_ent
                         in adapted_seq_ents.items() if not adapted_ent]:
        if seq_ent_type in ents:
            adapted_seq_ents[seq_ent_type] = ents[seq_ent_type]
            del ents[seq_ent_type]

    # Second pass: match on the base type (e.g. "PERSON" for "PERSON_0").
    # Iterate over a copy of the items so entries can be deleted safely.
    if ents:
        for seq_ent_type in [ent_type for ent_type, adapted_ent
                             in adapted_seq_ents.items() if not adapted_ent]:
            for ent_type, ent in list(ents.items()):
                if seq_ent_type.split("_")[0] in ent_type.split("_")[0]:
                    adapted_seq_ents[seq_ent_type] = ent
                    del ents[ent_type]
                    break

    # Third pass: sample a substitute entity of the same base type.
    for seq_ent_type in [ent_type for ent_type, adapted_ent
                         in adapted_seq_ents.items() if not adapted_ent]:
        if seq_ent_type.split("_")[0] in sub_ent_probs:
            sub_ents, sub_probs = zip(*sub_ent_probs[seq_ent_type.split("_")[0]].items())
            rand_ent_idx = _rng.choice(len(sub_ents), p=np.array(sub_probs))
            adapted_seq_ents[seq_ent_type] = sub_ents[rand_ent_idx]

    # Last resort: pick any known entity, or leave the placeholder unresolved.
    all_entities = list(ents.values())
    for type_ents in sub_ent_probs.values():
        all_entities.extend(type_ents.keys())
    for seq_ent_type in [ent_type for ent_type, adapted_ent
                         in adapted_seq_ents.items() if not adapted_ent]:
        if all_entities:
            adapted_seq_ents[seq_ent_type] = _rng.choice(all_entities)
        else:
            adapted_seq_ents[seq_ent_type] = "ENT_" + seq_ent_type

    return [adapted_seq_ents["_".join(token.split("_")[1:])]
            if token.startswith('ENT_') else token for token in seq]

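# Example (hypothetical; the seeded module RNG makes sampling deterministic):
#   adapt_tok_seq_ents(['ENT_PERSON_0', 'waved', 'to', 'ENT_PERSON_1', '.'],
#                      ents={'Mary': 'PERSON_0'},
#                      sub_ent_probs={'PERSON': {'John': 1.0}})
#   -> ['Mary', 'waved', 'to', 'John', '.']
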
def detokenize_tok_seq(encoder, seq, ents=(), begin_sentence=True):
    '''use simple rules for transforming a list of tokens back into a string;
    ents is an optional list of words (named entities) whose casing should be restored'''
    seq = [sent.split() for sent in segment(encoder, " ".join(seq))]
    detok_seq = []
    for sent in seq:
        if ents:
            # Restore the original casing of any entity appearing in the sentence.
            token_idx = 0
            while token_idx < len(sent):
                for ent in ents:
                    ent = ent.split()
                    if sent[token_idx:token_idx + len(ent)] == [token.lower() for token in ent]:
                        sent[token_idx:token_idx + len(ent)] = list(ent)
                        token_idx += len(ent) - 1
                        break
                token_idx += 1

        detok_sent = " ".join(sent)

        # Capitalize the pronoun "I".
        detok_sent = re.sub(r"(^| )i ", r"\1I ", detok_sent)

        # Reattach contraction suffixes to the preceding word.
        detok_sent = re.sub(r" n'\s*t ", "n't ", detok_sent)
        detok_sent = re.sub(r" '\s*d ", "'d ", detok_sent)
        detok_sent = re.sub(r" '\s*s ", "'s ", detok_sent)
        detok_sent = re.sub(r" '\s*ve ", "'ve ", detok_sent)
        detok_sent = re.sub(r" '\s*ll ", "'ll ", detok_sent)
        detok_sent = re.sub(r" '\s*m ", "'m ", detok_sent)
        detok_sent = re.sub(r" '\s*re ", "'re ", detok_sent)

        # Remove spaces around punctuation.
        detok_sent = re.sub(r" \.", ".", detok_sent)
        detok_sent = re.sub(r" !", "!", detok_sent)
        detok_sent = re.sub(r" \?", "?", detok_sent)
        detok_sent = re.sub(r" ,", ",", detok_sent)
        detok_sent = re.sub(r" - ", "-", detok_sent)
        detok_sent = re.sub(r" :", ":", detok_sent)
        detok_sent = re.sub(r" ;", ";", detok_sent)
        detok_sent = re.sub(r"\$ ", "$", detok_sent)
        detok_sent = re.sub(r"' '", "''", detok_sent)
        detok_sent = re.sub(r"` `", "``", detok_sent)

        # Normalize doubled quote characters to a standard double quote.
        detok_sent = re.sub(r"''", '"', detok_sent)
        detok_sent = re.sub(r"``", '"', detok_sent)

        # Collapse runs of mixed quote characters into a single quote.
        detok_sent = re.sub(r'(["\']\s*){2,}', '" ', detok_sent)

        # Close up spaces inside paired punctuation, e.g. '( x )' -> '(x)'.
        punc_pairs = {"'": "'", "`": "'", '"': '"', "(": ")", "[": "]"}
        open_punc = []
        char_idx = 0
        while char_idx < len(detok_sent):
            char = detok_sent[char_idx]
            if open_punc and char == punc_pairs[open_punc[-1]]:
                # Closing mark: delete the space before it.
                if char_idx > 0 and detok_sent[char_idx - 1] == " ":
                    detok_sent = detok_sent[:char_idx - 1] + detok_sent[char_idx:]
                open_punc.pop()
            elif char in punc_pairs:
                # Opening mark: delete the space after it.
                if char_idx < len(detok_sent) - 1 and detok_sent[char_idx + 1] == " ":
                    open_punc.append(char)
                    detok_sent = detok_sent[:char_idx + 1] + detok_sent[char_idx + 2:]
            # Only advance when the character at this position did not shift.
            if char_idx < len(detok_sent) and detok_sent[char_idx] == char:
                char_idx += 1

        detok_sent = detok_sent.strip()
        if begin_sentence:
            # Uppercase the first alphabetic character of the sentence.
            for char_idx, char in enumerate(detok_sent):
                if char.isalpha():
                    detok_sent = detok_sent[:char_idx + 1].upper() + detok_sent[char_idx + 1:]
                    break
        detok_seq.append(detok_sent)

    detok_seq = " ".join(detok_seq)
    contraction_patterns = ("'s", "'re", "'ve", "'d", "'ll", "'m", "n't")
    punctuation_patterns = (".", "!", "?", ",", "-", ":", ";", ")", "]")
    # Prepend a space unless the text starts with a suffix that attaches to the
    # preceding context (a contraction or closing punctuation).
    starts_with_pattern = (detok_seq.startswith(contraction_patterns)
                           or detok_seq.startswith(punctuation_patterns))
    if not starts_with_pattern and detok_seq:
        detok_seq = " " + detok_seq
    return detok_seq

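# Example usage (a sketch; sentence segmentation comes from the encoder):
#   detokenize_tok_seq(nlp, ['she', 'did', "n't", 'see', 'mary', '.'], ents=['Mary'])
#   -> " She didn't see Mary."
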
def filter_gen_seq(encoder, seq, n_sents=1, eos_tokens=()):
    '''given a generated sequence, keep only the first n_sents sentences, or
    the text up to and including the first end-of-sequence token'''
    leading_space = seq.startswith(" ") if seq else False
    if eos_tokens:
        # Truncate at the first end-of-sequence token, if any.
        doc = encoder(seq)
        for idx, word in enumerate(doc):
            wtext = getattr(word, 'text', getattr(word, 'string', str(word))).strip()
            if wtext in eos_tokens:
                span = doc[:idx + 1]
                seq = getattr(span, 'text', getattr(span, 'string', str(span))).strip()
                break
        else:
            # No end-of-sequence token found; keep the full text.
            seq = getattr(doc, 'text', getattr(doc, 'string', str(doc)))
    else:
        # Keep the first n_sents sentences, expanding the window if they are empty.
        sentences = segment(encoder, seq)
        n = n_sents
        seq = ""
        while n <= len(sentences):
            seq = " ".join(sentences[:n]).strip()
            if seq:
                break
            n += 1
        if not seq and sentences:
            seq = " ".join(sentences).strip()
    if leading_space and seq:
        seq = " " + seq.lstrip()
    return seq

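# Example usage (a sketch assuming a spaCy pipeline as the encoder):
#   filter_gen_seq(nlp, "He left. She stayed. They met.", n_sents=2)
#   -> "He left. She stayed."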