# clusd-search/src/utils.py
# Author: Ishika-max
# Commit: "CluSD end-to-end app" (4b3b4fa)
# File size: 2.75 kB
import os
import json
import random
import numpy as np
import torch
from beir import util
def set_seeds(seed_val=42):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    Args:
        seed_val: Seed passed unchanged to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed`` (in that order).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_val)
def _read_queries(path):
    """Return ``{query_id: query_text}`` parsed from a queries JSONL file."""
    queries = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank/trailing lines instead of crashing in json.loads
            record = json.loads(line)
            queries[record["_id"]] = record["text"]
    return queries


def _read_qrels(path):
    """Return ``{query_id: {doc_id: int_score}}`` parsed from a qrels TSV file."""
    qrels = {}
    with open(path, "r", encoding="utf-8") as f:
        next(f, None)  # Skip header; the default avoids StopIteration on an empty file
        for line in f:
            row = line.strip()
            if not row:
                continue  # a trailing newline must not break the 3-way unpack below
            q_id, doc_id, score = row.split("\t")
            qrels.setdefault(q_id, {})[doc_id] = int(score)
    return qrels


def _read_corpus_texts(path, n_docs):
    """Return ``{doc_id: "title text"}`` for the first ``n_docs`` corpus records.

    Only the joined text is kept per document (not the whole JSON record), so
    memory use is bounded by the retained text rather than by the full rows.
    """
    texts = {}
    n_read = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if n_read >= n_docs:
                break
            if not line.strip():
                continue
            doc = json.loads(line)
            # `or ""` also covers an explicit JSON null title/text.
            title = doc.get("title") or ""
            body = doc.get("text") or ""
            texts[doc["_id"]] = (title + " " + body).strip()
            n_read += 1
    return texts


def load_and_split_dataset(dataset_name="msmarco", data_root="./datasets", split_name="dev", n_docs=100000, n_train=150, n_test=150):
    """Load a corpus prefix plus matching queries, and split the queries.

    The dataset directory ``<data_root>/<dataset_name>`` is downloaded and
    unzipped with ``beir.util`` only when it does not exist yet. The corpus
    is read line by line and reading stops after ``n_docs`` documents, so the
    full corpus file is never held in memory.

    A query is "valid" when it appears in both the queries file and the qrels
    split and has at least one relevant document among the loaded documents.
    Valid queries are shuffled with the module-level ``random`` state (call
    ``set_seeds`` first for a reproducible split) and then cut into disjoint
    train and test lists. The train list takes at most half of the valid
    queries.

    Args:
        dataset_name: Dataset folder / zip name.
        data_root: Directory that holds (or will hold) the dataset folder.
        split_name: Name of the qrels file (``qrels/<split_name>.tsv``).
        n_docs: Maximum number of corpus records to read.
        n_train: Upper bound on the number of training queries.
        n_test: Upper bound on the number of test queries.

    Returns:
        A 6-tuple ``(doc_ids, doc_texts, queries_all, train_q_ids,
        test_q_ids, qrels)``: parallel lists of document ids and
        "title text" strings, the unfiltered ``{query_id: text}`` mapping,
        the two query-id lists, and ``{query_id: {doc_id: score}}`` for every
        valid query restricted to the loaded documents.

    Raises:
        ValueError: If a non-blank qrels row does not have exactly three
            tab-separated fields or its score is not an integer.
    """
    data_dir = os.path.join(data_root, dataset_name)
    os.makedirs(data_root, exist_ok=True)
    if not os.path.exists(data_dir):
        # Announce the download only when one actually happens.
        print("Downloading and extracting dataset...")
        zip_path = os.path.join(data_root, f"{dataset_name}.zip")
        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
        util.download_url(url, zip_path)
        util.unzip(zip_path, data_root)

    print(f"Loading queries and qrels (using '{split_name}' split)...")
    queries_all = _read_queries(os.path.join(data_dir, "queries.jsonl"))
    qrels_all = _read_qrels(os.path.join(data_dir, "qrels", f"{split_name}.tsv"))

    print(f"Streaming corpus to extract {n_docs} documents safely...")
    corpus_texts = _read_corpus_texts(os.path.join(data_dir, "corpus.jsonl"), n_docs)
    doc_ids = list(corpus_texts)
    doc_texts = list(corpus_texts.values())

    # Keep only queries that have text and at least one loaded relevant doc.
    valid_queries = [
        q_id
        for q_id, rel_docs in qrels_all.items()
        if q_id in queries_all and any(d in corpus_texts for d in rel_docs)
    ]

    random.shuffle(valid_queries)
    actual_train = min(n_train, len(valid_queries) // 2)
    actual_test = min(n_test, len(valid_queries) - actual_train)
    train_q_ids = valid_queries[:actual_train]
    test_q_ids = valid_queries[actual_train:actual_train + actual_test]

    # Restrict relevance judgments to documents that were actually loaded.
    qrels = {}
    for q_id in valid_queries:
        filtered = {d: r for d, r in qrels_all[q_id].items() if d in corpus_texts}
        if filtered:
            qrels[q_id] = filtered

    return doc_ids, doc_texts, queries_all, train_q_ids, test_q_ids, qrels