# model111 / scripts / add_option_letters.py
# LCZZZZ's picture
# Upload MemGen code and data
# e34b94f verified
#!/usr/bin/env python3
"""
Add letter labels (A, B, C, D, E...) to multiple choice options in Math Vision dataset.
This script updates options from:
"Options: option1, option2, option3"
to:
"Options: A. option1, B. option2, C. option3"
"""
import json
import os
import re
import shutil
def add_option_letters(prompt: str) -> tuple:
    """
    Add letter labels (A, B, C, ...) to the options listed in a prompt.

    The options section is located by a marker made of a literal backslash,
    the letter ``n``, a space and ``Options: ``. The text after the marker is
    split on commas that are not enclosed in ``$...$`` math, and each of the
    first ten options is prefixed with its letter. Options beyond the tenth
    are kept unlabeled.

    The prompt is returned unchanged (with ``was_updated`` False) when:
      * the marker is absent or occurs more than once,
      * the options section contains no options, or
      * the options already carry the expected labels, so running the
        script a second time does not produce "A. A. option".

    Args:
        prompt: Original prompt text

    Returns:
        (updated_prompt, was_updated) tuple
    """
    # NOTE(review): this matches a literal backslash + "n", not a newline
    # character -- confirm that this is how the dataset encodes the break.
    marker = "\\n Options: "
    letters = "ABCDEFGHIJ"

    question_part, found, options_part = prompt.partition(marker)
    # Require exactly one marker: none means "not multiple choice", several
    # make the question/options boundary ambiguous.
    if not found or marker in options_part:
        return prompt, False

    # Split on commas, ignoring commas that sit inside $...$ math.
    options = []
    buffer = []
    in_math = False
    for char in options_part:
        if char == '$':
            in_math = not in_math
        if char == ',' and not in_math:
            options.append("".join(buffer).strip())
            buffer = []
        else:
            buffer.append(char)
    last_option = "".join(buffer).strip()
    if last_option:
        options.append(last_option)

    # Nothing to label: report "not updated" instead of a no-op rewrite.
    if not options:
        return prompt, False

    # Already labeled (e.g. the script was run before): leave untouched.
    # A labeled empty option round-trips as "B." after stripping, so accept
    # that form as well.
    if all(
        option == f"{letter}." or option.startswith(f"{letter}. ")
        for letter, option in zip(letters, options)
    ):
        return prompt, False

    labeled_options = [
        f"{letter}. {option}" for letter, option in zip(letters, options)
    ]
    # More than ten options (unlikely): keep the extras without a label.
    labeled_options.extend(options[len(letters):])

    new_options_part = ", ".join(labeled_options)
    return f"{question_part}{marker}{new_options_part}", True
def process_json_file(file_path: str) -> dict:
    """
    Add option letters to every prompt in a JSON file, rewriting it in place.

    The loaded JSON is iterated item by item; for each item that is a dict
    with a string ``prompt``, the prompt is passed through
    ``add_option_letters`` and replaced when that function reports a change.
    Items that are not dicts, or whose ``prompt`` is not a string, are left
    untouched and counted as skipped.

    Before the file is overwritten, a byte-for-byte copy is saved next to it
    with the suffix ``.backup_no_letters``, unless that backup already exists
    (so the first, letter-free version is never clobbered by a later run).

    Args:
        file_path: Path of the JSON file to update.

    Returns:
        dict with statistics: total, updated, skipped
    """
    print(f"\n处理文件: {file_path}")

    # Load JSON
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Process prompts
    total = len(data)
    updated = 0
    for item in data:
        if not isinstance(item, dict):
            continue
        old_prompt = item.get('prompt', '')
        if not isinstance(old_prompt, str):
            # e.g. a null prompt: nothing to label, and not worth crashing.
            continue
        new_prompt, was_updated = add_option_letters(old_prompt)
        if was_updated:
            item['prompt'] = new_prompt
            updated += 1

    # Back up the untouched original before overwriting it. Copying the file
    # (rather than re-serialising parsed JSON) keeps the backup identical to
    # the source and never leaves a half-written backup behind.
    backup_path = file_path + '.backup_no_letters'
    if not os.path.exists(backup_path):
        shutil.copyfile(file_path, backup_path)
        print(f" ✓ 备份创建: {backup_path}")

    # Save updated JSON
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    stats = {
        'total': total,
        'updated': updated,
        'skipped': total - updated,
    }
    print(f" ✓ 总样本数: {stats['total']}")
    print(f" ✓ 已更新(有选项): {stats['updated']}")
    print(f" ✓ 跳过(无选项): {stats['skipped']}")
    return stats
def main():
    """
    Label the options in the Math Vision train/valid/test JSON files.

    Looks for ``train.json``, ``valid.json`` and ``test.json`` under
    ``data/math_vision`` (relative to the current working directory), runs
    ``process_json_file`` on each one that exists, and prints the combined
    statistics followed by the backup files found on disk. Returns early
    with an error message if the data directory does not exist.
    """
    data_dir = "data/math_vision"
    if not os.path.exists(data_dir):
        print(f"错误: 目录不存在: {data_dir}")
        return

    banner = "=" * 80
    print(banner)
    print("Math Vision 选项字母标注脚本")
    print(banner)
    print(f"数据目录: {data_dir}")
    print(f"更新内容: 给每个选项添加字母标识 (A. B. C. D. E. ...)")

    # Fixed set of dataset splits to process.
    json_files = ['train.json', 'valid.json', 'test.json']
    totals = dict.fromkeys(('total', 'updated', 'skipped'), 0)

    for name in json_files:
        path = os.path.join(data_dir, name)
        if os.path.exists(path):
            file_stats = process_json_file(path)
            for key in totals:
                totals[key] += file_stats[key]
        else:
            print(f"\n⚠ 跳过不存在的文件: {path}")

    print("\n" + banner)
    print("总结")
    print(banner)
    print(f"总样本数: {totals['total']}")
    print(f"有选项的样本(已添加字母): {totals['updated']}")
    print(f"无选项的样本(非选择题): {totals['skipped']}")
    print(f"\n✓ 完成!所有选择题选项已添加字母标识。")

    # List whichever backups are present on disk.
    print(f"\n备份文件位置:")
    candidate_backups = (
        os.path.join(data_dir, name + '.backup_no_letters')
        for name in json_files
    )
    for backup in filter(os.path.exists, candidate_backups):
        print(f" - {backup}")
# Run the labelling pass only when executed as a script, not when imported.
if __name__ == "__main__":
    main()