IMRNNs/src/imrnns/beir_data.py
yashsaxena21's picture
Upload folder using huggingface_hub
a608d21 verified
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from sklearn.model_selection import train_test_split
@dataclass(frozen=True)
class DatasetSplit:
    """Immutable bundle of the three pieces that make up one retrieval split.

    Instances are frozen: attributes cannot be reassigned after construction,
    although the dictionaries they reference are ordinary mutable dicts.
    """

    # Document collection, stored exactly as handed to the constructor.
    # NOTE(review): presumably maps doc id -> document fields as produced by
    # the BEIR loader; the element type is not constrained here.
    corpus: dict
    # Query id -> query text.
    queries: dict[str, str]
    # Query id -> {doc id -> integer relevance judgement}.
    qrels: dict[str, dict[str, int]]
def download_beir_dataset(dataset_name: str, datasets_dir: Path) -> Path:
    """Fetch the zip archive for a BEIR dataset and return where it was unpacked.

    Args:
        dataset_name: Name of the dataset; used verbatim as the archive's file
            stem in the download URL.
        datasets_dir: Directory handed to ``beir.util.download_and_unzip`` as
            the output location.

    Returns:
        The path reported by ``download_and_unzip``, wrapped in a ``Path``.
    """
    archive_url = (
        "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/"
        f"{dataset_name}.zip"
    )
    # The BEIR helper is called with a plain string path, so convert here.
    extracted = util.download_and_unzip(archive_url, str(datasets_dir))
    return Path(extracted)
def load_beir_source(
    dataset_name: str,
    datasets_dir: Path,
    max_queries: Optional[int] = None,
    source_split: str = "test",
) -> DatasetSplit:
    """Download a BEIR dataset and load one of its splits, optionally truncated.

    Args:
        dataset_name: BEIR dataset name, forwarded to ``download_beir_dataset``.
        datasets_dir: Directory under which the dataset is downloaded.
        max_queries: If given, keep only the first ``max_queries`` queries in
            the loader's iteration order; ``None`` keeps every query.
        source_split: Name of the split to load. For ``msmarco`` (matched
            case-insensitively) a request for ``"test"`` is redirected to
            ``"train"``.

    Returns:
        A ``DatasetSplit`` holding the full corpus, the retained queries, and
        the relevance judgements of those retained queries that have any.
    """
    # MS MARCO is special-cased: its "test" request is served from "train".
    redirect_to_train = dataset_name.lower() == "msmarco" and source_split == "test"
    split_to_load = "train" if redirect_to_train else source_split

    dataset_dir = download_beir_dataset(dataset_name, datasets_dir)
    loader = GenericDataLoader(data_folder=str(dataset_dir))
    corpus, all_queries, all_qrels = loader.load(split=split_to_load)

    kept_ids = list(all_queries)
    if max_queries is not None:
        kept_ids = kept_ids[:max_queries]

    kept_queries = {qid: all_queries[qid] for qid in kept_ids}
    # Queries without any judgement simply get no qrels entry.
    kept_qrels = {qid: all_qrels[qid] for qid in kept_ids if qid in all_qrels}
    return DatasetSplit(corpus=corpus, queries=kept_queries, qrels=kept_qrels)
def load_beir_splits(
    dataset_name: str,
    datasets_dir: Path,
    max_queries: Optional[int] = None,
    source_split: str = "test",
    *,
    random_state: int = 42,
) -> dict[str, DatasetSplit]:
    """Load a BEIR split and partition its queries into train/val/test.

    The queries returned by ``load_beir_source`` are split with
    ``train_test_split``: 30% are held out, and that held-out portion is then
    halved into validation and test, giving roughly 70/15/15. Every resulting
    split shares the same corpus object.

    Args:
        dataset_name: BEIR dataset name.
        datasets_dir: Directory under which the dataset is downloaded.
        max_queries: Optional cap on the number of source queries considered
            before splitting.
        source_split: BEIR split to draw the queries from.
        random_state: Seed passed to both ``train_test_split`` calls. The
            default of 42 reproduces the previously hard-coded partition.

    Returns:
        A dict with keys ``"train"``, ``"val"`` and ``"test"``, each mapping to
        a ``DatasetSplit``. Queries with no entry in the source qrels are
        dropped from their split.

    Raises:
        ValueError: Propagated from ``train_test_split`` when there are too
            few queries for a side of a split to be non-empty
            (NOTE(review): per scikit-learn's documented behaviour).
    """
    base = load_beir_source(
        dataset_name,
        datasets_dir,
        max_queries=max_queries,
        source_split=source_split,
    )
    qids = list(base.queries)

    train_ids, holdout_ids = train_test_split(
        qids, test_size=0.3, random_state=random_state
    )
    val_ids, test_ids = train_test_split(
        holdout_ids, test_size=0.5, random_state=random_state
    )

    splits: dict[str, DatasetSplit] = {}
    for split_name, split_ids in (
        ("train", train_ids),
        ("val", val_ids),
        ("test", test_ids),
    ):
        # Filter once so queries and qrels are guaranteed to share the same ids.
        judged_ids = [qid for qid in split_ids if qid in base.qrels]
        splits[split_name] = DatasetSplit(
            corpus=base.corpus,
            queries={qid: base.queries[qid] for qid in judged_ids},
            qrels={qid: base.qrels[qid] for qid in judged_ids},
        )
    return splits