"""Create one-shot training data for the first question of every category.

The script reads the JSON Lines file at ``TEST_RL_PATH``, keeps the first
record seen for each ``category`` value, copies that record's image into a
per-category folder under ``TRAIN_DIR`` and writes a parquet file in the same
folder that contains the record's prompt repeated ``NUM_REPEATS`` times.
"""
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Iterable, List

import pandas as pd


# Input JSONL, the directory holding its images, and the output root.
TEST_RL_PATH = Path("d:/Research/Rl4Phyx/ZeroSearch/One-Shot-RLVR/data/test/test_rl4phyx/test_rl.jsonl")
IMAGE_SRC_DIR = Path("d:/Research/Rl4Phyx/ZeroSearch/One-Shot-RLVR/data/test/test_rl4phyx/testimage_rl")
TRAIN_DIR = Path("d:/Research/Rl4Phyx/ZeroSearch/One-Shot-RLVR/data/train/physics_vlm")

# Category name (as stored in the JSONL) -> output folder name.
CATEGORY_FOLDER_MAP = {
    "Electromagnetism": "electromagnetism",
    "Mechanics": "mechanics",
    "Optics": "optics",
    "Modern Physics": "quantum",
    "Thermodynamics": "thermo",
    "Waves/Acoustics": "waves",
}

# How many identical rows are written to each parquet file.
NUM_REPEATS = 128

# Prompt layout; ``{{X}}`` renders as a literal ``{X}`` after ``str.format``.
# NOTE(review): the closing sentence hard-codes the choices A-D — confirm that
# every question really has exactly four options.
PROMPT_TEMPLATE = (
    "Look at the image and answer the physics question.\n"
    "\n"
    "{description}\n"
    "\n"
    "Question: {question}\n"
    "\n"
    "Options:\n"
    "{options}\n"
    "\n"
    "Please analyze the problem step by step and provide your final answer "
    "in the format \\boxed{{X}} where X is the letter of your choice "
    "(A, B, C, or D)."
)

Record = Dict[str, Any]


def parse_jsonl(lines: Iterable[str]) -> List[Record]:
    """Decode one JSON document per line, ignoring blank lines.

    Skipping whitespace-only lines keeps a trailing newline or an accidental
    empty line in the file from raising ``json.JSONDecodeError``.
    """
    return [json.loads(line) for line in lines if line.strip()]


def load_records(path: Path) -> List[Record]:
    """Read the UTF-8 JSON Lines file at *path* and return its records."""
    with open(path, "r", encoding="utf-8") as f:
        return parse_jsonl(f)


def first_record_per_category(records: Iterable[Record]) -> Dict[str, Record]:
    """Return the first record seen for each truthy ``category`` value.

    Records without a ``category`` key (or with an empty/None one) are
    skipped. The result preserves first-seen order of the categories.
    """
    first_by_category: Dict[str, Record] = {}
    for record in records:
        category = record.get("category")
        if category:
            first_by_category.setdefault(category, record)
    return first_by_category


def folder_name_for(category: str) -> str:
    """Map a category to its output folder name.

    Known categories use ``CATEGORY_FOLDER_MAP``; any other category is
    lower-cased with spaces and slashes replaced by underscores.
    """
    fallback = category.lower().replace(" ", "_").replace("/", "_")
    return CATEGORY_FOLDER_MAP.get(category, fallback)


def build_prompt(record: Record) -> str:
    """Render the user prompt from the record's description/question/options.

    Missing keys are rendered as empty strings.
    """
    # NOTE(review): ``options`` is interpolated as-is; if the JSONL stores it
    # as a list it will appear as a Python repr — confirm it is a string.
    return PROMPT_TEMPLATE.format(
        description=record.get("description", ""),
        question=record.get("question", ""),
        options=record.get("options", ""),
    )


def build_training_frame(
    record: Record, category: str, repeats: int = NUM_REPEATS
) -> pd.DataFrame:
    """Build a DataFrame holding *repeats* identical training rows.

    Each row carries the chat-style prompt, the record's answer, image name
    and index, the fixed data source ``"physics_vlm"`` and *category*.
    """
    row = {
        "prompt": [{"role": "user", "content": build_prompt(record)}],
        "answer": record.get("answer", ""),
        "image_path": record.get("image"),
        "data_source": "physics_vlm",
        "category": category,
        "index": record.get("index"),
    }
    return pd.DataFrame([dict(row) for _ in range(repeats)])


def process_category(category: str, record: Record) -> None:
    """Copy the record's image and write its parquet file for *category*."""
    folder_name = folder_name_for(category)
    target_folder = TRAIN_DIR / folder_name
    target_folder.mkdir(parents=True, exist_ok=True)

    image_name = record.get("image")
    if image_name:
        src_image = IMAGE_SRC_DIR / image_name
        if src_image.exists():
            dst_image = target_folder / image_name
            # The image name may contain sub-directories; make sure the
            # destination's parent exists before copying.
            dst_image.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src_image, dst_image)
            print(f" ✅ 图片复制: {image_name} -> {folder_name}/")
        else:
            # A missing image is reported but does not stop the parquet
            # file from being written.
            print(f" ⚠️ 图片不存在: {src_image}")

    df = build_training_frame(record, category)

    safe_name = folder_name.replace("/", "_").replace(" ", "_")
    parquet_file = target_folder / f"{safe_name}_1_rl.parquet"
    df.to_parquet(parquet_file, index=False)

    print(f" ✅ 训练数据: {parquet_file.name} ({len(df)} 条)")
    print(f" 答案: {record.get('answer')}, index: {record.get('index')}")


def main() -> None:
    """Run the whole pipeline: load, pick first per category, export."""
    records = load_records(TEST_RL_PATH)
    print(f"总记录数: {len(records)}")

    first_by_category = first_record_per_category(records)
    print(f"\n找到 {len(first_by_category)} 个类别:")

    for category, record in first_by_category.items():
        process_category(category, record)

    print("\n" + "=" * 50)
    print("✅ 所有训练数据创建完成!")


if __name__ == "__main__":
    main()
|
|