File size: 2,281 Bytes
a608d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from beir import util
from beir.datasets.data_loader import GenericDataLoader
from sklearn.model_selection import train_test_split


@dataclass(frozen=True)
class DatasetSplit:
    """Immutable bundle holding one split of a BEIR-style retrieval dataset."""

    # Document collection as returned by BEIR's GenericDataLoader.
    # NOTE(review): presumably doc_id -> {"title": ..., "text": ...} — confirm against loader.
    corpus: dict
    # query_id -> query text
    queries: dict[str, str]
    # query_id -> {doc_id -> relevance grade}
    qrels: dict[str, dict[str, int]]


def download_beir_dataset(dataset_name: str, datasets_dir: Path) -> Path:
    """Fetch and unpack a BEIR dataset archive; return the extraction directory.

    Args:
        dataset_name: BEIR dataset identifier (e.g. ``"scifact"``).
        datasets_dir: Directory under which the archive is downloaded/extracted.

    Returns:
        Path to the extracted dataset folder.
    """
    archive_url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    extracted_dir = util.download_and_unzip(archive_url, str(datasets_dir))
    return Path(extracted_dir)


def load_beir_source(
    dataset_name: str,
    datasets_dir: Path,
    max_queries: Optional[int] = None,
    source_split: str = "test",
) -> DatasetSplit:
    """Download (if needed) and load one split of a BEIR dataset.

    Args:
        dataset_name: BEIR dataset identifier.
        datasets_dir: Directory for downloaded/extracted datasets.
        max_queries: If given, keep only the first ``max_queries`` queries
            (in loader order) and the matching qrels.
        source_split: Which BEIR split to load (default ``"test"``).

    Returns:
        A ``DatasetSplit`` with the full corpus plus (possibly truncated)
        queries and qrels.
    """
    # MSMARCO in BEIR ships its usable judgments under "train", not "test".
    split = source_split
    if dataset_name.lower() == "msmarco" and split == "test":
        split = "train"

    dataset_path = download_beir_dataset(dataset_name, datasets_dir)
    corpus, queries, qrels = GenericDataLoader(data_folder=str(dataset_path)).load(split=split)

    if max_queries is not None:
        kept_ids = list(queries)[:max_queries]
        queries = {qid: queries[qid] for qid in kept_ids}
        # Some queries may lack judgments; keep qrels only where present.
        qrels = {qid: qrels[qid] for qid in kept_ids if qid in qrels}
    return DatasetSplit(corpus=corpus, queries=queries, qrels=qrels)

def load_beir_splits(
    dataset_name: str,
    datasets_dir: Path,
    max_queries: Optional[int] = None,
    source_split: str = "test",
) -> dict[str, DatasetSplit]:
    """Split one BEIR source split into train/val/test (70/15/15) partitions.

    Args:
        dataset_name: BEIR dataset identifier.
        datasets_dir: Directory for downloaded/extracted datasets.
        max_queries: Optional cap on the number of queries loaded upstream.
        source_split: Which BEIR split to partition (default ``"test"``).

    Returns:
        Mapping of ``"train"``/``"val"``/``"test"`` to ``DatasetSplit``s that
        share the same corpus and partition the qrel-backed queries.
    """
    base = load_beir_source(dataset_name, datasets_dir, max_queries=max_queries, source_split=source_split)
    # Only queries with relevance judgments are usable downstream. Filter
    # BEFORE splitting so the 70/15/15 ratios apply to evaluable queries;
    # previously unjudged queries entered the split and were dropped after,
    # silently skewing the effective split sizes.
    qids = [qid for qid in base.queries if qid in base.qrels]
    train_ids, temp_ids = train_test_split(qids, test_size=0.3, random_state=42)
    # Split the held-out 30% evenly into validation and test halves.
    val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

    splits: dict[str, DatasetSplit] = {}
    for split_name, split_ids in (("train", train_ids), ("val", val_ids), ("test", test_ids)):
        split_queries = {qid: base.queries[qid] for qid in split_ids}
        split_qrels = {qid: base.qrels[qid] for qid in split_ids}
        # Corpus is shared (not copied) across the three partitions.
        splits[split_name] = DatasetSplit(corpus=base.corpus, queries=split_queries, qrels=split_qrels)
    return splits