|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import csv |
|
|
import io |
|
|
import json |
|
|
import os |
|
|
import random |
|
|
import tarfile |
|
|
import urllib.request |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
# Seed the global generator at import time so every later random.shuffle call
# in this module is reproducible from run to run.
random.seed(42)

# (MMLU subject, category) pairs, kept in the same order as the original
# mapping literal.
_SUBJECT_CATEGORY_PAIRS = (
    ("abstract_algebra", "math"),
    ("anatomy", "health"),
    ("astronomy", "physics"),
    ("business_ethics", "business"),
    ("clinical_knowledge", "health"),
    ("college_biology", "biology"),
    ("college_chemistry", "chemistry"),
    ("college_computer_science", "computer science"),
    ("college_mathematics", "math"),
    ("college_medicine", "health"),
    ("college_physics", "physics"),
    ("computer_security", "computer science"),
    ("conceptual_physics", "physics"),
    ("econometrics", "economics"),
    ("electrical_engineering", "engineering"),
    ("elementary_mathematics", "math"),
    ("formal_logic", "philosophy"),
    ("global_facts", "other"),
    ("high_school_biology", "biology"),
    ("high_school_chemistry", "chemistry"),
    ("high_school_computer_science", "computer science"),
    ("high_school_european_history", "history"),
    ("high_school_geography", "geography"),
    ("high_school_government_and_politics", "politics"),
    ("high_school_macroeconomics", "economics"),
    ("high_school_mathematics", "math"),
    ("high_school_microeconomics", "economics"),
    ("high_school_physics", "physics"),
    ("high_school_psychology", "psychology"),
    ("high_school_statistics", "math"),
    ("high_school_us_history", "history"),
    ("high_school_world_history", "history"),
    ("human_aging", "health"),
    ("human_sexuality", "culture"),
    ("international_law", "law"),
    ("jurisprudence", "law"),
    ("logical_fallacies", "philosophy"),
    ("machine_learning", "computer science"),
    ("management", "business"),
    ("marketing", "business"),
    ("medical_genetics", "health"),
    ("miscellaneous", "other"),
    ("moral_disputes", "philosophy"),
    ("moral_scenarios", "philosophy"),
    ("nutrition", "health"),
    ("philosophy", "philosophy"),
    ("prehistory", "history"),
    ("professional_accounting", "other"),
    ("professional_law", "law"),
    ("professional_medicine", "health"),
    ("professional_psychology", "psychology"),
    ("public_relations", "politics"),
    ("security_studies", "politics"),
    ("sociology", "culture"),
    ("us_foreign_policy", "politics"),
    ("virology", "health"),
    ("world_religions", "philosophy"),
)

# Maps each MMLU subject name to a one-element list holding its category.
# Every subject gets its own list object, exactly as a dict literal would give.
subcategories = {subject: [category] for subject, category in _SUBJECT_CATEGORY_PAIRS}
|
|
|
|
|
|
|
|
def process_mmlu_dataset(split) -> Tuple[List[Dict[str, Any]], int]:
    """Download the MMLU archive and convert selected subjects of one split.

    The archive is fetched from the hard-coded Berkeley URL on every call and
    stored as ``data.tar`` in the directory that contains this file.  Every
    ``data/<split>/*.csv`` member is parsed; only subjects whose category in
    the module-level ``subcategories`` mapping is math, health, physics,
    biology, chemistry, computer science or engineering are converted.

    Args:
        split: Name of the directory inside the archive to read
            (``main`` passes ``"test"`` and ``"val"``).

    Returns:
        A ``(entries, skipped)`` tuple.  ``entries`` is a list of dicts with
        the keys ``Question``, ``Choice 1`` .. ``Choice 4``, ``Answer`` (the
        raw value of the sixth CSV column, presumably the answer letter) and
        ``Subject`` (the CSV file name up to its last underscore).
        ``skipped`` counts rows dropped because one of the six fields was
        empty.
    """
    data_url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
    data_dir = Path(__file__).absolute().parent
    data_file = str(data_dir / "data.tar")
    data_dir.mkdir(exist_ok=True)

    urllib.request.urlretrieve(data_url, data_file)

    column_names = ["question", "A", "B", "C", "D", "expected_answer"]
    member_prefix = f"data/{split}/"
    rows_by_subject = {}

    with tarfile.open(data_file, 'r') as tar:
        for member in tar.getmembers():
            if not (member.name.startswith(member_prefix) and member.name.endswith('.csv')):
                continue

            raw_file = tar.extractfile(member)
            if raw_file is None:
                # extractfile() yields None for members that are not regular files.
                continue

            file_name = os.path.basename(member.name)
            rows = []
            # newline='' lets the csv module handle line endings inside quoted
            # fields itself; the with-block closes the extracted member.
            with io.TextIOWrapper(raw_file, encoding='utf-8', newline='') as text_file:
                for row in csv.reader(text_file):
                    if len(row) == len(column_names):
                        rows.append(dict(zip(column_names, row)))
                    else:
                        print(f"Warning: Skipping row in {file_name} due to incorrect number of columns")

            # "<subject>_<suffix>.csv" -> "<subject>"
            rows_by_subject[file_name.rsplit('_', 1)[0]] = rows

    chosen_subcategories = {"math", "health", "physics", "biology", "chemistry", "computer science", "engineering"}

    output_data = []
    skipped = 0

    for category, question_list in rows_by_subject.items():
        # Subjects missing from `subcategories` have no category and are ignored.
        if not any(supercat in chosen_subcategories for supercat in subcategories.get(category, [])):
            continue

        for item in question_list:
            # A row is usable only when all six fields are non-empty.
            if not all(item[name] for name in column_names):
                skipped += 1
                continue

            output_data.append(
                {
                    'Question': item['question'],
                    'Choice 1': item['A'],
                    'Choice 2': item['B'],
                    'Choice 3': item['C'],
                    'Choice 4': item['D'],
                    'Answer': item['expected_answer'],
                    'Subject': category,
                }
            )

    return output_data, skipped
|
|
|
|
|
|
|
|
def _process_gpqa_config(config_name: str) -> Tuple[List[Dict[str, Any]], int]:
    """Load one ``Idavidrein/gpqa`` configuration and format its questions.

    Every split returned by ``load_dataset`` is traversed.  For each item the
    correct answer and the three incorrect answers are shuffled with the
    global ``random`` generator, so the resulting order depends on that
    generator's current state.

    Args:
        config_name: Configuration name passed to ``load_dataset``.

    Returns:
        A ``(entries, skipped)`` tuple.  Each entry has the keys ``Question``,
        ``Choice 1`` .. ``Choice 4`` and ``Answer``, where ``Answer`` is the
        letter ``A``-``D`` of the slot the correct answer was shuffled into.
        ``skipped`` counts items dropped because the question or one of the
        four answers was empty.
    """
    dataset = load_dataset("Idavidrein/gpqa", config_name)
    output_data = []
    skipped = 0

    for split in dataset.keys():
        for item in dataset[split]:
            question = item['Question']
            answers = [
                (item['Correct Answer'], True),
                (item['Incorrect Answer 1'], False),
                (item['Incorrect Answer 2'], False),
                (item['Incorrect Answer 3'], False),
            ]

            if not question or not all(text for text, _ in answers):
                skipped += 1
                continue

            random.shuffle(answers)

            entry = {'Question': question}
            correct_letter = None
            for i, (answer_text, is_correct) in enumerate(answers):
                entry[f'Choice {i+1}'] = answer_text
                if is_correct:
                    # Position 0..3 maps to letter A..D.
                    correct_letter = chr(ord('A') + i)
            entry['Answer'] = correct_letter

            output_data.append(entry)

    return output_data, skipped


def process_gpqa_dataset() -> Tuple[List[Dict[str, Any]], int]:
    """Process the GPQA dataset (``gpqa_main`` configuration).

    Returns:
        ``(entries, skipped)``: the formatted entries and the number of items
        dropped because a required field was empty.
    """
    return _process_gpqa_config("gpqa_main")


def process_gpqa_diamond_dataset() -> Tuple[List[Dict[str, Any]], int]:
    """Process the GPQA Diamond dataset (``gpqa_diamond`` configuration).

    Returns:
        ``(entries, skipped)``: the formatted entries and the number of items
        dropped because a required field was empty.
    """
    return _process_gpqa_config("gpqa_diamond")
|
|
|
|
|
|
|
|
def write_to_jsonl(data: List[Dict[str, Any]], filename: str) -> None:
    """Write the processed data to a JSONL file.

    Args:
        data: Entries to serialize; each one becomes a single line produced
            by ``json.dumps``.
        filename: Path of the file to create; an existing file is overwritten.
    """
    # Pin the encoding instead of relying on the platform's locale default.
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(entry) + '\n' for entry in data)
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command line and return the resulting namespace.

    The only option is ``--datasets``: one or more of ``mmlu``, ``gpqa``,
    ``gpqa_diamond`` or ``all``.  It defaults to ``['all']``.
    """
    dataset_option = {
        'nargs': '+',
        'choices': ['mmlu', 'gpqa', 'gpqa_diamond', 'all'],
        'default': ['all'],
        'help': 'Datasets to process (default: all)',
    }

    cli = argparse.ArgumentParser(description='Convert datasets to JSONL format')
    cli.add_argument('--datasets', **dataset_option)
    return cli.parse_args()
|
|
|
|
|
|
|
|
def _report_dataset(label: str, data: List[Dict[str, Any]], skipped: int) -> None:
    """Print the skipped-entry count (if any) and the first entry (if any)."""
    if skipped > 0:
        print(f"Skipped {skipped} {label} entries due to missing data")
    # An empty result has no first entry to show; indexing it would raise.
    if data:
        print(f"\n{label} Sample entry:")
        print(json.dumps(data[0], indent=2))


def main():
    """Convert the datasets selected on the command line to JSONL files.

    Output files are written to the current working directory:
    ``mmlu_dataset_test.jsonl`` and ``mmlu_dataset_val.jsonl`` for MMLU,
    ``gpqa_dataset.jsonl`` for GPQA and ``gpqa_diamond_dataset.jsonl`` for
    GPQA Diamond.  A short summary is printed for every file written.
    """
    args = parse_args()

    # 'all' anywhere in the selection expands to every supported dataset.
    selected = args.datasets
    if 'all' in selected:
        selected = ['mmlu', 'gpqa', 'gpqa_diamond']

    if 'mmlu' in selected:
        for split in ["test", "val"]:
            mmlu_data, mmlu_skipped = process_mmlu_dataset(split)
            write_to_jsonl(mmlu_data, f'mmlu_dataset_{split}.jsonl')
            print(f"Converted {len(mmlu_data)} {split} questions to mmlu_dataset_{split}.jsonl")
            _report_dataset("MMLU", mmlu_data, mmlu_skipped)

    if 'gpqa' in selected:
        gpqa_data, gpqa_skipped = process_gpqa_dataset()
        write_to_jsonl(gpqa_data, 'gpqa_dataset.jsonl')
        print(f"\nConverted {len(gpqa_data)} questions to gpqa_dataset.jsonl")
        _report_dataset("GPQA", gpqa_data, gpqa_skipped)

    if 'gpqa_diamond' in selected:
        gpqa_diamond_data, gpqa_diamond_skipped = process_gpqa_diamond_dataset()
        write_to_jsonl(gpqa_diamond_data, 'gpqa_diamond_dataset.jsonl')
        print(f"\nConverted {len(gpqa_diamond_data)} questions to gpqa_diamond_dataset.jsonl")
        _report_dataset("GPQA Diamond", gpqa_diamond_data, gpqa_diamond_skipped)


if __name__ == "__main__":
    main()
|
|
|