ayushm98 commited on
Commit
ad8fa3f
·
1 Parent(s): 70f6529

feat: add dataset loader for complexity classification

Browse files

- Support Easy2Hard-Bench dataset with difficulty scores
- Support ARC dataset with Easy/Challenge splits
- Convert continuous difficulty to binary labels
- Create train/validation/test splits (70/15/15)

Files changed (1) hide show
  1. ml/data/load_dataset.py +236 -0
ml/data/load_dataset.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load and preprocess Easy2Hard-Bench dataset for complexity classification."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Literal
6
+
7
+ from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
8
+
9
+
10
def load_easy2hard_bench(
    subset: Literal["all", "gsm8k", "arc", "winogrande"] = "all",
    difficulty_threshold: float = 0.5,
    max_samples: int | None = None,
    seed: int = 42,
) -> DatasetDict:
    """
    Load Easy2Hard-Bench dataset and convert to binary classification.

    Continuous difficulty scores are thresholded into binary labels
    (0 = simple, 1 = complex) and the pooled data is re-split 70/15/15
    into train/validation/test.

    Args:
        subset: Which subset to load ("all" for the combined dataset)
        difficulty_threshold: Score above this is "complex" (1), below is "simple" (0)
        max_samples: Maximum samples to use (None for all)
        seed: Random seed for shuffling

    Returns:
        DatasetDict with train/validation/test splits; each row has
        "text", "label", and "difficulty_score" columns.
    """
    print(f"Loading Easy2Hard-Bench dataset (subset={subset})...")

    # BUG FIX: `subset` was previously accepted but silently ignored — the
    # function always loaded the full dataset. Map the short names onto the
    # hub's config names so the parameter actually selects a subset.
    # NOTE(review): config names assumed from the dataset card — TODO confirm.
    config_map = {"gsm8k": "E2H-GSM8K", "arc": "E2H-ARC", "winogrande": "E2H-Winogrande"}
    if subset == "all":
        dataset = load_dataset("furonghuang-lab/Easy2Hard-Bench")
    else:
        dataset = load_dataset("furonghuang-lab/Easy2Hard-Bench", config_map[subset])

    # Pool every available split into one dataset; we create our own
    # train/validation/test partition below.
    combined = concatenate_datasets([dataset[name] for name in dataset.keys()])
    print(f"Total examples loaded: {len(combined)}")

    def process_example(example: dict) -> dict:
        """Extract text and create binary label from difficulty score."""
        # The text column name varies across subsets; try common candidates.
        text = example.get("question", "") or example.get("prompt", "") or example.get("input", "")

        # BUG FIX: a present-but-None "difficulty" value used to flow into the
        # `>=` comparison and raise TypeError ( .get's default only applies to
        # missing keys). Treat None the same as missing.
        difficulty = example.get("difficulty")
        if difficulty is None:
            difficulty = 0.5

        return {
            "text": str(text).strip(),
            "label": 1 if difficulty >= difficulty_threshold else 0,
            "difficulty_score": float(difficulty),
        }

    processed = combined.map(
        process_example,
        remove_columns=combined.column_names,
        desc="Processing examples",
    )

    # Drop rows where no usable text column was found.
    processed = processed.filter(lambda x: len(x["text"]) > 0)
    print(f"After filtering empty texts: {len(processed)}")

    processed = processed.shuffle(seed=seed)

    if max_samples and len(processed) > max_samples:
        processed = processed.select(range(max_samples))
        print(f"Limited to {max_samples} samples")

    # 70/15/15 split: hold out 30%, then halve it into validation/test.
    train_test = processed.train_test_split(test_size=0.3, seed=seed)
    val_test = train_test["test"].train_test_split(test_size=0.5, seed=seed)

    dataset_dict = DatasetDict(
        {
            "train": train_test["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )

    # Print per-split class balance (single column read instead of two
    # row-by-row passes over each split).
    print("\nDataset splits:")
    for split_name, split_data in dataset_dict.items():
        labels = split_data["label"]
        n_complex = sum(labels)
        n_simple = len(labels) - n_complex
        print(f" {split_name}: {len(split_data)} total ({n_simple} simple, {n_complex} complex)")

    return dataset_dict
103
+
104
+
105
def load_arc_dataset(max_samples: int | None = None, seed: int = 42) -> DatasetDict:
    """
    Load ARC dataset with pre-defined Easy/Challenge splits.

    This is an alternative to Easy2Hard-Bench that has explicit easy/hard
    labels: ARC-Easy questions become label 0 and ARC-Challenge label 1.

    Args:
        max_samples: Maximum samples per split (None for all)
        seed: Random seed for shuffling

    Returns:
        DatasetDict with train/validation/test splits
    """
    print("Loading ARC dataset (Easy + Challenge)...")

    def to_row(example: dict, label: int) -> dict:
        """Map one ARC example onto the (text, label, difficulty_score) schema."""
        return {
            "text": example["question"].strip(),
            "label": label,
            # Synthetic scores: easy questions sit at 0.25, hard ones at 0.75.
            "difficulty_score": 0.25 if label == 0 else 0.75,
        }

    # Load and relabel both configs: Easy -> 0 first, then Challenge -> 1.
    labeled_parts = []
    for config_name, label in (("ARC-Easy", 0), ("ARC-Challenge", 1)):
        raw = load_dataset("allenai/ai2_arc", config_name)["train"]
        labeled_parts.append(
            raw.map(lambda x, lbl=label: to_row(x, lbl), remove_columns=raw.column_names)
        )

    combined = concatenate_datasets(labeled_parts).shuffle(seed=seed)

    if max_samples and len(combined) > max_samples:
        combined = combined.select(range(max_samples))

    # 70/15/15: carve off 30%, then split that holdout half-and-half.
    holdout = combined.train_test_split(test_size=0.3, seed=seed)
    val_test = holdout["test"].train_test_split(test_size=0.5, seed=seed)

    dataset_dict = DatasetDict(
        {
            "train": holdout["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )

    print("\nDataset splits:")
    for name, data in dataset_dict.items():
        n_simple = sum(1 for x in data if x["label"] == 0)
        n_complex = sum(1 for x in data if x["label"] == 1)
        print(f" {name}: {len(data)} total ({n_simple} simple, {n_complex} complex)")

    return dataset_dict
168
+
169
+
170
+ def save_dataset(dataset: DatasetDict, output_dir: str | Path) -> None:
171
+ """Save processed dataset to disk."""
172
+ output_dir = Path(output_dir)
173
+ output_dir.mkdir(parents=True, exist_ok=True)
174
+
175
+ for split_name, split_data in dataset.items():
176
+ output_path = output_dir / f"{split_name}.jsonl"
177
+ with open(output_path, "w") as f:
178
+ for example in split_data:
179
+ f.write(json.dumps(example) + "\n")
180
+ print(f"Saved {split_name} to {output_path}")
181
+
182
+
183
if __name__ == "__main__":
    # CLI entry point: pick a dataset, optionally cap/save it, preview rows.
    import argparse

    parser = argparse.ArgumentParser(description="Load complexity classification dataset")
    parser.add_argument("--dataset", choices=["easy2hard", "arc"], default="easy2hard", help="Dataset to load")
    parser.add_argument("--max-samples", type=int, default=None, help="Maximum samples to use")
    parser.add_argument("--threshold", type=float, default=0.5, help="Difficulty threshold for binary classification")
    parser.add_argument("--output-dir", type=str, default="ml/data/processed", help="Output directory for processed data")
    parser.add_argument("--save", action="store_true", help="Save processed dataset to disk")
    args = parser.parse_args()

    if args.dataset == "easy2hard":
        dataset = load_easy2hard_bench(
            difficulty_threshold=args.threshold,
            max_samples=args.max_samples,
        )
    else:
        dataset = load_arc_dataset(max_samples=args.max_samples)

    if args.save:
        save_dataset(dataset, args.output_dir)

    # Preview the first few training examples.
    print("\nSample examples:")
    for idx, example in enumerate(dataset["train"].select(range(3)), start=1):
        label_str = "complex" if example["label"] == 1 else "simple"
        print(f"\n[{idx}] ({label_str}, score={example['difficulty_score']:.2f})")
        print(f" {example['text'][:100]}...")