rl4phyx-backup / ZeroSearch /One-Shot-RLVR /create_training_data.py
YUNTA88's picture
Upload folder using huggingface_hub
9a71cb6 verified
"""
Create training data and copy the image for the first problem of each category.
"""
import json
import pandas as pd
import shutil
from pathlib import Path
# Path configuration. All locations live under one data root:
#   TEST_RL_PATH  - input JSONL file, one record per line
#   IMAGE_SRC_DIR - directory holding the images referenced by the records
#   TRAIN_DIR     - output root; one sub-folder is created per category
_DATA_ROOT = Path("d:/Research/Rl4Phyx/ZeroSearch/One-Shot-RLVR/data")
_TEST_SET_DIR = _DATA_ROOT / "test" / "test_rl4phyx"

TEST_RL_PATH = _TEST_SET_DIR / "test_rl.jsonl"
IMAGE_SRC_DIR = _TEST_SET_DIR / "testimage_rl"
TRAIN_DIR = _DATA_ROOT / "train" / "physics_vlm"

# Mapping from a record's category label to its output folder name.
CATEGORY_FOLDER_MAP = {
    "Electromagnetism": "electromagnetism",
    "Mechanics": "mechanics",
    "Optics": "optics",
    "Modern Physics": "quantum",
    "Thermodynamics": "thermo",
    "Waves/Acoustics": "waves",
}
# Load every record from the JSONL file (one JSON object per line).
with open(TEST_RL_PATH, 'r', encoding='utf-8') as f:
    # Skip blank / whitespace-only lines: json.loads("") raises
    # JSONDecodeError, so a stray empty line would otherwise abort the run.
    records = [json.loads(line) for line in f if line.strip()]
print(f"总记录数: {len(records)}")

# Group by category, keeping only the first record seen for each one.
# Records whose 'category' is missing or empty are ignored.
first_by_category = {}
for r in records:
    cat = r.get('category')
    if cat:
        # setdefault keeps the earliest record and ignores later ones.
        first_by_category.setdefault(cat, r)
print(f"\n找到 {len(first_by_category)} 个类别:")
# Number of identical rows written per category (one problem repeated 128 times).
_NUM_REPEATS = 128

# Create the training data for the first problem of each category.
for cat, record in first_by_category.items():
    # Known categories use the explicit mapping; any other label falls back to
    # its lower-cased form with spaces and slashes replaced by underscores.
    folder_name = CATEGORY_FOLDER_MAP.get(cat, cat.lower().replace(" ", "_").replace("/", "_"))
    target_folder = TRAIN_DIR / folder_name
    # Make sure the output folder exists.
    target_folder.mkdir(parents=True, exist_ok=True)

    # 1. Copy the image into the category folder.
    image_name = record.get('image')
    if image_name:
        src_image = IMAGE_SRC_DIR / image_name
        dst_image = target_folder / image_name
        if src_image.exists():
            shutil.copy2(src_image, dst_image)
            print(f" ✅ 图片复制: {image_name} -> {folder_name}/")
        else:
            # Best effort: only warn; the training file below is still written.
            print(f" ⚠️ 图片不存在: {src_image}")

    # 2. Build the prompt (chat format).
    # `or ''` covers both a missing key and a key present with a null value;
    # a plain .get(key, '') would return None for the latter and the prompt
    # would contain the literal text "None".
    description = record.get('description') or ''
    question = record.get('question') or ''
    # NOTE(review): options is interpolated as-is, so a non-string value
    # (e.g. a list) shows up as its Python repr - confirm the record schema.
    options = record.get('options') or ''
    prompt_text = (
        "Look at the image and answer the physics question.\n"
        f"{description}\n"
        f"Question: {question}\n"
        "Options:\n"
        f"{options}\n"
        "Please analyze the problem step by step and provide your final answer "
        "in the format \\boxed{X} where X is the letter of your choice (A, B, C, or D)."
    )
    prompt_chat = [{"role": "user", "content": prompt_text}]

    # 3. Create the training rows (1 problem x _NUM_REPEATS copies).
    base_row = {
        "prompt": prompt_chat,
        "answer": record.get('answer', ''),
        "image_path": image_name,
        "data_source": "physics_vlm",
        "category": cat,
        "index": record.get('index'),
    }
    training_records = [dict(base_row) for _ in range(_NUM_REPEATS)]

    # 4. Save as parquet.
    df = pd.DataFrame(training_records)
    # Simplified file name: <category>_1_rl.parquet
    safe_name = folder_name.replace("/", "_").replace(" ", "_")
    parquet_file = target_folder / f"{safe_name}_1_rl.parquet"
    df.to_parquet(parquet_file, index=False)
    print(f" ✅ 训练数据: {parquet_file.name} ({_NUM_REPEATS} 条)")
    print(f" 答案: {record.get('answer')}, index: {record.get('index')}")

print("\n" + "=" * 50)
print("✅ 所有训练数据创建完成!")