"""MS MARCO data loading for training and evaluation.""" import random from typing import Optional import torch from torch.utils.data import Dataset, DataLoader from datasets import load_dataset class MSMARCOTripleDataset(Dataset): """MS MARCO passage ranking dataset with hard negatives. Each example yields (query, positive_passage, [negative_passages]). """ def __init__(self, tokenizer, max_samples: int = 100_000, num_negatives: int = 7, max_seq_length: int = 128, split: str = "train", seed: int = 42): self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.num_negatives = num_negatives # Load MS MARCO dataset print(f"Loading MS MARCO ({split} split, max {max_samples} samples)...") dataset = load_dataset("ms_marco", "v2.1", split=split, trust_remote_code=True) # Filter to examples with at least one selected passage self.examples = [] for i, ex in enumerate(dataset): if len(self.examples) >= max_samples: break passages = ex["passages"] selected = [j for j, s in enumerate(passages["is_selected"]) if s == 1] if selected: self.examples.append({ "query": ex["query"], "positive": passages["passage_text"][selected[0]], "negatives": [ passages["passage_text"][j] for j in range(len(passages["passage_text"])) if j not in selected ], }) print(f"Loaded {len(self.examples)} training examples.") self.rng = random.Random(seed) def __len__(self) -> int: return len(self.examples) def __getitem__(self, idx: int) -> dict: ex = self.examples[idx] # Sample negatives (from in-passage negatives, pad with random if needed) available_negs = ex["negatives"] if len(available_negs) >= self.num_negatives: negs = self.rng.sample(available_negs, self.num_negatives) else: negs = available_negs[:] # Pad with random negatives from other examples while len(negs) < self.num_negatives: rand_ex = self.examples[self.rng.randint(0, len(self.examples) - 1)] if rand_ex["positive"] != ex["positive"]: negs.append(rand_ex["positive"]) return { "query": ex["query"], "positive": ex["positive"], "negatives": negs, } def collate_fn(batch: list[dict], tokenizer, max_seq_length: int = 128) -> dict: """Collate batch into tokenized tensors.""" queries = [b["query"] for b in batch] positives = [b["positive"] for b in batch] all_negatives = [] for b in batch: all_negatives.extend(b["negatives"]) # Tokenize q_enc = tokenizer( queries, padding=True, truncation=True, max_length=max_seq_length, return_tensors="pt", ) p_enc = tokenizer( positives, padding=True, truncation=True, max_length=max_seq_length, return_tensors="pt", ) n_enc = tokenizer( all_negatives, padding=True, truncation=True, max_length=max_seq_length, return_tensors="pt", ) num_negatives = len(batch[0]["negatives"]) return { "query_input_ids": q_enc["input_ids"], "query_attention_mask": q_enc["attention_mask"], "pos_input_ids": p_enc["input_ids"], "pos_attention_mask": p_enc["attention_mask"], "neg_input_ids": n_enc["input_ids"], "neg_attention_mask": n_enc["attention_mask"], "num_negatives": num_negatives, } def get_dataloader(tokenizer, max_samples: int = 100_000, num_negatives: int = 7, batch_size: int = 64, max_seq_length: int = 128, split: str = "train", seed: int = 42, num_workers: int = 0) -> DataLoader: """Create a DataLoader for MS MARCO training.""" dataset = MSMARCOTripleDataset( tokenizer=tokenizer, max_samples=max_samples, num_negatives=num_negatives, max_seq_length=max_seq_length, split=split, seed=seed, ) def _collate(batch): return collate_fn(batch, tokenizer, max_seq_length) return DataLoader( dataset, batch_size=batch_size, shuffle=True, collate_fn=_collate, num_workers=num_workers, drop_last=True, ) class MSMARCOEvalDataset: """MS MARCO dev set for evaluation.""" def __init__(self, tokenizer, max_queries: int = 5000, max_seq_length: int = 128, seed: int = 42): self.tokenizer = tokenizer self.max_seq_length = max_seq_length print(f"Loading MS MARCO dev set (max {max_queries} queries)...") dataset = load_dataset("ms_marco", "v2.1", split="validation", trust_remote_code=True) self.queries = [] self.positives = [] # list of list of positive passage texts self.all_passages = [] # flat list of all passages for retrieval self.passage_set = set() rng = random.Random(seed) indices = list(range(len(dataset))) rng.shuffle(indices) for i in indices: if len(self.queries) >= max_queries: break ex = dataset[i] passages = ex["passages"] selected = [j for j, s in enumerate(passages["is_selected"]) if s == 1] if not selected: continue self.queries.append(ex["query"]) pos_texts = [passages["passage_text"][j] for j in selected] self.positives.append(pos_texts) # Add all passages to the corpus for text in passages["passage_text"]: if text not in self.passage_set: self.passage_set.add(text) self.all_passages.append(text) print(f"Loaded {len(self.queries)} eval queries, " f"{len(self.all_passages)} unique passages.")