juliaturc commited on
Commit
4e68d3a
·
1 Parent(s): 8699925

Script to run retrieval for the Kaggle competition

Browse files
benchmarks/retrieval/retrieve_kaggle.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to call retrieval on the Kaggle dataset."""
2
+
3
+ import csv
4
+ import json
5
+ import logging
6
+
7
+ import configargparse
8
+
9
+ import sage.config
10
+ from sage.retriever import build_retriever_from_args
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger()
14
+ logger.setLevel(logging.INFO)
15
+
16
+
17
+ def main():
18
+ parser = configargparse.ArgParser(
19
+ description="Runs retrieval on the Kaggle dataset.", ignore_unknown_config_file_keys=True
20
+ )
21
+ parser.add("--benchmark", required=True, help="Path to the Kaggle dataset.")
22
+ parser.add("--output-file", required=True, help="Path to the output file with predictions.")
23
+
24
+ sage.config.add_config_args(parser)
25
+ sage.config.add_embedding_args(parser)
26
+ sage.config.add_vector_store_args(parser)
27
+ sage.config.add_reranking_args(parser)
28
+ args = parser.parse_args()
29
+ sage.config.validate_vector_store_args(args)
30
+
31
+ retriever = build_retriever_from_args(args)
32
+
33
+ with open(args.benchmark, "r") as f:
34
+ benchmark = csv.DictReader(f)
35
+ benchmark = [row for row in benchmark]
36
+
37
+ outputs = []
38
+ for question_idx, item in enumerate(benchmark):
39
+ print(f"Processing question {question_idx}...")
40
+
41
+ retrieved = retriever.invoke(item["question"])
42
+ # Sort by score in descending order.
43
+ retrieved = sorted(retrieved, key=lambda doc: doc.metadata.get("score", doc.metadata.get("relevance_score")), reverse=True)
44
+ # Keep top 3, since the Kaggle competition only evaluates the top 3.
45
+ retrieved = retrieved[:3]
46
+ retrieved_filenames = [doc.metadata["file_path"] for doc in retrieved]
47
+ outputs.append((item["id"], json.dumps(retrieved_filenames)))
48
+
49
+ with open(args.output_file, "w") as f:
50
+ csv_writer = csv.writer(f)
51
+ csv_writer.writerow(["id", "documents"])
52
+ csv_writer.writerows(outputs)
53
+
54
+
55
+ if __name__ == "__main__":
56
+ main()