"""Tokenization utilities for RNNLM - entity extraction, replacement, and decoding."""
import re

import numpy as np

# RNG for adapt_tok_seq_ents when sampling from sub_ent_probs
_rng = np.random.RandomState(0)

# Prefix marking a token that stands for an entity type (e.g. "ENT_PERSON_0").
_ENT_PREFIX = 'ENT_'

# Contraction rules: re-attach clitics that the tokenizer split off.
_CONTRACTION_RULES = tuple(
    (re.compile(pattern), replacement) for pattern, replacement in (
        (r" n'\s*t ", "n't "),
        (r" '\s*d ", "'d "),
        (r" '\s*s ", "'s "),
        (r" '\s*ve ", "'ve "),
        (r" '\s*ll ", "'ll "),
        (r" '\s*m ", "'m "),
        (r" '\s*re ", "'re "),
    )
)

# Literal punctuation fix-ups, applied in this order (order matters: the
# paired quote characters are first joined and only then turned into '"').
_PUNCT_REPLACEMENTS = (
    (" .", "."),
    (" !", "!"),
    (" ?", "?"),
    (" ,", ","),
    (" - ", "-"),
    (" :", ":"),
    (" ;", ";"),
    ("$ ", "$"),
    ("' '", "''"),
    ("` `", "``"),
    ("''", '"'),
    ("``", '"'),
)

_FIRST_PERSON_RE = re.compile(r"(^| )i ")
_REPEATED_QUOTES_RE = re.compile(r'(["\']\s*){2,}')

# Maps each opening punctuation mark to its closing mark.
_PUNC_PAIRS = {"'": "'", "`": "'", '"': '"', "(": ")", "[": "]"}

_CONTRACTION_PREFIXES = ("'s", "'re", "'ve", "'d", "'ll", "'m", "n't")
_PUNCTUATION_PREFIXES = (".", "!", "?", ",", "-", ":", ";", ")", "]")


def _text_of(obj):
    """Return the raw (unstripped) text of a token/span/doc-like object.

    Tries the ``text`` attribute, then ``string``, and finally falls back to
    ``str(obj)`` so that plain strings are accepted as well.
    """
    text = getattr(obj, 'text', None)
    if text is None:
        text = getattr(obj, 'string', None)
    if text is None:
        text = str(obj)
    return text


def segment(encoder, seq):
    """Split ``seq`` into sentences.

    Args:
        encoder: callable applied to ``seq``; its result must expose ``.sents``.
        seq: text to segment.

    Returns:
        List of sentence strings, each stripped of surrounding whitespace.
    """
    doc = encoder(seq)
    return [_text_of(sent).strip() for sent in doc.sents]


def tokenize(encoder, seq, lowercase=True, recognize_ents=False, lemmatize=False,
             include_tags=(), include_pos=(), prepend_start=False):
    """Tokenize ``seq`` into a list of strings.

    Args:
        encoder: callable applied to ``seq`` that returns an iterable of words.
            With ``recognize_ents`` the result must also expose ``.ents`` and
            each word ``.ent_type_``.
        seq: text to tokenize.
        lowercase: lowercase tokens, except those starting with ``ENT_``.
            Ignored when ``lemmatize`` is set.
        recognize_ents: merge each named entity into a single token.
        lemmatize: use ``word.lemma_`` instead of the surface text
            (tokens starting with ``ENT_`` keep their text).
        include_tags: if non-empty, keep only words whose ``tag_`` is listed.
        include_pos: if non-empty, keep only words whose ``pos_`` is listed.
        prepend_start: insert the start token at position 0.

    Returns:
        List of non-empty token strings (plus the start token if requested).
    """
    words = encoder(seq)

    if recognize_ents:
        # Map the index of each entity's first word to the entity itself,
        # ignoring entities that consist only of whitespace.
        ent_start_idxs = {ent.start: ent for ent in words.ents
                          if _text_of(ent).strip()}
        # Keep words outside entities; replace the first word of each entity
        # by the whole entity and drop the entity's remaining words.
        words = [ent_start_idxs.get(word_idx, word)
                 for word_idx, word in enumerate(words)
                 if not word.ent_type_ or word_idx in ent_start_idxs]

    def _wtext(word):
        return _text_of(word).strip()

    # Don't apply POS filtering to phrases (words with underscores)
    if include_tags:  # fine-grained POS tags
        words = [word for word in words
                 if "_" in _wtext(word) or word.tag_ in include_tags]
    if include_pos:  # coarse-grained POS tags
        words = [word for word in words
                 if "_" in _wtext(word) or word.pos_ in include_pos]

    if lemmatize:
        tokens = [_wtext(word) if _wtext(word).startswith(_ENT_PREFIX)
                  else word.lemma_ for word in words]
    elif lowercase:
        # Only tokens already transformed to entity types (prefixed with
        # 'ENT_') are exempt from lowercasing.
        tokens = [_wtext(word) if _wtext(word).startswith(_ENT_PREFIX)
                  else _wtext(word).lower() for word in words]
    else:
        tokens = [_wtext(word) for word in words]

    # some words may be empty strings, so filter
    tokens = [token for token in tokens if token]

    if prepend_start:
        # NOTE(review): the start token is an empty string here - confirm this
        # is the intended marker and not a placeholder that lost its content.
        tokens.insert(0, u"")
    return tokens


def ent_counts_to_probs(ent_counts):
    """Convert entity counts to probabilities for sampling when adapting entities.

    Args:
        ent_counts: dict mapping entity type -> {entity: count}.

    Returns:
        Dict of the same shape where each count is divided by the total count
        of its entity type.
    """
    ent_probs = {}
    for ent_type, counts in ent_counts.items():
        total = sum(counts.values())  # computed once per type, not per entity
        ent_probs[ent_type] = {ent: count * 1.0 / total
                               for ent, count in counts.items()}
    return ent_probs


def get_ents(encoder, seq, include_ent_types=('PERSON', 'NORP', 'ORG', 'GPE')):
    """Collect the named entities in ``seq``.

    Args:
        encoder: callable applied to ``seq``; its result must expose ``.ents``,
            each with a ``label_``.
        seq: text to analyse.
        include_ent_types: entity labels to keep.

    Returns:
        Tuple ``(ents, ent_counts)``: ``ents`` maps entity text to its entity
        type, ``ent_counts`` maps entity text to its number of occurrences.
    """
    ents = {}
    ent_counts = {}
    for ent in encoder(seq).ents:
        ent_type = ent.label_
        if ent_type not in include_ent_types:
            continue
        ent_text = _text_of(ent).strip()
        if not ent_text:
            # whitespace can be detected as an entity, so skip it
            continue
        ents[ent_text] = ent_type
        ent_counts[ent_text] = ent_counts.get(ent_text, 0) + 1
    return ents, ent_counts


def number_ents(encoder, ents, ent_counts):
    """Give each entity a numbered type, e.g. ``PERSON_0``.

    Entities are processed from most to least frequent. An entity that has the
    same type as an already-numbered entity and shares its first or last token
    is treated as co-referring and receives the same numbered type.

    Args:
        encoder: callable passed on to :func:`tokenize`.
        ents: dict mapping entity text -> entity type.
        ent_counts: dict mapping entity text -> occurrence count.

    Returns:
        Dict mapping entity text -> numbered entity type.
    """
    ranked = sorted(((count, ent, ents[ent]) for ent, count in ent_counts.items()),
                    reverse=True)

    # Each tokenization runs the encoder, so do it at most once per entity.
    token_cache = {}

    def _tokens(ent):
        if ent not in token_cache:
            token_cache[ent] = tokenize(encoder, ent, lowercase=False)
        return token_cache[ent]

    ent_type_counts = {}
    num_ents = {}
    for _count, ent, ent_type in ranked:
        tok_ent = _tokens(ent)
        coref = None
        if tok_ent:  # an entity without tokens cannot co-refer with anything
            for prev_ent in num_ents:
                prev_tokens = _tokens(prev_ent)
                # treat ents with same first or last word as co-referring
                if (prev_tokens and ents[prev_ent] == ent_type
                        and (prev_tokens[0] == tok_ent[0]
                             or prev_tokens[-1] == tok_ent[-1])):
                    coref = prev_ent
                    break
        if coref is not None:
            num_ents[ent] = num_ents[coref]
            continue
        type_parts = ent_type.split("_")
        base_type = type_parts[0]
        ent_type_counts[base_type] = ent_type_counts.get(base_type, 0) + 1
        # insert number id after entity type (and before tag, if it exists)
        type_parts.insert(1, str(ent_type_counts[base_type] - 1))
        num_ents[ent] = "_".join(type_parts)
    return num_ents


def replace_ents_in_seq(encoder, seq):
    """Extract entities from ``seq`` and replace them with their entity types.

    Args:
        encoder: callable passed on to :func:`get_ents`, :func:`number_ents`
            and :func:`tokenize`.
        seq: text to transform.

    Returns:
        The tokens of ``seq`` joined by single spaces, with each recognized
        entity replaced by ``ENT_`` followed by its numbered type.
    """
    ents, ent_counts = get_ents(encoder, seq)
    ents = number_ents(encoder, ents, ent_counts)
    # tokenize() already returns stripped strings, with entities merged.
    tokens = tokenize(encoder, seq, lowercase=False, recognize_ents=True)
    return " ".join(_ENT_PREFIX + ents[token] if token in ents else token
                    for token in tokens)


def _flatten_ids(item):
    """Return the token IDs in an arbitrarily nested list/tuple as flat ints."""
    if isinstance(item, (list, tuple)):
        flat = []
        for sub_item in item:
            flat.extend(_flatten_ids(sub_item))
        return flat
    return [int(item.item()) if hasattr(item, 'item') else int(item)]


def _seq_to_rows(seq):
    """Convert one sequence (list, tuple, array or tensor) to a list of rows."""
    if hasattr(seq, 'cpu'):
        seq = seq.cpu()
    if hasattr(seq, 'tolist'):
        seq = seq.tolist()
    else:
        # list(tensor) gives list of row tensors - convert each to list
        seq = [elem.tolist() if hasattr(elem, 'tolist') else elem for elem in seq]
    # If 2D (batch, seq_len), take each row; else single sequence
    if seq and isinstance(seq[0], list):
        return seq
    return [seq]


def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, max_new_sents=None,
                    eos_tokens=(), detokenize=False, ents=(), capitalize_ents=False,
                    adapt_ents=False, sub_ent_probs=None, begin_sentence=True):
    """Transform numerical sequences (token IDs) back into strings.

    Args:
        encoder: callable used for sentence segmentation; only needed when
            ``detokenize``, ``eos_tokens`` or ``max_new_sents`` is set.
        lexicon_lookup: indexable mapping token ID -> token string.
        unk_word: token used for IDs that are out of range or map to an
            empty entry.
        seqs: a single sequence of IDs or a collection of sequences
            (lists, tuples, arrays or tensor-like objects).
        max_new_sents: if set (and ``eos_tokens`` is empty), keep only this
            many sentences of each decoded sequence.
        eos_tokens: if non-empty, cut each decoded sequence after the first
            occurrence of one of these tokens.
        detokenize: format the tokens into readable text instead of joining
            them with spaces.
        ents: per-sequence dicts mapping entity -> numbered entity type.
        capitalize_ents: when detokenizing, restore the casing of ``ents``.
        adapt_ents: replace ``ENT_*`` tokens using ``ents``/``sub_ent_probs``.
        sub_ent_probs: dict mapping entity type -> {entity: probability}.
        begin_sentence: passed on to :func:`detokenize_tok_seq`.

    Returns:
        List with one decoded string per row.
    """
    # len() rather than truthiness: a multi-element array is ambiguous as a bool
    if seqs is None or len(seqs) == 0:
        return []
    if not isinstance(seqs[0], (list, np.ndarray, tuple)):
        seqs = [seqs]

    decoded_seqs = []
    for seq_idx, seq in enumerate(seqs):
        for row_idx, row in enumerate(_seq_to_rows(seq)):
            tokens = [
                lexicon_lookup[tok_id]
                if (0 <= tok_id < len(lexicon_lookup) and lexicon_lookup[tok_id])
                else unk_word
                for tok_id in _flatten_ids(row)
            ]
            # index of the entity dict that belongs to this row (clamped)
            ent_idx = min(seq_idx + row_idx, len(ents) - 1) if ents else 0

            if adapt_ents:
                # replace ENT_* with entities from ents, or sub_ent_probs/UNK as fallback
                tokens = adapt_tok_seq_ents(
                    tokens,
                    ents=ents[ent_idx] if ents else {},
                    sub_ent_probs=sub_ent_probs or {})

            if detokenize:
                # apply rules for transforming token list into formatted sequence
                cap_ents = ents[ent_idx] if (ents and capitalize_ents) else []
                decoded = detokenize_tok_seq(
                    encoder, tokens, ents=cap_ents, begin_sentence=begin_sentence)
            else:
                # otherwise just join tokens with whitespace between each
                decoded = " ".join(tokens)

            if eos_tokens:
                # cut the sequence after the first end-of-sentence token
                decoded = filter_gen_seq(encoder, decoded, eos_tokens=eos_tokens)
            elif max_new_sents:
                # keep only the first max_new_sents sentences
                decoded = filter_gen_seq(encoder, decoded, n_sents=max_new_sents)
            decoded_seqs.append(decoded)
    return decoded_seqs


def adapt_tok_seq_ents(seq, ents=None, sub_ent_probs=None):
    """Replace ``ENT_*`` tokens in ``seq`` with concrete entity names.

    Each distinct entity type in ``seq`` is resolved by the first strategy
    that succeeds:

    1. an entity in ``ents`` with exactly that numbered type;
    2. an unused entity in ``ents`` whose base type contains the base type
       of the token (substring match);
    3. an entity sampled from ``sub_ent_probs`` for that base type;
    4. any remaining entity, picked at random;
    5. the original ``ENT_*`` token.

    Args:
        seq: list of token strings.
        ents: dict mapping entity -> numbered entity type (not modified).
        sub_ent_probs: dict mapping base entity type -> {entity: probability}.

    Returns:
        New list of tokens, or ``seq`` itself if it has no ``ENT_*`` tokens.
    """
    sub_ent_probs = sub_ent_probs or {}
    # reverse ents so that types map to names
    available = {ent_type: ent for ent, ent_type in (ents or {}).items()}
    prefix_len = len(_ENT_PREFIX)
    adapted = {token[prefix_len:]: None
               for token in seq if token.startswith(_ENT_PREFIX)}
    if not adapted:
        return seq

    def _pending():
        return [ent_type for ent_type, ent in adapted.items() if not ent]

    # 1. exact match on the numbered entity type
    for seq_ent_type in _pending():
        if seq_ent_type in available:
            adapted[seq_ent_type] = available.pop(seq_ent_type)

    # 2. any unused entity of a matching base type
    for seq_ent_type in _pending():
        base_type = seq_ent_type.split("_")[0]
        match = next((ent_type for ent_type in available
                      if base_type in ent_type.split("_")[0]), None)
        if match is not None:
            adapted[seq_ent_type] = available.pop(match)

    # 3. sample a substitute entity of the same base type
    for seq_ent_type in _pending():
        candidates = sub_ent_probs.get(seq_ent_type.split("_")[0])
        if candidates:  # an empty candidate dict cannot be sampled from
            sub_ents, sub_probs = zip(*candidates.items())
            rand_ent_idx = _rng.choice(len(sub_ents), p=np.array(sub_probs))
            adapted[seq_ent_type] = sub_ents[rand_ent_idx]

    # 4./5. Use ANY available entity (any type) when no type-specific match found
    all_entities = list(available.values())
    for type_ents in sub_ent_probs.values():
        all_entities.extend(type_ents.keys())
    for seq_ent_type in _pending():
        if all_entities:
            adapted[seq_ent_type] = _rng.choice(all_entities)
        else:
            adapted[seq_ent_type] = _ENT_PREFIX + seq_ent_type

    # Only tokens that actually carry the ENT_ prefix are substituted; other
    # tokens that merely contain an underscore are left untouched.
    return [adapted[token[prefix_len:]] if token.startswith(_ENT_PREFIX) else token
            for token in seq]


def _capitalize_ents(sent, ents):
    """Restore in place the casing of the entities found in token list ``sent``."""
    # Pre-split once; skip empty entities, which would match at every position
    # without ever advancing.
    ent_tokens = [ent.split() for ent in ents]
    ent_tokens = [(toks, [tok.lower() for tok in toks]) for toks in ent_tokens if toks]
    token_idx = 0
    while token_idx < len(sent):
        for toks, lowered in ent_tokens:
            if sent[token_idx:token_idx + len(toks)] == lowered:
                sent[token_idx:token_idx + len(toks)] = list(toks)
                token_idx += len(toks) - 1
                break
        token_idx += 1


def _attach_paired_punctuation(text):
    """Remove the space after opening and before closing quotes/brackets."""
    open_punc = []
    char_idx = 0
    while char_idx < len(text):
        char = text[char_idx]
        if open_punc and char == _PUNC_PAIRS[open_punc[-1]]:
            # end quote/parenthesis: drop the space in front of it
            if char_idx > 0 and text[char_idx - 1] == " ":
                text = text[:char_idx - 1] + text[char_idx:]
            open_punc.pop()
        elif char in _PUNC_PAIRS:
            # start quote/parenthesis: only counts if followed by a space
            if char_idx < len(text) - 1 and text[char_idx + 1] == " ":
                open_punc.append(char)
                text = text[:char_idx + 1] + text[char_idx + 2:]
        # After a removal in front of char_idx the next character has already
        # shifted into this position, so only advance if char is still here.
        if char_idx < len(text) and text[char_idx] == char:
            char_idx += 1
    return text


def detokenize_tok_seq(encoder, seq, ents=(), begin_sentence=True):
    """Use simple rules for transforming a list of tokens back into a string.

    Args:
        encoder: callable passed on to :func:`segment`.
        seq: list of token strings.
        ents: optional iterable of named entities whose original casing
            should be restored.
        begin_sentence: capitalize the first alphabetic character of every
            sentence.

    Returns:
        The formatted text. A leading space is prepended unless the text is
        empty or starts with a contraction or closing punctuation.
    """
    # split sequence into sentences
    sents = [sent.split() for sent in segment(encoder, " ".join(seq))]
    detok_sents = []
    for sent in sents:
        if ents:
            _capitalize_ents(sent, ents)
        detok_sent = " ".join(sent)

        # capitalize first-person "I" pronoun
        detok_sent = _FIRST_PERSON_RE.sub(r"\1I ", detok_sent)
        # rules for contractions
        for pattern, replacement in _CONTRACTION_RULES:
            detok_sent = pattern.sub(replacement, detok_sent)
        # rules for formatting punctuation; paired single quotes/backticks
        # are joined and then replaced with a double quotation mark
        for old, new in _PUNCT_REPLACEMENTS:
            detok_sent = detok_sent.replace(old, new)
        # filter repetitive characters
        detok_sent = _REPEATED_QUOTES_RE.sub('" ', detok_sent)

        detok_sent = _attach_paired_punctuation(detok_sent).strip()

        # capitalize first alphabetic character if begin_sentence is True
        if begin_sentence:
            for char_idx, char in enumerate(detok_sent):
                if char.isalpha():
                    detok_sent = (detok_sent[:char_idx + 1].upper()
                                  + detok_sent[char_idx + 1:])
                    break
        detok_sents.append(detok_sent)

    detok_seq = " ".join(detok_sents)
    # Only prepend space if detok_seq doesn't start with these
    starts_with_pattern = detok_seq.startswith(
        _CONTRACTION_PREFIXES + _PUNCTUATION_PREFIXES)
    if detok_seq and not starts_with_pattern:
        detok_seq = " " + detok_seq
    return detok_seq


def filter_gen_seq(encoder, seq, n_sents=1, eos_tokens=()):
    """Truncate a generated sequence.

    Args:
        encoder: callable applied to ``seq``. With ``eos_tokens`` its result
            must be iterable and sliceable; otherwise it must expose ``.sents``.
        seq: generated text.
        n_sents: number of sentences to keep when ``eos_tokens`` is empty.
        eos_tokens: if non-empty, cut ``seq`` after the first occurrence of
            one of these tokens instead of inferring sentence boundaries.

    Returns:
        The truncated text. If ``seq`` started with a space and the result is
        non-empty, the result starts with exactly one space.
    """
    leading_space = seq.startswith(" ") if seq else False

    if eos_tokens:
        doc = encoder(seq)
        for idx, word in enumerate(doc):
            if _text_of(word).strip() in eos_tokens:
                seq = _text_of(doc[:idx + 1]).strip()
                break
        else:
            # no end-of-sentence token found: keep the whole text
            seq = _text_of(doc)
    else:
        sentences = segment(encoder, seq)
        n = n_sents
        seq = ""
        # extend the prefix until it contains some text
        while n <= len(sentences):
            seq = " ".join(sentences[:n]).strip()
            if seq:
                break
            n += 1
        if not seq and sentences:
            seq = " ".join(sentences).strip()

    if leading_space and seq:
        seq = " " + seq.lstrip()
    return seq