Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import random | |
| import numpy as np | |
| import torch | |
| from beir import util | |
def set_seeds(seed_val=42):
    """Seed Python's, NumPy's and PyTorch's global random generators.

    Args:
        seed_val: Seed passed unchanged to each of the three generators.
    """
    # Order matches the original: stdlib first, then NumPy, then PyTorch.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed_val)
def _iter_jsonl(path, limit=None):
    """Yield one parsed JSON object per non-blank line of the file at *path*.

    Blank or whitespace-only lines are skipped and do not count toward
    *limit*. Iteration stops after *limit* objects when it is not None.
    """
    with open(path, "r", encoding="utf-8") as handle:
        if limit is not None and limit <= 0:
            return
        emitted = 0
        for line in handle:
            if not line.strip():
                continue
            yield json.loads(line)
            emitted += 1
            # Stop right after the last wanted record so the rest of a large
            # file is never read.
            if limit is not None and emitted >= limit:
                return


def _load_qrels(path):
    """Parse a tab-separated qrels file into ``{query_id: {doc_id: int_score}}``.

    The first line is treated as a header and discarded; an empty file yields
    an empty mapping. Blank lines are ignored.
    """
    qrels = {}
    with open(path, "r", encoding="utf-8") as handle:
        next(handle, None)  # Skip header; the default avoids StopIteration on an empty file.
        for line in handle:
            row = line.strip()
            if not row:
                continue
            q_id, doc_id, score = row.split("\t")
            qrels.setdefault(q_id, {})[doc_id] = int(score)
    return qrels


def load_and_split_dataset(dataset_name="msmarco", data_root="./datasets", split_name="dev", n_docs=100000, n_train=150, n_test=150):
    """Load a corpus prefix plus its judged queries and split them into train/test.

    If ``<data_root>/<dataset_name>`` does not exist, the dataset zip is
    downloaded and unpacked through ``beir.util``. The first ``n_docs``
    records of ``corpus.jsonl`` are then read line by line, and only queries
    with at least one judged document inside that prefix are kept.

    Args:
        dataset_name: Dataset folder name, also used to build the download URL.
        data_root: Directory that holds (or will hold) the dataset folder.
        split_name: Base name of the qrels file, ``qrels/<split_name>.tsv``.
        n_docs: Maximum number of corpus records to read.
        n_train: Upper bound on training queries (also capped at half of the
            usable queries).
        n_test: Upper bound on test queries, taken after the training ones.

    Returns:
        A tuple ``(doc_ids, doc_texts, queries_all, train_q_ids, test_q_ids,
        qrels)``. ``doc_ids`` and ``doc_texts`` are parallel lists, where each
        text is the title and text joined by a space and stripped.
        ``queries_all`` maps every query id in ``queries.jsonl`` to its text.
        ``train_q_ids`` and ``test_q_ids`` are disjoint lists of query ids.
        ``qrels`` maps each usable query id to ``{doc_id: score}``, restricted
        to the loaded documents.

    Note:
        The split uses ``random.shuffle`` and therefore depends on the global
        ``random`` state; seed it beforehand for reproducible splits.
    """
    print("Downloading and extracting dataset...")
    data_dir = os.path.join(data_root, dataset_name)
    os.makedirs(data_root, exist_ok=True)
    if not os.path.exists(data_dir):
        zip_path = os.path.join(data_root, f"{dataset_name}.zip")
        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
        util.download_url(url, zip_path)
        util.unzip(zip_path, data_root)

    print(f"Loading queries and qrels (using '{split_name}' split)...")
    queries_all = {
        record["_id"]: record["text"]
        for record in _iter_jsonl(os.path.join(data_dir, "queries.jsonl"))
    }
    qrels_all = _load_qrels(os.path.join(data_dir, "qrels", f"{split_name}.tsv"))

    print(f"Streaming corpus to extract {n_docs} documents safely...")
    # Keep only the concatenated text per document rather than the whole
    # parsed record, so peak memory is bounded by what is actually returned.
    text_by_doc_id = {}
    for doc in _iter_jsonl(os.path.join(data_dir, "corpus.jsonl"), limit=n_docs):
        # "or ''" also covers an explicit JSON null, which would otherwise
        # break the string concatenation.
        title = doc.get("title") or ""
        body = doc.get("text") or ""
        text_by_doc_id[doc["_id"]] = f"{title} {body}".strip()
    doc_ids = list(text_by_doc_id)
    doc_texts = list(text_by_doc_id.values())

    # Restrict each query's judgments to the loaded documents once, and reuse
    # the result both to decide validity and to build the returned qrels.
    selected_doc_ids = set(doc_ids)
    filtered_by_query = {}
    for q_id, rel_docs in qrels_all.items():
        if q_id not in queries_all:
            continue
        kept = {d: r for d, r in rel_docs.items() if d in selected_doc_ids}
        if kept:
            filtered_by_query[q_id] = kept

    valid_queries = list(filtered_by_query)
    random.shuffle(valid_queries)

    # Training gets at most half of the usable queries; clamping at zero keeps
    # negative limits from turning into negative slice indices.
    actual_train = max(0, min(n_train, len(valid_queries) // 2))
    actual_test = max(0, min(n_test, len(valid_queries) - actual_train))
    train_q_ids = valid_queries[:actual_train]
    test_q_ids = valid_queries[actual_train:actual_train + actual_test]

    # Built in shuffled order so iteration order matches the split order.
    qrels = {q_id: filtered_by_query[q_id] for q_id in valid_queries}
    return doc_ids, doc_texts, queries_all, train_q_ids, test_q_ids, qrels