""" Dataset loaders for DPA experiments. Loads HotpotQA, GSM8K, ToolBench for multi-step reasoning evaluation. """ from datasets import load_dataset from torch.utils.data import Dataset import torch import json class MultiStepReasoningDataset(Dataset): """Unified dataset for multi-step reasoning tasks.""" def __init__(self, dataset_name="hotpotqa", split="validation", tokenizer=None, max_length=2048, max_samples=None): self.tokenizer = tokenizer self.max_length = max_length self.dataset_name = dataset_name if dataset_name == "hotpotqa": ds = load_dataset("hotpot_qa", "distractor", split=split) self.data = [self._process_hotpotqa(item) for item in ds] elif dataset_name == "gsm8k": ds = load_dataset("openai/gsm8k", "main", split=split) self.data = [self._process_gsm8k(item) for item in ds] elif dataset_name == "toolbench": # ToolBench needs manual download self.data = self._load_toolbench(split) else: raise ValueError(f"Unknown dataset: {dataset_name}") if max_samples: self.data = self.data[:max_samples] print(f"Loaded {len(self.data)} samples from {dataset_name}/{split}") def _process_hotpotqa(self, item): context = " ".join([ " ".join(sents) for sents in item["context"]["sentences"] ]) return { "input": f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {item['question']}\n\nAnswer:", "target": item["answer"], "type": item["type"], # "bridge" or "comparison" "num_hops": 2 if item["type"] == "bridge" else 1, } def _process_gsm8k(self, item): return { "input": f"Solve step by step:\n\n{item['question']}\n\nSolution:", "target": item["answer"], "type": "math_reasoning", "num_hops": item["answer"].count("\n") + 1, } def _load_toolbench(self, split): # Placeholder — ToolBench needs separate download return [] def __len__(self): return len(self.data) def __getitem__(self, idx): item = self.data[idx] if self.tokenizer: encoding = self.tokenizer( item["input"], max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt", ) return { "input_ids": encoding["input_ids"].squeeze(0), "attention_mask": encoding["attention_mask"].squeeze(0), "target": item["target"], "num_hops": item["num_hops"], } return item def get_dataset(name, split="validation", tokenizer=None, max_samples=500): """Convenience function to load a dataset.""" return MultiStepReasoningDataset( dataset_name=name, split=split, tokenizer=tokenizer, max_samples=max_samples, )