# cascade/ml/data/load_dataset.py
# feat: add dataset loader for complexity classification (commit ad8fa3f, ayushm98)
"""Load and preprocess Easy2Hard-Bench dataset for complexity classification."""
import json
from pathlib import Path
from typing import Literal
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
def load_easy2hard_bench(
    subset: Literal["all", "gsm8k", "arc", "winogrande"] = "all",
    difficulty_threshold: float = 0.5,
    max_samples: int | None = None,
    seed: int = 42,
) -> DatasetDict:
    """
    Load Easy2Hard-Bench and convert it to a binary complexity-classification task.

    Args:
        subset: Which subset to load ("all" for combined dataset).
            NOTE(review): currently unused — the default config is always loaded
            and all splits are merged; per-subset filtering is a TODO.
        difficulty_threshold: Scores >= this are "complex" (1), below are "simple" (0).
        max_samples: Maximum samples to keep after shuffling (None for all).
        seed: Random seed for shuffling and splitting.

    Returns:
        DatasetDict with train/validation/test splits (70/15/15); each example
        has "text", "label", and "difficulty_score" columns.
    """
    print(f"Loading Easy2Hard-Bench dataset (subset={subset})...")
    # Load the dataset from HuggingFace and merge every available split into
    # one pool; we re-split deterministically below.
    dataset = load_dataset("furonghuang-lab/Easy2Hard-Bench")
    combined = concatenate_datasets([dataset[name] for name in dataset.keys()])
    print(f"Total examples loaded: {len(combined)}")

    def process_example(example: dict) -> dict:
        """Extract text and derive a binary label from the difficulty score."""
        # Different subsets use different column names for the prompt text.
        text = example.get("question", "") or example.get("prompt", "") or example.get("input", "")
        # A missing OR explicitly-null difficulty falls back to the neutral
        # midpoint 0.5.  A plain .get("difficulty", 0.5) keeps a stored None,
        # which would crash both the comparison and float() below.
        difficulty = example.get("difficulty")
        if difficulty is None:
            difficulty = 0.5
        difficulty = float(difficulty)
        return {
            "text": str(text).strip(),
            "label": 1 if difficulty >= difficulty_threshold else 0,
            "difficulty_score": difficulty,
        }

    processed = combined.map(
        process_example,
        remove_columns=combined.column_names,
        desc="Processing examples",
    )
    # Drop rows where none of the known text columns were present.
    processed = processed.filter(lambda x: len(x["text"]) > 0)
    print(f"After filtering empty texts: {len(processed)}")
    processed = processed.shuffle(seed=seed)
    if max_samples and len(processed) > max_samples:
        processed = processed.select(range(max_samples))
        print(f"Limited to {max_samples} samples")
    # 70/15/15 split: hold out 30%, then halve it into validation and test.
    train_test = processed.train_test_split(test_size=0.3, seed=seed)
    val_test = train_test["test"].train_test_split(test_size=0.5, seed=seed)
    dataset_dict = DatasetDict(
        {
            "train": train_test["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )
    # Per-split class balance, useful for spotting a degenerate threshold.
    print("\nDataset splits:")
    for split_name, split_data in dataset_dict.items():
        n_simple = sum(1 for x in split_data if x["label"] == 0)
        n_complex = sum(1 for x in split_data if x["label"] == 1)
        print(f" {split_name}: {len(split_data)} total ({n_simple} simple, {n_complex} complex)")
    return dataset_dict
def load_arc_dataset(max_samples: int | None = None, seed: int = 42) -> DatasetDict:
    """
    Load the ARC dataset, using its pre-defined Easy/Challenge tiers as labels.

    An alternative to Easy2Hard-Bench with explicit easy/hard annotations
    (Easy -> label 0, Challenge -> label 1).

    Args:
        max_samples: Maximum samples overall (None for all).
        seed: Random seed for shuffling and splitting.

    Returns:
        DatasetDict with train/validation/test splits.
    """
    print("Loading ARC dataset (Easy + Challenge)...")

    def tag_question(example: dict, label: int) -> dict:
        """Keep only the question text, the binary label, and a nominal score."""
        # Easy questions get a nominal 0.25 difficulty score, Challenge 0.75.
        return {
            "text": example["question"].strip(),
            "label": label,
            "difficulty_score": 0.25 if label == 0 else 0.75,
        }

    # Load both configs in a fixed order so concatenation + shuffle(seed) is
    # reproducible, mapping each train split through the labeler.
    parts = []
    for config_name, tier_label in (("ARC-Easy", 0), ("ARC-Challenge", 1)):
        raw = load_dataset("allenai/ai2_arc", config_name)
        train_split = raw["train"]
        parts.append(
            train_split.map(
                lambda ex, lbl=tier_label: tag_question(ex, lbl),
                remove_columns=train_split.column_names,
            )
        )

    pool = concatenate_datasets(parts).shuffle(seed=seed)
    if max_samples and len(pool) > max_samples:
        pool = pool.select(range(max_samples))

    # 70/15/15 split: carve off 30%, then halve it into validation and test.
    holdout = pool.train_test_split(test_size=0.3, seed=seed)
    val_test = holdout["test"].train_test_split(test_size=0.5, seed=seed)
    dataset_dict = DatasetDict(
        {
            "train": holdout["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )

    # Report the class balance of each split.
    print("\nDataset splits:")
    for split_name, split_data in dataset_dict.items():
        n_simple = sum(1 for x in split_data if x["label"] == 0)
        n_complex = sum(1 for x in split_data if x["label"] == 1)
        print(f" {split_name}: {len(split_data)} total ({n_simple} simple, {n_complex} complex)")
    return dataset_dict
def save_dataset(dataset: DatasetDict, output_dir: str | Path) -> None:
"""Save processed dataset to disk."""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for split_name, split_data in dataset.items():
output_path = output_dir / f"{split_name}.jsonl"
with open(output_path, "w") as f:
for example in split_data:
f.write(json.dumps(example) + "\n")
print(f"Saved {split_name} to {output_path}")
if __name__ == "__main__":
    # CLI entry point: load a dataset, optionally persist it, then preview a
    # few training examples.
    import argparse

    parser = argparse.ArgumentParser(description="Load complexity classification dataset")
    parser.add_argument(
        "--dataset",
        choices=["easy2hard", "arc"],
        default="easy2hard",
        help="Dataset to load",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum samples to use",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Difficulty threshold for binary classification",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="ml/data/processed",
        help="Output directory for processed data",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save processed dataset to disk",
    )
    args = parser.parse_args()

    # Dispatch to the selected loader.
    if args.dataset == "easy2hard":
        dataset = load_easy2hard_bench(
            difficulty_threshold=args.threshold,
            max_samples=args.max_samples,
        )
    else:
        dataset = load_arc_dataset(max_samples=args.max_samples)

    if args.save:
        save_dataset(dataset, args.output_dir)

    # Preview the first three training examples (text truncated to 100 chars).
    print("\nSample examples:")
    for idx, sample in enumerate(dataset["train"].select(range(3)), start=1):
        label_str = "complex" if sample["label"] == 1 else "simple"
        print(f"\n[{idx}] ({label_str}, score={sample['difficulty_score']:.2f})")
        print(f" {sample['text'][:100]}...")