learn / test_time_scaling /concept_difficulty_augment_extract.py
unfair11212's picture
Upload folder using huggingface_hub
a80f6e6 verified
import os
import json
import re
from glob import glob
# Root directory whose model outputs are processed (edit as needed).
DATA_ROOT = 'data/concept_difficulty_augment/Qwen__Qwen2.5-7B-Instruct/abstract_algebra/harder'
# Extracted questions are written under here, one JSON file per entry.
OUTPUT_ROOT = os.path.join(DATA_ROOT, 'extracted_entries')
# Matches a JSON object of the form {"question": "...", "options": {...}}
# embedded in free-form model output; DOTALL lets it span multiple lines.
QUESTION_JSON_RE = re.compile(r'\{\s*"question"\s*:\s*".*?",\s*"options"\s*:\s*\{.*?\}\s*\}', re.DOTALL)


def extract_question_json(model_output):
    """Return the first question/options fragment in *model_output* that parses as JSON.

    Regex candidates are tried in order of appearance, so a malformed first
    fragment does not hide a valid later one.

    Args:
        model_output: Raw model output text. Non-string values (e.g. ``None``)
            are treated as having nothing to extract.

    Returns:
        The parsed object, or ``None`` when no candidate parses.
    """
    if not isinstance(model_output, str):
        return None
    for match in QUESTION_JSON_RE.finditer(model_output):
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            continue  # malformed candidate; try the next one
    return None


def output_dir_for(results_path):
    """Return the directory that entries from *results_path* are written to.

    The results file's directory, relative to DATA_ROOT, is mirrored under
    OUTPUT_ROOT. A file directly in DATA_ROOT therefore writes to OUTPUT_ROOT
    itself, while nested files get their own subdirectory so their
    ``question_XXXX.json`` outputs cannot overwrite each other.
    """
    rel_dir = os.path.relpath(os.path.dirname(results_path), DATA_ROOT)
    return os.path.normpath(os.path.join(OUTPUT_ROOT, rel_dir))


# Recursively find every all_results.json under DATA_ROOT (DATA_ROOT itself
# included, since '**' also matches zero directories). Sorted so the
# processing order is deterministic.
all_results_files = sorted(glob(os.path.join(DATA_ROOT, '**', 'all_results.json'), recursive=True))
os.makedirs(OUTPUT_ROOT, exist_ok=True)

extracted_count = 0
for results_path in all_results_files:
    with open(results_path, 'r', encoding='utf-8') as f:
        results = json.load(f)
    out_dir = output_dir_for(results_path)
    os.makedirs(out_dir, exist_ok=True)
    for idx, entry in enumerate(results):
        parsed = extract_question_json(entry.get('model_output'))
        if parsed is None:
            continue  # skip entries with no parseable question JSON
        # Name the file by the entry's index in the results file so each
        # output can be traced to its source entry even when others are skipped.
        fname = f'question_{idx:04d}.json'
        with open(os.path.join(out_dir, fname), 'w', encoding='utf-8') as out_f:
            json.dump(parsed, out_f, ensure_ascii=False, indent=2)
        extracted_count += 1
# Report how many entries were actually written, not how many were seen.
print(f'Extracted {extracted_count} entries to {OUTPUT_ROOT}/')