# Snapshot metadata from the file viewer: file size 5,203 bytes, revision e34b94f.
#!/usr/bin/env python3
"""
Add letter labels (A, B, C, D, E...) to multiple choice options in Math Vision dataset.
This script updates options from:
"Options: option1, option2, option3"
to:
"Options: A. option1, B. option2, C. option3"
"""
import json
import os
import re
def add_option_letters(prompt: str) -> tuple:
    """
    Prefix each multiple-choice option in a prompt with a letter label.

    The option list is the text that follows the options marker (a literal
    backslash, the letter ``n``, a space and ``Options: ``). Options are
    separated by commas; a comma enclosed by a pair of ``$`` signs is treated
    as part of a math expression and does not split.

    The call is idempotent: a prompt whose options already carry the labels
    ``A.``, ``B.``, ... in order is returned unchanged, so running the script
    twice does not produce ``A. A. option``.

    Args:
        prompt: Original prompt text
    Returns:
        (updated_prompt, was_updated) tuple. ``was_updated`` is False when the
        prompt has no options marker, has more than one marker, has an empty
        option list, or is already labelled.
    """
    # NOTE(review): this matches a literal backslash followed by "n", not a
    # newline character -- confirm this is how the dataset stores prompts.
    marker = "\\n Options: "
    labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    question_part, found, options_part = prompt.partition(marker)
    # Exactly one marker is required; anything else is left untouched.
    if not found or marker in options_part:
        return prompt, False

    # Split on commas that sit outside $...$ math spans.
    options = []
    start = 0
    inside_math = False
    for index, char in enumerate(options_part):
        if char == '$':
            inside_math = not inside_math
        elif char == ',' and not inside_math:
            options.append(options_part[start:index].strip())
            start = index + 1
    # The trailing piece is only kept when it is non-empty.
    last_option = options_part[start:].strip()
    if last_option:
        options.append(last_option)

    if not options:
        return prompt, False

    # Already labelled by a previous run (or by the dataset itself): skip.
    if all(
        option == f"{label}." or option.startswith(f"{label}. ")
        for label, option in zip(labels, options)
    ):
        return prompt, False

    # Options beyond the 26th (unlikely) are kept without a label.
    labeled_options = [
        f"{labels[i]}. {option}" if i < len(labels) else option
        for i, option in enumerate(options)
    ]
    return f"{question_part}{marker}{', '.join(labeled_options)}", True
def process_json_file(file_path: str) -> dict:
    """
    Add option letters to every prompt in one JSON file, rewriting it in place.

    Before the file is rewritten, a byte-for-byte copy of it is saved next to
    it as ``<file_path>.backup_no_letters``, unless that backup already exists
    (an existing backup is never overwritten). The file is then re-serialised
    with ``indent=2`` and ``ensure_ascii=False``.

    Items that are not JSON objects, and items whose ``prompt`` is missing or
    not a string, are left unchanged and counted as skipped.

    Args:
        file_path: Path of the JSON file to process.
    Returns:
        dict with statistics: total, updated, skipped
    """
    print(f"\n处理文件: {file_path}")

    # Read the raw bytes once: they are parsed below and, if no backup exists
    # yet, written out verbatim so the backup is an exact copy of the input.
    with open(file_path, 'rb') as f:
        raw = f.read()
    data = json.loads(raw.decode('utf-8'))

    # Backup original file if not already backed up
    backup_path = file_path + '.backup_no_letters'
    if not os.path.exists(backup_path):
        with open(backup_path, 'wb') as f:
            f.write(raw)
        print(f" ✓ 备份创建: {backup_path}")

    # Process prompts
    total = len(data)
    updated = 0
    for item in data:
        old_prompt = item.get('prompt') if isinstance(item, dict) else None
        if not isinstance(old_prompt, str):
            # Nothing to label (missing/null prompt or a non-object entry).
            continue
        new_prompt, was_updated = add_option_letters(old_prompt)
        if was_updated:
            item['prompt'] = new_prompt
            updated += 1

    # Save updated JSON
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    stats = {
        'total': total,
        'updated': updated,
        'skipped': total - updated,
    }
    print(f" ✓ 总样本数: {stats['total']}")
    print(f" ✓ 已更新(有选项): {stats['updated']}")
    print(f" ✓ 跳过(无选项): {stats['skipped']}")
    return stats
def main():
    """
    Label the options in every Math Vision split file and print a summary.

    Looks for train.json, valid.json and test.json under ``data/math_vision``
    (relative to the current working directory), runs ``process_json_file``
    on each one that exists, and reports the accumulated statistics together
    with the backup files found afterwards. Returns early with an error
    message when the data directory is missing.
    """
    dataset_dir = "data/math_vision"
    if not os.path.exists(dataset_dir):
        print(f"错误: 目录不存在: {dataset_dir}")
        return

    rule = "=" * 80
    print(rule)
    print("Math Vision 选项字母标注脚本")
    print(rule)
    print(f"数据目录: {dataset_dir}")
    print("更新内容: 给每个选项添加字母标识 (A. B. C. D. E. ...)")

    # The fixed set of split files handled by this script.
    split_names = ['train.json', 'valid.json', 'test.json']
    totals = {'total': 0, 'updated': 0, 'skipped': 0}

    for split_name in split_names:
        split_path = os.path.join(dataset_dir, split_name)
        if os.path.exists(split_path):
            file_stats = process_json_file(split_path)
            for key in ('total', 'updated', 'skipped'):
                totals[key] += file_stats[key]
        else:
            print(f"\n⚠ 跳过不存在的文件: {split_path}")

    print("\n" + rule)
    print("总结")
    print(rule)
    print(f"总样本数: {totals['total']}")
    print(f"有选项的样本(已添加字母): {totals['updated']}")
    print(f"无选项的样本(非选择题): {totals['skipped']}")
    print("\n✓ 完成!所有选择题选项已添加字母标识。")

    # List whichever backups are present on disk after the run.
    print("\n备份文件位置:")
    candidate_backups = [
        os.path.join(dataset_dir, split_name + '.backup_no_letters')
        for split_name in split_names
    ]
    for backup_path in candidate_backups:
        if os.path.exists(backup_path):
            print(f" - {backup_path}")
# Run the labelling pass only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
# End of script.