Spaces:
Build error
Build error
zhenyundeng
committed on
Commit
·
016ab20
1
Parent(s):
200e5b6
update files
Browse files
- .gitattributes +3 -1
- app.py +179 -11
- utils.py +3 -1
.gitattributes
CHANGED
|
@@ -25,7 +25,6 @@
|
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
@@ -33,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 28 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.db filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
|
app.py
CHANGED
|
@@ -69,7 +69,9 @@ nlp = spacy.load("en_core_web_sm")
|
|
| 69 |
# ---------------------------------------------------------------------------
|
| 70 |
# Load sample dict for AVeriTeC search
|
| 71 |
# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
|
|
|
|
| 72 |
|
|
|
|
| 73 |
# ---------------------------------------------------------------------------
|
| 74 |
# ---------- Load pretrained models ----------
|
| 75 |
# ---------- load Evidence retrieval model ----------
|
|
@@ -424,9 +426,8 @@ def QAprediction(claim, evidence, sources):
|
|
| 424 |
|
| 425 |
# ----------GoogleAPIretriever---------
|
| 426 |
def generate_reference_corpus(reference_file):
|
| 427 |
-
with open(reference_file) as f:
|
| 428 |
-
|
| 429 |
-
train_examples = json.load(f)
|
| 430 |
|
| 431 |
all_data_corpus = []
|
| 432 |
tokenized_corpus = []
|
|
@@ -578,6 +579,12 @@ def get_and_store(url_link, fp, worker, worker_stack):
|
|
| 578 |
gc.collect()
|
| 579 |
|
| 580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
|
| 582 |
search_results = []
|
| 583 |
for i in range(3):
|
|
@@ -599,7 +606,7 @@ def get_google_search_results(api_key, search_engine_id, google_search, sort_dat
|
|
| 599 |
return search_results
|
| 600 |
|
| 601 |
|
| 602 |
-
def
|
| 603 |
# default config
|
| 604 |
api_key = os.environ["GOOGLE_API_KEY"]
|
| 605 |
search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
|
|
@@ -651,7 +658,6 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-0
|
|
| 651 |
for page_num in range(n_pages):
|
| 652 |
search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
|
| 653 |
this_search_string, page=page_num)
|
| 654 |
-
search_results = search_results[:5]
|
| 655 |
|
| 656 |
for result in search_results:
|
| 657 |
link = str(result["link"])
|
|
@@ -668,8 +674,6 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-0
|
|
| 668 |
if link.endswith(".pdf") or link.endswith(".doc"):
|
| 669 |
continue
|
| 670 |
|
| 671 |
-
store_file_path = ""
|
| 672 |
-
|
| 673 |
if link in visited:
|
| 674 |
store_file_path = visited[link]
|
| 675 |
else:
|
|
@@ -678,7 +682,7 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-0
|
|
| 678 |
store_counter) + ".store"
|
| 679 |
visited[link] = store_file_path
|
| 680 |
|
| 681 |
-
while len(worker_stack) == 0: # Wait for a
|
| 682 |
sleep(1)
|
| 683 |
|
| 684 |
worker = worker_stack.pop()
|
|
@@ -692,6 +696,89 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-0
|
|
| 692 |
return retrieve_evidence
|
| 693 |
|
| 694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
def claim2prompts(example):
|
| 696 |
claim = example["claim"]
|
| 697 |
|
|
@@ -725,8 +812,8 @@ def claim2prompts(example):
|
|
| 725 |
|
| 726 |
|
| 727 |
def generate_step2_reference_corpus(reference_file):
|
| 728 |
-
with open(reference_file) as f:
|
| 729 |
-
|
| 730 |
|
| 731 |
prompt_corpus = []
|
| 732 |
tokenized_corpus = []
|
|
@@ -762,6 +849,87 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=10): # top_k=100
|
|
| 762 |
tokenized_corpus = []
|
| 763 |
all_data_corpus = []
|
| 764 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 765 |
for retri_evi in tqdm.tqdm(retrieve_evidence):
|
| 766 |
store_file = retri_evi[-1]
|
| 767 |
|
|
@@ -1222,7 +1390,7 @@ def chat(claim, history, sources):
|
|
| 1222 |
try:
|
| 1223 |
# Log answer on Azure Blob Storage
|
| 1224 |
# IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
|
| 1225 |
-
if
|
| 1226 |
timestamp = str(datetime.now().timestamp())
|
| 1227 |
# timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 1228 |
file = timestamp + ".json"
|
|
|
|
| 69 |
# ---------------------------------------------------------------------------
|
| 70 |
# Load sample dict for AVeriTeC search
|
| 71 |
# all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
|
| 72 |
+
train_examples = json.load(open('averitec/data/train.json', 'r'))
|
| 73 |
|
| 74 |
+
print(train_examples[0]['claim'])
|
| 75 |
# ---------------------------------------------------------------------------
|
| 76 |
# ---------- Load pretrained models ----------
|
| 77 |
# ---------- load Evidence retrieval model ----------
|
|
|
|
| 426 |
|
| 427 |
# ----------GoogleAPIretriever---------
|
| 428 |
def generate_reference_corpus(reference_file):
|
| 429 |
+
# with open(reference_file) as f:
|
| 430 |
+
# train_examples = json.load(f)
|
|
|
|
| 431 |
|
| 432 |
all_data_corpus = []
|
| 433 |
tokenized_corpus = []
|
|
|
|
| 579 |
gc.collect()
|
| 580 |
|
| 581 |
|
| 582 |
+
def get_text_from_link(url_link):
|
| 583 |
+
page_lines = url2lines(url_link)
|
| 584 |
+
|
| 585 |
+
return "\n".join([url_link] + page_lines)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
|
| 589 |
search_results = []
|
| 590 |
for i in range(3):
|
|
|
|
| 606 |
return search_results
|
| 607 |
|
| 608 |
|
| 609 |
+
def averitec_search_michael(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1): # n_pages=3
|
| 610 |
# default config
|
| 611 |
api_key = os.environ["GOOGLE_API_KEY"]
|
| 612 |
search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
|
|
|
|
| 658 |
for page_num in range(n_pages):
|
| 659 |
search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
|
| 660 |
this_search_string, page=page_num)
|
|
|
|
| 661 |
|
| 662 |
for result in search_results:
|
| 663 |
link = str(result["link"])
|
|
|
|
| 674 |
if link.endswith(".pdf") or link.endswith(".doc"):
|
| 675 |
continue
|
| 676 |
|
|
|
|
|
|
|
| 677 |
if link in visited:
|
| 678 |
store_file_path = visited[link]
|
| 679 |
else:
|
|
|
|
| 682 |
store_counter) + ".store"
|
| 683 |
visited[link] = store_file_path
|
| 684 |
|
| 685 |
+
while len(worker_stack) == 0: # Wait for a worker to become available. Check every second.
|
| 686 |
sleep(1)
|
| 687 |
|
| 688 |
worker = worker_stack.pop()
|
|
|
|
| 696 |
return retrieve_evidence
|
| 697 |
|
| 698 |
|
| 699 |
+
def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1): # n_pages=3
|
| 700 |
+
# default config
|
| 701 |
+
api_key = os.environ["GOOGLE_API_KEY"]
|
| 702 |
+
search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
|
| 703 |
+
|
| 704 |
+
blacklist = [
|
| 705 |
+
"jstor.org", # Blacklisted because their pdfs are not labelled as such, and clog up the download
|
| 706 |
+
"facebook.com", # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
|
| 707 |
+
"ftp.cs.princeton.edu", # Blacklisted because it hosts many large NLP corpora that keep showing up
|
| 708 |
+
"nlp.cs.princeton.edu",
|
| 709 |
+
"huggingface.co"
|
| 710 |
+
]
|
| 711 |
+
|
| 712 |
+
blacklist_files = [ # Blacklisted some NLP nonsense that crashes my machine with OOM errors
|
| 713 |
+
"/glove.",
|
| 714 |
+
"ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
|
| 715 |
+
"https://web.mit.edu/adamrose/Public/googlelist",
|
| 716 |
+
]
|
| 717 |
+
|
| 718 |
+
# save to folder
|
| 719 |
+
store_folder = "averitec/data/store/retrieved_docs"
|
| 720 |
+
#
|
| 721 |
+
index = 0
|
| 722 |
+
questions = [q["question"] for q in generate_question]
|
| 723 |
+
|
| 724 |
+
# check the date of the claim
|
| 725 |
+
current_date = datetime.now().strftime("%Y-%m-%d")
|
| 726 |
+
sort_date = check_claim_date(current_date) # check_date="2022-01-01"
|
| 727 |
+
|
| 728 |
+
#
|
| 729 |
+
search_strings = []
|
| 730 |
+
search_types = []
|
| 731 |
+
|
| 732 |
+
search_string_2 = string_to_search_query(claim, None)
|
| 733 |
+
search_strings += [search_string_2, claim, ]
|
| 734 |
+
search_types += ["claim", "claim-noformat", ]
|
| 735 |
+
|
| 736 |
+
search_strings += questions
|
| 737 |
+
search_types += ["question" for _ in questions]
|
| 738 |
+
|
| 739 |
+
# start to search
|
| 740 |
+
search_results = []
|
| 741 |
+
visited = {}
|
| 742 |
+
store_counter = 0
|
| 743 |
+
worker_stack = list(range(10))
|
| 744 |
+
|
| 745 |
+
retrieve_evidence = []
|
| 746 |
+
|
| 747 |
+
for this_search_string, this_search_type in zip(search_strings, search_types):
|
| 748 |
+
for page_num in range(n_pages):
|
| 749 |
+
search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
|
| 750 |
+
this_search_string, page=page_num)
|
| 751 |
+
search_results = search_results[:5]
|
| 752 |
+
|
| 753 |
+
for result in search_results:
|
| 754 |
+
link = str(result["link"])
|
| 755 |
+
domain = get_domain_name(link)
|
| 756 |
+
|
| 757 |
+
if domain in blacklist:
|
| 758 |
+
continue
|
| 759 |
+
broken = False
|
| 760 |
+
for b_file in blacklist_files:
|
| 761 |
+
if b_file in link:
|
| 762 |
+
broken = True
|
| 763 |
+
if broken:
|
| 764 |
+
continue
|
| 765 |
+
if link.endswith(".pdf") or link.endswith(".doc"):
|
| 766 |
+
continue
|
| 767 |
+
|
| 768 |
+
store_file_path = ""
|
| 769 |
+
|
| 770 |
+
if link in visited:
|
| 771 |
+
web_text = visited[link]
|
| 772 |
+
else:
|
| 773 |
+
web_text = get_text_from_link(link)
|
| 774 |
+
visited[link] = web_text
|
| 775 |
+
|
| 776 |
+
line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, web_text]
|
| 777 |
+
retrieve_evidence.append(line)
|
| 778 |
+
|
| 779 |
+
return retrieve_evidence
|
| 780 |
+
|
| 781 |
+
|
| 782 |
def claim2prompts(example):
|
| 783 |
claim = example["claim"]
|
| 784 |
|
|
|
|
| 812 |
|
| 813 |
|
| 814 |
def generate_step2_reference_corpus(reference_file):
|
| 815 |
+
# with open(reference_file) as f:
|
| 816 |
+
# train_examples = json.load(f)
|
| 817 |
|
| 818 |
prompt_corpus = []
|
| 819 |
tokenized_corpus = []
|
|
|
|
| 849 |
tokenized_corpus = []
|
| 850 |
all_data_corpus = []
|
| 851 |
|
| 852 |
+
for retri_evi in tqdm.tqdm(retrieve_evidence):
|
| 853 |
+
# store_file = retri_evi[-1]
|
| 854 |
+
# with open(store_file, 'r') as f:
|
| 855 |
+
web_text = retri_evi[-1]
|
| 856 |
+
lines_in_web = web_text.split("\n")
|
| 857 |
+
|
| 858 |
+
first = True
|
| 859 |
+
for line in lines_in_web:
|
| 860 |
+
# for line in f:
|
| 861 |
+
line = line.strip()
|
| 862 |
+
|
| 863 |
+
if first:
|
| 864 |
+
first = False
|
| 865 |
+
location_url = line
|
| 866 |
+
continue
|
| 867 |
+
|
| 868 |
+
if len(line) > 3:
|
| 869 |
+
entry = nltk.word_tokenize(line)
|
| 870 |
+
if (location_url, line) not in all_data_corpus:
|
| 871 |
+
tokenized_corpus.append(entry)
|
| 872 |
+
all_data_corpus.append((location_url, line))
|
| 873 |
+
|
| 874 |
+
if len(tokenized_corpus) == 0:
|
| 875 |
+
print("")
|
| 876 |
+
|
| 877 |
+
bm25 = BM25Okapi(tokenized_corpus)
|
| 878 |
+
s = bm25.get_scores(nltk.word_tokenize(claim))
|
| 879 |
+
top_n = np.argsort(s)[::-1][:top_k]
|
| 880 |
+
docs = [all_data_corpus[i] for i in top_n]
|
| 881 |
+
|
| 882 |
+
generate_qa_pairs = []
|
| 883 |
+
# Then, generate questions for those top 50:
|
| 884 |
+
for doc in tqdm.tqdm(docs):
|
| 885 |
+
# prompt_lookup_str = example["claim"] + " " + doc[1]
|
| 886 |
+
prompt_lookup_str = doc[1]
|
| 887 |
+
|
| 888 |
+
prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
|
| 889 |
+
prompt_n = 10
|
| 890 |
+
prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
|
| 891 |
+
prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
|
| 892 |
+
|
| 893 |
+
claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
|
| 894 |
+
prompt = "\n\n".join(prompt_docs + [claim_prompt])
|
| 895 |
+
sentences = [prompt]
|
| 896 |
+
|
| 897 |
+
inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
|
| 898 |
+
outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
|
| 899 |
+
early_stopping=True)
|
| 900 |
+
|
| 901 |
+
tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
|
| 902 |
+
# We are not allowed to generate more than 250 characters:
|
| 903 |
+
tgt_text = tgt_text[:250]
|
| 904 |
+
|
| 905 |
+
qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
|
| 906 |
+
generate_qa_pairs.append(qa_pair)
|
| 907 |
+
|
| 908 |
+
return generate_qa_pairs
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10): # top_k=100
|
| 912 |
+
#
|
| 913 |
+
reference_file = "averitec/data/train.json"
|
| 914 |
+
tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
|
| 915 |
+
prompt_bm25 = BM25Okapi(tokenized_corpus)
|
| 916 |
+
|
| 917 |
+
# Define the bloom model:
|
| 918 |
+
accelerator = Accelerator()
|
| 919 |
+
accel_device = accelerator.device
|
| 920 |
+
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 921 |
+
# tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
|
| 922 |
+
# model = BloomForCausalLM.from_pretrained(
|
| 923 |
+
# "bigscience/bloom-7b1",
|
| 924 |
+
# device_map="auto",
|
| 925 |
+
# torch_dtype=torch.bfloat16,
|
| 926 |
+
# offload_folder="./offload"
|
| 927 |
+
# )
|
| 928 |
+
|
| 929 |
+
#
|
| 930 |
+
tokenized_corpus = []
|
| 931 |
+
all_data_corpus = []
|
| 932 |
+
|
| 933 |
for retri_evi in tqdm.tqdm(retrieve_evidence):
|
| 934 |
store_file = retri_evi[-1]
|
| 935 |
|
|
|
|
| 1390 |
try:
|
| 1391 |
# Log answer on Azure Blob Storage
|
| 1392 |
# IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
|
| 1393 |
+
if os.environ["AZURE_ISSAVE"] == "TRUE":
|
| 1394 |
timestamp = str(datetime.now().timestamp())
|
| 1395 |
# timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 1396 |
file = timestamp + ".json"
|
utils.py
CHANGED
|
@@ -2,11 +2,13 @@ import numpy as np
|
|
| 2 |
import random
|
| 3 |
import string
|
| 4 |
import uuid
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def create_user_id():
|
| 8 |
"""Create user_id
|
| 9 |
str: String to id user
|
| 10 |
"""
|
|
|
|
| 11 |
user_id = str(uuid.uuid4())
|
| 12 |
-
return user_id
|
|
|
|
| 2 |
import random
|
| 3 |
import string
|
| 4 |
import uuid
|
| 5 |
+
from datetime import datetime
|
| 6 |
|
| 7 |
|
| 8 |
def create_user_id():
|
| 9 |
"""Create user_id
|
| 10 |
str: String to id user
|
| 11 |
"""
|
| 12 |
+
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
| 13 |
user_id = str(uuid.uuid4())
|
| 14 |
+
return current_date + '_' +user_id
|