#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 08/07/2024
import gradio as gr
import tqdm
import torch
import numpy as np
from time import sleep
from datetime import datetime
import threading
import gc
import os
import json
import pytorch_lightning as pl
from urllib.parse import urlparse
from accelerate import Accelerator
import spaces
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from rank_bm25 import BM25Okapi
# import bm25s
# import Stemmer  # optional: for stemming
from html2lines import url2lines
from googleapiclient.discovery import build
from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
from averitec.data.sample_claims import CLAIMS_Type

# ---------------------------------------------------------------------------
# load .env
from utils import create_user_id

user_id = create_user_id()

from azure.storage.fileshare import ShareServiceClient

try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception as e:
    pass

# os.environ["TOKENIZERS_PARALLELISM"] = "false"
account_url = os.environ["AZURE_ACCOUNT_URL"]
credential = {
    "account_key": os.environ['AZURE_ACCOUNT_KEY'],
    "account_name": os.environ['AZURE_ACCOUNT_NAME']
}
file_share_name = "averitec"
azure_service = ShareServiceClient(account_url=account_url, credential=credential)
azure_share_client = azure_service.get_share_client(file_share_name)

# ---------- Setting ----------
import requests
from bs4 import BeautifulSoup
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')

import nltk
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk import pos_tag, word_tokenize, sent_tokenize

import spacy
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")

# ---------------------------------------------------------------------------
# Load sample dict for AVeriTeC search
# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
train_examples = json.load(open('averitec/data/train.json', 'r'))
def claim2prompts(example):
    claim = example["claim"]
    # claim_str = "Claim: " + claim + "||Evidence: "
    claim_str = "Evidence: "

    for question in example["questions"]:
        q_text = question["question"].strip()
        if len(q_text) == 0:
            continue

        if not q_text[-1] == "?":
            q_text += "?"

        answer_strings = []
        for a in question["answers"]:
            if a["answer_type"] in ["Extractive", "Abstractive"]:
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())

        for a_text in answer_strings:
            if not a_text[-1] in [".", "!", ":", "?"]:
                a_text += "."

            # prompt_lookup_str = claim + " " + a_text
            prompt_lookup_str = a_text
            this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
            yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))
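# Each yielded pair is (BM25 lookup string, few-shot prompt block). Illustratively (these example
# strings are made up, not from the training data), an answer "The law was passed in 2019." with
# question "When was the law passed?" yields a prompt block that, after the "||" separators are
# replaced with newlines, reads:
#   Evidence:  The law was passed in 2019.
#   Question answered: When was the law passed?
# This is the format the question-generation prompts below are assembled from.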
def generate_reference_corpus(reference_file):
    all_data_corpus = []
    tokenized_corpus = []

    for train_example in train_examples:
        train_claim = train_example["claim"]

        speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len(
            train_example["speaker"]) > 1 else "they"

        questions = [q["question"] for q in train_example["questions"]]

        claim_dict_builder = {}
        claim_dict_builder["claim"] = train_claim
        claim_dict_builder["speaker"] = speaker
        claim_dict_builder["questions"] = questions

        tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"]))
        all_data_corpus.append(claim_dict_builder)

    return tokenized_corpus, all_data_corpus


def generate_step2_reference_corpus(reference_file):
    prompt_corpus = []
    tokenized_corpus = []

    for example in train_examples:
        for lookup_str, prompt in claim2prompts(example):
            entry = nltk.word_tokenize(lookup_str)
            tokenized_corpus.append(entry)
            prompt_corpus.append(prompt)

    return tokenized_corpus, prompt_corpus


reference_file = "averitec/data/train.json"
tokenized_corpus0, all_data_corpus0 = generate_reference_corpus(reference_file)
qg_bm25 = BM25Okapi(tokenized_corpus0)

tokenized_corpus1, prompt_corpus1 = generate_step2_reference_corpus(reference_file)
prompt_bm25 = BM25Okapi(tokenized_corpus1)

# print(train_examples[0]['claim'])
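# Minimal sketch of how these BM25 indexes are queried further down (same pattern as in
# prompt_question_generation and decorate_with_questions): score the tokenized query against the
# corpus and keep the highest-scoring entries. The query string here is illustrative only.
#
#     scores = qg_bm25.get_scores(nltk.word_tokenize("Some claim to check"))
#     top_n = np.argsort(scores)[::-1][:10]           # indices of the 10 best-matching training claims
#     few_shot_docs = [all_data_corpus0[i] for i in top_n]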
# ---------------------------------------------------------------------------
# ---------- Load pretrained models ----------
# ---------- load Evidence retrieval model ----------
# from drqa import retriever
# db_class = retriever.get_class('sqlite')
# doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db")
# ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz")

# ---------- Load Veracity and Justification prediction model ----------
print("Loading models ...")

LABEL = [
    "Supported",
    "Refuted",
    "Not Enough Evidence",
    "Conflicting Evidence/Cherrypicking",
]
# device: fall back to CPU when no GPU is available, so the models below are always defined
device = "cuda" if torch.cuda.is_available() else "cpu"

# question generation
qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1", torch_dtype=torch.bfloat16).to(device)
# qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
# qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device)

# rerank
rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
rerank_bert_model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
rerank_trained_model = DualEncoderModule.load_from_checkpoint(
    best_checkpoint, tokenizer=rerank_tokenizer, model=rerank_bert_model).to(device)

# Veracity
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=4, problem_type="single_label_classification")
veracity_model = SequenceClassificationModule.load_from_checkpoint(
    "averitec/pretrained_models/bert_veracity.ckpt", tokenizer=veracity_tokenizer, model=bert_model).to(device)

# Justification
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
justification_model = JustificationGenerationModule.load_from_checkpoint(
    best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# ---------- Setting ----------
class Docs:
    def __init__(self, metadata=None, page_content=""):
        # avoid a shared mutable default for metadata
        self.metadata = metadata if metadata is not None else {}
        self.page_content = page_content
def make_html_source(source, i):
    meta = source.metadata
    content = source.page_content.strip()

    card = f"""
    <div class="card" id="doc{i}">
        <div class="card-content">
            <h2>Doc {i} - URL: <a href="{meta['url']}" target="_blank" class="pdf-link">{meta['url']}</a></h2>
            <p>{content}</p>
        </div>
        <div class="card-footer">
            <span>CACHED SOURCE URL:</span>
            <a href="{meta['cached_source_url']}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">🔗</span>
            </a>
        </div>
    </div>
    """
    return card
# ----- veracity_prediction -----
class SequenceClassificationDataLoader(pl.LightningDataModule):
    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        encoded_dict = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding="max_length" if pad_to_max_length else "longest",
            truncation=True,
            return_tensors=return_tensors,
        )
        input_ids = encoded_dict["input_ids"]
        attention_masks = encoded_dict["attention_mask"]
        return input_ids, attention_masks

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        if bool_explanation is not None and len(bool_explanation) > 0:
            bool_explanation = ", because " + bool_explanation.lower().strip()
        else:
            bool_explanation = ""
        return (
            "[CLAIM] "
            + claim.strip()
            + " [QUESTION] "
            + question.strip()
            + " "
            + answer.strip()
            + bool_explanation
        )
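# Illustratively (the claim, question and answer below are made-up inputs), quadruple_to_string
# produces the flat string fed to the veracity classifier:
#   "[CLAIM] Vaccines cause X. [QUESTION] Do vaccines cause X? No, studies show no link."
# with ", because ..." appended only when a boolean explanation is present.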
def veracity_prediction(claim, qa_evidence):
    dataLoader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = []
    for evidence in qa_evidence:
        evidence_strings.append(
            dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], ""))

    if len(evidence_strings) == 0:  # If we found no evidence, e.g. because Google returned 0 pages, just output NEI.
        pred_label = "Not Enough Evidence"
        return pred_label

    tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
    example_support = torch.argmax(
        veracity_model(tokenized_strings.to(veracity_model.device),
                       attention_mask=attention_mask.to(veracity_model.device)).logits, axis=1)
    # example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)

    has_unanswerable = False
    has_true = False
    has_false = False

    for v in example_support:
        if v == 0:
            has_true = True
        if v == 1:
            has_false = True
        if v in (2, 3,):  # TODO another hack -- we can't have different labels for train and test so we do this
            has_unanswerable = True

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif not has_true and has_false:
        answer = 1
    else:
        answer = 3

    pred_label = LABEL[answer]
    return pred_label
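# Per-evidence predictions are aggregated into one verdict as follows: any "unanswerable"
# prediction (class 2 or 3) wins; otherwise all-supporting -> Supported, all-refuting -> Refuted,
# and a mix of both -> Conflicting Evidence/Cherrypicking. E.g. (illustrative) per-QA classes
# [0, 0, 1] yield LABEL[3], i.e. "Conflicting Evidence/Cherrypicking".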
def extract_claim_str(claim, qa_evidence, verdict_label):
    claim_str = "[CLAIM] " + claim + " [EVIDENCE] "

    for evidence in qa_evidence:
        q_text = evidence.metadata['query'].strip()
        if len(q_text) == 0:
            continue

        if not q_text[-1] == "?":
            q_text += "?"

        answer_strings = []
        answer_strings.append(evidence.metadata['answer'])

        claim_str += q_text
        for a_text in answer_strings:
            if a_text:
                if not a_text[-1] == ".":
                    a_text += "."
                claim_str += " " + a_text.strip()

        claim_str += " "

    claim_str += " [VERDICT] " + verdict_label
    return claim_str
def justification_generation(claim, qa_evidence, verdict_label):
    #
    # claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
    claim_str = "[CLAIM] " + claim + " [EVIDENCE] "

    for evi in qa_evidence:
        q_text = evi.metadata['query'].strip()
        if len(q_text) == 0:
            continue

        if not q_text[-1] == "?":
            q_text += "?"

        answer_strings = []
        answer_strings.append(evi.metadata['answer'])

        claim_str += q_text
        for a_text in answer_strings:
            if a_text:
                if not a_text[-1] == ".":
                    a_text += "."
                claim_str += " " + a_text.strip()

        claim_str += " "

    claim_str += " [VERDICT] " + verdict_label
    #
    claim_str = claim_str.strip()  # strip() returns a new string, so re-assign it
    pred_justification = justification_model.generate(claim_str, device=justification_model.device)
    # pred_justification = justification_model.generate(claim_str, device=device)
    return pred_justification.strip()
def QAprediction(claim, evidence, sources):
    parts = []
    #
    evidence_title = f"""<h5>Retrieved Evidence:</h5>"""
    for i, evi in enumerate(evidence, 1):
        part = f"""<span>Doc {i}</span>"""
        subpart = f"""<a href="#doc{i}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{i}</sup></span></a>"""
        subparts = "".join([part, subpart])
        parts.append(subparts)

    evidence_part = ", ".join(parts)

    prediction_title = f"""<h5>Prediction:</h5>"""
    # if 'Google' in sources:
    #     verdict_label = google_veracity_prediction(claim, evidence)
    #     justification_label = google_justification_generation(claim, evidence, verdict_label)
    #     justification_part = f"""<span>Justification: {justification_label}</span>"""
    # if 'WikiPedia' in sources:
    #     verdict_label = wikipedia_veracity_prediction(claim, evidence)
    #     justification_label = wikipedia_justification_generation(claim, evidence, verdict_label)
    #     # justification_label = "See retrieved docs."
    #     justification_part = f"""<span>Justification: {justification_label}</span>"""

    verdict_label = veracity_prediction(claim, evidence)
    justification_label = justification_generation(claim, evidence, verdict_label)
    # justification_label = "See retrieved docs."
    justification_part = f"""<span>Justification: {justification_label}</span>"""

    verdict_part = f"""Verdict: <span>{verdict_label}.</span><br>"""

    content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part])
    return content_parts, [verdict_label, justification_label]
# ----------GoogleAPIretriever---------
def doc2prompt(doc):
    prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[
        "claim"].strip() + "\". Criticism includes questions like: "
    questions = [q.strip() for q in doc["questions"]]
    return prompt_parts + " ".join(questions)


def docs2prompt(top_docs):
    return "\n\n".join([doc2prompt(d) for d in top_docs])
def prompt_question_generation(test_claim, speaker="they", topk=10):
    #
    # reference_file = "averitec/data/train.json"
    # tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file)
    # bm25 = BM25Okapi(tokenized_corpus)
    # --------------------------------------------------
    # test claim
    s = qg_bm25.get_scores(nltk.word_tokenize(test_claim))
    top_n = np.argsort(s)[::-1][:topk]
    docs = [all_data_corpus0[i] for i in top_n]
    # --------------------------------------------------

    prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
             "\". Criticism includes questions like: "
    sentences = [prompt]

    inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
    # inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
    outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)

    tgt_text = qg_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    in_len = len(sentences[0])
    questions_str = tgt_text[in_len:].split("\n")[0]

    qs = questions_str.split("?")
    qs = [q.strip() + "?" for q in qs if q.strip() and len(q.strip()) < 300]
    #
    generate_question = [{"question": q, "answers": []} for q in qs]
    return generate_question
def check_claim_date(check_date):
    try:
        year, month, date = check_date.split("-")
    except Exception:
        month, date, year = "01", "01", "2022"

    if len(year) == 2 and int(year) <= 30:
        year = "20" + year
    elif len(year) == 2:
        year = "19" + year
    elif len(year) == 1:
        year = "200" + year

    if len(month) == 1:
        month = "0" + month

    if len(date) == 1:
        date = "0" + date

    sort_date = year + month + date
    return sort_date
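# Illustratively: check_claim_date("2024-07-01") -> "20240701"; a two-digit year such as
# "22-07-01" is expanded to "20220701"; anything that does not split into three parts falls back
# to 2022-01-01. The result is used as the upper bound of the Google date filter below.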
def string_to_search_query(text, author):
    parts = word_tokenize(text.strip())
    tags = pos_tag(parts)

    keep_tags = ["CD", "JJ", "NN", "VB"]

    if author is not None:
        search_string = author.split()
    else:
        search_string = []

    for token, tag in zip(parts, tags):
        for keep_tag in keep_tags:
            if tag[1].startswith(keep_tag):
                search_string.append(token)

    search_string = " ".join(search_string)
    return search_string
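# Only numerals (CD), adjectives (JJ), nouns (NN*) and verbs (VB*) are kept as query terms.
# Illustratively, "The president signed the bill in 2020" would reduce to roughly
# "president signed bill 2020" (the exact output depends on the NLTK tagger).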
def google_search(search_term, api_key, cse_id, **kwargs):
    service = build("customsearch", "v1", developerKey=api_key)
    res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute()

    if "items" in res:
        return res['items']
    else:
        return []


def get_domain_name(url):
    if '://' not in url:
        url = 'http://' + url

    domain = urlparse(url).netloc

    if domain.startswith("www."):
        return domain[4:]
    else:
        return domain
def get_and_store(url_link, fp, worker, worker_stack):
    page_lines = url2lines(url_link)
    with open(fp, "w") as out_f:
        print("\n".join([url_link] + page_lines), file=out_f)
    worker_stack.append(worker)
    gc.collect()


def get_text_from_link(url_link):
    page_lines = url2lines(url_link)
    return "\n".join([url_link] + page_lines)
def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
    search_results = []
    for _ in range(1):
        try:
            search_results += google_search(
                search_string,
                api_key,
                search_engine_id,
                num=3,  # num=10,
                start=0 + 10 * page,
                sort="date:r:19000101:" + sort_date,
                dateRestrict=None,
                gl="US"
            )
            break
        except Exception:
            sleep(1)
    return search_results
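# The Custom Search call above restricts results to pages dated between 1900-01-01 and the
# claim's sort_date via the sort="date:r:YYYYMMDD:YYYYMMDD" range expression, asks for 3 results
# per query (num=3), and pages through them with `start`. (Parameter semantics as understood from
# the Programmable Search Engine API; adjust if the API changes.)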
# @spaces.GPU
def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
    # default config
    api_key = os.environ["GOOGLE_API_KEY"]
    search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]

    blacklist = [
        "jstor.org",  # Blacklisted because their pdfs are not labelled as such, and clog up the download
        "facebook.com",  # Blacklisted because only post titles can be scraped, but the scraper doesn't know this
        "ftp.cs.princeton.edu",  # Blacklisted because it hosts many large NLP corpora that keep showing up
        "nlp.cs.princeton.edu",
        "huggingface.co"
    ]

    blacklist_files = [  # Blacklisted some NLP nonsense that crashes my machine with OOM errors
        "/glove.",
        "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
        "https://web.mit.edu/adamrose/Public/googlelist",
    ]

    # save to folder
    store_folder = "averitec/data/store/retrieved_docs"
    #
    index = 0
    questions = [q["question"] for q in generate_question][:3]
    # questions = [q["question"] for q in generate_question]  # ori

    # check the date of the claim
    current_date = datetime.now().strftime("%Y-%m-%d")
    sort_date = check_claim_date(current_date)  # check_date="2022-01-01"

    #
    search_strings = []
    search_types = []

    search_string_2 = string_to_search_query(claim, None)
    search_strings += [search_string_2, claim, ]
    search_types += ["claim", "claim-noformat", ]

    search_strings += questions
    search_types += ["question" for _ in questions]

    # start to search
    search_results = []
    visited = {}
    store_counter = 0
    worker_stack = list(range(10))

    retrieve_evidence = []

    for this_search_string, this_search_type in zip(search_strings, search_types):
        for page_num in range(n_pages):
            search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                                       this_search_string, page=page_num)

            for result in search_results:
                link = str(result["link"])
                domain = get_domain_name(link)

                if domain in blacklist:
                    continue

                broken = False
                for b_file in blacklist_files:
                    if b_file in link:
                        broken = True
                if broken:
                    continue

                if link.endswith(".pdf") or link.endswith(".doc"):
                    continue

                store_file_path = ""

                if link in visited:
                    web_text = visited[link]
                else:
                    web_text = get_text_from_link(link)
                    visited[link] = web_text

                line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, web_text]
                retrieve_evidence.append(line)

    return retrieve_evidence
def decorate_with_questions(claim, retrieve_evidence, top_k=3):  # top_k=5, 10, 100
    #
    # reference_file = "averitec/data/train.json"
    # tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
    # prompt_bm25 = BM25Okapi(tokenized_corpus)
    #
    tokenized_corpus = []
    all_data_corpus = []

    for retri_evi in tqdm.tqdm(retrieve_evidence):
        # store_file = retri_evi[-1]
        # with open(store_file, 'r') as f:
        web_text = retri_evi[-1]
        lines_in_web = web_text.split("\n")

        first = True
        for line in lines_in_web:
            # for line in f:
            line = line.strip()

            if first:
                first = False
                location_url = line
                continue

            if len(line) > 3:
                entry = nltk.word_tokenize(line)
                if (location_url, line) not in all_data_corpus:
                    tokenized_corpus.append(entry)
                    all_data_corpus.append((location_url, line))

    if len(tokenized_corpus) == 0:
        # Nothing usable was scraped; return early instead of building an empty BM25 index.
        return []

    bm25 = BM25Okapi(tokenized_corpus)
    s = bm25.get_scores(nltk.word_tokenize(claim))
    top_n = np.argsort(s)[::-1][:top_k]
    docs = [all_data_corpus[i] for i in top_n]

    generate_qa_pairs = []
    # Then, generate questions for those top documents:
    for doc in tqdm.tqdm(docs):
        # prompt_lookup_str = example["claim"] + " " + doc[1]
        prompt_lookup_str = doc[1]

        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
        prompt_n = 10
        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
        prompt_docs = [prompt_corpus1[i] for i in prompt_top_n]

        claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
        prompt = "\n\n".join(prompt_docs + [claim_prompt])
        sentences = [prompt]

        inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
        # inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
        outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)

        tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
        # We are not allowed to generate more than 250 characters:
        tgt_text = tgt_text[:250]

        qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
        generate_qa_pairs.append(qa_pair)

    return generate_qa_pairs
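# Prompt layout used above (a minimal sketch; the first blocks come from the BM25-retrieved
# training prompts, the last block is the open-ended one the model completes with a question):
#
#   Evidence: <retrieved training answer>
#   Question answered: <its training question>
#
#   Evidence: <another training answer>
#   Question answered: <its training question>
#
#   Evidence: <sentence scraped from the web for this claim>
#   Question answered: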
def triple_to_string(x):
    return " </s> ".join([item.strip() for item in x])


def rerank_questions(claim, bm25_qas, topk=3):
    #
    strs_to_score = []
    values = []

    for question, answer, source in bm25_qas:
        str_to_score = triple_to_string([claim, question, answer])

        strs_to_score.append(str_to_score)
        values.append([question, answer, source])

    if len(bm25_qas) > 0:
        encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True,
                                        return_tensors="pt").to(rerank_trained_model.device)
        # encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(device)
        input_ids = encoded_dict['input_ids']
        attention_masks = encoded_dict['attention_mask']

        scores = torch.softmax(rerank_trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]

        top_n = torch.argsort(scores, descending=True)[:topk]
        pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
    else:
        pass_through = []

    top3_qa_pairs = pass_through
    return top3_qa_pairs
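# Reranking in a nutshell: each candidate is scored as "claim </s> question </s> answer", the
# binary classifier's softmax probability for class 1 is used as a relevance score, and the topk
# (default 3) QA pairs are kept in descending order of that probability.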
def Googleretriever(query, sources):
    # ----- Generate QA pairs using AVeriTeC
    # step 1: generate questions for the query/claim using Bloom
    generate_question = prompt_question_generation(query)

    # step 2: retrieve evidence for the generated questions using Google API
    retrieve_evidence = averitec_search(query, generate_question)

    # step 3: generate QA pairs for each retrieved document
    bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)

    # step 4: rerank QA pairs
    top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)

    # Add score to metadata
    results = []
    for i, qa in enumerate(top3_qa_pairs):
        metadata = dict()

        metadata['name'] = qa['question']
        metadata['url'] = qa['source_url']
        metadata['cached_source_url'] = qa['source_url']
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['title'] = qa['question']
        metadata['evidence'] = qa['answers']
        metadata['query'] = qa['question']
        metadata['answer'] = qa['answers']
        metadata['page_content'] = "<b>Question</b>: " + qa['question'] + "<br>" + "<b>Answer</b>: " + qa['answers']
        page_content = f"""{metadata['page_content']}"""

        results.append(Docs(metadata, page_content))

    return results
# ----------GoogleAPIretriever---------

# ----------Wikipediaretriever---------
def bm25_retriever(query, corpus, topk=3):
    bm25 = BM25Okapi(corpus)
    #
    query_tokens = word_tokenize(query)
    scores = bm25.get_scores(query_tokens)
    top_n = np.argsort(scores)[::-1][:topk]
    top_n_scores = [scores[i] for i in top_n]

    return top_n, top_n_scores
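# Minimal usage sketch (corpus entries are pre-tokenized sentences, as built in
# relevant_sentence_retrieval below; the strings here are illustrative):
#
#     corpus = [word_tokenize(s) for s in ["Paris is the capital of France.", "Berlin is in Germany."]]
#     idx, scores = bm25_retriever("capital of France", corpus, topk=1)
#     # idx[0] == 0 -> the first sentence is the best match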
def bm25s_retriever(query, corpus, topk=3):
    # NOTE: this variant needs the optional `bm25s` and `Stemmer` imports that are commented out
    # at the top of the file; it is currently only referenced from commented-out code below.
    # optional: create a stemmer
    stemmer = Stemmer.Stemmer("english")

    # Tokenize the corpus and only keep the ids (faster and saves memory)
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)

    # Create the BM25 model and index the corpus
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    # Query the corpus
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)

    # Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k)
    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=topk)
    top_n = [corpus.index(res) for res in results[0]]

    return top_n, scores
def find_evidence_from_wikipedia_dumps(claim):
    # NOTE: relies on the DrQA `ranker` and `doc_db` objects whose setup is commented out in the
    # "load Evidence retrieval model" section above; only used when the "Dump" source is selected.
    doc = nlp(claim)
    entities_in_claim = [str(ent).lower() for ent in doc.ents]
    title2id = ranker.doc_dict[0]

    wiki_intro, ent_list = [], []
    for ent in entities_in_claim:
        if ent in title2id.keys():
            ids = title2id[ent]
            introduction = doc_db.get_doc_intro(ids)
            wiki_intro.append([ent, introduction])
            # fulltext = doc_db.get_doc_text(ids)
            # evidence.append([ent, fulltext])
            ent_list.append(ent)

    if len(wiki_intro) < 5:
        evidence_tfidf = process_topk(claim, title2id, ent_list, k=5)
        wiki_intro.extend(evidence_tfidf)

    return wiki_intro, doc
def relevant_sentence_retrieval(query, wiki_intro, k):
    # 1. Create corpus here
    corpus, sentences = [], []
    titles = []

    for i, (title, intro) in enumerate(wiki_intro):
        sents_in_intro = sent_tokenize(intro)

        for sent in sents_in_intro:
            corpus.append(word_tokenize(sent))
            sentences.append(sent)
            titles.append(title)
    #
    # ----- BM25
    bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
    bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
    bm25_top_n_titles = [titles[i] for i in bm25_top_n]

    # ----- BM25s
    # bm25s_top_n, bm25s_top_n_scores = bm25s_retriever(query, sentences, topk=k)  # corpus -> sentences
    # bm25s_top_n_sents = [sentences[i] for i in bm25s_top_n]
    # bm25s_top_n_titles = [titles[i] for i in bm25s_top_n]

    return bm25_top_n_sents, bm25_top_n_titles
def process_topk(query, title2id, ent_list, k=1):
    doc_names, doc_scores = ranker.closest_docs(query, k)

    evidence_tfidf = []
    for _name in doc_names:
        if _name not in ent_list and len(ent_list) < 5:
            ent_list.append(_name)
            idx = title2id[_name]
            introduction = doc_db.get_doc_intro(idx)
            evidence_tfidf.append([_name, introduction])
            # fulltext = doc_db.get_doc_text(idx)
            # evidence_tfidf.append([_name, fulltext])

    return evidence_tfidf
def WikipediaDumpsretriever(claim):
    #
    # 1. extract relevant wikipedia pages from wikipedia dumps
    wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim)
    # wiki_intro = [['trump', "'''Trump''' most commonly refers to:\n* Donald Trump (born 1946), President of the United States from 2017 to 2021 \n* Trump (card games), any playing card given an ad-hoc high rank\n\n'''Trump''' may also refer to:"]]

    # 2. extract relevant sentences from extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3)
    #
    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        metadata = dict()

        metadata['name'] = claim
        metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
        page_content = f"""{metadata['page_content']}"""

        results.append(Docs(metadata, page_content))

    return results
# ----------WikipediaAPIretriever---------
def clean_str(p):
    return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")


def get_page_obs(page):
    # find all paragraphs
    paragraphs = page.split("\n")
    paragraphs = [p.strip() for p in paragraphs if p.strip()]

    # # find all sentences
    # sentences = []
    # for p in paragraphs:
    #     sentences += p.split('. ')
    # sentences = [s.strip() + '.' for s in sentences if s.strip()]
    # # return ' '.join(sentences[:5])
    # return ' '.join(sentences)

    return ' '.join(paragraphs[:5])
def search_entity_wikipedia(entity):
    find_evidence = []

    page_py = wiki_wiki.page(str(entity))
    if page_py.exists():
        introduction = page_py.summary
        find_evidence.append([str(entity), introduction])

    return find_evidence


def search_step(entity):
    ent_ = entity.replace(" ", "+")
    search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}"
    response_text = requests.get(search_url).text
    soup = BeautifulSoup(response_text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})

    find_evidence = []

    if result_divs:  # mismatch
        # If the Wikipedia page for the entity does not exist, find similar Wikipedia pages.
        result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
        similar_titles = result_titles[:5]

        for _t in similar_titles:
            if len(find_evidence) < 5:
                _evi = search_step(_t)
                find_evidence.extend(_evi)
    else:
        page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]

        if any("may refer to:" in p for p in page):
            _evi = search_step("[" + entity + "]")
            find_evidence.extend(_evi)
        else:
            # page_py = wiki_wiki.page(entity)
            #
            # if page_py.exists():
            #     introduction = page_py.summary
            # else:
            page_text = ""
            for p in page:
                if len(p.split(" ")) > 2:
                    page_text += clean_str(p)
                    if not p.endswith("\n"):
                        page_text += "\n"
            introduction = get_page_obs(page_text)

            find_evidence.append([entity, introduction])

    return find_evidence
def find_similar_wikipedia(entity, relevant_wikipages):
    # If fewer than 5 relevant Wikipedia pages were found for the entity, find similar Wikipedia pages.
    ent_ = entity.replace(" ", "+")
    search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
    response_text = requests.get(search_url).text
    soup = BeautifulSoup(response_text, features="html.parser")
    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})

    if result_divs:
        result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
        similar_titles = result_titles[:5]

        saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
        for _t in similar_titles:
            if _t not in saved_titles and len(relevant_wikipages) < 5:
                _evi = search_entity_wikipedia(_t)
                # _evi = search_step(_t)
                relevant_wikipages.extend(_evi)

    return relevant_wikipages
def find_evidence_from_wikipedia(claim):
    #
    doc = nlp(claim)
    #
    wikipedia_page = []
    for ent in doc.ents:
        relevant_wikipages = search_entity_wikipedia(str(ent))

        if len(relevant_wikipages) < 5:
            relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)

        wikipedia_page.extend(relevant_wikipages)

    return wikipedia_page
def relevant_wikipedia_API_retriever(claim):
    #
    doc = nlp(claim)

    wiki_intro = []
    for ent in doc.ents:
        page_py = wiki_wiki.page(str(ent))

        if page_py.exists():
            introduction = page_py.summary
        else:
            introduction = "No documents found."

        wiki_intro.append([str(ent), introduction])

    return wiki_intro, doc
def Wikipediaretriever(claim, sources):
    #
    # 1. extract relevant wikipedia pages from wikipedia dumps
    if "Dump" in sources:
        wikipedia_page, _ = find_evidence_from_wikipedia_dumps(claim)  # returns (pages, spaCy doc)
    else:
        wikipedia_page = find_evidence_from_wikipedia(claim)
        # wiki_intro, doc = relevant_wikipedia_API_retriever(claim)

    # 2. extract relevant sentences from extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)
    #
    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        metadata = dict()

        metadata['name'] = claim
        metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
        page_content = f"""{metadata['page_content']}"""

        results.append(Docs(metadata, page_content))

    return results
def log_on_azure(file, logs, azure_share_client):
    logs = json.dumps(logs)
    file_client = azure_share_client.get_file_client(file)
    file_client.upload_file(logs)
def chat(claim, history, sources):
    evidence = []
    if 'Google' in sources:
        evidence = Googleretriever(claim, sources)
    if 'WikiPedia' in sources:
        evidence = Wikipediaretriever(claim, sources)

    answer_set, answer_output = QAprediction(claim, evidence, sources)

    docs_html = ""
    if len(evidence) > 0:
        docs_html = []
        for i, evi in enumerate(evidence, 1):
            docs_html.append(make_html_source(evi, i))
        docs_html = "".join(docs_html)
    else:
        print("No documents found")

    url_of_evidence = ""
    output_language = "English"
    output_query = claim

    history[-1] = (claim, answer_set)
    history = [tuple(x) for x in history]

    ############################################################
    evi_list = []
    for evi in evidence:
        title_str = evi.metadata['title']
        evi_str = evi.metadata['evidence']
        url_str = evi.metadata['url']
        evi_list.append([title_str, evi_str, url_str])

    try:
        # Log the answer on the Azure file share.
        # If AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
        if os.environ.get("AZURE_ISSAVE") == "TRUE":
            # timestamp = str(datetime.now().timestamp())
            timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            file = timestamp + ".json"
            logs = {
                "user_id": str(user_id),
                "claim": claim,
                "sources": sources,
                "evidence": evi_list,
                "answer": answer_output,
                "time": timestamp,
            }
            log_on_azure(file, logs, azure_share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(
            f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
    ##########

    return history, docs_html, output_query, output_language
def main():
    init_prompt = """
Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims.

What do you want to fact-check?
"""

    with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo:
        with gr.Tab("AVeriTeC"):
            with gr.Row(elem_id="chatbot-row"):
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        value=[(None, init_prompt)],
                        show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel",
                        avatar_images=(None, "assets/averitec.png")
                    )  # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),

                    with gr.Row(elem_id="input-message"):
                        textbox = gr.Textbox(placeholder="Ask me what claim do you want to check!", show_label=False,
                                             scale=7, lines=1, interactive=True, elem_id="input-textbox")
                        # submit = gr.Button("", elem_id="submit-button", scale=1, interactive=True, icon="https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")

                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.Tabs() as tabs:
                        with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                            examples_hidden = gr.Textbox(visible=False)
                            first_key = list(CLAIMS_Type.keys())[0]
                            dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True,
                                                           show_label=True,
                                                           label="Select claim type",
                                                           elem_id="dropdown-samples")

                            samples = []
                            for i, key in enumerate(CLAIMS_Type.keys()):
                                examples_visible = True if i == 0 else False

                                with gr.Row(visible=examples_visible) as group_examples:
                                    examples_questions = gr.Examples(
                                        CLAIMS_Type[key],
                                        [examples_hidden],
                                        examples_per_page=8,
                                        run_on_click=False,
                                        elem_id=f"examples{i}",
                                        api_name=f"examples{i}",
                                        # label = "Click on the example question or enter your own",
                                        # cache_examples=True,
                                    )

                                samples.append(group_examples)

                        with gr.Tab("Sources", elem_id="tab-citations", id=1):
                            sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                            docs_textbox = gr.State("")

                        with gr.Tab("Configuration", elem_id="tab-config", id=2):
                            gr.Markdown("Reminder: We currently only support fact-checking in English!")

                            # dropdown_sources = gr.Radio(
                            #     ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"],
                            #     label="Select source",
                            #     value="WikiPediaAPI",
                            #     interactive=True,
                            # )
                            dropdown_sources = gr.Radio(
                                ["Google", "WikiPedia"],
                                label="Select source",
                                value="WikiPedia",
                                interactive=True,
                            )

                            dropdown_retriever = gr.Dropdown(
                                ["BM25", "BM25s"],
                                label="Select evidence retriever",
                                multiselect=False,
                                value="BM25",
                                interactive=True,
                            )

                            output_query = gr.Textbox(label="Query used for retrieval", show_label=True,
                                                      elem_id="reformulated-query", lines=2, interactive=False)
                            output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1,
                                                         interactive=False)

        with gr.Tab("About", elem_classes="max-height other-tabs"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)")

        def start_chat(query, history):
            history = history + [(query, None)]
            history = [tuple(x) for x in history]
            return (gr.update(interactive=False), gr.update(selected=1), history)

        def finish_chat():
            return (gr.update(interactive=True, value=""))

        (textbox
         .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
         .then(chat, [textbox, chatbot, dropdown_sources],
               [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox")
         .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
         )

        (examples_hidden
         .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False,
                 api_name="start_chat_examples")
         .then(chat, [examples_hidden, chatbot, dropdown_sources],
               [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples")
         .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
         )

        def change_sample_questions(key):
            index = list(CLAIMS_Type.keys()).index(key)
            visible_bools = [False] * len(samples)
            visible_bools[index] = True
            return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]

        dropdown_samples.change(change_sample_questions, dropdown_samples, samples)

    demo.queue()
    demo.launch()
    # demo.launch(share=True)


if __name__ == "__main__":
    main()