theachyuttiwari committed on
Commit
3a1ca92
·
1 Parent(s): daaf6f3

Upload create_dpr_training_from_dataset.py

Browse files
Files changed (1) hide show
  1. create_dpr_training_from_dataset.py +103 -0
create_dpr_training_from_dataset.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ import json
4
+ import re
5
+
6
+ from sentence_transformers import SentenceTransformer
7
+ from sentence_transformers.util import semantic_search, cos_sim
8
+ from tqdm.auto import tqdm
9
+ from datasets import load_dataset
10
+
11
+ from common import clean_answer, clean_question
12
+
13
+
14
def find_hard_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                            exclude_answer_patterns, similarity_threshold=(0.5, 0.6), k=25, min_count=3):
    """Collect "hard" negative contexts for the question at ``embedding_index``.

    Hard negatives are answers belonging to *other* questions whose embeddings
    are moderately similar to the query embedding — similar enough to be
    confusable with a true positive, yet not the correct answer.

    Args:
        dataset: example records; each must expose ``answers["text"]``.
        dataset_embeddings: question embeddings parallel to ``dataset``.
        embedding_index: row of ``dataset_embeddings`` to mine negatives for.
        exclude_answer_patterns: unused here; kept for signature parity with
            ``find_negative_ctxs``.
        similarity_threshold: ``(low, high)`` cosine-similarity band; only
            neighbours with ``low < score <= high`` contribute answers.
            NOTE: was a mutable list default (``[0.5, 0.6]``) — now an
            immutable tuple to avoid the shared-mutable-default pitfall.
        k: number of nearest neighbours retrieved by semantic search.
        min_count: number of hard negative contexts to return.

    Returns:
        At most ``min_count`` dicts of the form ``{"title": "", "text": ...}``.
    """
    hard_negative_ctxs = []
    results = semantic_search(dataset_embeddings[embedding_index], dataset_embeddings, top_k=k,
                              score_function=cos_sim)
    # ``results`` is a list (one entry per query) of lists of dicts, e.g.
    # [{'corpus_id': 8, 'score': -0.019427383318543434},
    #  ...
    #  {'corpus_id': 10, 'score': -0.09040290117263794}]
    # Hard negatives are the most similar hits; plain negatives (see
    # ``find_negative_ctxs``) are the most dissimilar to ``embedding_index``.
    # Skip index 0: the top hit is the query question itself.
    hard_negative_results = results[0][1:k + 1]
    assert len(hard_negative_results) > min_count * 2
    for r in hard_negative_results:
        example = dataset[r["corpus_id"]]
        # Keep only neighbours inside the "confusable but wrong" band.
        if similarity_threshold[0] < r["score"] <= similarity_threshold[1]:
            for a in example["answers"]["text"]:
                hard_negative_ctxs.append({"title": "", "text": clean_answer(a)})
        if len(hard_negative_ctxs) > min_count:
            break
    return hard_negative_ctxs[:min_count]
34
+
35
+
36
def find_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                       exclude_answer_patterns, similarity_threshold=0.1, k=7, min_count=3):
    """Collect random "easy" negative contexts for the question at ``embedding_index``.

    Easy negatives are answers drawn from a random sample of other questions
    whose embeddings are clearly dissimilar to the query embedding.

    Args:
        dataset: example records; each must expose ``answers["text"]``.
        dataset_embeddings: question embeddings parallel to ``dataset``.
        embedding_index: row of ``dataset_embeddings`` to mine negatives for.
        exclude_answer_patterns: unused here; kept for signature parity with
            ``find_hard_negative_ctxs``.
        similarity_threshold: only candidates with cosine similarity strictly
            below this value contribute answers.
        k: scales the random-sample size (``k * 20`` candidates are drawn).
        min_count: number of negative contexts to return.

    Returns:
        At most ``min_count`` dicts of the form ``{"title": "", "text": ...}``.
    """
    negative_ctxs = []
    # Clamp the sample size: random.sample raises ValueError when asked for
    # more items than the population holds (small datasets would crash).
    sample_size = min(k * 20, len(dataset_embeddings))
    random_sample = random.sample(range(len(dataset_embeddings)), sample_size)
    similarities = cos_sim(dataset_embeddings[embedding_index], dataset_embeddings[random_sample])[0].tolist()
    for idx, score in enumerate(similarities):
        # Keep only clearly-dissimilar candidates.
        if score < similarity_threshold:
            example = dataset[random_sample[idx]]
            for a in example["answers"]["text"]:
                negative_ctxs.append({"title": "", "text": clean_answer(a)})
        if len(negative_ctxs) > min_count:
            break
    return negative_ctxs[:min_count]
49
+
50
+
51
def generate_dpr_training_file(args):
    """Create DPR-format JSONL training files from the LFQA (ELI5) dataset.

    For each of the train/validation/test splits, encodes every question title
    with a SentenceTransformer, mines easy and hard negative contexts per
    question, and writes one JSON object per line to ``eli5-dpr-<split>.jsonl``
    with keys ``id``, ``question``, ``positive_ctxs``, ``negative_ctxs`` and
    ``hard_negative_ctxs``. Questions lacking enough contexts are skipped and
    counted.

    Args:
        args: parsed CLI namespace; only ``args.embedding_model`` is read.
    """
    embedder = SentenceTransformer(args.embedding_model)

    eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
    eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
    eli5_test_set = load_dataset("vblagoje/lfqa", split="test")

    # Encode every question title once per split; reused for both the
    # semantic-search (hard negatives) and random-sample (easy negatives) mining.
    train_set = embedder.encode([example["title"] for example in eli5_train_set], convert_to_tensor=True,
                                show_progress_bar=True)
    validation_set = embedder.encode([example["title"] for example in eli5_validation_set], convert_to_tensor=True,
                                     show_progress_bar=True)
    test_set = embedder.encode([example["title"] for example in eli5_test_set], convert_to_tensor=True,
                               show_progress_bar=True)
    # Answers matching these patterns are boilerplate/quote noise and are kept
    # out of positive contexts. Raw strings for regex per convention ("\n" is
    # interpreted identically by the re engine either way).
    exclude_answer_patterns = [re.compile(r"not sure what you"), re.compile(r"\n\n >")]
    for dataset_name, dataset, dataset_embeddings in zip(["train", "validation", "test"],
                                                         [eli5_train_set, eli5_validation_set, eli5_test_set],
                                                         [train_set, validation_set, test_set]):
        min_elements = 3
        skip_count = 0
        progress_bar = tqdm(range(len(dataset)), desc="Creating DPR formatted question/passage docs")
        # Explicit encoding so output is reproducible regardless of locale.
        with open('eli5-dpr-' + dataset_name + '.jsonl', 'w', encoding='utf-8') as fp:
            for idx, example in enumerate(dataset):
                negative_ctxs = find_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                hard_negative_ctxs = find_hard_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"] if
                                    not any(p.search(a) for p in exclude_answer_patterns)]
                # Fall back to unfiltered answers rather than dropping the question
                # when every answer matched an exclusion pattern.
                if not positive_context:
                    positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"]]
                if len(positive_context) > 0 and len(negative_ctxs) > 0 and len(hard_negative_ctxs) >= min_elements:
                    json.dump({"id": example["q_id"],
                               "question": clean_question(example["title"]),
                               "positive_ctxs": positive_context[:min_elements],
                               "negative_ctxs": negative_ctxs[:min_elements],
                               "hard_negative_ctxs": hard_negative_ctxs[:min_elements]}, fp)
                    fp.write("\n")
                else:
                    skip_count += 1
                progress_bar.update(1)

        print(f"Skipped {skip_count} questions")
92
+
93
+
94
if __name__ == "__main__":
    # CLI entry point: build DPR-format JSONL files from the LFQA dataset.
    arg_parser = argparse.ArgumentParser(description="Creates DPR training file from LFQA dataset")
    arg_parser.add_argument(
        "--embedding_model",
        default="all-mpnet-base-v2",
        help="Embedding model to use for question encoding and semantic search",
    )
    # parse_known_args tolerates extra flags (e.g. launcher-injected args).
    parsed_args, _ = arg_parser.parse_known_args()
    generate_dpr_training_file(parsed_args)