Spaces:
Build error
Build error
zhenyundeng commited on
Commit ·
7168c2f
1
Parent(s): 8d2d2b1
update
Browse files- app.py +81 -76
- requirements.txt +3 -2
app.py
CHANGED
|
@@ -15,6 +15,7 @@ import json
|
|
| 15 |
import pytorch_lightning as pl
|
| 16 |
from urllib.parse import urlparse
|
| 17 |
from accelerate import Accelerator
|
|
|
|
| 18 |
|
| 19 |
from transformers import BartTokenizer, BartForConditionalGeneration
|
| 20 |
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
|
|
@@ -273,6 +274,7 @@ def fever_veracity_prediction(claim, evidence):
|
|
| 273 |
return pred_label
|
| 274 |
|
| 275 |
|
|
|
|
| 276 |
def veracity_prediction(claim, qa_evidence):
|
| 277 |
# bert_model_name = "bert-base-uncased"
|
| 278 |
# tokenizer = BertTokenizer.from_pretrained(bert_model_name)
|
|
@@ -375,6 +377,7 @@ def google_justification_generation(claim, qa_evidence, verdict_label):
|
|
| 375 |
return pred_justification.strip()
|
| 376 |
|
| 377 |
|
|
|
|
| 378 |
def justification_generation(claim, qa_evidence, verdict_label):
|
| 379 |
#
|
| 380 |
claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
|
|
@@ -465,6 +468,7 @@ def docs2prompt(top_docs):
|
|
| 465 |
return "\n\n".join([doc2prompt(d) for d in top_docs])
|
| 466 |
|
| 467 |
|
|
|
|
| 468 |
def prompt_question_generation(test_claim, speaker="they", topk=10):
|
| 469 |
#
|
| 470 |
reference_file = "averitec/data/train.json"
|
|
@@ -926,88 +930,89 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=5): # top_k=10, 100
|
|
| 926 |
return generate_qa_pairs
|
| 927 |
|
| 928 |
|
| 929 |
-
def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10): # top_k=100
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
|
| 1006 |
|
| 1007 |
def triple_to_string(x):
|
| 1008 |
return " </s> ".join([item.strip() for item in x])
|
| 1009 |
|
| 1010 |
|
|
|
|
| 1011 |
def rerank_questions(claim, bm25_qas, topk=3):
|
| 1012 |
#
|
| 1013 |
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
|
|
| 15 |
import pytorch_lightning as pl
|
| 16 |
from urllib.parse import urlparse
|
| 17 |
from accelerate import Accelerator
|
| 18 |
+
import spaces
|
| 19 |
|
| 20 |
from transformers import BartTokenizer, BartForConditionalGeneration
|
| 21 |
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
|
|
|
|
| 274 |
return pred_label
|
| 275 |
|
| 276 |
|
| 277 |
+
@spaces.GPU
|
| 278 |
def veracity_prediction(claim, qa_evidence):
|
| 279 |
# bert_model_name = "bert-base-uncased"
|
| 280 |
# tokenizer = BertTokenizer.from_pretrained(bert_model_name)
|
|
|
|
| 377 |
return pred_justification.strip()
|
| 378 |
|
| 379 |
|
| 380 |
+
@spaces.GPU
|
| 381 |
def justification_generation(claim, qa_evidence, verdict_label):
|
| 382 |
#
|
| 383 |
claim_str = extract_claim_str(claim, qa_evidence, verdict_label)
|
|
|
|
| 468 |
return "\n\n".join([doc2prompt(d) for d in top_docs])
|
| 469 |
|
| 470 |
|
| 471 |
+
@spaces.GPU
|
| 472 |
def prompt_question_generation(test_claim, speaker="they", topk=10):
|
| 473 |
#
|
| 474 |
reference_file = "averitec/data/train.json"
|
|
|
|
| 930 |
return generate_qa_pairs
|
| 931 |
|
| 932 |
|
| 933 |
+
# def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10): # top_k=100
|
| 934 |
+
# #
|
| 935 |
+
# reference_file = "averitec/data/train.json"
|
| 936 |
+
# tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
|
| 937 |
+
# prompt_bm25 = BM25Okapi(tokenized_corpus)
|
| 938 |
+
#
|
| 939 |
+
# # Define the bloom model:
|
| 940 |
+
# accelerator = Accelerator()
|
| 941 |
+
# accel_device = accelerator.device
|
| 942 |
+
# # device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 943 |
+
# # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
|
| 944 |
+
# # model = BloomForCausalLM.from_pretrained(
|
| 945 |
+
# # "bigscience/bloom-7b1",
|
| 946 |
+
# # device_map="auto",
|
| 947 |
+
# # torch_dtype=torch.bfloat16,
|
| 948 |
+
# # offload_folder="./offload"
|
| 949 |
+
# # )
|
| 950 |
+
#
|
| 951 |
+
# #
|
| 952 |
+
# tokenized_corpus = []
|
| 953 |
+
# all_data_corpus = []
|
| 954 |
+
#
|
| 955 |
+
# for retri_evi in tqdm.tqdm(retrieve_evidence):
|
| 956 |
+
# store_file = retri_evi[-1]
|
| 957 |
+
#
|
| 958 |
+
# with open(store_file, 'r') as f:
|
| 959 |
+
# first = True
|
| 960 |
+
# for line in f:
|
| 961 |
+
# line = line.strip()
|
| 962 |
+
#
|
| 963 |
+
# if first:
|
| 964 |
+
# first = False
|
| 965 |
+
# location_url = line
|
| 966 |
+
# continue
|
| 967 |
+
#
|
| 968 |
+
# if len(line) > 3:
|
| 969 |
+
# entry = nltk.word_tokenize(line)
|
| 970 |
+
# if (location_url, line) not in all_data_corpus:
|
| 971 |
+
# tokenized_corpus.append(entry)
|
| 972 |
+
# all_data_corpus.append((location_url, line))
|
| 973 |
+
#
|
| 974 |
+
# if len(tokenized_corpus) == 0:
|
| 975 |
+
# print("")
|
| 976 |
+
#
|
| 977 |
+
# bm25 = BM25Okapi(tokenized_corpus)
|
| 978 |
+
# s = bm25.get_scores(nltk.word_tokenize(claim))
|
| 979 |
+
# top_n = np.argsort(s)[::-1][:top_k]
|
| 980 |
+
# docs = [all_data_corpus[i] for i in top_n]
|
| 981 |
+
#
|
| 982 |
+
# generate_qa_pairs = []
|
| 983 |
+
# # Then, generate questions for those top 50:
|
| 984 |
+
# for doc in tqdm.tqdm(docs):
|
| 985 |
+
# # prompt_lookup_str = example["claim"] + " " + doc[1]
|
| 986 |
+
# prompt_lookup_str = doc[1]
|
| 987 |
+
#
|
| 988 |
+
# prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
|
| 989 |
+
# prompt_n = 10
|
| 990 |
+
# prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
|
| 991 |
+
# prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
|
| 992 |
+
#
|
| 993 |
+
# claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
|
| 994 |
+
# prompt = "\n\n".join(prompt_docs + [claim_prompt])
|
| 995 |
+
# sentences = [prompt]
|
| 996 |
+
#
|
| 997 |
+
# inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
|
| 998 |
+
# outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
|
| 999 |
+
# early_stopping=True)
|
| 1000 |
+
#
|
| 1001 |
+
# tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
|
| 1002 |
+
# # We are not allowed to generate more than 250 characters:
|
| 1003 |
+
# tgt_text = tgt_text[:250]
|
| 1004 |
+
#
|
| 1005 |
+
# qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
|
| 1006 |
+
# generate_qa_pairs.append(qa_pair)
|
| 1007 |
+
#
|
| 1008 |
+
# return generate_qa_pairs
|
| 1009 |
|
| 1010 |
|
| 1011 |
def triple_to_string(x):
|
| 1012 |
return " </s> ".join([item.strip() for item in x])
|
| 1013 |
|
| 1014 |
|
| 1015 |
+
@spaces.GPU
|
| 1016 |
def rerank_questions(claim, bm25_qas, topk=3):
|
| 1017 |
#
|
| 1018 |
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
requirements.txt
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
gradio
|
| 2 |
-
nltk
|
| 3 |
rank_bm25
|
| 4 |
accelerate
|
| 5 |
trafilatura
|
| 6 |
-
spacy
|
| 7 |
pytorch_lightning
|
| 8 |
transformers==4.29.2
|
| 9 |
datasets
|
|
@@ -21,3 +21,4 @@ azure-storage-blob
|
|
| 21 |
bm25s
|
| 22 |
PyStemmer
|
| 23 |
lxml_html_clean
|
|
|
|
|
|
| 1 |
gradio
|
| 2 |
+
nltk==3.8.1
|
| 3 |
rank_bm25
|
| 4 |
accelerate
|
| 5 |
trafilatura
|
| 6 |
+
spacy==3.7.5
|
| 7 |
pytorch_lightning
|
| 8 |
transformers==4.29.2
|
| 9 |
datasets
|
|
|
|
| 21 |
bm25s
|
| 22 |
PyStemmer
|
| 23 |
lxml_html_clean
|
| 24 |
+
spaces
|