#!/usr/bin/env python3
"""
Update Math Vision prompts with a multiple-choice answering instruction.

For train.json, valid.json and test.json under data/math_vision/ (a path
relative to the current working directory), every sample prompt that still
carries a legacy instruction header is rewritten so that the header ends with
the English sentence "If it is a multiple choice question, directly give the
option letter.".  An earlier variant that carried the same instruction in
Chinese is rolled back to the English wording as well.
"""

import json
import os
import shutil
from pathlib import Path

# NOTE(review): "\\n" in the three headers below is a literal backslash
# followed by "n", not a newline character.  Confirm that the dataset really
# stores the separator that way.

# Original header, without any multiple-choice instruction.
_LEGACY_HEADER = "Solve the problem and output the answer in the format of \\boxed{your answer}.\\n Question:"
# Header with the instruction in Chinese (rolled back to English).
_CHINESE_HEADER = "Solve the problem and output the answer in the format of \\boxed{your answer}. 如果是选择题直接给出选项字母.\\n Question:"
# Target header with the instruction in English.
_CURRENT_HEADER = "Solve the problem and output the answer in the format of \\boxed{your answer}. If it is a multiple choice question, directly give the option letter.\\n Question:"

# Dataset splits that main() looks for inside the data directory.
_JSON_FILES = ['train.json', 'valid.json', 'test.json']


def update_prompt(prompt: str) -> str:
    """
    Return *prompt* with every legacy instruction header upgraded.

    Both the Chinese-instruction header and the original header without any
    instruction are replaced by the English-instruction header.  A prompt that
    contains neither (already updated, or in a different format) is returned
    unchanged, so the function is idempotent.
    """
    # The English header contains neither legacy header as a substring, so
    # chaining the two replacements can never rewrite text it just inserted.
    upgraded = prompt.replace(_CHINESE_HEADER, _CURRENT_HEADER)
    return upgraded.replace(_LEGACY_HEADER, _CURRENT_HEADER)


def update_json_file(file_path: str) -> dict:
    """
    Update the prompts stored in one JSON file, in place.

    The file is loaded with json.load and iterated; each dict item has its
    'prompt' value passed through update_prompt().  Items that are not dicts,
    or whose 'prompt' is missing or not a string, are left untouched and
    counted as skipped.

    The file is only rewritten when at least one prompt changed.  Before the
    first rewrite a byte-for-byte copy is saved next to it as
    '<file_path>.backup'; an existing backup is never overwritten, so the
    oldest copy is preserved across repeated runs.

    Returns:
        dict with statistics: total, updated, skipped
    """
    print(f"\n处理文件: {file_path}")

    # Load JSON
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Update prompts
    total = len(data)
    updated = 0
    for item in data:
        if not isinstance(item, dict):
            continue
        old_prompt = item.get('prompt', '')
        if not isinstance(old_prompt, str):
            # e.g. JSON null; there is nothing to rewrite.
            continue
        new_prompt = update_prompt(old_prompt)
        if new_prompt != old_prompt:
            item['prompt'] = new_prompt
            updated += 1

    if updated:
        # Back up the untouched original before overwriting it.  Copying the
        # file (rather than re-serialising the parsed data) keeps the backup
        # byte-identical to the original.
        backup_path = os.fspath(file_path) + '.backup'
        if not os.path.exists(backup_path):
            shutil.copy2(file_path, backup_path)
            print(f" ✓ 备份创建: {backup_path}")

        # Save updated JSON
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    stats = {
        'total': total,
        'updated': updated,
        'skipped': total - updated,
    }

    print(f" ✓ 总样本数: {stats['total']}")
    print(f" ✓ 已更新: {stats['updated']}")
    print(f" ✓ 跳过: {stats['skipped']}")

    return stats


def main(data_dir: str = "data/math_vision"):
    """
    Update every known dataset split inside *data_dir* and print a summary.

    Splits that do not exist are reported and skipped.  If *data_dir* itself
    does not exist, an error message is printed and nothing else happens.
    """
    if not os.path.exists(data_dir):
        print(f"错误: 目录不存在: {data_dir}")
        return

    print("=" * 80)
    print("Math Vision Prompt 更新脚本")
    print("=" * 80)
    print(f"数据目录: {data_dir}")
    print("更新内容: 添加 'If it is a multiple choice question, directly give the option letter.' 指令")

    total_stats = {'total': 0, 'updated': 0, 'skipped': 0}

    for filename in _JSON_FILES:
        file_path = os.path.join(data_dir, filename)

        if not os.path.exists(file_path):
            print(f"\n⚠ 跳过不存在的文件: {file_path}")
            continue

        stats = update_json_file(file_path)
        for key in total_stats:
            total_stats[key] += stats[key]

    print("\n" + "=" * 80)
    print("总结")
    print("=" * 80)
    print(f"总样本数: {total_stats['total']}")
    print(f"已更新: {total_stats['updated']}")
    print(f"跳过: {total_stats['skipped']}")
    print("\n✓ 完成!所有prompt已更新。")
    print("\n备份文件位置:")
    for filename in _JSON_FILES:
        backup_path = os.path.join(data_dir, filename + '.backup')
        if os.path.exists(backup_path):
            print(f" - {backup_path}")


if __name__ == "__main__":
    main()