File size: 4,507 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
#!/usr/bin/env python3
"""
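Update Math Vision prompts to include an instruction for multiple choice questions.
This script updates train.json, valid.json and test.json in data/math_vision/ so that each prompt includes the
"If it is a multiple choice question, directly give the option letter." instruction.
Prompts that still carry the earlier Chinese wording ("如果是选择题直接给出选项字母") are rewritten to the English one.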
"""
import json
import os
from pathlib import Path
def update_prompt(prompt: str) -> str:
    """
    Return *prompt* with the multiple-choice instruction inserted.

    Two outdated prompt headers are recognised and checked in this order:
    the header carrying the Chinese instruction, then the header with no
    instruction at all. Only the first variant found is rewritten; every
    occurrence of it is replaced by the header with the English
    instruction. A prompt containing neither variant (for example one
    that was already updated) is returned unchanged.
    """
    # NOTE: "\\n" in these literals is a backslash followed by the letter
    # "n", not a newline character; the match is made against that text.
    english_header = "Solve the problem and output the answer in the format of \\boxed{your answer}. If it is a multiple choice question, directly give the option letter.\\n Question:"

    # The Chinese variant comes first so it takes priority over the bare header.
    stale_headers = (
        "Solve the problem and output the answer in the format of \\boxed{your answer}. 如果是选择题直接给出选项字母.\\n Question:",
        "Solve the problem and output the answer in the format of \\boxed{your answer}.\\n Question:",
    )

    for stale in stale_headers:
        if stale in prompt:
            return prompt.replace(stale, english_header)

    # Neither stale header is present: nothing to do.
    return prompt
def update_json_file(file_path: str) -> dict:
    """
    Update the prompts stored in one JSON file, in place.

    The file must contain a JSON array. For every element that is an
    object with a string 'prompt', the prompt is passed through
    update_prompt() and stored back if it changed. Elements that are not
    objects, or whose 'prompt' is not a string, are left untouched and
    counted as skipped.

    When at least one prompt changed, a byte-for-byte copy of the original
    file is saved as ``<file_path>.backup`` (unless that backup already
    exists) and the file is rewritten. When nothing changed, neither the
    file nor a backup is written.

    Args:
        file_path: Path of the JSON file to update.

    Returns:
        dict with statistics: total, updated, skipped

    Raises:
        ValueError: If the top-level JSON value is not a list.
        OSError, UnicodeDecodeError, json.JSONDecodeError: Propagated
            unchanged when the file cannot be read, decoded or parsed.
    """
    print(f"\n处理文件: {file_path}")

    # Keep the raw bytes so the backup is an exact copy of what was on disk,
    # not a re-serialised version of the parsed data.
    with open(file_path, 'rb') as f:
        raw = f.read()
    data = json.loads(raw.decode('utf-8'))
    if not isinstance(data, list):
        raise ValueError(
            f"expected a JSON array in {file_path}, got {type(data).__name__}"
        )

    # Update prompts
    total = len(data)
    updated = 0
    for item in data:
        if not isinstance(item, dict):
            continue
        old_prompt = item.get('prompt', '')
        if not isinstance(old_prompt, str):
            continue
        new_prompt = update_prompt(old_prompt)
        if new_prompt != old_prompt:
            item['prompt'] = new_prompt
            updated += 1

    if updated:
        # Backup original file. Exclusive-create mode never overwrites an
        # earlier backup, which holds the oldest known state of the file.
        backup_path = file_path + '.backup'
        try:
            with open(backup_path, 'xb') as f:
                f.write(raw)
        except FileExistsError:
            pass
        else:
            print(f" ✓ 备份创建: {backup_path}")

        # Save updated JSON
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    stats = {
        'total': total,
        'updated': updated,
        'skipped': total - updated,
    }
    print(f" ✓ 总样本数: {stats['total']}")
    print(f" ✓ 已更新: {stats['updated']}")
    print(f" ✓ 跳过: {stats['skipped']}")
    return stats
def main(data_dir: str = "data/math_vision"):
    """
    Update the prompts of the dataset splits found in *data_dir*.

    Looks for train.json, valid.json and test.json inside *data_dir*,
    runs update_json_file() on each one that exists, then prints the
    summed statistics and the backup files present in the directory.

    Args:
        data_dir: Directory holding the JSON splits. Defaults to the
            original hard-coded location, relative to the current
            working directory.

    Returns:
        None. If *data_dir* is not an existing directory, an error is
        printed and nothing is processed.
    """
    if not os.path.isdir(data_dir):
        print(f"错误: 目录不存在: {data_dir}")
        return

    banner = "=" * 80
    print(banner)
    print("Math Vision Prompt 更新脚本")
    print(banner)
    print(f"数据目录: {data_dir}")
    print("更新内容: 添加 'If it is a multiple choice question, directly give the option letter.' 指令")

    # Only these three split files are considered
    json_files = ['train.json', 'valid.json', 'test.json']
    total_stats = {'total': 0, 'updated': 0, 'skipped': 0}
    for filename in json_files:
        file_path = os.path.join(data_dir, filename)
        if not os.path.exists(file_path):
            print(f"\n⚠ 跳过不存在的文件: {file_path}")
            continue
        stats = update_json_file(file_path)
        for key in total_stats:
            total_stats[key] += stats[key]

    print("\n" + banner)
    print("总结")
    print(banner)
    print(f"总样本数: {total_stats['total']}")
    print(f"已更新: {total_stats['updated']}")
    print(f"跳过: {total_stats['skipped']}")
    print("\n✓ 完成!所有prompt已更新。")

    # List whichever backups exist, including ones left by earlier runs
    print("\n备份文件位置:")
    for filename in json_files:
        backup_path = os.path.join(data_dir, filename + '.backup')
        if os.path.exists(backup_path):
            print(f" - {backup_path}")
# Run the update only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|