File size: 4,507 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
"""
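Update Math Vision prompts to include an instruction for multiple-choice questions.

This script updates train.json, valid.json and test.json in data/math_vision/ so that
each prompt carries the English instruction
"If it is a multiple choice question, directly give the option letter."
Prompts that still contain the earlier Chinese instruction
"如果是选择题直接给出选项字母" are rewritten to use the English wording instead.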
"""

import json
import os
import shutil
from pathlib import Path


def update_prompt(prompt: str) -> str:
    """
    Return ``prompt`` with the multiple-choice instruction inserted.

    Two legacy headers are recognised, and they are tried in this order:

    * the header carrying the Chinese instruction (rolled back to English);
    * the bare original header with no instruction at all.

    All occurrences of the first legacy header found in ``prompt`` are
    replaced by the English header. If neither header is present (the prompt
    is already updated or uses a different format), the prompt is returned
    unchanged.

    Old: "Solve the problem and output the answer in the format of \\boxed{your answer}.\\n Question:"
    New: "Solve the problem and output the answer in the format of \\boxed{your answer}. If it is a multiple choice question, directly give the option letter.\\n Question:"
    """
    target_header = "Solve the problem and output the answer in the format of \\boxed{your answer}. If it is a multiple choice question, directly give the option letter.\\n Question:"

    # Order matters: the Chinese variant must win when both are present.
    legacy_headers = (
        "Solve the problem and output the answer in the format of \\boxed{your answer}. 如果是选择题直接给出选项字母.\\n Question:",
        "Solve the problem and output the answer in the format of \\boxed{your answer}.\\n Question:",
    )

    for legacy in legacy_headers:
        if legacy in prompt:
            return prompt.replace(legacy, target_header)

    # Nothing recognised: already updated or an unrelated format.
    return prompt


def update_json_file(file_path: str) -> dict:
    """
    Rewrite the prompts in one JSON dataset file in place.

    The file is loaded with ``json.load`` and iterated item by item. Each
    item's ``prompt`` field (treated as "" when absent) is passed through
    ``update_prompt``. Items are accessed with ``.get``/``[...]``, so this
    presumably expects a JSON array of objects (NOTE(review): confirm the
    dataset schema).

    Before anything is written, a byte-for-byte copy of the file is saved to
    ``<file_path>.backup``. If that backup already exists it is left alone, so
    the first original is never clobbered by a later run. The file itself is
    rewritten only when at least one prompt actually changed.

    Args:
        file_path: Path to the JSON file to update.

    Returns:
        dict with statistics: total, updated, skipped
    """
    print(f"\n处理文件: {file_path}")

    # Load JSON
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Back up the untouched on-disk bytes. Copying the file (instead of
    # re-parsing and re-dumping it) preserves the original formatting and
    # escapes exactly and avoids reading/parsing the file a second time.
    backup_path = file_path + '.backup'
    if not os.path.exists(backup_path):
        shutil.copy2(file_path, backup_path)
        print(f"  ✓ 备份创建: {backup_path}")

    # Update prompts
    total = len(data)
    updated = 0
    for item in data:
        old_prompt = item.get('prompt', '')
        new_prompt = update_prompt(old_prompt)
        if new_prompt != old_prompt:
            item['prompt'] = new_prompt
            updated += 1

    # Only rewrite when something changed, so reruns don't reformat the file.
    if updated:
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    stats = {
        'total': total,
        'updated': updated,
        'skipped': total - updated
    }

    print(f"  ✓ 总样本数: {stats['total']}")
    print(f"  ✓ 已更新: {stats['updated']}")
    print(f"  ✓ 跳过: {stats['skipped']}")

    return stats


def main(data_dir: str = "data/math_vision",
         json_files: tuple = ('train.json', 'valid.json', 'test.json')) -> int:
    """
    Update the prompts of the Math Vision split files under ``data_dir``.

    Args:
        data_dir: Directory holding the split files. A relative path is
            resolved against the current working directory.
        json_files: File names inside ``data_dir`` to process. Missing files
            are reported and skipped.

    Returns:
        Process exit status: 1 if ``data_dir`` does not exist, otherwise 0.
    """
    if not os.path.exists(data_dir):
        print(f"错误: 目录不存在: {data_dir}")
        return 1

    # Materialize once: the names are iterated twice (processing + backup list).
    json_files = tuple(json_files)

    print("=" * 80)
    print("Math Vision Prompt 更新脚本")
    print("=" * 80)
    print(f"数据目录: {data_dir}")
    print("更新内容: 添加 'If it is a multiple choice question, directly give the option letter.' 指令")

    total_stats = {'total': 0, 'updated': 0, 'skipped': 0}

    for filename in json_files:
        file_path = os.path.join(data_dir, filename)

        if not os.path.exists(file_path):
            print(f"\n⚠ 跳过不存在的文件: {file_path}")
            continue

        stats = update_json_file(file_path)
        for key in total_stats:
            total_stats[key] += stats[key]

    print("\n" + "=" * 80)
    print("总结")
    print("=" * 80)
    print(f"总样本数: {total_stats['total']}")
    print(f"已更新: {total_stats['updated']}")
    print(f"跳过: {total_stats['skipped']}")
    print("\n✓ 完成!所有prompt已更新。")
    print("\n备份文件位置:")
    for filename in json_files:
        backup_path = os.path.join(data_dir, filename + '.backup')
        if os.path.exists(backup_path):
            print(f"  - {backup_path}")

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a missing data directory fails the shell.
    raise SystemExit(main())