| import os | |
| import random | |
| from colbert.utils.parser import Arguments | |
| from colbert.utils.runs import Run | |
| from colbert.evaluation.loaders import load_colbert, load_topK, load_qrels | |
| from colbert.evaluation.loaders import load_queries, load_topK_pids, load_collection | |
| from colbert.evaluation.ranking import evaluate | |
| from colbert.evaluation.metrics import evaluate_recall | |
def main():
    """Run the exhaustive (not index-based) ColBERT re-ranking evaluation.

    Seeds ``random``, builds the argument parser (model, model-inference and
    re-ranking-input parameters plus an optional integer ``--depth``), parses
    the arguments and then, inside ``Run.context()``:

    1. loads the model and checkpoint via ``load_colbert`` and the qrels via
       ``load_qrels``;
    2. loads the re-ranking input in one of two modes:
       * ``args.collection`` and ``args.queries`` both set: they are passed
         to ``load_collection`` / ``load_queries`` and the top-K pids (and
         updated qrels) come from ``load_topK_pids``;
       * neither set: ``load_topK`` supplies queries, top-K docs and top-K
         pids;
    3. calls ``evaluate_recall`` and then ``evaluate``.

    Raises:
        AssertionError: if only one of ``args.collection`` / ``args.queries``
            is set, or if ``args.shortcircuit`` is set while ``args.qrels``
            is empty. Raised explicitly (rather than via ``assert``) so the
            checks still run under ``python -O``; the exception type is the
            same one the previous ``assert`` statements produced.
    """
    random.seed(12345)  # fixed seed for the `random` module

    parser = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()

    parser.add_argument('--depth', dest='depth', required=False, default=None, type=int)

    args = parser.parse()

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.qrels = load_qrels(args.qrels)

        if args.collection or args.queries:
            # The two arguments only make sense as a pair: fail early with a
            # clear message instead of handing a missing value to a loader.
            if not (args.collection and args.queries):
                raise AssertionError(
                    "The collection and queries arguments must be provided together."
                )

            args.queries = load_queries(args.queries)
            args.collection = load_collection(args.collection)
            args.topK_pids, args.qrels = load_topK_pids(args.topK, args.qrels)
        else:
            args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)

        # Checked after loading because `args.qrels` may have been replaced
        # by `load_topK_pids` above.
        if args.shortcircuit and not args.qrels:
            raise AssertionError(
                "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) "
                "can only be applied if qrels is provided."
            )

        evaluate_recall(args.qrels, args.queries, args.topK_pids)
        evaluate(args)
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()