model111 / scripts /restore_math_vision_prompts.py
LCZZZZ's picture
Upload MemGen code and data
e34b94f verified
#!/usr/bin/env python3
"""
Restore Math Vision prompts to original format (remove the multiple choice instruction).
This script removes the "If it is a multiple choice question, directly give the option letter." instruction.
"""
import json
import os
def restore_prompt(prompt: str) -> str:
    r"""
    Strip the multiple-choice instruction from a Math Vision prompt.

    The instruction header is matched as one exact substring: the
    "... in the format of \boxed{your answer}." sentence, followed by
    " If it is a multiple choice question, directly give the option letter."
    and the trailing "\n Question:" marker.  Every occurrence is rewritten to
    the same header without the multiple-choice sentence.

    The "\n" in the marker is matched as a literal backslash followed by the
    letter n (two characters), not as a newline character.

    A prompt that does not contain the extended header is returned unchanged.
    """
    with_mc_hint = "Solve the problem and output the answer in the format of \\boxed{your answer}. If it is a multiple choice question, directly give the option letter.\\n Question:"
    without_mc_hint = "Solve the problem and output the answer in the format of \\boxed{your answer}.\\n Question:"

    # Guard clause: nothing to strip, hand the prompt back untouched.
    if with_mc_hint not in prompt:
        return prompt

    return prompt.replace(with_mc_hint, without_mc_hint)
def restore_json_file(file_path: str) -> dict:
    """
    Restore every prompt in a JSON file to the original format, in place.

    The loaded JSON is iterated item by item; each item's 'prompt' value
    (a missing key is treated as '') is passed through restore_prompt() and
    stored back only when the result differs.  The file is always rewritten
    with ensure_ascii=False and indent=2, even when no prompt changed.

    The new content is written to a temporary file next to the target and
    then swapped in with os.replace(), so a failure while serializing can
    never leave a truncated dataset behind.

    Args:
        file_path: Path of the JSON file to rewrite.

    Returns:
        dict with statistics: total, restored, skipped

    Raises:
        OSError: if the file cannot be read or the replacement cannot be
            written.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    print(f"\n处理文件: {file_path}")

    # Load JSON
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Restore prompts
    total = len(data)
    restored = 0
    for item in data:
        old_prompt = item.get('prompt', '')
        new_prompt = restore_prompt(old_prompt)
        if new_prompt != old_prompt:
            item['prompt'] = new_prompt
            restored += 1

    # Save updated JSON atomically.  Resolve symlinks first so the link itself
    # is preserved and the temp file lands on the same filesystem as the
    # real target (a requirement for os.replace to be atomic).
    target_path = os.path.realpath(file_path)
    tmp_path = f"{target_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        # Keep the permission bits of the file being replaced.
        os.chmod(tmp_path, os.stat(target_path).st_mode & 0o7777)
        os.replace(tmp_path, target_path)
    finally:
        # After a successful replace the temp file is gone; this only fires
        # when something above failed part-way through.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    stats = {
        'total': total,
        'restored': restored,
        'skipped': total - restored
    }
    print(f"  ✓ 总样本数: {stats['total']}")
    print(f"  ✓ 已恢复: {stats['restored']}")
    print(f"  ✓ 跳过: {stats['skipped']}")
    return stats
def main():
    """
    Restore the prompts of the Math Vision train/valid/test splits.

    Looks for train.json, valid.json and test.json under the relative
    directory data/math_vision, runs restore_json_file() on each one that
    exists, and prints per-run totals.  If the directory itself is missing,
    an error line is printed and the function returns without doing anything.
    """
    data_dir = "data/math_vision"
    banner = "=" * 80

    if not os.path.exists(data_dir):
        print(f"错误: 目录不存在: {data_dir}")
        return

    print(banner)
    print("Math Vision Prompt 恢复脚本")
    print(banner)
    print(f"数据目录: {data_dir}")
    print(f"恢复内容: 删除 'If it is a multiple choice question, directly give the option letter.' 指令")

    # Running totals across every split that was actually processed.
    totals = dict.fromkeys(('total', 'restored', 'skipped'), 0)

    for split_name in ('train.json', 'valid.json', 'test.json'):
        split_path = os.path.join(data_dir, split_name)
        if os.path.exists(split_path):
            split_stats = restore_json_file(split_path)
            for key in totals:
                totals[key] += split_stats[key]
        else:
            print(f"\n⚠ 跳过不存在的文件: {split_path}")

    print("\n" + banner)
    print("总结")
    print(banner)
    print(f"总样本数: {totals['total']}")
    print(f"已恢复: {totals['restored']}")
    print(f"跳过: {totals['skipped']}")
    print(f"\n✓ 完成!所有prompt已恢复到原始格式。")
# Run the restoration only when executed as a script, not when imported.
if __name__ == "__main__":
    main()