| | import argparse |
| | import time |
| | import json |
| | import nltk |
| | from rank_bm25 import BM25Okapi |
| | import numpy as np |
| | import torch |
| | from transformers import BloomTokenizerFast, BloomForCausalLM |
| |
|
| |
|
def claim2prompts(example):
    """Yield (lookup_string, prompt_text) pairs mined from one annotated claim.

    For every non-empty question in ``example["questions"]`` and every usable
    answer to it, yields:

      * ``lookup_string`` — the normalized answer text (used as the BM25
        retrieval key for this demonstration), and
      * ``prompt_text`` — a two-line few-shot demonstration of the form
        ``"Evidence: <answer>\\nQuestion answered: <question>?"``.

    Answer handling:
      * ``Extractive`` / ``Abstractive`` answers are used verbatim.
      * ``Boolean`` answers are expanded to
        ``"<answer>, because <boolean_explanation>"``.
      * Other answer types, and empty answer strings, are skipped
        (an empty string would otherwise crash the trailing-punctuation check).

    Parameters
    ----------
    example : dict
        One reference-corpus record with ``"questions"``, each holding
        ``"question"`` and a list of ``"answers"`` dicts.

    Yields
    ------
    tuple[str, str]
        ``(lookup_string, prompt_text)`` with newlines in the answer
        flattened to spaces.
    """
    for question in example["questions"]:
        q_text = question["question"].strip()
        if not q_text:
            # Skip blank questions entirely.
            continue
        # Normalize: every question ends with a question mark.
        if not q_text.endswith("?"):
            q_text += "?"

        answer_strings = []
        for a in question["answers"]:
            if a["answer_type"] in ["Extractive", "Abstractive"]:
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                # Boolean answers carry their justification in a separate field;
                # fold it in so the demonstration reads as natural text.
                answer_strings.append(
                    a["answer"]
                    + ", because "
                    + a["boolean_explanation"].lower().strip()
                )

        for a_text in answer_strings:
            a_text = a_text.strip()
            if not a_text:
                # Guard: an empty answer would make a_text[-1] raise IndexError.
                continue
            # Normalize: ensure the answer ends with sentence punctuation.
            if a_text[-1] not in [".", "!", ":", "?"]:
                a_text += "."

            # Single space after "Evidence:" so demonstrations match the live
            # query format built in __main__ ("Evidence: <sentence>\n...").
            this_q_claim_str = (
                "Evidence: " + a_text + "||Question answered: " + q_text
            )
            # "||" is a placeholder so the newline survives the flattening
            # of any newlines embedded in the answer text itself.
            yield (
                a_text,
                this_q_claim_str.replace("\n", " ").replace("||", "\n"),
            )
| |
|
| |
|
if __name__ == "__main__":
    # Pipeline: (1) mine (answer, prompt) demonstrations from a reference
    # corpus and index them with BM25; (2) for each retrieved evidence
    # sentence in the target knowledge file, retrieve the 10 most similar
    # demonstrations, build a few-shot prompt, and have BLOOM generate the
    # question that the sentence answers; (3) stream results out as JSONL.
    parser = argparse.ArgumentParser(
        description="Use a prompt to generate questions that could be answered by top-k retrieved evidence. Output generated questions."
    )
    # Reference corpus: a single JSON array of annotated examples (json.load below).
    parser.add_argument("--reference_corpus", default="data/train.json", help="")
    # NOTE(review): --target_file is parsed but never read in this script.
    parser.add_argument("--target_file", default="data/dev.json", help="")
    # Input: JSON Lines, one claim per line, each carrying a f"top_{top_k}" list
    # of {"sentence", "url"} dicts (see the read loop below).
    parser.add_argument(
        "-i",
        "--top_k_target_knowledge",
        default="data_store/dev_top_k_sentences.json",
        help="Directory where the sentences for the scraped data is saved.",
    )
    # Output: JSON Lines, one record per claim with the generated QA triples.
    parser.add_argument(
        "-o",
        "--output_questions",
        default="data_store/dev_top_k_qa.json",
        help="Directory where the sentences for the scraped data is saved.",
    )
    # Must match the key present in the input file (f"top_{args.top_k}").
    parser.add_argument(
        "--top_k",
        default=100,
        type=int,
        help="How many documents should we pick out with BM25",
    )
    args = parser.parse_args()

    with open(args.reference_corpus, "r", encoding="utf-8") as json_file:
        train_examples = json.load(json_file)

    # Parallel lists: tokenized_corpus[i] is the BM25 key (tokenized answer)
    # for the few-shot demonstration text prompt_corpus[i].
    prompt_corpus, tokenized_corpus = [], []

    for example in train_examples:
        for lookup_str, prompt in claim2prompts(example):
            entry = nltk.word_tokenize(lookup_str)
            tokenized_corpus.append(entry)
            prompt_corpus.append(prompt)

    prompt_bm25 = BM25Okapi(tokenized_corpus)

    # 7B BLOOM; device_map="auto" shards across available devices and spills
    # overflow weights to ./offload. bfloat16 halves memory vs fp32.
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
    model = BloomForCausalLM.from_pretrained(
        "bigscience/bloom-7b1",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        offload_folder="./offload",
    )

    with open(args.output_questions, "w", encoding="utf-8") as output_file:
        with open(args.top_k_target_knowledge, "r", encoding="utf-8") as json_file:
            # The knowledge file is JSONL: one claim record per line.
            for i, line in enumerate(json_file):
                data = json.loads(line)
                top_k_sentences_urls = data[f"top_{args.top_k}"]
                claim = data["claim"]
                claim_id = data["claim_id"]

                # Accumulates [question, sentence, url] triples for this claim.
                bm25_qau = []
                for sent_i, sentences_urls in enumerate(top_k_sentences_urls):

                    prompt_lookup_str = sentences_urls["sentence"]
                    url = sentences_urls["url"]

                    # Score the evidence sentence against every mined answer;
                    # take the 10 best-matching demonstrations as in-context
                    # examples for the generation prompt.
                    prompt_s = prompt_bm25.get_scores(
                        nltk.word_tokenize(prompt_lookup_str)
                    )
                    prompt_n = 10
                    prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
                    # Comprehension variable i is scoped to the comprehension
                    # (Python 3) and does not clobber the file-line index above.
                    prompt_docs = [prompt_corpus[i] for i in prompt_top_n]

                    # Query in the same format as the demonstrations, left
                    # open after "Question answered: " for the model to complete.
                    claim_prompt = (
                        "Evidence: "
                        + prompt_lookup_str.replace("\n", " ")
                        + "\nQuestion answered: "
                    )

                    prompt = "\n\n".join(prompt_docs + [claim_prompt])

                    # Batch of one; padding=True is a no-op here but harmless.
                    inputs = tokenizer([prompt], padding=True, return_tensors="pt").to(
                        model.device
                    )
                    st = time.time()
                    # max_length counts prompt + continuation tokens.
                    outputs = model.generate(
                        inputs["input_ids"],
                        max_length=5000,
                        num_beams=2,
                        no_repeat_ngram_size=2,
                        early_stopping=True,
                    )
                    print(
                        f"Generated QA for sent {sent_i} in file {i}. Time elapsed: {time.time() - st}"
                    )

                    # Decode only the newly generated tokens (slice off the prompt).
                    tgt_text = tokenizer.batch_decode(
                        outputs[:, inputs["input_ids"].shape[-1] :],
                        skip_special_tokens=True,
                    )[0]

                    # Hard cap on the continuation; the question is extracted below.
                    tgt_text = tgt_text[:250]

                    # Keep text up to the first "?" as the generated question;
                    # pair it with the evidence sentence and its source URL.
                    qau_pair = [
                        tgt_text.strip().split("?")[0].replace("\n", " ") + "?",
                        prompt_lookup_str.replace("\n", " "),
                        url,
                    ]

                    bm25_qau.append(qau_pair)

                json_data = {
                    "claim_id": claim_id,
                    "claim": claim,
                    "bm25_qau": bm25_qau,
                }
                output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
                # Flush per claim so partial results survive an interruption.
                output_file.flush()
|