"""
Download the Multimodal-Cold-Start dataset from HuggingFace,
convert it to the same JSONL format as the existing physics CoT data,
save the images locally, and merge with the existing 1,467 physics samples.
"""
import os

# NOTE: must be set BEFORE `datasets`/`huggingface_hub` are imported — the hub
# client reads HF_ENDPOINT at import time, so the mirror has to be configured
# first. That is why the third-party imports live inside main() below.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import json

DATASET_NAME = "WaltonFuture/Multimodal-Cold-Start"
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train"
IMAGE_DIR = os.path.join(OUTPUT_DIR, "images_coldstart")
OUTPUT_JSONL = os.path.join(OUTPUT_DIR, "coldstart_formatted.jsonl")
MERGED_JSONL = os.path.join(OUTPUT_DIR, "merged_train.jsonl")
EXISTING_JSONL = os.path.join(OUTPUT_DIR, "sft_train_formatted.jsonl")

# Prompt prepended to every user turn; must match the existing physics data.
SYSTEM_PROMPT = "Please reason step by step, and put your final answer within \\boxed{}."


def build_record(problem: str, answer: str, img_path: str) -> dict:
    """Build one chat-format training record that references a local image.

    The layout (user turn = image + prompted problem text, assistant turn =
    answer text) mirrors the existing physics CoT JSONL format.
    """
    return {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": f"file://{img_path}"},
                    {"type": "text", "text": f"{SYSTEM_PROMPT}\n\n{problem}"},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": answer}],
            },
        ]
    }


def write_jsonl(path: str, records: list) -> None:
    """Write `records` to `path`, one JSON object per line (UTF-8, non-ASCII kept)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def main() -> None:
    """Download, convert, save images, and merge with existing physics samples."""
    # Imported here (after HF_ENDPOINT is set, and not at module top) so that
    # importing this module for its helpers needs no heavy third-party deps.
    from datasets import load_dataset
    from tqdm import tqdm

    os.makedirs(IMAGE_DIR, exist_ok=True)

    print(f"Downloading {DATASET_NAME}...")
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Downloaded {len(dataset)} samples")

    converted = []
    for i, sample in tqdm(enumerate(dataset), total=len(dataset), desc="Converting"):
        # Save the first image of each sample; skip if already on disk so
        # re-runs do not redo the (slow) JPEG re-encoding.
        img_path = os.path.join(IMAGE_DIR, f"cs_{i}.jpg")
        if not os.path.exists(img_path):
            sample["images"][0].convert("RGB").save(img_path)
        converted.append(build_record(sample["problem"], sample["answer"], img_path))

    print(f"\nSaving {len(converted)} cold-start samples to {OUTPUT_JSONL}")
    write_jsonl(OUTPUT_JSONL, converted)

    # Merge with the previously formatted physics samples, if present.
    existing = []
    if os.path.exists(EXISTING_JSONL):
        with open(EXISTING_JSONL, "r", encoding="utf-8") as f:
            existing = [json.loads(line) for line in f]
        print(f"Loaded {len(existing)} existing physics samples")

    merged = converted + existing
    print(f"Total merged: {len(merged)} ({len(converted)} cold-start + {len(existing)} physics)")
    write_jsonl(MERGED_JSONL, merged)

    print(f"\nMerged data saved to {MERGED_JSONL}")
    print("Done!")


if __name__ == "__main__":
    main()