Commit
·
3a1ca92
1
Parent(s):
daaf6f3
Upload create_dpr_training_from_dataset.py
Browse files
create_dpr_training_from_dataset.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from sentence_transformers.util import semantic_search, cos_sim
|
| 8 |
+
from tqdm.auto import tqdm
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
from common import clean_answer, clean_question
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def find_hard_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                            exclude_answer_patterns, similarity_threshold=(0.5, 0.6), k=25, min_count=3):
    """Collect up to ``min_count`` hard-negative contexts for one question.

    Runs a semantic search of the question embedding at ``embedding_index``
    against all ``dataset_embeddings`` and keeps answers whose cosine score
    falls inside the ``(low, high]`` band of ``similarity_threshold`` —
    similar enough to be confusing for a retriever, but not near-duplicates.

    Args:
        dataset: examples aligned with ``dataset_embeddings``; each item is
            expected to expose ``example["answers"]["text"]``.
        dataset_embeddings: embeddings, one row per example in ``dataset``.
        embedding_index: row index of the query question.
        exclude_answer_patterns: unused here; kept for signature parity with
            ``find_negative_ctxs`` (callers pass the same arguments to both).
        similarity_threshold: ``(low, high]`` cosine-similarity band accepted
            as a hard negative. NOTE: was a mutable list default ``[0.5, 0.6]``;
            a tuple avoids the shared-mutable-default pitfall and indexes the
            same way.
        k: number of nearest neighbours to inspect.
        min_count: number of hard negatives to return.

    Returns:
        A list of at most ``min_count`` DPR-shaped context dicts
        ``{"title": "", "text": ...}``.
    """
    hard_negative_ctxs = []
    results = semantic_search(dataset_embeddings[embedding_index], dataset_embeddings, top_k=k,
                              score_function=cos_sim)
    # ``results`` is a list (one entry per query) of lists of dicts, e.g.
    # [{'corpus_id': 8, 'score': -0.019...}, ..., {'corpus_id': 10, 'score': -0.090...}]
    # Skip element 0: the closest hit is the query question itself.
    hard_negative_results = results[0][1:k + 1]
    assert len(hard_negative_results) > min_count * 2
    for r in hard_negative_results:
        example = dataset[r["corpus_id"]]
        if similarity_threshold[0] < r["score"] <= similarity_threshold[1]:
            for a in example["answers"]["text"]:
                hard_negative_ctxs.append({"title": "", "text": clean_answer(a)})
        if len(hard_negative_ctxs) > min_count:
            break
    return hard_negative_ctxs[:min_count]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def find_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                       exclude_answer_patterns, similarity_threshold=0.1, k=7, min_count=3):
    """Collect up to ``min_count`` easy-negative contexts for one question.

    Draws ``k * 20`` random candidate examples, scores them against the
    question at ``embedding_index`` with cosine similarity, and keeps the
    answers of candidates scoring below ``similarity_threshold`` — i.e.
    clearly unrelated to the question. ``exclude_answer_patterns`` is
    accepted for signature parity with ``find_hard_negative_ctxs`` but is
    not used here.

    Returns a list of at most ``min_count`` DPR-shaped context dicts
    ``{"title": "", "text": ...}``.
    """
    collected = []
    candidate_indices = random.sample(range(len(dataset_embeddings)), k * 20)
    scores = cos_sim(dataset_embeddings[embedding_index], dataset_embeddings[candidate_indices])[0].tolist()
    for position, similarity in enumerate(scores):
        if similarity < similarity_threshold:
            candidate = dataset[candidate_indices[position]]
            for answer_text in candidate["answers"]["text"]:
                collected.append({"title": "", "text": clean_answer(answer_text)})
        if len(collected) > min_count:
            break
    return collected[:min_count]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def generate_dpr_training_file(args):
    """Build DPR-formatted JSONL files from the LFQA (ELI5) dataset.

    Encodes every question title of the train/validation/test splits with
    the sentence-transformer named by ``args.embedding_model``, then for
    each question assembles:

    * positive contexts — the question's own answers (low-quality answers
      matched by ``exclude_answer_patterns`` are filtered out when possible),
    * easy negatives — answers of random low-similarity questions, and
    * hard negatives — answers of high-similarity (but different) questions,

    writing one JSON object per line to ``eli5-dpr-<split>.jsonl``.
    Questions without enough contexts of each kind are skipped and counted.

    Args:
        args: namespace with an ``embedding_model`` attribute (model name or
            path understood by ``SentenceTransformer``).
    """
    embedder = SentenceTransformer(args.embedding_model)

    eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
    eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
    eli5_test_set = load_dataset("vblagoje/lfqa", split="test")

    train_set = embedder.encode([example["title"] for example in eli5_train_set], convert_to_tensor=True,
                                show_progress_bar=True)
    validation_set = embedder.encode([example["title"] for example in eli5_validation_set], convert_to_tensor=True,
                                     show_progress_bar=True)
    test_set = embedder.encode([example["title"] for example in eli5_test_set], convert_to_tensor=True,
                               show_progress_bar=True)
    # Answers matching any of these are considered low quality and are kept
    # out of the positive contexts whenever an alternative answer exists.
    exclude_answer_patterns = [re.compile("not sure what you"), re.compile("\n\n >")]
    min_elements = 3  # contexts required/kept per question (was rebound every split iteration)
    for dataset_name, dataset, dataset_embeddings in zip(["train", "validation", "test"],
                                                         [eli5_train_set, eli5_validation_set, eli5_test_set],
                                                         [train_set, validation_set, test_set]):
        skip_count = 0
        progress_bar = tqdm(range(len(dataset)), desc="Creating DPR formatted question/passage docs")
        # Explicit encoding: the platform default is locale-dependent.
        with open(f"eli5-dpr-{dataset_name}.jsonl", "w", encoding="utf-8") as fp:
            for idx, example in enumerate(dataset):
                negative_ctxs = find_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                hard_negative_ctxs = find_hard_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                # Generator inside any(): no throwaway list per answer.
                positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"]
                                    if not any(p.search(a) for p in exclude_answer_patterns)]
                if not positive_context:
                    # Every answer was filtered: fall back to the unfiltered
                    # answers rather than dropping the question entirely.
                    positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"]]
                if len(positive_context) > 0 and len(negative_ctxs) > 0 and len(hard_negative_ctxs) >= min_elements:
                    json.dump({"id": example["q_id"],
                               "question": clean_question(example["title"]),
                               "positive_ctxs": positive_context[:min_elements],
                               "negative_ctxs": negative_ctxs[:min_elements],
                               "hard_negative_ctxs": hard_negative_ctxs[:min_elements]}, fp)
                    fp.write("\n")
                else:
                    skip_count += 1
                progress_bar.update(1)

        print(f"Skipped {skip_count} questions")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
    # CLI entry point: parse known args (ignoring extras) and build the files.
    arg_parser = argparse.ArgumentParser(description="Creates DPR training file from LFQA dataset")
    arg_parser.add_argument(
        "--embedding_model",
        default="all-mpnet-base-v2",
        help="Embedding model to use for question encoding and semantic search",
    )
    parsed_args, _unknown = arg_parser.parse_known_args()
    generate_dpr_training_file(parsed_args)
|