import os
import json
import random

import numpy as np
import torch
from beir import util


def set_seeds(seed_val=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed_val: Integer seed passed to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``.
    """
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)


def load_and_split_dataset(dataset_name="msmarco", data_root="./datasets", split_name="dev",
                           n_docs=100000, n_train=150, n_test=150):
    """OOM-Safe dataset loader that streams JSONL directly from disk.

    If ``<data_root>/<dataset_name>`` does not exist, the dataset archive is
    downloaded and unzipped into ``data_root`` first. The loader then reads
    ``queries.jsonl``, ``qrels/<split_name>.tsv`` and at most the first
    ``n_docs`` documents of ``corpus.jsonl``, keeps the queries that have at
    least one judged document inside that corpus prefix, shuffles them with
    the global ``random`` module (call ``set_seeds`` first for reproducible
    splits) and cuts disjoint train/test query lists.

    Args:
        dataset_name: Dataset folder / archive name.
        data_root: Directory that holds (or will hold) the dataset folder.
        split_name: Name of the qrels file (without ``.tsv``) to read.
        n_docs: Maximum number of corpus documents to load.
        n_train: Upper bound on train queries; also capped at half of the
            valid queries.
        n_test: Upper bound on test queries; also capped at what remains
            after the train queries are taken.

    Returns:
        A tuple ``(doc_ids, doc_texts, queries_all, train_q_ids, test_q_ids,
        qrels)`` where ``doc_ids`` and ``doc_texts`` are parallel lists,
        ``queries_all`` maps every query id in ``queries.jsonl`` to its text
        (not only the valid ones), and ``qrels`` maps each valid query id to
        ``{doc_id: score}`` restricted to the loaded documents.

    Raises:
        ValueError: If a non-blank qrels row does not have exactly three
            tab-separated fields or its score is not an integer.
    """
    print("Downloading and extracting dataset...")
    data_dir = os.path.join(data_root, dataset_name)
    os.makedirs(data_root, exist_ok=True)

    if not os.path.exists(data_dir):
        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
        zip_path = os.path.join(data_root, f"{dataset_name}.zip")
        util.download_url(url, zip_path)
        util.unzip(zip_path, data_root)

    print(f"Loading queries and qrels (using '{split_name}' split)...")
    queries_all = {}
    with open(os.path.join(data_dir, "queries.jsonl"), "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            q = json.loads(line)
            queries_all[q["_id"]] = q["text"]

    qrels_all = {}
    with open(os.path.join(data_dir, "qrels", f"{split_name}.tsv"), "r", encoding="utf-8") as f:
        next(f, None)  # Skip header; the default avoids StopIteration on an empty file
        for line in f:
            row = line.strip()
            if not row:
                continue  # a blank line would otherwise break the 3-way unpack
            q_id, doc_id, score = row.split("\t")
            qrels_all.setdefault(q_id, {})[doc_id] = int(score)

    print(f"Streaming corpus to extract {n_docs} documents safely...")
    corpus = {}
    with open(os.path.join(data_dir, "corpus.jsonl"), "r", encoding="utf-8") as f:
        n_loaded = 0
        for line in f:
            # Stop before parsing so lines past the limit are never decoded.
            if n_loaded >= n_docs:
                break
            if not line.strip():
                continue
            doc = json.loads(line)
            corpus[doc["_id"]] = doc
            n_loaded += 1

    doc_ids = list(corpus)
    # `or ""` also covers an explicit JSON null, which `.get(key, "")` would
    # pass through as None and crash the concatenation.
    doc_texts = [
        ((doc.get("title") or "") + " " + (doc.get("text") or "")).strip()
        for doc in corpus.values()
    ]

    selected_doc_ids = set(doc_ids)
    # A query is usable only if it has text and at least one judged document
    # inside the loaded corpus prefix.
    valid_queries = [
        q_id
        for q_id, rel_docs in qrels_all.items()
        if q_id in queries_all and not selected_doc_ids.isdisjoint(rel_docs)
    ]

    random.shuffle(valid_queries)
    actual_train = min(n_train, len(valid_queries) // 2)
    actual_test = min(n_test, len(valid_queries) - actual_train)

    train_q_ids = valid_queries[:actual_train]
    test_q_ids = valid_queries[actual_train:actual_train + actual_test]

    # Built in shuffled order, restricted to loaded documents. Every valid
    # query has at least one such document, so no entry ends up empty.
    qrels = {}
    for q in valid_queries:
        filtered = {d: r for d, r in qrels_all[q].items() if d in selected_doc_ids}
        if filtered:
            qrels[q] = filtered

    return doc_ids, doc_texts, queries_all, train_q_ids, test_q_ids, qrels