model111 / scripts /update_math_vision_prompts.py
LCZZZZ's picture
Upload MemGen code and data
e34b94f verified
#!/usr/bin/env python3
"""
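Update Math Vision prompts to include an instruction for multiple-choice questions.
This script rewrites the prompts in data/math_vision/{train,valid,test}.json so that
they contain the English instruction "If it is a multiple choice question, directly
give the option letter." Prompts that carry the earlier Chinese wording
"如果是选择题直接给出选项字母" ("for multiple-choice questions, give the option
letter directly") are switched to the English instruction.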
"""
import json
import os
import shutil
from pathlib import Path
def update_prompt(prompt: str) -> str:
    """
    Return ``prompt`` with the multiple-choice instruction added.

    Two stale forms of the answer-format sentence are rewritten to the new one:

    * Old:     "Solve the problem ... \\boxed{your answer}.\\n Question:"
    * Chinese: "Solve the problem ... \\boxed{your answer}. 如果是选择题直接给出选项字母.\\n Question:"
    * New:     "Solve the problem ... \\boxed{your answer}. If it is a multiple
      choice question, directly give the option letter.\\n Question:"

    The templates in this script separate the sentence from "Question:" with the
    two-character sequence backslash + ``n``. A real newline separator is matched
    as well. NOTE(review): confirm which form the dataset actually stores.

    Prompts already in the new form, prompts in another format, and non-string
    values (e.g. a JSON ``null`` prompt) are returned unchanged, so the function
    is safe to apply repeatedly.
    """
    if not isinstance(prompt, str):
        return prompt

    head = "Solve the problem and output the answer in the format of \\boxed{your answer}."
    english_instruction = " If it is a multiple choice question, directly give the option letter."
    chinese_instruction = " 如果是选择题直接给出选项字母."

    for newline in ("\\n", "\n"):
        tail = newline + " Question:"
        new_text = head + english_instruction + tail
        # Roll back the Chinese variant first, then upgrade the bare template.
        # Neither stale form is a substring of new_text, so repeated calls are no-ops.
        for stale in (head + chinese_instruction + tail, head + tail):
            prompt = prompt.replace(stale, new_text)
    return prompt
def _write_json_atomic(path: str, data) -> None:
    """Serialise ``data`` to ``path`` via a sibling temp file and ``os.replace``."""
    tmp_path = path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        # Keep the original file's permission bits on the replacement.
        shutil.copymode(path, tmp_path)
        # Atomic swap: a crash mid-write can never leave a truncated dataset.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def update_json_file(file_path: str) -> dict:
    """
    Rewrite the prompts in one JSON dataset file using ``update_prompt``.

    The file must hold a top-level JSON array. For every object in it, the
    ``'prompt'`` value is passed through ``update_prompt``. Non-object entries
    are left alone and counted as skipped.

    If at least one prompt changes:

    * before the first modification, a byte-for-byte copy of the file is saved
      to ``<file_path>.backup``. An existing backup is never overwritten, so it
      keeps the oldest version seen;
    * the file is rewritten (UTF-8, ``indent=2``, non-ASCII kept) atomically.

    If nothing changes, the file is left untouched and no backup is made.

    Args:
        file_path: path of the JSON file to update in place.

    Returns:
        dict with statistics: ``total`` (number of entries), ``updated``
        (entries whose prompt changed) and ``skipped`` (``total - updated``).

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    print(f"\n处理文件: {file_path}")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(
            f"{file_path}: expected a top-level JSON array, got {type(data).__name__}"
        )

    total = len(data)
    updated = 0
    for item in data:
        if not isinstance(item, dict):
            continue
        old_prompt = item.get('prompt', '')
        new_prompt = update_prompt(old_prompt)
        if new_prompt != old_prompt:
            item['prompt'] = new_prompt
            updated += 1

    if updated:
        backup_path = file_path + '.backup'
        if not os.path.exists(backup_path):
            # Copy raw bytes so the backup is identical to the original file,
            # not a re-serialisation of the parsed data.
            shutil.copy2(file_path, backup_path)
            print(f" ✓ 备份创建: {backup_path}")
        _write_json_atomic(file_path, data)
    else:
        print(" - 没有需要更新的prompt,文件保持不变")

    stats = {
        'total': total,
        'updated': updated,
        'skipped': total - updated,
    }
    print(f" ✓ 总样本数: {stats['total']}")
    print(f" ✓ 已更新: {stats['updated']}")
    print(f" ✓ 跳过: {stats['skipped']}")
    return stats
def main(data_dir: str = "data/math_vision",
         json_files=("train.json", "valid.json", "test.json")) -> None:
    """
    Update the prompts in each split file of ``data_dir`` and print a summary.

    Args:
        data_dir: directory containing the split files. A relative path is
            resolved against the current working directory, so with the default
            the script must be run from the directory that contains ``data/``.
        json_files: names of the split files to process, relative to
            ``data_dir``. Files that do not exist are reported and skipped.
    """
    if not os.path.isdir(data_dir):
        print(f"错误: 目录不存在: {data_dir}")
        return

    # Materialise once: the names are iterated twice (update, then backup listing).
    json_files = list(json_files)

    print("=" * 80)
    print("Math Vision Prompt 更新脚本")
    print("=" * 80)
    print(f"数据目录: {data_dir}")
    print("更新内容: 添加 'If it is a multiple choice question, directly give the option letter.' 指令")

    total_stats = {'total': 0, 'updated': 0, 'skipped': 0}
    for filename in json_files:
        file_path = os.path.join(data_dir, filename)
        if not os.path.exists(file_path):
            print(f"\n⚠ 跳过不存在的文件: {file_path}")
            continue
        stats = update_json_file(file_path)
        for key in total_stats:
            total_stats[key] += stats[key]

    print("\n" + "=" * 80)
    print("总结")
    print("=" * 80)
    print(f"总样本数: {total_stats['total']}")
    print(f"已更新: {total_stats['updated']}")
    print(f"跳过: {total_stats['skipped']}")
    print("\n✓ 完成!所有prompt已更新。")
    print("\n备份文件位置:")
    for filename in json_files:
        backup_path = os.path.join(data_dir, filename + '.backup')
        if os.path.exists(backup_path):
            print(f" - {backup_path}")
if __name__ == "__main__":
    # Run the update only when executed as a script, not when imported.
    main()