File size: 7,479 Bytes
ad8fa3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""Load and preprocess Easy2Hard-Bench dataset for complexity classification."""

import json
from pathlib import Path
from typing import Literal

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset


def load_easy2hard_bench(
    subset: Literal["all", "gsm8k", "arc", "winogrande"] = "all",
    difficulty_threshold: float = 0.5,
    max_samples: int | None = None,
    seed: int = 42,
) -> DatasetDict:
    """
    Load Easy2Hard-Bench dataset and convert to binary classification.

    Every split published for the selected configuration is merged into one
    pool. Each row is mapped to ``{"text", "label", "difficulty_score"}``.
    Rows with empty text are dropped. The pool is shuffled, optionally
    truncated, and re-split 70/15/15.

    Args:
        subset: Which subset to load. "all" loads the dataset without naming
            a configuration; "gsm8k", "arc" and "winogrande" load the matching
            Easy2Hard-Bench configuration.
        difficulty_threshold: Score at or above this is "complex" (1), below
            is "simple" (0).
        max_samples: Maximum samples to use (None or 0 for all).
        seed: Random seed for shuffling and splitting.

    Returns:
        DatasetDict with train/validation/test splits.

    Raises:
        ValueError: If ``subset`` is not one of the supported values.
    """
    # NOTE(review): configuration names taken from the Easy2Hard-Bench dataset
    # card on the Hugging Face Hub -- confirm if the Hub layout changes.
    subset_configs = {
        "gsm8k": "E2H-GSM8K",
        "arc": "E2H-ARC",
        "winogrande": "E2H-Winogrande",
    }
    if subset != "all" and subset not in subset_configs:
        raise ValueError(
            f"Unknown subset {subset!r}; expected 'all' or one of {sorted(subset_configs)}"
        )

    print(f"Loading Easy2Hard-Bench dataset (subset={subset})...")

    # Load the dataset from HuggingFace; "all" keeps the config-less call.
    if subset == "all":
        dataset = load_dataset("furonghuang-lab/Easy2Hard-Bench")
    else:
        dataset = load_dataset("furonghuang-lab/Easy2Hard-Bench", subset_configs[subset])

    # Pool every published split; the data is re-split 70/15/15 below.
    combined = concatenate_datasets(list(dataset.values()))

    print(f"Total examples loaded: {len(combined)}")

    def process_example(example: dict) -> dict:
        """Extract text and create binary label from difficulty score."""
        # First non-empty candidate text field. The trailing `or ""` keeps an
        # all-None row from becoming the literal string "None", which would
        # otherwise survive the empty-text filter below.
        text = (
            example.get("question", "")
            or example.get("prompt", "")
            or example.get("input", "")
            or ""
        )

        # A missing key and an explicit None both fall back to 0.5 instead of
        # raising TypeError on `None >= difficulty_threshold`.
        difficulty = example.get("difficulty")
        if difficulty is None:
            difficulty = 0.5
        difficulty = float(difficulty)

        return {
            "text": str(text).strip(),
            "label": 1 if difficulty >= difficulty_threshold else 0,
            "difficulty_score": difficulty,
        }

    # Replace the source columns with the three normalized ones.
    processed = combined.map(
        process_example,
        remove_columns=combined.column_names,
        desc="Processing examples",
    )

    # Filter out empty texts
    processed = processed.filter(lambda x: len(x["text"]) > 0)

    print(f"After filtering empty texts: {len(processed)}")

    processed = processed.shuffle(seed=seed)

    # Truncation happens after the shuffle so the kept rows are a random sample.
    if max_samples and len(processed) > max_samples:
        processed = processed.select(range(max_samples))
        print(f"Limited to {max_samples} samples")

    # 70/15/15: carve off 30%, then halve that remainder into validation/test.
    train_test = processed.train_test_split(test_size=0.3, seed=seed)
    val_test = train_test["test"].train_test_split(test_size=0.5, seed=seed)

    dataset_dict = DatasetDict(
        {
            "train": train_test["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )

    print("\nDataset splits:")
    for split_name, split_data in dataset_dict.items():
        # Labels are strictly 0/1 (see process_example), so summing the column
        # counts the complex rows without decoding every row into a dict.
        n_complex = sum(split_data["label"])
        n_simple = len(split_data) - n_complex
        print(f"  {split_name}: {len(split_data)} total ({n_simple} simple, {n_complex} complex)")

    return dataset_dict


def load_arc_dataset(max_samples: int | None = None, seed: int = 42) -> DatasetDict:
    """
    Load ARC dataset with pre-defined Easy/Challenge splits.

    This is an alternative to Easy2Hard-Bench that has explicit easy/hard
    labels:

    - every ARC-Easy question is labeled "simple" (0);
    - every ARC-Challenge question is labeled "complex" (1).

    Only the ``train`` split of each ARC configuration is used. The pooled
    rows are shuffled and re-split 70/15/15.

    Args:
        max_samples: Maximum number of samples in total. The cap applies to
            the pooled Easy + Challenge rows before splitting (None or 0 for
            all).
        seed: Random seed for shuffling and splitting.

    Returns:
        DatasetDict with train/validation/test splits. Each row holds
        ``text``, ``label`` and a placeholder ``difficulty_score``.
    """
    print("Loading ARC dataset (Easy + Challenge)...")

    # Load both configurations
    arc_easy = load_dataset("allenai/ai2_arc", "ARC-Easy")
    arc_challenge = load_dataset("allenai/ai2_arc", "ARC-Challenge")

    def process_arc(example: dict, label: int) -> dict:
        """Map an ARC row to the text/label/difficulty_score schema."""
        return {
            "text": example["question"].strip(),
            "label": label,
            # ARC has no graded score; fixed placeholders stand in for one.
            "difficulty_score": 0.25 if label == 0 else 0.75,
        }

    # Label by source configuration: Easy -> 0, Challenge -> 1.
    easy_data = arc_easy["train"].map(
        lambda x: process_arc(x, 0),
        remove_columns=arc_easy["train"].column_names,
    )
    challenge_data = arc_challenge["train"].map(
        lambda x: process_arc(x, 1),
        remove_columns=arc_challenge["train"].column_names,
    )

    # Pool, then shuffle so truncation and splitting mix both sources.
    combined = concatenate_datasets([easy_data, challenge_data])
    combined = combined.shuffle(seed=seed)

    if max_samples and len(combined) > max_samples:
        combined = combined.select(range(max_samples))

    # 70/15/15: carve off 30%, then halve that remainder into validation/test.
    train_test = combined.train_test_split(test_size=0.3, seed=seed)
    val_test = train_test["test"].train_test_split(test_size=0.5, seed=seed)

    dataset_dict = DatasetDict(
        {
            "train": train_test["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )

    print("\nDataset splits:")
    for split_name, split_data in dataset_dict.items():
        # Labels are strictly 0/1, so summing the column counts complex rows
        # without decoding every row into a dict.
        n_complex = sum(split_data["label"])
        n_simple = len(split_data) - n_complex
        print(f"  {split_name}: {len(split_data)} total ({n_simple} simple, {n_complex} complex)")

    return dataset_dict


def save_dataset(dataset: DatasetDict, output_dir: str | Path) -> None:
    """
    Save processed dataset to disk as one JSON Lines file per split.

    Each split ``name`` is written to ``<output_dir>/<name>.jsonl``, one
    ``json.dumps`` object per line. Any existing file is overwritten. The
    output directory (and its parents) is created if missing.

    Args:
        dataset: Mapping of split name to an iterable of JSON-serializable
            row dicts (e.g. a ``DatasetDict``).
        output_dir: Directory to write the ``.jsonl`` files into.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for split_name, split_data in dataset.items():
        output_path = output_dir / f"{split_name}.jsonl"
        # Explicit UTF-8 so the file encoding does not depend on the platform
        # locale (PEP 597).
        with open(output_path, "w", encoding="utf-8") as f:
            for example in split_data:
                f.write(json.dumps(example) + "\n")
        print(f"Saved {split_name} to {output_path}")


if __name__ == "__main__":
    # Example usage: build a dataset from the command line, optionally save it,
    # and preview a few training rows.
    import argparse

    parser = argparse.ArgumentParser(description="Load complexity classification dataset")
    parser.add_argument(
        "--dataset",
        choices=["easy2hard", "arc"],
        default="easy2hard",
        help="Dataset to load",
    )
    parser.add_argument(
        "--subset",
        choices=["all", "gsm8k", "arc", "winogrande"],
        default="all",
        help="Easy2Hard-Bench subset (ignored for --dataset arc)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum samples to use",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Difficulty threshold for binary classification",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for shuffling and splitting",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="ml/data/processed",
        help="Output directory for processed data",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save processed dataset to disk",
    )

    args = parser.parse_args()

    if args.dataset == "easy2hard":
        dataset = load_easy2hard_bench(
            subset=args.subset,
            difficulty_threshold=args.threshold,
            max_samples=args.max_samples,
            seed=args.seed,
        )
    else:
        dataset = load_arc_dataset(max_samples=args.max_samples, seed=args.seed)

    if args.save:
        save_dataset(dataset, args.output_dir)

    # Preview at most three rows: a small --max-samples can leave the train
    # split with fewer, and select(range(3)) would then raise.
    print("\nSample examples:")
    train_split = dataset["train"]
    preview_chars = 100
    for i, example in enumerate(train_split.select(range(min(3, len(train_split)))), start=1):
        label_str = "complex" if example["label"] == 1 else "simple"
        text = example["text"]
        # Only mark the preview as truncated when something was actually cut.
        ellipsis = "..." if len(text) > preview_chars else ""
        print(f"\n[{i}] ({label_str}, score={example['difficulty_score']:.2f})")
        print(f"    {text[:preview_chars]}{ellipsis}")