| | import os
|
| | import sys
|
| | import tqdm
|
| | import ujson
|
| | import random
|
| |
|
| | from argparse import ArgumentParser
|
| | from collections import OrderedDict
|
| | from colbert.utils.utils import print_message, file_tqdm
|
| |
|
| |
|
def _build_qid_to_file_idx(query_paths):
    """Map every query id to the index of the queries file that contains it.

    Each non-blank line of each queries file is split on tabs, and its first
    column is parsed as an integer query id; the remaining columns are ignored.
    A qid may repeat within one file, but it must not appear in two files.

    Args:
        query_paths: sequence of queries-file paths; a path's position in the
            sequence is the file index assigned to its qids.

    Returns:
        dict mapping int qid -> int file index.

    Raises:
        ValueError: if the same qid appears in two different queries files
            (previously an ``assert``, which ``python -O`` would strip).
    """
    qid_to_file_idx = {}

    for file_idx, path in enumerate(query_paths):
        with open(path) as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    # Tolerate blank lines (e.g. a trailing empty line) instead of
                    # crashing on int('').
                    continue

                qid, *_ = stripped.split('\t')
                qid = int(qid)

                owner = qid_to_file_idx.setdefault(qid, file_idx)
                if owner != file_idx:
                    raise ValueError(
                        f"qid {qid} appears in both {query_paths[owner]} and {path}"
                    )

    return qid_to_file_idx


def main(args):
    """Split a combined ranking file into one ranking file per queries file.

    Every line of ``args.ranking`` is routed, unchanged, to the output file of
    the queries file that contains the line's qid (its first tab-separated
    column). The output for ``args.all_queries[i]`` is written to
    ``f'{args.ranking}.{i}'``.

    Args:
        args: namespace with ``ranking`` (path of the combined ranking file)
            and ``all_queries`` (list of queries-file paths).

    Raises:
        FileExistsError: if any output path already exists; nothing is
            overwritten.
        ValueError: if a qid appears in more than one queries file.
        KeyError: if the ranking contains a qid that no queries file lists.
    """
    from contextlib import ExitStack

    qid_to_file_idx = _build_qid_to_file_idx(args.all_queries)

    all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))]

    # Check every target before creating any of them, so an existing file aborts
    # the run without leaving a partial set of new outputs behind. This is an
    # explicit raise rather than an assert so it survives `python -O`.
    existing = [path for path in all_outputs_paths if os.path.exists(path)]
    if existing:
        raise FileExistsError(f"Refusing to overwrite existing output file(s): {existing}")

    # ExitStack guarantees every output is flushed and closed even if routing
    # fails part-way through (e.g. on an unknown qid).
    with ExitStack() as stack:
        # 'x' = exclusive create: it also fails if a file appeared after the check above.
        all_outputs = [stack.enter_context(open(path, 'x')) for path in all_outputs_paths]

        with open(args.ranking) as f:
            print_message(f"#> Loading ranked lists from {f.name} ..")

            last_file_idx = -1

            for line in file_tqdm(f):
                stripped = line.strip()
                if not stripped:
                    continue

                qid, *_ = stripped.split('\t')
                qid_int = int(qid)

                try:
                    file_idx = qid_to_file_idx[qid_int]
                except KeyError:
                    raise KeyError(
                        f"qid {qid_int} from {f.name} is not listed in any --all-queries file"
                    ) from None

                # Log only when the destination changes, not on every line.
                if file_idx != last_file_idx:
                    print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}")
                    last_file_idx = file_idx

                # Write the original line (including its newline) verbatim.
                all_outputs[file_idx].write(line)

        print()

        for out in all_outputs:
            print(out.name)

    print("#> Done!")
|
| |
|
| |
|
if __name__ == "__main__":
    # Seeds the global RNG for reproducibility. Nothing visible in this script
    # draws random numbers. NOTE(review): confirm whether imported code relies on
    # this; kept for backward compatibility.
    random.seed(12345)

    parser = ArgumentParser(
        description='Split a combined ranking file into one ranking file per queries file.'
    )

    parser.add_argument(
        '--ranking', dest='ranking', required=True, type=str,
        help='Tab-separated ranking file whose first column is the integer query id.',
    )
    parser.add_argument(
        '--all-queries', dest='all_queries', required=True, type=str, nargs='+',
        help='One or more tab-separated queries files (integer query id in the first '
             'column). Rankings for the i-th file are written to <ranking>.i.',
    )

    args = parser.parse_args()

    main(args)
|
| |
|