import json
import os
import re
import string
from concurrent.futures import ThreadPoolExecutor

import nltk
import numpy as np
import spacy
from nltk.tokenize import sent_tokenize
from rank_bm25 import BM25Okapi

nltk.download("punkt")


class Atomizer(object):
    def __init__(self, client, demo_dir):
        self.nlp = spacy.load("en_core_web_sm")
        self.is_bio = True
        self.demo_path = os.path.join(
            demo_dir, "demos.json" if self.is_bio else "demos_complex.json")
        self.client = client

        # Load the demonstrations and build a BM25 index over their sentences
        # so the most similar demos can be retrieved for each input sentence.
        with open(self.demo_path, 'r') as f:
            self.demos = json.load(f)
        tokenized_corpus = [doc.split(" ") for doc in self.demos.keys()]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def save_cache(self):
        self.client.save_cache()

    def run(self, generation, cost_estimate=None):
        """Convert the generation into a set of atomic facts.
        Return a total word-count cost estimate instead if cost_estimate is not None."""
        assert isinstance(generation, str), "generation must be a string"
        paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]
        return self.get_atomic_facts_from_paragraph(paragraphs, cost_estimate=cost_estimate)

    def get_atomic_facts_from_paragraph(self, paragraphs, cost_estimate=None):
        sentences = []
        para_breaks = []
        for para_idx, paragraph in enumerate(paragraphs):
            if para_idx > 0:
                para_breaks.append(len(sentences))

            initials = detect_initials(paragraph)

            curr_sentences = sent_tokenize(paragraph)
            curr_sentences_2 = sent_tokenize(paragraph)

            curr_sentences = fix_sentence_splitter(curr_sentences, initials)
            curr_sentences_2 = fix_sentence_splitter(curr_sentences_2, initials)

            # Sanity check: the sentence-splitter fix must be deterministic,
            # to ensure the credibility of the fixing algorithm.
            assert curr_sentences == curr_sentences_2, (paragraph, curr_sentences, curr_sentences_2)

            sentences += curr_sentences

        # In the non-bio setting, skip boilerplate lead-ins and sign-offs
        # ("Sure...", "I hope...") before sending sentences to the extractor.
        atoms_or_estimate = self.get_init_atomic_facts_from_sentence(
            [sent for i, sent in enumerate(sentences)
             if not (not self.is_bio and (
                 (i == 0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or
                 (i == len(sentences) - 1 and (sent.startswith("Please") or
                                               sent.startswith("I hope") or
                                               sent.startswith("Here are")))))],
            cost_estimate=cost_estimate)

        if cost_estimate:
            return atoms_or_estimate
        atoms = atoms_or_estimate

        atomic_facts_pairs = []
        for i, sent in enumerate(sentences):
            if not self.is_bio and (
                    (i == 0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or
                    (i == len(sentences) - 1 and (sent.startswith("Please") or
                                                  sent.startswith("I hope") or
                                                  sent.startswith("Here are")))):
                atomic_facts_pairs.append((sent, []))
            elif self.is_bio and sent.startswith("This sentence does not contain any facts"):
                atomic_facts_pairs.append((sent, []))
            elif sent.startswith("Sure") or sent.startswith("Please") or (i == 0 and sent.startswith("Here are")):
                atomic_facts_pairs.append((sent, []))
            else:
                atomic_facts_pairs.append((sent, atoms[sent]))

        # postprocess_atomic_facts fixes minor issues from InstructGPT.
        # It is also supposed to handle sentence-splitter issues, but since
        # those were already fixed above, the new para_breaks should be
        # identical to the original para_breaks.
        if self.is_bio:
            atomic_facts_pairs, para_breaks = postprocess_atomic_facts(
                atomic_facts_pairs, list(para_breaks), self.nlp)

        return atomic_facts_pairs, para_breaks

    def get_init_atomic_facts_from_sentence(self, sentences, cost_estimate=None):
        """Get the initial atomic facts from the sentences.
        Return a total word-count cost estimate instead if cost_estimate is not None."""
        is_bio = self.is_bio
        demos = self.demos

        # k: number of BM25-retrieved demos per sentence; n: number of fixed
        # leading demos included in every prompt.
        k = 1 if is_bio else 0
        n = 7 if is_bio else 8

        prompts = []
        prompt_to_sent = {}
        atoms = {}
        for sentence in sentences:
            if sentence in atoms:
                continue
            top_matchings = best_demos(sentence, self.bm25, list(demos.keys()), k)
            prompt = ""

            for i in range(n):
                demo_sent = list(demos.keys())[i]
                prompt += "Please breakdown the following sentence into independent facts: {}\n".format(demo_sent)
                for fact in demos[demo_sent]:
                    prompt += "- {}\n".format(fact)
                prompt += "\n"

            for match in top_matchings:
                prompt += "Please breakdown the following sentence into independent facts: {}\n".format(match)
                for fact in demos[match]:
                    prompt += "- {}\n".format(fact)
                prompt += "\n"

            prompt += "Please breakdown the following sentence into independent facts: {}\n".format(sentence)
            prompts.append(prompt)
            prompt_to_sent[prompt] = sentence

        if cost_estimate:
            # Estimate the total number of prompt words that would be sent,
            # optionally skipping prompts that are already cached.
            total_words_estimate = 0
            for prompt in prompts:
                if cost_estimate == "consider_cache" and (prompt.strip() + "_0") in self.client.cache_dict:
                    continue
                total_words_estimate += len(prompt.split())
            return total_words_estimate

        # Query the client for all prompts in parallel. max_workers must be
        # positive, so guard against an empty prompt list.
        with ThreadPoolExecutor(max_workers=max(1, len(prompts))) as executor:
            outputs = list(executor.map(lambda x: self.client.query(x), prompts))
        for prompt, output in zip(prompts, outputs):
            atoms[prompt_to_sent[prompt]] = text_to_sentences(output[0]['message'])

        self.client.cache_outputs(
            prompts=prompts,
            sample_indices=np.zeros((len(prompts),), dtype=int),
            outputs=outputs,
        )

        # The demonstrations themselves already come with gold atomic facts.
        for key, value in demos.items():
            if key not in atoms:
                atoms[key] = value
        return atoms
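# NOTE (added documentation; inferred from how this file uses `client`, not
# from an upstream spec): the client passed to Atomizer is expected to expose
#   - query(prompt)        -> a list of dicts with a "message" key
#   - cache_dict           -> a mapping keyed by prompt.strip() + "_0"
#   - cache_outputs(prompts=..., sample_indices=..., outputs=...)
#   - save_cache()
# and demos.json is expected to map demonstration sentences to lists of their
# atomic facts: {"<sentence>": ["<fact 1>", "<fact 2>", ...]}. A typical call
# (hypothetical client and path):
#   atomizer = Atomizer(my_client, demo_dir="demos/")
#   pairs, para_breaks = atomizer.run("Alice is a chemist. She was born in 1990.")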
def best_demos(query, bm25, demos_sents, k):
    """Return the k demonstration sentences most similar to the query under BM25."""
    tokenized_query = query.split(" ")
    top_matchings = bm25.get_top_n(tokenized_query, demos_sents, k)
    return top_matchings


def text_to_sentences(text):
    """Transform InstructGPT's bulleted ("- ...") output into a list of fact strings."""
    sentences = text.split("- ")[1:]
    # strip() already removes trailing newlines; also drop empty fragments so
    # the period check below cannot index into an empty string.
    sentences = [sent.strip() for sent in sentences if len(sent.strip()) > 0]
    if len(sentences) > 0 and sentences[-1][-1] != '.':
        sentences[-1] = sentences[-1] + '.'
    return sentences
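# Illustrative example (hand-checked against the implementation above):
#   text_to_sentences("- Alice is a chemist.\n- Alice was born in 1990\n")
#   -> ["Alice is a chemist.", "Alice was born in 1990."]
# Only the final fact has a period appended if missing; earlier facts are
# returned as the model wrote them.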
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


MONTHS = ["january", "february", "march", "april", "may", "june",
          "july", "august", "september", "october", "november", "december"]


def is_num(text):
    try:
        int(text)
        return True
    except Exception:
        return False


def is_date(text):
    text = normalize_answer(text)
    for token in text.split(" "):
        if (not is_num(token)) and token not in MONTHS:
            return False
    return True


def extract_numeric_values(text):
    pattern = r'\b\d+\b'  # regular expression pattern for integers
    numeric_values = re.findall(pattern, text)
    # return the distinct numeric strings found in the text
    return set(numeric_values)


def detect_entities(text, nlp):
    doc = nlp(text)
    entities = set()

    def _add_to_entities(text):
        # split hyphenated spans (e.g. date ranges) into their parts
        if "-" in text:
            for _text in text.split("-"):
                entities.add(_text.strip())
        else:
            entities.add(text)

    for ent in doc.ents:
        # spaCy often makes errors on other entity types, so keep only these.
        if ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
            if is_date(ent.text):
                _add_to_entities(ent.text)
            else:
                for token in ent.text.split():
                    if is_date(token):
                        _add_to_entities(token)

    # add bare numbers that spaCy missed and that are not already covered
    for new_ent in extract_numeric_values(text):
        if not np.any([new_ent in ent for ent in entities]):
            entities.add(new_ent)

    return entities


def postprocess_atomic_facts(_atomic_facts, para_breaks, nlp):
    verbs = ["born.", " appointed.", " characterized.", " described.", " known.",
             " member.", " advocate.", "served.", "elected."]
    permitted_verbs = ["founding member."]

    atomic_facts = []
    new_atomic_facts = []
    new_para_breaks = []

    # First pass: merge one-word "sentences" (splitter artifacts) back into
    # the previous sentence.
    for i, (sent, facts) in enumerate(_atomic_facts):
        sent = sent.strip()
        if len(sent.split()) == 1 and i not in para_breaks and i > 0:
            atomic_facts[-1][0] += " " + sent
            atomic_facts[-1][1] += facts
        else:
            if i in para_breaks:
                new_para_breaks.append(len(atomic_facts))
            atomic_facts.append([sent, facts])

    # Second pass: clean up each sentence's facts.
    for sent, facts in atomic_facts:
        entities = detect_entities(sent, nlp)
        covered_entities = set()
        new_facts = []
        for fact_idx, fact in enumerate(facts):
            # drop a fact ending in a bare verb if a longer fact subsumes it
            if any([fact.endswith(verb) for verb in verbs]) and \
                    not any([fact.endswith(verb) for verb in permitted_verbs]):
                if any([fact[:-1] in other_fact
                        for other_idx, other_fact in enumerate(facts) if other_idx != fact_idx]):
                    continue
            sent_entities = detect_entities(fact, nlp)
            covered_entities |= set([e for e in sent_entities if e in entities])
            new_entities = sent_entities - entities
            if len(new_entities) > 0:
                do_pass = False
                for new_ent in new_entities:
                    # replace a truncated entity in the fact with the full
                    # entity from the sentence, if one exists
                    pre_ent = None
                    for ent in entities:
                        if ent.startswith(new_ent):
                            pre_ent = ent
                            break
                    if pre_ent is None:
                        do_pass = True
                        break
                    fact = fact.replace(new_ent, pre_ent)
                    covered_entities.add(pre_ent)
                if do_pass:
                    continue
            if fact in new_facts:
                continue
            new_facts.append(fact)
        try:
            assert entities == covered_entities
        except Exception:
            # spaCy's entity detection can be inconsistent between the
            # sentence and its facts; fall back to the original facts.
            new_facts = facts
        new_atomic_facts.append((sent, new_facts))

    return new_atomic_facts, new_para_breaks
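# Illustrative examples for the date/number helpers above (hand-checked):
#   normalize_answer("The U.S. Census!")  -> "us census"
#   is_date("January 5, 1990")            -> True   (month name plus numbers)
#   is_date("Austin, Texas")              -> False
#   extract_numeric_values("born in 1990, wrote 3 books") -> {"1990", "3"}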
def is_integer(s):
    try:
        int(s)
        return True
    except Exception:
        return False


def detect_initials(text):
    # e.g. matches "H. P." in "H. P. Lovecraft"
    pattern = r"[A-Z]\. ?[A-Z]\."
    return re.findall(pattern, text)


def fix_sentence_splitter(curr_sentences, initials):
    # First, re-merge sentence pairs that the tokenizer split in the middle
    # of a pair of initials (e.g. "H." / "P. Lovecraft ...").
    for initial in initials:
        if not np.any([initial in sent for sent in curr_sentences]):
            alpha1, alpha2 = [t.strip() for t in initial.split(".") if len(t.strip()) > 0]
            for i, (sent1, sent2) in enumerate(zip(curr_sentences, curr_sentences[1:])):
                if sent1.endswith(alpha1 + ".") and sent2.startswith(alpha2 + "."):
                    # merge sentence i and i+1
                    curr_sentences = curr_sentences[:i] + \
                        [curr_sentences[i] + " " + curr_sentences[i + 1]] + \
                        curr_sentences[i + 2:]
                    break

    # Then, merge fragments that are too short or that start lowercase into
    # the previous sentence.
    sentences = []
    combine_with_previous = None
    for sent_idx, sent in enumerate(curr_sentences):
        if len(sent.split()) <= 1 and sent_idx == 0:
            assert not combine_with_previous
            combine_with_previous = True
            sentences.append(sent)
        elif len(sent.split()) <= 1:
            assert sent_idx > 0
            sentences[-1] += " " + sent
            combine_with_previous = False
        elif sent[0].isalpha() and not sent[0].isupper() and sent_idx > 0:
            assert sent_idx > 0, curr_sentences
            sentences[-1] += " " + sent
            combine_with_previous = False
        elif combine_with_previous:
            assert sent_idx > 0
            sentences[-1] += " " + sent
            combine_with_previous = False
        else:
            assert not combine_with_previous
            sentences.append(sent)
    return sentences
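# A minimal smoke test for the pure helper functions above. It is not part of
# the original module; it only exercises code that needs no LLM client, demo
# file, or spaCy model, and all expected values were traced by hand.
if __name__ == "__main__":
    assert text_to_sentences("- Alice is a chemist.\n- Alice was born in 1990\n") == \
        ["Alice is a chemist.", "Alice was born in 1990."]
    assert normalize_answer("The U.S. Census!") == "us census"
    assert detect_initials("H. P. Lovecraft was born in 1890.") == ["H. P."]
    # A tokenizer that broke "H. P." into two sentences gets merged back.
    assert fix_sentence_splitter(
        ["H.", "P. Lovecraft was born in 1890."], ["H. P."]
    ) == ["H. P. Lovecraft was born in 1890."]
    print("helper smoke tests passed")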