"""
Dataset loaders for DPA experiments.
Loads HotpotQA and GSM8K for multi-step reasoning evaluation; ToolBench is a stub that currently yields no samples (it needs a separate manual download).
"""
from datasets import load_dataset
from torch.utils.data import Dataset
import torch
import json
class MultiStepReasoningDataset(Dataset):
    """Unified dataset for multi-step reasoning tasks.

    Every example is normalised to a dict with the keys ``input`` (prompt
    text), ``target`` (reference answer text), ``type`` (task sub-type) and
    ``num_hops`` (heuristic count of reasoning steps).

    Args:
        dataset_name: ``"hotpotqa"``, ``"gsm8k"`` or ``"toolbench"``.
        split: Split name forwarded to ``load_dataset``. GSM8K publishes only
            ``train`` and ``test``, so ``"validation"`` (the default) is
            mapped to ``"test"`` for it.
        tokenizer: Optional tokenizer callable. When set, ``__getitem__``
            returns padded ``input_ids`` / ``attention_mask`` tensors instead
            of the raw example dict.
        max_length: Token length used for truncation and padding.
        max_samples: Keep only the first ``max_samples`` examples; ``None``
            or ``0`` keeps all of them.

    Raises:
        ValueError: If ``dataset_name`` is unknown or ``max_samples`` is
            negative.
    """

    # Split names a dataset does not publish, mapped to the split to use instead.
    _SPLIT_ALIASES = {"gsm8k": {"validation": "test"}}

    def __init__(self, dataset_name="hotpotqa", split="validation",
                 tokenizer=None, max_length=2048, max_samples=None):
        from itertools import islice

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset_name = dataset_name

        if max_samples is not None and max_samples < 0:
            raise ValueError(f"max_samples must be >= 0, got {max_samples}")
        # None and 0 both mean "no limit" (same as the previous truthiness
        # check); islice(..., None) yields every element.
        limit = max_samples or None

        split = self._resolve_split(dataset_name, split)
        if dataset_name == "hotpotqa":
            ds = load_dataset("hotpot_qa", "distractor", split=split)
            # Cap rows *before* formatting so unused examples are never processed.
            self.data = [self._process_hotpotqa(item) for item in islice(ds, limit)]
        elif dataset_name == "gsm8k":
            ds = load_dataset("openai/gsm8k", "main", split=split)
            self.data = [self._process_gsm8k(item) for item in islice(ds, limit)]
        elif dataset_name == "toolbench":
            # ToolBench needs manual download
            self.data = list(islice(self._load_toolbench(split), limit))
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")
        print(f"Loaded {len(self.data)} samples from {dataset_name}/{split}")

    @classmethod
    def _resolve_split(cls, dataset_name, split):
        """Return the split to request, substituting aliases for missing splits."""
        return cls._SPLIT_ALIASES.get(dataset_name, {}).get(split, split)

    def _process_hotpotqa(self, item):
        """Format a HotpotQA (distractor) row as a context + question prompt.

        All paragraph sentences are flattened into one space-separated
        context string; paragraph titles are not included.
        """
        context = " ".join(" ".join(sents) for sents in item["context"]["sentences"])
        return {
            "input": f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {item['question']}\n\nAnswer:",
            "target": item["answer"],
            "type": item["type"],  # "bridge" or "comparison"
            # NOTE(review): heuristic — comparison questions count as 1 hop here,
            # although HotpotQA pairs them with two supporting paragraphs;
            # confirm this is the intended definition.
            "num_hops": 2 if item["type"] == "bridge" else 1,
        }

    def _process_gsm8k(self, item):
        """Format a GSM8K row; ``target`` keeps the full worked solution."""
        answer = item["answer"]
        return {
            "input": f"Solve step by step:\n\n{item['question']}\n\nSolution:",
            "target": answer,
            "type": "math_reasoning",
            "num_hops": self._count_gsm8k_steps(answer),
        }

    @staticmethod
    def _count_gsm8k_steps(answer):
        """Count reasoning lines in a GSM8K solution, excluding the ``####`` answer line.

        Always returns at least 1.
        """
        steps = [
            line for line in answer.splitlines()
            if line.strip() and not line.lstrip().startswith("####")
        ]
        return max(len(steps), 1)

    def _load_toolbench(self, split):
        """Placeholder: ToolBench needs a separate download, so no samples are returned."""
        return []

    def __len__(self):
        """Number of loaded examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx``: tokenized tensors if a tokenizer is set, else the raw dict."""
        item = self.data[idx]
        if self.tokenizer is None:
            return item
        # NOTE(review): tokenizers typically truncate on the right, so an
        # over-long prompt would lose its trailing question / "Answer:" cue —
        # confirm the tokenizer's truncation side is what you want.
        encoding = self.tokenizer(
            item["input"], max_length=self.max_length,
            truncation=True, padding="max_length", return_tensors="pt",
        )
        return {
            # squeeze(0) drops the leading batch dim (presumably size 1 from
            # return_tensors="pt" on a single string).
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "target": item["target"],
            "num_hops": item["num_hops"],
        }
def get_dataset(name, split="validation", tokenizer=None, max_samples=500,
                max_length=2048):
    """Build a ``MultiStepReasoningDataset`` with convenient defaults.

    Args:
        name: Dataset name (``"hotpotqa"``, ``"gsm8k"`` or ``"toolbench"``).
        split: Split to load.
        tokenizer: Optional tokenizer forwarded to the dataset.
        max_samples: Cap on the number of examples; 500 by default here,
            whereas the dataset class itself defaults to no cap.
        max_length: Token length for truncation/padding; the default matches
            the dataset class default.

    Returns:
        The constructed ``MultiStepReasoningDataset``.
    """
    return MultiStepReasoningDataset(
        dataset_name=name, split=split, tokenizer=tokenizer,
        max_length=max_length, max_samples=max_samples,
    )