| """MS MARCO data loading for training and evaluation.""" |
|
|
| import random |
| from typing import Optional |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from datasets import load_dataset |
|
|
|
|
class MSMARCOTripleDataset(Dataset):
    """MS MARCO passage ranking dataset of (query, positive, negatives) triples.

    Each item is a dict ``{"query": str, "positive": str, "negatives": list[str]}``
    whose ``negatives`` list always has exactly ``num_negatives`` entries.

    Negatives come from the query's own non-selected candidate passages
    (sampled without replacement when there are enough of them). When a query
    has too few, the list is topped up with the positive passages of other,
    randomly chosen examples.

    Args:
        tokenizer: Stored on the instance; not used by this class directly.
        max_samples: Maximum number of examples to keep.
        num_negatives: Number of negative passages returned per item.
        max_seq_length: Stored on the instance; not used by this class directly.
        split: Split of ``ms_marco`` v2.1 to load.
        seed: Seed for the instance's negative-sampling RNG (``self.rng``).
    """

    # Budget of random draws per missing negative when topping up from other
    # examples. Bounds the loop so a dataset in which no other example has a
    # different positive (e.g. a single example) cannot hang __getitem__.
    _MAX_DRAWS_PER_NEGATIVE = 100

    def __init__(self, tokenizer, max_samples: int = 100_000,
                 num_negatives: int = 7, max_seq_length: int = 128,
                 split: str = "train", seed: int = 42):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.num_negatives = num_negatives

        print(f"Loading MS MARCO ({split} split, max {max_samples} samples)...")
        dataset = load_dataset("ms_marco", "v2.1", split=split, trust_remote_code=True)

        self.examples = []
        for ex in dataset:
            if len(self.examples) >= max_samples:
                break
            example = self._build_example(ex)
            if example is not None:
                self.examples.append(example)

        print(f"Loaded {len(self.examples)} training examples.")
        self.rng = random.Random(seed)

    @staticmethod
    def _build_example(ex: dict) -> Optional[dict]:
        """Convert one raw MS MARCO record into a training example.

        The first passage flagged ``is_selected == 1`` becomes the positive.
        Every other passage that is neither selected nor textually identical
        to the positive becomes a negative (an identical text would be a false
        negative).

        Returns:
            ``{"query", "positive", "negatives"}`` dict, or ``None`` when the
            record has no selected passage.
        """
        passages = ex["passages"]
        texts = passages["passage_text"]
        selected = [j for j, s in enumerate(passages["is_selected"]) if s == 1]
        if not selected:
            return None

        positive = texts[selected[0]]
        excluded = set(selected)
        negatives = [
            text for j, text in enumerate(texts)
            if j not in excluded and text != positive
        ]
        return {"query": ex["query"], "positive": positive, "negatives": negatives}

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        """Return example ``idx`` with exactly ``self.num_negatives`` negatives.

        Raises:
            RuntimeError: If the example has no negatives of its own and no
                other example with a different positive was found to borrow from.
        """
        ex = self.examples[idx]
        positive = ex["positive"]
        available_negs = ex["negatives"]

        if len(available_negs) >= self.num_negatives:
            negs = self.rng.sample(available_negs, self.num_negatives)
        else:
            negs = list(available_negs)
            # Top up with other examples' positives, within a bounded budget.
            draws_left = self._MAX_DRAWS_PER_NEGATIVE * (self.num_negatives - len(negs))
            last = len(self.examples) - 1
            while len(negs) < self.num_negatives and draws_left > 0:
                draws_left -= 1
                candidate = self.examples[self.rng.randint(0, last)]["positive"]
                if candidate != positive:
                    negs.append(candidate)

            if len(negs) < self.num_negatives:
                if not negs:
                    raise RuntimeError(
                        f"Cannot build {self.num_negatives} negatives for example "
                        f"{idx}: it has no negative passages and no other example "
                        "with a different positive was found."
                    )
                # Repeat the negatives we do have so every item still carries
                # exactly num_negatives entries.
                negs = [negs[k % len(negs)] for k in range(self.num_negatives)]

        return {
            "query": ex["query"],
            "positive": positive,
            "negatives": negs,
        }
|
|
|
|
def collate_fn(batch: list[dict], tokenizer, max_seq_length: int = 128) -> dict:
    """Tokenize a batch of (query, positive, negatives) items into tensors.

    Args:
        batch: Items with ``"query"``, ``"positive"`` and ``"negatives"`` keys.
            Every item must carry the same number of negatives.
        tokenizer: Called as ``tokenizer(texts, padding=True, truncation=True,
            max_length=max_seq_length, return_tensors="pt")``; its result is
            indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_seq_length: Truncation length passed to the tokenizer.

    Returns:
        Dict with ``query_*``, ``pos_*`` and ``neg_*`` ``input_ids`` /
        ``attention_mask`` tensors plus ``"num_negatives"``. Negatives are
        flattened item by item: ``batch[i]``'s negatives occupy rows
        ``i * num_negatives`` to ``(i + 1) * num_negatives - 1``.

    Raises:
        ValueError: If ``batch`` is empty or its items have differing numbers
            of negatives (the flattened layout above would be misaligned).
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch.")

    num_negatives = len(batch[0]["negatives"])
    if any(len(b["negatives"]) != num_negatives for b in batch):
        counts = sorted({len(b["negatives"]) for b in batch})
        raise ValueError(
            f"All items in a batch must have the same number of negatives; got {counts}."
        )

    def encode(texts):
        return tokenizer(
            texts, padding=True, truncation=True,
            max_length=max_seq_length, return_tensors="pt",
        )

    q_enc = encode([b["query"] for b in batch])
    p_enc = encode([b["positive"] for b in batch])

    all_negatives = [neg for b in batch for neg in b["negatives"]]
    if all_negatives:
        n_enc = encode(all_negatives)
        neg_input_ids = n_enc["input_ids"]
        neg_attention_mask = n_enc["attention_mask"]
    else:
        # Don't call the tokenizer on an empty list (empty-input behavior is
        # tokenizer-specific); return empty tensors instead.
        neg_input_ids = torch.empty((0, 0), dtype=torch.long)
        neg_attention_mask = torch.empty((0, 0), dtype=torch.long)

    return {
        "query_input_ids": q_enc["input_ids"],
        "query_attention_mask": q_enc["attention_mask"],
        "pos_input_ids": p_enc["input_ids"],
        "pos_attention_mask": p_enc["attention_mask"],
        "neg_input_ids": neg_input_ids,
        "neg_attention_mask": neg_attention_mask,
        "num_negatives": num_negatives,
    }
|
|
|
|
def _seed_worker(worker_id: int) -> None:
    """Reseed a DataLoader worker's copy of ``dataset.rng``.

    Each worker process receives its own copy of the dataset, so the copies
    all start with the same ``rng`` state and would draw identical negative
    samples. ``worker_info.seed`` is derived from the main-process RNG and the
    worker id, giving each worker a distinct stream.
    """
    info = torch.utils.data.get_worker_info()
    if info is None:  # not running inside a worker process
        return
    rng = getattr(info.dataset, "rng", None)
    if isinstance(rng, random.Random):
        rng.seed(info.seed)


def get_dataloader(tokenizer, max_samples: int = 100_000,
                   num_negatives: int = 7, batch_size: int = 64,
                   max_seq_length: int = 128, split: str = "train",
                   seed: int = 42, num_workers: int = 0) -> DataLoader:
    """Create a shuffled DataLoader over MS MARCO training triples.

    Args:
        tokenizer: Used by ``collate_fn`` to tokenize each batch.
        max_samples: Maximum number of training examples to load.
        num_negatives: Negatives per query.
        batch_size: Queries per batch; incomplete final batches are dropped.
        max_seq_length: Tokenizer truncation length.
        split: Dataset split to load.
        seed: Seed for the dataset's negative-sampling RNG.
        num_workers: DataLoader worker processes.

    Returns:
        A ``DataLoader`` yielding the dicts produced by ``collate_fn``.
    """
    from functools import partial

    dataset = MSMARCOTripleDataset(
        tokenizer=tokenizer, max_samples=max_samples,
        num_negatives=num_negatives, max_seq_length=max_seq_length,
        split=split, seed=seed,
    )

    # A partial of a module-level function can be pickled, unlike a nested
    # closure, so worker processes started with "spawn" can receive it.
    collate = partial(collate_fn, tokenizer=tokenizer, max_seq_length=max_seq_length)

    return DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        collate_fn=collate, num_workers=num_workers,
        worker_init_fn=_seed_worker, drop_last=True,
    )
|
|
|
|
class MSMARCOEvalDataset:
    """Evaluation subset of MS MARCO, drawn from the v2.1 ``validation`` split.

    Records are visited in a seeded random order. Records without any selected
    passage are skipped, and loading stops once ``max_queries`` queries are kept.

    Attributes:
        queries: Kept query strings.
        positives: For each kept query (same order), its selected passage texts.
        all_passages: Every distinct passage text of the kept records, in
            first-seen order.
        passage_set: Set of the texts in ``all_passages``, used for dedup.
        tokenizer / max_seq_length: Stored only; not used by this class.
    """

    def __init__(self, tokenizer, max_queries: int = 5000,
                 max_seq_length: int = 128, seed: int = 42):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        print(f"Loading MS MARCO dev set (max {max_queries} queries)...")
        dataset = load_dataset("ms_marco", "v2.1", split="validation", trust_remote_code=True)

        self.queries = []
        self.positives = []
        self.all_passages = []
        self.passage_set = set()

        # Seeded shuffle of record positions so the chosen subset is reproducible.
        visit_order = list(range(len(dataset)))
        random.Random(seed).shuffle(visit_order)

        for record_idx in visit_order:
            if len(self.queries) >= max_queries:
                break
            record = dataset[record_idx]
            passages = record["passages"]
            texts = passages["passage_text"]
            relevant = [
                texts[k] for k, flag in enumerate(passages["is_selected"]) if flag == 1
            ]
            if not relevant:
                continue

            self.queries.append(record["query"])
            self.positives.append(relevant)

            for text in texts:
                if text in self.passage_set:
                    continue
                self.passage_set.add(text)
                self.all_passages.append(text)

        print(f"Loaded {len(self.queries)} eval queries, "
              f"{len(self.all_passages)} unique passages.")
|
|