"""
Data API: load evaluation datasets and format their questions as chat prompts.
"""
import json
from datasets import Dataset, load_dataset, load_from_disk, concatenate_datasets
from prompts import gsm8k_prompt, asdiv_aug_prompt, math_500_prompt, aime_prompt, strategy_qa_prompt, du_prompt
def get_dataset(data_name_or_path, tokenizer, prompt_idx):
    """
    Load an evaluation dataset and attach chat-formatted prompts to it.

    The dataset is selected by a case-insensitive substring match on
    ``data_name_or_path``. For gsm8k, asdiv-aug, math-500 and the AIME sets, a
    copy saved with ``save_to_disk`` at ``data_name_or_path`` is tried first;
    if loading it fails, the dataset is fetched with ``load_dataset`` instead.

    Args:
        data_name_or_path: dataset name or local path. Must contain one of
            "gsm8k", "asdiv-aug", "math-500", "aime_2024", "aime2025",
            "strategyqa" or "date_understanding" (case-insensitive).
        tokenizer: tokenizer providing ``apply_chat_template``.
        prompt_idx: which query prompt to use; passed through to the
            dataset-specific prompt builder.

    Returns:
        dataset: the mapped dataset with columns "prompt" (chat messages),
        "formatted" (chat-templated text with the generation prompt appended),
        "question" and "answer". StrategyQA is delegated to ``get_strategyqa``.

    Raises:
        ValueError: if ``data_name_or_path`` matches no supported dataset.
    """
    name = data_name_or_path.lower()

    def load_local_or(split, download):
        # Try the on-disk copy first (selecting ``split`` when one is given);
        # on any ordinary failure (missing path, missing split, ...) call
        # ``download()``. ``except Exception`` rather than a bare ``except:``
        # so KeyboardInterrupt / SystemExit are not swallowed.
        try:
            local = load_from_disk(data_name_or_path)
            return local if split is None else local[split]
        except Exception:
            return download()

    # Per-dataset answer clean-up flags, applied inside preprocess_function.
    strip_degree = False        # remove LaTeX "^\circ" markers (AIME 2025)
    take_option_letter = False  # keep only ans[1] (date_understanding)

    if "gsm8k" in name:
        dataset = load_local_or(
            "test", lambda: load_dataset("openai/gsm8k", "socratic")["test"]
        )
        question_col, answer_col, build_prompt = "question", "answer", gsm8k_prompt
    elif "asdiv-aug" in name:
        dataset = load_local_or(
            "test", lambda: load_dataset("xuyige/ASDiv-Aug")["test"]
        )
        question_col, answer_col, build_prompt = "question", "answer", asdiv_aug_prompt
    elif "math-500" in name:
        dataset = load_local_or(
            "test", lambda: load_dataset("HuggingFaceH4/MATH-500")["test"]
        )
        question_col, answer_col, build_prompt = "problem", "answer", math_500_prompt
    elif "aime_2024" in name:
        # The local copy is used as-is (no split indexing).
        dataset = load_local_or(
            None, lambda: load_dataset("Maxwell-Jia/AIME_2024")["train"]
        )
        question_col, answer_col, build_prompt = "Problem", "Answer", aime_prompt
    elif "aime2025" in name:
        # The Hub version is split into two configs; merge them into one set.
        dataset = load_local_or(
            None,
            lambda: concatenate_datasets([
                load_dataset("opencompass/AIME2025", "AIME2025-I")["test"],
                load_dataset("opencompass/AIME2025", "AIME2025-II")["test"],
            ]),
        )
        question_col, answer_col, build_prompt = "question", "answer", aime_prompt
        strip_degree = True
    elif "strategyqa" in name:
        return get_strategyqa(tokenizer, prompt_idx)
    elif "date_understanding" in name:
        dataset = load_dataset("maveriq/bigbenchhard", "date_understanding")["train"]
        question_col, answer_col, build_prompt = "input", "target", du_prompt
        take_option_letter = True
    else:
        raise ValueError(f"Unsupported dataset: {data_name_or_path}")

    def preprocess_function(examples):
        """
        Build prompts for one batch of examples (used with ``batched=True``).

        Args:
            examples: batch mapping column names to lists of values.
        Returns:
            dict with the "prompt", "formatted", "question" and "answer" columns.
        """
        questions = examples[question_col]
        answers = examples[answer_col]
        prompt = [build_prompt(q, prompt_idx) for q in questions]
        formatted = [
            tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            for messages in prompt
        ]
        if strip_degree:
            # Raw string: '\c' is not a valid escape sequence.
            answers = [ans.replace(r"^\circ", "") for ans in answers]
        if take_option_letter:
            # Presumably targets look like "(A)" and index 1 is the option
            # letter -- NOTE(review): confirm against the dataset.
            answers = [ans[1] for ans in answers]
        return {
            "prompt": prompt,
            "formatted": formatted,
            "question": questions,
            "answer": answers,
        }

    return dataset.map(preprocess_function, batched=True, load_from_cache_file=False)
def get_strategyqa(tokenizer, prompt_idx, path="strategyqa_train.json"):
    """
    Build the StrategyQA evaluation set from a local JSON file.

    Args:
        tokenizer: tokenizer providing ``apply_chat_template``.
        prompt_idx: which query prompt to use; passed to ``strategy_qa_prompt``.
        path: JSON file containing a list of records, each with "question"
            and "answer" keys. Defaults to "strategyqa_train.json", resolved
            relative to the current working directory (the previously
            hard-coded location).

    Returns:
        A ``Dataset`` with columns "prompt" (chat messages), "formatted"
        (chat-templated text with the generation prompt appended),
        "question" and "answer", in file order.
    """
    # JSON text is UTF-8 (RFC 8259); pin the encoding instead of relying on
    # the platform locale. The file is closed before the tokenizer work.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    questions = [ins["question"] for ins in data]
    answers = [ins["answer"] for ins in data]
    prompt = [strategy_qa_prompt(q, prompt_idx) for q in questions]
    formatted = [
        tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in prompt
    ]
    return Dataset.from_dict({
        "prompt": prompt,
        "formatted": formatted,
        "question": questions,
        "answer": answers,
    })
|