Spaces:
Build error
Build error
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Created by zd302 at 17/07/2024 | |
| import torch | |
| import numpy as np | |
| import requests | |
| from rank_bm25 import BM25Okapi | |
| from bs4 import BeautifulSoup | |
| from transformers import BartTokenizer, BartForConditionalGeneration | |
| from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification | |
| from transformers import RobertaTokenizer, RobertaForSequenceClassification | |
| import pytorch_lightning as pl | |
| from averitec.models.DualEncoderModule import DualEncoderModule | |
| from averitec.models.SequenceClassificationModule import SequenceClassificationModule | |
| from averitec.models.JustificationGenerationModule import JustificationGenerationModule | |
import wikipediaapi
# Shared Wikipedia API client (English), identified by the project user agent.
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
import os
import subprocess
import sys
import nltk
# Tokenizer data required by word_tokenize / sent_tokenize below.
nltk.download('punkt')
from nltk import pos_tag, word_tokenize, sent_tokenize
import spacy

# Load the spaCy English pipeline used for named-entity extraction.
# spacy.load raises OSError when the model package is not installed; only
# then is it downloaded, using the interpreter that is running this module
# (a bare "python" on PATH may belong to a different environment) and an
# argument list instead of a shell command string.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")
# ---------- Load Veracity and Justification prediction model ----------
# Verdict names, in class-id order; veracity_prediction indexes this list
# with the aggregated class id (0-3), matching num_labels=4 below.
LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]
# Veracity
# Run on the first CUDA device when one is available, otherwise on the CPU.
# Kept as a plain string because it is also passed to
# justification_model.generate(..., device=device).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
# Wrap the base BERT classifier in the project module, restore it from the
# checkpoint file, and move it to the selected device.
veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
                                                                   tokenizer=veracity_tokenizer, model=bert_model).to(device)
# Justification
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
# Checkpoint path; the file name records epoch, validation loss and METEOR.
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
# Same pattern as the veracity model: wrap BART, restore, move to device.
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
| # --------------------------------------------------------------------------- | |
| # ---------------------------------------------------------------------------- | |
class Docs:
    """Lightweight container pairing an evidence text with its metadata.

    Attributes:
        metadata: Dictionary of fields describing the evidence.
        page_content: Display text for the evidence.
    """

    def __init__(self, metadata=None, page_content=""):
        # Build a fresh dict per instance. A ``dict()`` default in the
        # signature is evaluated once, so every Docs() created without an
        # explicit metadata argument would share (and mutate) the same dict.
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content
| # ------------------------------ Googleretriever ----------------------------- | |
def Googleretriever():
    """Placeholder for a Google-based retriever; always returns 0."""
    placeholder_result = 0
    return placeholder_result
| # ------------------------------ Googleretriever ----------------------------- | |
| # ------------------------------ Wikipediaretriever -------------------------- | |
def search_entity_wikipeida(entity):
    """Look up the Wikipedia page titled *entity* and return its summary.

    Args:
        entity: Page title to look up. Non-string values (for example a
            spaCy entity span) are converted with ``str()`` first.

    Returns:
        ``[[title, summary]]`` when the page exists, otherwise ``[]``.
    """
    # Normalise once so the lookup and the returned title use the same plain
    # string; callers may hand in spaCy entity objects rather than str.
    title = str(entity)
    page_py = wiki_wiki.page(title)
    if not page_py.exists():
        return []
    return [[title, page_py.summary]]
def clean_str(p):
    """Resolve backslash escape sequences in *p* and return the result.

    The text is round-tripped through ``unicode-escape`` and ``latin1`` so
    that non-ASCII characters come back unchanged after escape decoding.
    """
    utf8_bytes = p.encode()
    # unicode-escape resolves backslash escapes and reads every other byte
    # as a Latin-1 character ...
    unescaped = utf8_bytes.decode("unicode-escape")
    # ... so encoding as Latin-1 recovers the underlying UTF-8 byte sequence.
    restored_bytes = unescaped.encode("latin1")
    return restored_bytes.decode("utf-8")
def find_similar_wikipedia(entity, relevant_wikipages, max_pages=5, timeout=10):
    """Top up *relevant_wikipages* with hits from Wikipedia's full-text search.

    Used when fewer than ``max_pages`` pages were found for an entity: the
    first search results are looked up one by one until the limit is reached.

    Args:
        entity: Entity text to search for.
        relevant_wikipages: List of ``[title, summary]`` pairs found so far.
            It is extended in place and also returned.
        max_pages: Upper bound on the number of pages collected (also the
            number of search results considered).
        timeout: Seconds to wait for the search request before
            ``requests`` raises a timeout error.

    Returns:
        The (possibly extended) *relevant_wikipages* list.
    """
    # Let requests build the query string so that characters such as "&",
    # "#" or "+" inside the entity are percent-encoded instead of corrupting
    # the URL; spaces are still sent as "+".
    response = requests.get(
        "https://en.wikipedia.org/w/index.php",
        params={
            "search": entity,
            "title": "Special:Search",
            "profile": "advanced",
            "fulltext": 1,
            "ns0": 1,
        },
        timeout=timeout,
    )
    soup = BeautifulSoup(response.text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
    if not result_divs:
        return relevant_wikipages

    similar_titles = [
        clean_str(div.get_text().strip()) for div in result_divs[:max_pages]
    ]
    # Titles already collected; kept up to date so a title repeated in the
    # search results is not fetched twice.
    saved_titles = {page[0] for page in relevant_wikipages}
    for title in similar_titles:
        if len(relevant_wikipages) >= max_pages:
            break
        if title in saved_titles:
            continue
        saved_titles.add(title)
        relevant_wikipages.extend(search_entity_wikipeida(title))
    return relevant_wikipages
def find_evidence_from_wikipedia(claim):
    """Collect Wikipedia introductions for the named entities in *claim*.

    Each entity recognised by the spaCy pipeline is looked up directly; when
    that yields fewer than five pages, Wikipedia's search is used to find
    similar pages.

    Args:
        claim: Claim text to extract entities from.

    Returns:
        List of ``[title, summary]`` pairs, concatenated over all entities
        (empty when the claim contains no recognised entity).
    """
    doc = nlp(claim)
    wikipedia_page = []
    for ent in doc.ents:
        # doc.ents yields spaCy span objects; the lookup helpers work with
        # page titles, so convert to plain text once and reuse it.
        entity_text = str(ent)
        relevant_wikipages = search_entity_wikipeida(entity_text)
        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(entity_text, relevant_wikipages)
        wikipedia_page.extend(relevant_wikipages)
    return wikipedia_page
def bm25_retriever(query, corpus, topk=3):
    """Rank tokenised documents against *query* with BM25.

    Args:
        query: Raw query string; tokenised with ``word_tokenize``.
        corpus: List of tokenised documents (each a list of tokens).
        topk: Maximum number of documents to return.

    Returns:
        ``(top_n, top_n_scores)`` where ``top_n`` is a NumPy array of indices
        into *corpus*, best match first, and ``top_n_scores`` is the list of
        the corresponding BM25 scores. Both are empty for an empty corpus.
    """
    if not corpus:
        # BM25Okapi cannot build its statistics (average document length)
        # from zero documents, so report "no matches" instead of failing.
        return np.array([], dtype=int), []
    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(word_tokenize(query))
    # argsort is ascending: reverse it, then keep the best `topk` indices.
    top_n = np.argsort(scores)[::-1][:topk]
    return top_n, [scores[i] for i in top_n]
def relevant_sentence_retrieval(query, wiki_intro, k):
    """Select the *k* introduction sentences most relevant to *query*.

    Args:
        query: Query text (the claim).
        wiki_intro: Iterable of ``(title, intro)`` pairs.
        k: Maximum number of sentences to return.

    Returns:
        ``(sentences, titles)``: the top-ranked sentences, best first, and
        the title of the page each one came from. Both lists are empty when
        *wiki_intro* yields no sentences.
    """
    # 1. Split every introduction into sentences; the three lists stay
    #    index-aligned so a BM25 hit maps back to its sentence and title.
    corpus, sentences, titles = [], [], []
    for title, intro in wiki_intro:
        for sent in sent_tokenize(intro):
            corpus.append(word_tokenize(sent))
            sentences.append(sent)
            titles.append(title)

    # Nothing to rank (no pages found, or only empty introductions): BM25
    # cannot index an empty corpus, so return empty results directly.
    if not corpus:
        return [], []

    # 2. Rank the sentences with BM25 and map the indices back.
    bm25_top_n, _ = bm25_retriever(query, corpus, topk=k)
    bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
    bm25_top_n_titles = [titles[i] for i in bm25_top_n]
    return bm25_top_n_sents, bm25_top_n_titles
| # ------------------------------ Wikipediaretriever ----------------------------- | |
def Wikipediaretriever(claim):
    """Retrieve up to three Wikipedia evidence sentences for *claim*.

    Args:
        claim: Claim text to find evidence for.

    Returns:
        List of ``Docs`` objects, best match first. Each carries the evidence
        sentence, its page title and URLs in ``metadata`` and an HTML snippet
        in ``page_content``. Empty when no Wikipedia page was found.
    """
    # 1. Extract relevant Wikipedia pages for the entities in the claim.
    wikipedia_page = find_evidence_from_wikipedia(claim)
    if not wikipedia_page:
        # No entity or no matching page: there is nothing to rank.
        return []

    # 2. Extract the most relevant sentences from those pages.
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for rank, (sent, title) in enumerate(zip(sents, titles), start=1):
        # Wikipedia page URLs use underscores in place of whitespace. The
        # join must run over the title's words: joining the title string
        # itself would put an underscore between every character.
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = {
            'name': claim,
            'url': page_url,
            'cached_source_url': page_url,
            'short_name': "Evidence {}".format(rank),
            'page_number': "",
            'query': sent,
            'title': title,
            'evidence': sent,
            'answer': "",
        }
        metadata['page_content'] = (
            "<b>Title</b>: " + str(metadata['title']) + "<br>"
            + "<b>Evidence</b>: " + metadata['evidence']
        )
        results.append(Docs(metadata, metadata['page_content']))
    return results
| # ------------------------------ Veracity Prediction ------------------------------ | |
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Tokenisation and input-formatting helpers for the veracity classifier."""

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        """Store the tokenizer and the loader settings on the instance."""
        super().__init__()
        self.tokenizer, self.data_file = tokenizer, data_file
        self.batch_size, self.add_extra_nee = batch_size, add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenise *source_sentences* and return ``(input_ids, attention_mask)``.

        Inputs are truncated to *max_length*; padding goes to *max_length*
        when *pad_to_max_length* is set, otherwise to the longest input.
        """
        padding_strategy = "longest"
        if pad_to_max_length:
            padding_strategy = "max_length"
        encoded = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding=padding_strategy,
            truncation=True,
            return_tensors=return_tensors,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Format a claim/question/answer (plus optional explanation) as one string."""
        # The explanation is appended only when one was actually supplied.
        suffix = ""
        if bool_explanation is not None and len(bool_explanation) > 0:
            suffix = ", because " + bool_explanation.lower().strip()
        parts = (
            "[CLAIM] ",
            claim.strip(),
            " [QUESTION] ",
            question.strip(),
            " ",
            answer.strip(),
            suffix,
        )
        return "".join(parts)
def veracity_prediction(claim, evidence):
    """Predict the verdict label for *claim* given retrieved *evidence*.

    Every evidence item is classified on its own and the per-item classes
    are then aggregated into a single verdict.

    Args:
        claim: Claim text.
        evidence: Iterable of objects whose ``metadata`` dict provides the
            ``"query"`` and ``"answer"`` strings.

    Returns:
        One of the strings in ``LABEL``; ``"Not Enough Evidence"`` when
        *evidence* is empty.
    """
    evidence = list(evidence)
    if not evidence:
        # No evidence found (e.g. retrieval returned nothing): output NEI
        # without building the tokenisation helper or touching the model.
        return "Not Enough Evidence"

    dataLoader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )
    evidence_strings = [
        dataLoader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], "")
        for evi in evidence
    ]
    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)

    # Pure inference: make sure dropout is disabled (a no-op if the model is
    # already in eval mode) and skip gradient tracking, which would only
    # cost memory here.
    veracity_model.eval()
    with torch.no_grad():
        logits = veracity_model(
            tokenized_strings.to(device), attention_mask=attention_mask.to(device)
        ).logits
    example_support = torch.argmax(logits, dim=1).tolist()

    has_true = 0 in example_support
    has_false = 1 in example_support
    # TODO another hack -- we cant have different labels for train and test,
    # so classes 2 and 3 are both treated as "unanswerable" per item.
    has_unanswerable = any(v in (2, 3) for v in example_support)

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif has_false and not has_true:
        answer = 1
    else:
        # Both supporting and refuting items were found.
        answer = 3
    return LABEL[answer]
| # ------------------------------ Justification Generation ------------------------------ | |
def extract_claim_str(claim, evidence, verdict_label):
    """Build the text input for the justification generator.

    The result has the form
    ``[CLAIM] <claim> [EVIDENCE] <question? answer. >...  [VERDICT] <label>``.
    Evidence items whose query is blank are skipped; a "?" is appended to
    queries and a "." to non-empty answers that lack one.
    """
    pieces = ["[CLAIM] ", claim, " [EVIDENCE] "]
    for item in evidence:
        question = item.metadata['query'].strip()
        if not question:
            continue
        if not question.endswith("?"):
            question = question + "?"
        answer = item.metadata['answer']
        pieces.append(question)
        if answer:
            if not answer.endswith("."):
                answer = answer + "."
            pieces.append(" " + answer.strip())
        # Separator after each question/answer pair.
        pieces.append(" ")
    pieces.append(" [VERDICT] ")
    pieces.append(verdict_label)
    return "".join(pieces)
def justification_generation(claim, evidence, verdict_label):
    """Generate a textual justification for *verdict_label*.

    Args:
        claim: Claim text.
        evidence: Evidence items as accepted by ``extract_claim_str``.
        verdict_label: Predicted verdict to justify.

    Returns:
        The generated justification with surrounding whitespace removed.
    """
    # str.strip() returns a new string; the result has to be kept, otherwise
    # the call has no effect on what is sent to the model.
    claim_str = extract_claim_str(claim, evidence, verdict_label).strip()
    pred_justification = justification_model.generate(claim_str, device=device)
    return pred_justification.strip()