| """ |
| Dataset loaders for DPA experiments. |
| Loads HotpotQA and GSM8K for multi-step reasoning evaluation; the ToolBench loader is a placeholder that currently returns no samples. |
| """ |
|
|
| from datasets import load_dataset |
| from torch.utils.data import Dataset |
| import torch |
| import json |
|
|
|
|
class MultiStepReasoningDataset(Dataset):
    """Unified dataset for multi-step reasoning tasks.

    Every sample is a dict with the keys ``input`` (prompt text), ``target``
    (reference answer), ``type`` (question/task category) and ``num_hops``
    (a rough count of reasoning steps).

    Args:
        dataset_name: One of ``"hotpotqa"``, ``"gsm8k"`` or ``"toolbench"``.
        split: Split name forwarded to ``load_dataset``. For ``"gsm8k"`` a
            request for ``"validation"`` falls back to ``"test"``.
        tokenizer: Optional tokenizer callable. When provided, ``__getitem__``
            returns tokenized tensors instead of the raw sample dict.
        max_length: Length the tokenizer truncates and pads each prompt to.
        max_samples: Keep at most this many samples; ``None`` keeps them all
            and ``0`` yields an empty dataset.

    Raises:
        ValueError: If ``dataset_name`` is not recognised or ``max_samples``
            is negative.
    """

    def __init__(self, dataset_name="hotpotqa", split="validation",
                 tokenizer=None, max_length=2048, max_samples=None):
        if max_samples is not None and max_samples < 0:
            raise ValueError(
                f"max_samples must be non-negative, got {max_samples}"
            )

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset_name = dataset_name

        if dataset_name == "hotpotqa":
            ds = load_dataset("hotpot_qa", "distractor", split=split)
            self.data = self._process_rows(
                ds, self._process_hotpotqa, max_samples
            )
        elif dataset_name == "gsm8k":
            if split == "validation":
                # The GSM8K "main" config only publishes train/test, so the
                # class-wide default split would otherwise fail to load.
                print("GSM8K has no 'validation' split; using 'test' instead")
                split = "test"
            ds = load_dataset("openai/gsm8k", "main", split=split)
            self.data = self._process_rows(
                ds, self._process_gsm8k, max_samples
            )
        elif dataset_name == "toolbench":
            self.data = self._load_toolbench(split)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        # `is not None` (rather than truthiness) so that max_samples=0 means
        # "no samples" instead of silently meaning "no limit".
        if max_samples is not None:
            self.data = self.data[:max_samples]

        print(f"Loaded {len(self.data)} samples from {dataset_name}/{split}")

    @staticmethod
    def _process_rows(rows, process, max_samples):
        """Apply ``process`` to the first ``max_samples`` rows (all if None).

        Limiting before processing avoids building prompts for rows that
        would be discarded by the ``max_samples`` cut anyway.
        """
        from itertools import islice

        return [process(row) for row in islice(rows, max_samples)]

    def _process_hotpotqa(self, item):
        """Turn one HotpotQA record into the unified sample dict.

        All sentences of all context paragraphs are flattened into a single
        space-separated string. ``num_hops`` is 2 for ``"bridge"`` questions
        and 1 for every other question type.
        """
        context = " ".join(
            " ".join(sentences) for sentences in item["context"]["sentences"]
        )
        prompt = (
            "Answer the question based on the context.\n\n"
            f"Context: {context}\n\n"
            f"Question: {item['question']}\n\n"
            "Answer:"
        )
        return {
            "input": prompt,
            "target": item["answer"],
            "type": item["type"],
            "num_hops": 2 if item["type"] == "bridge" else 1,
        }

    def _process_gsm8k(self, item):
        """Turn one GSM8K record into the unified sample dict.

        ``num_hops`` is the number of lines in the reference answer.
        """
        answer = item["answer"]
        return {
            "input": f"Solve step by step:\n\n{item['question']}\n\nSolution:",
            "target": answer,
            "type": "math_reasoning",
            "num_hops": answer.count("\n") + 1,
        }

    def _load_toolbench(self, split):
        """Placeholder ToolBench loader; currently returns no samples."""
        # TODO: implement ToolBench loading; `split` is unused until then.
        return []

    def __len__(self):
        """Return the number of samples kept after the ``max_samples`` cut."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx``, tokenized when a tokenizer was supplied.

        Without a tokenizer the raw sample dict is returned. With one, the
        prompt is truncated/padded to ``max_length`` and the result holds
        ``input_ids``, ``attention_mask``, ``target`` and ``num_hops``.
        """
        item = self.data[idx]
        # Compare against None explicitly: truthiness of a tokenizer object
        # goes through its __len__, which is unrelated to "was one provided".
        if self.tokenizer is None:
            return item

        # Assumes a Hugging Face-style tokenizer that returns tensors with a
        # leading batch axis of size 1, which squeeze(0) removes.
        encoding = self.tokenizer(
            item["input"], max_length=self.max_length,
            truncation=True, padding="max_length", return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "target": item["target"],
            "num_hops": item["num_hops"],
        }
|
|
|
|
def get_dataset(name, split="validation", tokenizer=None, max_samples=500,
                max_length=2048):
    """Build a ``MultiStepReasoningDataset`` with evaluation-friendly defaults.

    Args:
        name: Dataset name, passed through as ``dataset_name``.
        split: Split to load.
        tokenizer: Optional tokenizer; when given, items come back tokenized.
        max_samples: Cap on the number of samples. Defaults to 500 here,
            whereas the class constructor defaults to no cap.
        max_length: Tokenizer truncation/padding length, forwarded to the
            dataset so callers are not tied to its default of 2048.

    Returns:
        The constructed ``MultiStepReasoningDataset``.
    """
    return MultiStepReasoningDataset(
        dataset_name=name,
        split=split,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
    )
|
|