theachyuttiwari committed on
Commit
4eeb7ed
·
1 Parent(s): 8a678af

Upload eval_generate.py

Browse files
Files changed (1) hide show
  1. eval_generate.py +140 -0
eval_generate.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import torch
6
+ from datasets import load_dataset
7
+ from tqdm.auto import tqdm
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder
9
+
10
+ from common import articles_to_paragraphs, kilt_wikipedia_columns
11
+ from common import kilt_wikipedia_paragraph_columns as columns
12
+
13
+
14
+ def eval_generate(args):
15
+ device = ("cuda" if torch.cuda.is_available() else "cpu")
16
+ question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
17
+ question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
18
+ _ = question_model.eval()
19
+
20
+ eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
21
+ eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
22
+ _ = eli5_model.eval()
23
+
24
+ min_snippet_length = 20
25
+ topk = 21
26
+ min_chars_per_passage = 200
27
+ kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
28
+ kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
29
+ remove_columns=kilt_wikipedia_columns,
30
+ batch_size=256,
31
+ cache_file_name=f"./data/wiki_kilt_paragraphs_full.arrow",
32
+ desc="Expanding wiki articles into paragraphs")
33
+
34
+ # use paragraphs that are not simple fragments or very short sentences
35
+ kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
36
+ lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
37
+ kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
38
+
39
+ def embed_questions_for_retrieval(questions):
40
+ query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
41
+ with torch.no_grad():
42
+ q_reps = question_model(query["input_ids"].to(device),
43
+ query["attention_mask"].to(device)).pooler_output
44
+ return q_reps.cpu().numpy()
45
+
46
+ def query_index(question):
47
+ question_embedding = embed_questions_for_retrieval([question])
48
+ scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
49
+
50
+ retrieved_examples = []
51
+ r = list(zip(wiki_passages[k] for k in columns))
52
+ for i in range(topk):
53
+ retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
54
+ return retrieved_examples
55
+
56
+ def create_kilt_datapoint(q_id, query, answer, res_list):
57
+ # make a KILT data point
58
+ # see https://github.com/facebookresearch/KILT#kilt-data-format
59
+
60
+ provenance = [{
61
+ "wikipedia_id": r["wikipedia_id"], # *mandatory*
62
+ "title": r["title"],
63
+ "section": r["section"],
64
+ "start_paragraph_id": r["start_paragraph_id"],
65
+ "start_character": r["start_character"],
66
+ "end_paragraph_id": r["end_paragraph_id"],
67
+ "end_character": r["end_character"],
68
+ "text": r["text"],
69
+ "bleu_score": None, # wrt original evidence
70
+ "meta": None # dataset/task specific
71
+ } for r in res_list]
72
+
73
+ output = [{"answer": answer, "provenance": provenance}]
74
+
75
+ return {"id": q_id,
76
+ "input": query,
77
+ "output": output, # each element is an answer or provenance (can have multiple of each)
78
+ "meta": None # dataset/task specific
79
+ }
80
+
81
+ kilt_output = []
82
+ with open(args.kilt_input_file, "r") as f:
83
+ kilt_items = [json.loads(x) for x in f.read().strip().split("\n")]
84
+ progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document")
85
+ for idx, item in enumerate(kilt_items):
86
+ query = item["input"]
87
+ res_list = query_index(query)
88
+
89
+ res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
90
+ documents = [res["text"] for res in res_list]
91
+ conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
92
+
93
+ query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
94
+
95
+ model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
96
+ generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
97
+ attention_mask=model_input["attention_mask"].to(device),
98
+ min_length=50,
99
+ max_length=250,
100
+ do_sample=False,
101
+ early_stopping=True,
102
+ num_beams=8,
103
+ temperature=1.0,
104
+ top_k=None,
105
+ top_p=None,
106
+ no_repeat_ngram_size=3,
107
+ num_return_sequences=1)
108
+ answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
109
+ clean_up_tokenization_spaces=True)
110
+
111
+ kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
112
+ kilt_output.append(kilt_example)
113
+ progress_bar.update(1)
114
+
115
+ with open(args.kilt_output_file, "w") as fp:
116
+ for kilt_example in kilt_output:
117
+ json.dump(kilt_example, fp)
118
+ fp.write("\n")
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = argparse.ArgumentParser()
123
+ parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
124
+ parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
125
+ parser.add_argument(
126
+ "--question_encoder_name",
127
+ default="vblagoje/dpr-question_encoder-single-lfqa-base",
128
+ help="Question encoder to use",
129
+ )
130
+
131
+ parser.add_argument(
132
+ "--index_file_name",
133
+ default="../data/kilt_dpr_wikipedia_first.faiss",
134
+ help="Faiss index with passage embeddings",
135
+ )
136
+
137
+ args = parser.parse_args()
138
+
139
+ assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} couldn't be loaded"
140
+ eval_generate(args)