# -*- coding: utf-8 -*-
"""
@Author : Bao
@Date : 2020/9/17
@Desc : Document selection and sentence ranking code from KGAT. Not used in LOREN.
@Last modified by : Bao
@Last modified date : 2020/9/17
"""
import re
import time
import json

import nltk
import wikipedia
from tqdm import tqdm
from allennlp.predictors import Predictor
from drqa.retriever import DocDB, utils
from drqa.retriever.utils import normalize

class FeverDocDB(DocDB):
    def __init__(self, path=None):
        super().__init__(path)

    def get_doc_lines(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT lines FROM documents WHERE id = ?",
            (utils.normalize(doc_id),)
        )
        result = cursor.fetchone()
        cursor.close()
        result = result[0] if result is not None else ''

        doc_lines = []
        for line in result.split('\n'):
            if len(line) == 0:
                continue
            # Each line in the FEVER wiki dump has the form "<line_id>\t<sentence>\t<linked entities>";
            # keep only the sentence text and skip malformed or empty lines.
            fields = line.split('\t')
            if len(fields) < 2 or len(fields[1]) == 0:
                continue
            doc_lines.append((doc_id, len(doc_lines), fields[1], 0))
        return doc_lines

    def get_non_empty_doc_ids(self):
        """Fetch all ids of docs stored in the db."""
        cursor = self.connection.cursor()
        cursor.execute("SELECT id FROM documents WHERE length(trim(text)) > 0")
        results = [r[0] for r in cursor.fetchall()]
        cursor.close()
        return results
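
# A minimal usage sketch for FeverDocDB, assuming a FEVER wiki dump has already been loaded
# into an SQLite database (the page id below is a placeholder, not taken from this repo):
#
#     db = FeverDocDB('data/fever.db')
#     for doc_id, line_id, sentence, _ in db.get_doc_lines('Some_Page_-LRB-film-RRB-'):
#         print(line_id, sentence)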

class DocRetrieval:
    def __init__(self, database_path, add_claim=False, k_wiki_results=None):
        self.db = FeverDocDB(database_path)
        self.add_claim = add_claim
        self.k_wiki_results = k_wiki_results
        self.porter_stemmer = nltk.PorterStemmer()
        self.tokenizer = nltk.word_tokenize
        self.predictor = Predictor.from_path(
            "https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"
        )

    def get_NP(self, tree, nps):
        """Recursively collect the words of all NP nodes in a hierplane constituency tree."""
        if isinstance(tree, dict):
            if "children" not in tree:
                if tree['nodeType'] == "NP":
                    # print(tree['word'])
                    # print(tree)
                    nps.append(tree['word'])
            elif "children" in tree:
                if tree['nodeType'] == "NP":
                    # print(tree['word'])
                    nps.append(tree['word'])
                    self.get_NP(tree['children'], nps)
                else:
                    self.get_NP(tree['children'], nps)
        elif isinstance(tree, list):
            for sub_tree in tree:
                self.get_NP(sub_tree, nps)
        return nps

    def get_subjects(self, tree):
        """Heuristically extract subject spans: the words preceding each VP/S/VBZ child of the root."""
        subject_words = []
        subjects = []
        for subtree in tree['children']:
            if subtree['nodeType'] == "VP" or subtree['nodeType'] == 'S' or subtree['nodeType'] == 'VBZ':
                subjects.append(' '.join(subject_words))
                subject_words.append(subtree['word'])
            else:
                subject_words.append(subtree['word'])
        return subjects

    def get_noun_phrases(self, claim):
        """Parse the claim with the constituency parser and return its noun phrases and subjects."""
        tokens = self.predictor.predict(claim)
        nps = []
        tree = tokens['hierplane_tree']['root']
        noun_phrases = self.get_NP(tree, nps)
        subjects = self.get_subjects(tree)
        for subject in subjects:
            if len(subject) > 0:
                noun_phrases.append(subject)
        if self.add_claim:
            noun_phrases.append(claim)
        return list(set(noun_phrases))

    def get_doc_for_claim(self, noun_phrases):
        """Search the online Wikipedia API with each noun phrase and collect candidate page ids."""
        predicted_pages = []
        for np in noun_phrases:
            if len(np) > 300:
                continue
            i = 1
            while i < 12:
                try:
                    # print(np)
                    # res = server.lookup(np, keep_all=True)
                    # docs = [y for _, y in res] if res is not None else []
                    docs = wikipedia.search(np)
                    if self.k_wiki_results is not None:
                        predicted_pages.extend(docs[:self.k_wiki_results])
                    else:
                        predicted_pages.extend(docs)
                except (ConnectionResetError, ConnectionError, ConnectionAbortedError, ConnectionRefusedError):
                    print("Connection reset error received! Trial #" + str(i))
                    time.sleep(600 * i)
                    i += 1
                else:
                    break
            # sleep_num = random.uniform(0.1, 0.7)
            # time.sleep(sleep_num)

        predicted_pages = set(predicted_pages)
        processed_pages = []
        for page in predicted_pages:
            # Convert Wikipedia titles to FEVER page ids: spaces become underscores and
            # brackets/colons are escaped as -LRB-, -RRB- and -COLON-.
            page = page.replace(" ", "_")
            page = page.replace("(", "-LRB-")
            page = page.replace(")", "-RRB-")
            page = page.replace(":", "-COLON-")
            processed_pages.append(page)
        return processed_pages
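
    # For example, under this encoding a Wikipedia title like "Soul Food (film)" (a placeholder
    # title, not taken from this repo) would become the FEVER page id "Soul_Food_-LRB-film-RRB-".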

    def np_conc(self, noun_phrases):
        """Turn noun phrases directly into FEVER page ids and keep those that exist in the local db."""
        noun_phrases = set(noun_phrases)
        predicted_pages = []
        for np in noun_phrases:
            page = np.replace('( ', '-LRB-')
            page = page.replace(' )', '-RRB-')
            page = page.replace(' - ', '-')
            page = page.replace(' :', '-COLON-')
            page = page.replace(' ,', ',')
            page = page.replace(" 's", "'s")
            page = page.replace(' ', '_')
            if len(page) < 1:
                continue
            doc_lines = self.db.get_doc_lines(page)
            if len(doc_lines) > 0:
                predicted_pages.append(page)
        return predicted_pages

    def exact_match(self, claim):
        """Combine db matches and Wikipedia search results, keeping pages whose title words all occur in the claim."""
        noun_phrases = self.get_noun_phrases(claim)
        wiki_results = self.get_doc_for_claim(noun_phrases)
        wiki_results = list(set(wiki_results))

        claim = claim.replace(".", "")
        claim = claim.replace("-", " ")
        words = [self.porter_stemmer.stem(word.lower()) for word in self.tokenizer(claim)]
        words = set(words)

        predicted_pages = self.np_conc(noun_phrases)
        for page in wiki_results:
            page = normalize(page)
            processed_page = re.sub("-LRB-.*?-RRB-", "", page)
            processed_page = re.sub("_", " ", processed_page)
            processed_page = re.sub("-COLON-", ":", processed_page)
            processed_page = processed_page.replace("-", " ")
            processed_page = processed_page.replace("–", " ")
            processed_page = processed_page.replace(".", "")
            page_words = [self.porter_stemmer.stem(word.lower())
                          for word in self.tokenizer(processed_page) if len(word) > 0]
            if all([item in words for item in page_words]):
                if ':' in page:
                    page = page.replace(":", "-COLON-")
                predicted_pages.append(page)
        predicted_pages = list(set(predicted_pages))
        return noun_phrases, wiki_results, predicted_pages
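
# A minimal end-to-end sketch for DocRetrieval (the claim is a made-up example; it needs the
# AllenNLP parser download, internet access for wikipedia.search, and a built fever.db):
#
#     client = DocRetrieval('data/fever.db', add_claim=True, k_wiki_results=7)
#     noun_phrases, wiki_results, pages = client.exact_match('Some person was born in some city.')
#     print(pages)  # FEVER page ids that can be passed to client.db.get_doc_lines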

def save_to_file(results, client, filename):
    """Write one JSON line per claim, with the evidence sentences of its predicted pages."""
    with open(filename, 'w', encoding='utf-8') as fout:
        for _id, line in enumerate(results):
            claim = line['claim']
            evidence = []
            for page in line['predicted_pages']:
                evidence.extend(client.db.get_doc_lines(page))
            print(json.dumps({'claim': claim, 'evidence': evidence}, ensure_ascii=False), file=fout)

if __name__ == '__main__':
    database_path = 'data/fever.db'
    add_claim = True
    k_wiki_results = 7
    client = DocRetrieval(database_path, add_claim, k_wiki_results)

    results = []
    with open('data/claims.json', 'r', encoding='utf-8') as fin:
        for line in tqdm(fin):
            line = json.loads(line)
            _, _, predicted_pages = client.exact_match(line['claim'])
            evidence = []
            for page in predicted_pages:
                evidence.extend(client.db.get_doc_lines(page))
            line['evidence'] = evidence
            results.append(line)

    with open('data/pages.json', 'w', encoding='utf-8') as fout:
        for line in results:
            print(json.dumps(line, ensure_ascii=False), file=fout)