"""Load and preprocess Easy2Hard-Bench dataset for complexity classification.""" |
|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Literal |
|
|
|
|
|
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset |
|
|
|
|
|
|
|
|
def load_easy2hard_bench( |
|
|
subset: Literal["all", "gsm8k", "arc", "winogrande"] = "all", |
|
|
difficulty_threshold: float = 0.5, |
|
|
max_samples: int | None = None, |
|
|
seed: int = 42, |
|
|
) -> DatasetDict: |
|
|
""" |
|
|
Load Easy2Hard-Bench dataset and convert to binary classification. |
|
|
|
|
|
Args: |
|
|
subset: Which subset to load ("all" for combined dataset) |
|
|
difficulty_threshold: Score above this is "complex" (1), below is "simple" (0) |
|
|
max_samples: Maximum samples to use (None for all) |
|
|
seed: Random seed for shuffling |
|
|
|
|
|
Returns: |
|
|
DatasetDict with train/validation/test splits |
|
|
""" |
|
|
print(f"Loading Easy2Hard-Bench dataset (subset={subset})...") |
|
|
|
|
|
|
|
|
dataset = load_dataset("furonghuang-lab/Easy2Hard-Bench") |
|
|
|
|
|
|
|
|
    all_data = []
    for split_name in dataset.keys():
        all_data.append(dataset[split_name])

    combined = concatenate_datasets(all_data)
    print(f"Total examples loaded: {len(combined)}")

    def process_example(example: dict) -> dict:
        """Extract text and create a binary label from the difficulty score."""
        # Field names vary across subsets; fall back through the known ones.
        text = example.get("question", "") or example.get("prompt", "") or example.get("input", "")

        # Guard against a missing or None score so the comparison below cannot fail.
        difficulty = example.get("difficulty")
        if difficulty is None:
            difficulty = 0.5  # neutral default

        label = 1 if difficulty >= difficulty_threshold else 0

        return {
            "text": str(text).strip(),
            "label": label,
            "difficulty_score": float(difficulty),
        }

    processed = combined.map(
        process_example,
        remove_columns=combined.column_names,
        desc="Processing examples",
    )

    # Drop examples whose text field came back empty.
    processed = processed.filter(lambda x: len(x["text"]) > 0)
    print(f"After filtering empty texts: {len(processed)}")

    processed = processed.shuffle(seed=seed)

    if max_samples and len(processed) > max_samples:
        processed = processed.select(range(max_samples))
        print(f"Limited to {max_samples} samples")

    # 70/15/15 split: hold out 30%, then divide the holdout evenly.
    train_test = processed.train_test_split(test_size=0.3, seed=seed)
    val_test = train_test["test"].train_test_split(test_size=0.5, seed=seed)

    dataset_dict = DatasetDict(
        {
            "train": train_test["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )

    print("\nDataset splits:")
    for split_name, split_data in dataset_dict.items():
        n_simple = sum(1 for x in split_data if x["label"] == 0)
        n_complex = sum(1 for x in split_data if x["label"] == 1)
        print(f"  {split_name}: {len(split_data)} total ({n_simple} simple, {n_complex} complex)")

    return dataset_dict
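

# Usage sketch (illustrative, not executed on import): assumes network access
# to the Hugging Face Hub.
#
#     ds = load_easy2hard_bench(difficulty_threshold=0.6, max_samples=1000)
#     ds["train"][0]  # -> {"text": ..., "label": 0 or 1, "difficulty_score": ...}
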

def load_arc_dataset(max_samples: int | None = None, seed: int = 42) -> DatasetDict:
    """
    Load the ARC dataset with its predefined Easy/Challenge splits.

    This is an alternative to Easy2Hard-Bench that has explicit easy/hard labels.

    Args:
        max_samples: Maximum number of samples to keep (None for all)
        seed: Random seed for shuffling

    Returns:
        DatasetDict with train/validation/test splits
    """
    print("Loading ARC dataset (Easy + Challenge)...")

    arc_easy = load_dataset("allenai/ai2_arc", "ARC-Easy")
    arc_challenge = load_dataset("allenai/ai2_arc", "ARC-Challenge")

    def process_arc(example: dict, label: int) -> dict:
        """Map an ARC example to the common text/label/score schema."""
        return {
            "text": example["question"].strip(),
            "label": label,
            # ARC has no graded difficulty, so assign nominal pseudo-scores
            # on either side of the 0.5 threshold used elsewhere.
            "difficulty_score": 0.25 if label == 0 else 0.75,
        }

easy_data = arc_easy["train"].map( |
|
|
lambda x: process_arc(x, 0), |
|
|
remove_columns=arc_easy["train"].column_names, |
|
|
) |
|
|
challenge_data = arc_challenge["train"].map( |
|
|
lambda x: process_arc(x, 1), |
|
|
remove_columns=arc_challenge["train"].column_names, |
|
|
) |
|
|
|
|
|
|
|
|
combined = concatenate_datasets([easy_data, challenge_data]) |
|
|
combined = combined.shuffle(seed=seed) |
|
|
|
|
|
if max_samples and len(combined) > max_samples: |
|
|
combined = combined.select(range(max_samples)) |
|
|
|
|
|
|
|
|
    # Same 70/15/15 split as load_easy2hard_bench.
    train_test = combined.train_test_split(test_size=0.3, seed=seed)
    val_test = train_test["test"].train_test_split(test_size=0.5, seed=seed)

    dataset_dict = DatasetDict(
        {
            "train": train_test["train"],
            "validation": val_test["train"],
            "test": val_test["test"],
        }
    )

    print("\nDataset splits:")
    for split_name, split_data in dataset_dict.items():
        n_simple = sum(1 for x in split_data if x["label"] == 0)
        n_complex = sum(1 for x in split_data if x["label"] == 1)
        print(f"  {split_name}: {len(split_data)} total ({n_simple} simple, {n_complex} complex)")

    return dataset_dict
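

# Usage sketch (illustrative): same output schema as load_easy2hard_bench,
# with labels taken from ARC's own Easy/Challenge partitioning.
#
#     ds = load_arc_dataset(max_samples=5000)
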

def save_dataset(dataset: DatasetDict, output_dir: str | Path) -> None:
    """Save the processed dataset to disk as one JSONL file per split."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for split_name, split_data in dataset.items():
        output_path = output_dir / f"{split_name}.jsonl"
        with open(output_path, "w") as f:
            for example in split_data:
                f.write(json.dumps(example) + "\n")
        print(f"Saved {split_name} to {output_path}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Load complexity classification dataset")
    parser.add_argument(
        "--dataset",
        choices=["easy2hard", "arc"],
        default="easy2hard",
        help="Dataset to load",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum samples to use",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Difficulty threshold for binary classification (easy2hard only)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="ml/data/processed",
        help="Output directory for processed data",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save processed dataset to disk",
    )

    args = parser.parse_args()

if args.dataset == "easy2hard": |
|
|
dataset = load_easy2hard_bench( |
|
|
difficulty_threshold=args.threshold, |
|
|
max_samples=args.max_samples, |
|
|
) |
|
|
else: |
|
|
dataset = load_arc_dataset(max_samples=args.max_samples) |
|
|
|
|
|
if args.save: |
|
|
save_dataset(dataset, args.output_dir) |
|
|
|
|
|
|
|
|
print("\nSample examples:") |
|
|
for i, example in enumerate(dataset["train"].select(range(3))): |
|
|
label_str = "complex" if example["label"] == 1 else "simple" |
|
|
print(f"\n[{i+1}] ({label_str}, score={example['difficulty_score']:.2f})") |
|
|
print(f" {example['text'][:100]}...") |