| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional |
|
|
| from beir import util |
| from beir.datasets.data_loader import GenericDataLoader |
| from sklearn.model_selection import train_test_split |
|
|
|
|
@dataclass(frozen=True)
class DatasetSplit:
    """Immutable bundle holding the three parts of one dataset split.

    ``frozen=True`` means the attributes cannot be reassigned after
    construction. The dict objects themselves are stored as given (no copy),
    so their contents remain mutable.
    """

    # Document collection as returned by the loader; the value layout is not
    # constrained here (presumably doc id -> document record -- confirm).
    corpus: dict
    # Query id -> query text.
    queries: dict[str, str]
    # Query id -> {document id -> integer judgment}.
    qrels: dict[str, dict[str, int]]
|
|
|
|
def download_beir_dataset(dataset_name: str, datasets_dir: Path) -> Path:
    """Download the zip archive of a BEIR dataset and unpack it.

    Args:
        dataset_name: Dataset identifier; inserted verbatim into the archive
            URL as ``<dataset_name>.zip``.
        datasets_dir: Directory passed to ``util.download_and_unzip`` as the
            output location.

    Returns:
        The location reported by ``util.download_and_unzip``, wrapped in a
        ``Path``.
    """
    archive_url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    # The BEIR helper is given a plain string, so convert the Path first.
    output_dir = str(datasets_dir)
    extracted_location = util.download_and_unzip(archive_url, output_dir)
    return Path(extracted_location)
|
|
|
|
def load_beir_source(
    dataset_name: str,
    datasets_dir: Path,
    max_queries: Optional[int] = None,
    source_split: str = "test",
) -> DatasetSplit:
    """Download a BEIR dataset and load one of its splits, optionally truncated.

    Args:
        dataset_name: Dataset identifier forwarded to ``download_beir_dataset``.
        datasets_dir: Directory the dataset is downloaded into.
        max_queries: When given, only the first ``max_queries`` queries (in the
            loader's iteration order) are kept. ``None`` keeps all of them.
        source_split: Name of the split to load. For the ``msmarco`` dataset
            (case-insensitive) a request for ``"test"`` is redirected to
            ``"train"``.

    Returns:
        A ``DatasetSplit`` with the full corpus, the selected queries, and the
        qrels entries that exist for those queries.
    """
    split_to_load = source_split
    if dataset_name.lower() == "msmarco" and source_split == "test":
        # msmarco is read from its train split when "test" is requested.
        split_to_load = "train"

    dataset_root = download_beir_dataset(dataset_name, datasets_dir)
    loader = GenericDataLoader(data_folder=str(dataset_root))
    corpus, all_queries, all_qrels = loader.load(split=split_to_load)

    # Slicing with a stop of None returns the whole list, so this single
    # expression covers both the "limit" and the "no limit" case.
    selected_ids = list(all_queries)[:max_queries]

    kept_queries = {}
    kept_qrels = {}
    for query_id in selected_ids:
        kept_queries[query_id] = all_queries[query_id]
        # Queries without judgments stay in `queries` but get no qrels entry.
        if query_id in all_qrels:
            kept_qrels[query_id] = all_qrels[query_id]

    return DatasetSplit(corpus=corpus, queries=kept_queries, qrels=kept_qrels)
|
|
def load_beir_splits(
    dataset_name: str,
    datasets_dir: Path,
    max_queries: Optional[int] = None,
    source_split: str = "test",
) -> dict[str, DatasetSplit]:
    """Load one BEIR split and partition its queries into train/val/test.

    The query ids are first divided 70% / 30% and the 30% part is then halved,
    both times with ``random_state=42`` so the partition is reproducible.
    Every resulting split shares the same corpus object.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_beir_source``.
        datasets_dir: Directory the dataset is downloaded into.
        max_queries: Optional cap on the number of queries, applied by
            ``load_beir_source`` before partitioning.
        source_split: Name of the split to load from disk.

    Returns:
        A dict with the keys ``"train"``, ``"val"`` and ``"test"`` (in that
        order), each mapping to a ``DatasetSplit``. Query ids that have no
        qrels entry are left out of both ``queries`` and ``qrels``.
    """
    source = load_beir_source(
        dataset_name,
        datasets_dir,
        max_queries=max_queries,
        source_split=source_split,
    )

    all_ids = list(source.queries)
    train_ids, holdout_ids = train_test_split(all_ids, test_size=0.3, random_state=42)
    val_ids, test_ids = train_test_split(holdout_ids, test_size=0.5, random_state=42)

    def _subset(ids) -> DatasetSplit:
        # Keep only ids that have judgments, so queries and qrels stay aligned.
        judged_ids = [qid for qid in ids if qid in source.qrels]
        return DatasetSplit(
            corpus=source.corpus,
            queries={qid: source.queries[qid] for qid in judged_ids},
            qrels={qid: source.qrels[qid] for qid in judged_ids},
        )

    return {
        "train": _subset(train_ids),
        "val": _subset(val_ids),
        "test": _subset(test_ids),
    }
|
|