# Hugging Face page header (scrape residue, kept for provenance):
#   jasonfan's picture
#   Upload folder using huggingface_hub
#   09dd617 verified
"""
Dataset loaders for DPA experiments.
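Loads HotpotQA and GSM8K for multi-step reasoning evaluation.
ToolBench is accepted as a dataset name but is currently a placeholder
that yields no samples (it requires a separate manual download).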
"""
from datasets import load_dataset
from torch.utils.data import Dataset
import torch
import json
class MultiStepReasoningDataset(Dataset):
    """Unified dataset for multi-step reasoning tasks.

    Every sample is normalised to a dict with the keys ``input`` (prompt
    string), ``target`` (reference answer string), ``type`` and ``num_hops``.

    Args:
        dataset_name: One of ``"hotpotqa"``, ``"gsm8k"`` or ``"toolbench"``.
        split: Split name forwarded to ``load_dataset``.
        tokenizer: Optional callable used by ``__getitem__`` to encode the
            prompt. When ``None``, ``__getitem__`` returns the raw dict.
        max_length: Forwarded to the tokenizer as ``max_length``; prompts are
            truncated and padded to this length.
        max_samples: Keep at most this many samples. A falsy value
            (``None`` or ``0``) disables the limit.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """

    def __init__(self, dataset_name="hotpotqa", split="validation",
                 tokenizer=None, max_length=2048, max_samples=None):
        # Imported here so the class does not depend on a new module-level
        # import being present.
        from itertools import islice

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset_name = dataset_name

        # Convert only as many raw rows as will be kept instead of processing
        # the whole split and discarding most of it afterwards.
        # ``islice(..., None)`` iterates everything.
        head = max_samples if max_samples and max_samples > 0 else None

        if dataset_name == "hotpotqa":
            ds = load_dataset("hotpot_qa", "distractor", split=split)
            self.data = [
                self._process_hotpotqa(row) for row in islice(ds, head)
            ]
        elif dataset_name == "gsm8k":
            ds = load_dataset("openai/gsm8k", "main", split=split)
            self.data = [
                self._process_gsm8k(row) for row in islice(ds, head)
            ]
        elif dataset_name == "toolbench":
            # ToolBench needs manual download
            self.data = self._load_toolbench(split)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        # Kept as a plain list slice so every truthy ``max_samples`` value
        # (including negative ones) behaves exactly as it always has, and so
        # the limit also applies to sources that are not read via ``islice``.
        if max_samples:
            self.data = self.data[:max_samples]

        print(f"Loaded {len(self.data)} samples from {dataset_name}/{split}")

    def _process_hotpotqa(self, item):
        """Convert one HotpotQA row into the unified sample dict.

        All sentences of all context paragraphs are joined with single spaces
        to form the context shown in the prompt.
        """
        context = " ".join(
            " ".join(sents) for sents in item["context"]["sentences"]
        )
        return {
            "input": f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {item['question']}\n\nAnswer:",
            "target": item["answer"],
            "type": item["type"],  # "bridge" or "comparison"
            "num_hops": 2 if item["type"] == "bridge" else 1,
        }

    def _process_gsm8k(self, item):
        """Convert one GSM8K row into the unified sample dict.

        ``num_hops`` is the number of lines in the reference answer text.
        """
        return {
            "input": f"Solve step by step:\n\n{item['question']}\n\nSolution:",
            "target": item["answer"],
            "type": "math_reasoning",
            "num_hops": item["answer"].count("\n") + 1,
        }

    def _load_toolbench(self, split):
        """Return ToolBench samples for ``split`` (currently always empty)."""
        # Placeholder — ToolBench needs separate download
        return []

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx``.

        Without a tokenizer the stored dict is returned unchanged. With a
        tokenizer the prompt is encoded and a dict with ``input_ids``,
        ``attention_mask``, ``target`` and ``num_hops`` is returned; the
        leading dimension of the encoded tensors is squeezed away.
        """
        item = self.data[idx]

        # Compare against None explicitly: truth-testing the tokenizer would
        # fall back to its ``__len__`` when it defines one, which costs a
        # ``len()`` call per item and skips tokenization when it reports 0.
        if self.tokenizer is None:
            return item

        encoding = self.tokenizer(
            item["input"],
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "target": item["target"],
            "num_hops": item["num_hops"],
        }
def get_dataset(name, split="validation", tokenizer=None, max_samples=500,
                max_length=2048):
    """Convenience function to load a dataset.

    Args:
        name: Dataset name, passed on as ``dataset_name``.
        split: Dataset split to load.
        tokenizer: Optional tokenizer passed through to the dataset.
        max_samples: Upper bound on the number of samples kept
            (defaults to 500 here).
        max_length: Tokenizer ``max_length`` passed through to the dataset.
            Defaults to 2048, the dataset class's own default, so existing
            callers see no change.

    Returns:
        A ``MultiStepReasoningDataset`` built from the given arguments.
    """
    return MultiStepReasoningDataset(
        dataset_name=name,
        split=split,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
    )