"""
Dataset loaders for DPA experiments.
Loads HotpotQA, GSM8K, ToolBench for multi-step reasoning evaluation.
"""

from datasets import load_dataset
from torch.utils.data import Dataset
import torch
import json


class MultiStepReasoningDataset(Dataset):
    """Unified dataset for multi-step reasoning tasks."""

    def __init__(self, dataset_name="hotpotqa", split="validation",
                 tokenizer=None, max_length=2048, max_samples=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset_name = dataset_name

        if dataset_name == "hotpotqa":
            ds = load_dataset("hotpot_qa", "distractor", split=split)
            self.data = [self._process_hotpotqa(item) for item in ds]
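        elif dataset_name == "gsm8k":
            # GSM8K ships only "train" and "test" splits, so map the default
            # "validation" request to "test" instead of failing in load_dataset.
            gsm8k_split = "test" if split == "validation" else split
            ds = load_dataset("openai/gsm8k", "main", split=gsm8k_split)
            self.data = [self._process_gsm8k(item) for item in ds]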
        elif dataset_name == "toolbench":
            # ToolBench needs manual download
            self.data = self._load_toolbench(split)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

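        # Compare against None so that max_samples=0 yields an empty dataset
        # rather than silently disabling the limit.
        if max_samples is not None:
            self.data = self.data[:max_samples]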

        print(f"Loaded {len(self.data)} samples from {dataset_name}/{split}")

    def _process_hotpotqa(self, item):
        context = " ".join([
            " ".join(sents) for sents in item["context"]["sentences"]
        ])
        return {
            "input": f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {item['question']}\n\nAnswer:",
            "target": item["answer"],
            "type": item["type"],  # "bridge" or "comparison"
            "num_hops": 2 if item["type"] == "bridge" else 1,
        }

    def _process_gsm8k(self, item):
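        # Count reasoning lines only: GSM8K answers end with a "#### <answer>"
        # line, which holds the final result rather than an extra step.
        steps = [
            line for line in item["answer"].splitlines()
            if line.strip() and not line.startswith("####")
        ]
        return {
            "input": f"Solve step by step:\n\n{item['question']}\n\nSolution:",
            "target": item["answer"],
            "type": "math_reasoning",
            "num_hops": max(len(steps), 1),
        }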

    def _load_toolbench(self, split):
        # Placeholder — ToolBench needs separate download
        return []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        if self.tokenizer is not None:
            encoding = self.tokenizer(
                item["input"], max_length=self.max_length,
                truncation=True, padding="max_length", return_tensors="pt",
            )
            return {
                "input_ids": encoding["input_ids"].squeeze(0),
                "attention_mask": encoding["attention_mask"].squeeze(0),
                "target": item["target"],
                "num_hops": item["num_hops"],
            }
        return item


def get_dataset(name, split="validation", tokenizer=None, max_samples=500):
    """Convenience function to load a dataset."""
    return MultiStepReasoningDataset(
        dataset_name=name, split=split,
        tokenizer=tokenizer, max_samples=max_samples,
    )
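

# Sketch, not part of the original loader: GSM8K "target" strings hold the full
# worked solution and end with a "#### <final answer>" line. For exact-match
# scoring it may help to compare only that final answer. The helper name below
# is hypothetical, and stripping thousands separators is an assumption about
# how answers should be normalized.
def extract_gsm8k_final_answer(target):
    """Return the text after the last '####' marker, or None if it is absent."""
    marker = "####"
    if marker not in target:
        return None
    # Keep what follows the last marker, then drop whitespace and commas.
    return target.rsplit(marker, 1)[1].strip().replace(",", "")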