File size: 6,110 Bytes
b464490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""MS MARCO data loading for training and evaluation."""

import random
from functools import partial
from typing import Optional

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset


class MSMARCOTripleDataset(Dataset):
    """MS MARCO passage ranking dataset with hard negatives.

    Each item is a dict with keys ``"query"``, ``"positive"`` and
    ``"negatives"``. ``"negatives"`` is a list of exactly ``num_negatives``
    passage texts.

    The positive is the first passage flagged ``is_selected == 1``. Negatives
    are sampled from the query's unselected passages. When there are too few,
    they are padded with positives taken from other examples.
    """

    # Upper bound on rejection-sampling draws per padded negative before
    # falling back to an exhaustive scan (see _random_other_positive).
    _MAX_PAD_ATTEMPTS = 100

    def __init__(self, tokenizer, max_samples: int = 100_000,
                 num_negatives: int = 7, max_seq_length: int = 128,
                 split: str = "train", seed: int = 42):
        """Load up to ``max_samples`` usable examples from an MS MARCO split.

        Args:
            tokenizer: Stored on the instance; not used by this class itself.
            max_samples: Maximum number of examples to keep. Rows without any
                selected passage are skipped and do not count.
            num_negatives: Number of negatives returned per item.
            max_seq_length: Stored on the instance; not used by this class itself.
            split: Split name passed to ``load_dataset("ms_marco", "v2.1", ...)``.
            seed: Seed for the negative-sampling RNG.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.num_negatives = num_negatives

        # Load MS MARCO dataset
        print(f"Loading MS MARCO ({split} split, max {max_samples} samples)...")
        dataset = load_dataset("ms_marco", "v2.1", split=split, trust_remote_code=True)

        # Keep only examples with at least one selected passage, stopping as
        # soon as max_samples have been collected.
        self.examples = []
        for ex in dataset:
            if len(self.examples) >= max_samples:
                break
            triple = self._extract_triple(ex)
            if triple is not None:
                self.examples.append(triple)

        print(f"Loaded {len(self.examples)} training examples.")
        # NOTE: DataLoader workers each receive a copy of this RNG. Unless it
        # is reseeded per worker, they all draw the same sampling stream.
        self.rng = random.Random(seed)

    @staticmethod
    def _extract_triple(ex: dict) -> Optional[dict]:
        """Turn a raw MS MARCO row into a query/positive/negatives dict.

        Returns None when the row has no passage with ``is_selected == 1``.
        All selected passages are excluded from the negatives, not only the
        one used as the positive.
        """
        passages = ex["passages"]
        texts = passages["passage_text"]
        selected = [j for j, s in enumerate(passages["is_selected"]) if s == 1]
        if not selected:
            return None
        selected_set = set(selected)  # O(1) membership in the filter below
        return {
            "query": ex["query"],
            "positive": texts[selected[0]],
            "negatives": [t for j, t in enumerate(texts) if j not in selected_set],
        }

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        """Return example ``idx`` with exactly ``num_negatives`` negatives.

        If the example has at least ``num_negatives`` in-passage negatives,
        they are sampled without replacement. Otherwise all of them are kept
        and the list is padded with positives of other examples.

        Raises:
            ValueError: if padding is needed but no example has a positive
                passage different from this example's positive.
        """
        ex = self.examples[idx]
        available_negs = ex["negatives"]
        if len(available_negs) >= self.num_negatives:
            negs = self.rng.sample(available_negs, self.num_negatives)
        else:
            negs = list(available_negs)
            # Pad with random negatives from other examples
            while len(negs) < self.num_negatives:
                negs.append(self._random_other_positive(ex["positive"]))

        return {
            "query": ex["query"],
            "positive": ex["positive"],
            "negatives": negs,
        }

    def _random_other_positive(self, positive: str) -> str:
        """Return a randomly chosen example positive that differs from ``positive``."""
        num_examples = len(self.examples)
        for _ in range(self._MAX_PAD_ATTEMPTS):
            candidate = self.examples[self.rng.randint(0, num_examples - 1)]["positive"]
            if candidate != positive:
                return candidate
        # Rejection sampling keeps failing (e.g. every example shares this
        # positive). Scan once so we either pick a valid passage or fail
        # loudly instead of looping forever.
        candidates = [e["positive"] for e in self.examples if e["positive"] != positive]
        if not candidates:
            raise ValueError(
                "Cannot pad negatives: no other example has a different positive passage."
            )
        return self.rng.choice(candidates)


def collate_fn(batch: list[dict], tokenizer, max_seq_length: int = 128) -> dict:
    """Tokenize a list of dataset items into padded tensors.

    Args:
        batch: Items carrying ``"query"``, ``"positive"`` and ``"negatives"``.
        tokenizer: Called as ``tokenizer(texts, padding=True, truncation=True,
            max_length=max_seq_length, return_tensors="pt")``. The result is
            indexed for ``"input_ids"`` and ``"attention_mask"``.
        max_seq_length: Truncation length forwarded to the tokenizer.

    Returns:
        A dict with ids and masks for queries, positives and negatives, plus
        ``num_negatives``.

        The negatives of all items are concatenated, in batch order, into a
        single encoding. ``num_negatives`` is the negative count of the first
        item.
    """
    query_texts = [item["query"] for item in batch]
    positive_texts = [item["positive"] for item in batch]
    negative_texts = [neg for item in batch for neg in item["negatives"]]

    def encode(texts):
        return tokenizer(
            texts, padding=True, truncation=True,
            max_length=max_seq_length, return_tensors="pt",
        )

    queries_tok = encode(query_texts)
    positives_tok = encode(positive_texts)
    negatives_tok = encode(negative_texts)

    negatives_per_item = len(batch[0]["negatives"])
    return {
        "query_input_ids": queries_tok["input_ids"],
        "query_attention_mask": queries_tok["attention_mask"],
        "pos_input_ids": positives_tok["input_ids"],
        "pos_attention_mask": positives_tok["attention_mask"],
        "neg_input_ids": negatives_tok["input_ids"],
        "neg_attention_mask": negatives_tok["attention_mask"],
        "num_negatives": negatives_per_item,
    }


def _seed_worker(worker_id: int) -> None:
    """Reseed the dataset copy held by a DataLoader worker process.

    Each worker receives its own copy of the dataset, including its ``rng``.
    Without reseeding, every worker would draw the same negative-sampling
    stream.

    Inside a worker, PyTorch makes ``torch.initial_seed()`` return a base seed
    (drawn from the main process for each new iterator) plus ``worker_id``.
    Reseeding from it gives distinct, reproducible streams per worker.
    """
    info = torch.utils.data.get_worker_info()
    if info is not None and hasattr(info.dataset, "rng"):
        info.dataset.rng.seed(torch.initial_seed())


def get_dataloader(tokenizer, max_samples: int = 100_000,
                   num_negatives: int = 7, batch_size: int = 64,
                   max_seq_length: int = 128, split: str = "train",
                   seed: int = 42, num_workers: int = 0) -> DataLoader:
    """Create a shuffled DataLoader of tokenized MS MARCO training triples.

    Args:
        tokenizer: Passed to the dataset and used by ``collate_fn``.
        max_samples: Maximum number of training examples to load.
        num_negatives: Negatives per query.
        batch_size: Batch size; incomplete final batches are dropped.
        max_seq_length: Truncation length used when tokenizing.
        split: MS MARCO split to load.
        seed: Seed for the dataset's negative-sampling RNG.
        num_workers: DataLoader worker processes.

    Returns:
        A ``DataLoader`` yielding dicts produced by ``collate_fn``.
    """
    dataset = MSMARCOTripleDataset(
        tokenizer=tokenizer, max_samples=max_samples,
        num_negatives=num_negatives, max_seq_length=max_seq_length,
        split=split, seed=seed,
    )

    # functools.partial is picklable, unlike a locally defined closure. This
    # matters when workers are started with the "spawn" method.
    collate = partial(collate_fn, tokenizer=tokenizer, max_seq_length=max_seq_length)

    return DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        collate_fn=collate, num_workers=num_workers,
        drop_last=True,
        # Only invoked when num_workers > 0.
        worker_init_fn=_seed_worker,
    )


class MSMARCOEvalDataset:
    """Evaluation queries, gold passages and a retrieval corpus from the MS MARCO dev set.

    Attributes:
        queries: Kept query strings, in seeded shuffled order.
        positives: For each kept query, the texts of its passages flagged
            ``is_selected == 1``.
        all_passages: Deduplicated list of every passage attached to a kept
            query, in first-seen order.
        passage_set: Set mirror of ``all_passages``, used for deduplication.
    """

    def __init__(self, tokenizer, max_queries: int = 5000,
                 max_seq_length: int = 128, seed: int = 42):
        """Sample up to ``max_queries`` dev queries that have a selected passage.

        Args:
            tokenizer: Stored on the instance; not used by this class itself.
            max_queries: Maximum number of queries to keep.
            max_seq_length: Stored on the instance; not used by this class itself.
            seed: Seed for the shuffle that decides which rows are visited first.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        print(f"Loading MS MARCO dev set (max {max_queries} queries)...")
        dataset = load_dataset("ms_marco", "v2.1", split="validation", trust_remote_code=True)

        self.queries = []
        self.positives = []
        self.all_passages = []
        self.passage_set = set()

        # Visit rows in a reproducible random order.
        visit_order = list(range(len(dataset)))
        random.Random(seed).shuffle(visit_order)

        for row_idx in visit_order:
            if len(self.queries) >= max_queries:
                break
            row = dataset[row_idx]
            passage_info = row["passages"]
            hits = [pos for pos, flag in enumerate(passage_info["is_selected"]) if flag == 1]
            if not hits:
                continue  # no gold passage: unusable for evaluation

            texts = passage_info["passage_text"]
            self.queries.append(row["query"])
            self.positives.append([texts[pos] for pos in hits])

            # Grow the shared corpus with any passage not seen before.
            for text in texts:
                if text in self.passage_set:
                    continue
                self.passage_set.add(text)
                self.all_passages.append(text)

        print(f"Loaded {len(self.queries)} eval queries, "
              f"{len(self.all_passages)} unique passages.")