# rl4phyx-backup / root_scripts / download_coldstart.py
# Uploaded by YUNTA88: "Upload root_scripts/download_coldstart.py with huggingface_hub"
# Commit 057c5c5 (verified)
"""
Download the Multimodal-Cold-Start dataset from Hugging Face,
convert it to the same JSONL chat format as the existing physics CoT data,
save its images to disk, and merge it with the existing 1,467 physics samples.
"""
import os

# Route Hugging Face Hub downloads through a mirror. This stays ahead of the
# `datasets` import so the value is already set when the library loads.
# setdefault (rather than plain assignment) keeps any HF_ENDPOINT the caller
# already exported instead of silently overriding it.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
import os
import json
from datasets import load_dataset
from tqdm import tqdm
# ===== Configuration =====
# Hugging Face dataset id to download.
DATASET_NAME = "WaltonFuture/Multimodal-Cold-Start"

# Root directory for every file this script reads or writes.
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train"

# Derived locations under OUTPUT_DIR:
#   IMAGE_DIR      - where the cold-start images are written
#   OUTPUT_JSONL   - the converted cold-start records
#   MERGED_JSONL   - cold-start + existing physics records
#   EXISTING_JSONL - the pre-existing physics records (input)
IMAGE_DIR, OUTPUT_JSONL, MERGED_JSONL, EXISTING_JSONL = (
    os.path.join(OUTPUT_DIR, leaf)
    for leaf in (
        "images_coldstart",
        "coldstart_formatted.jsonl",
        "merged_train.jsonl",
        "sft_train_formatted.jsonl",
    )
)

# Instruction prepended to every problem in the user turn.
SYSTEM_PROMPT = "Please reason step by step, and put your final answer within \\boxed{}."
os.makedirs(IMAGE_DIR, exist_ok=True)

print(f"Downloading {DATASET_NAME}...")
dataset = load_dataset(DATASET_NAME, split="train")
print(f"Downloaded {len(dataset)} samples")


def _save_image_atomically(img, path):
    """Save *img* as an RGB JPEG at *path* via a temp file plus rename.

    The image is written to a sibling temp file and then moved into place.
    An interrupted run therefore never leaves a truncated ``.jpg`` behind,
    which the ``os.path.exists`` cache check would later treat as done.
    """
    tmp_path = path + ".tmp"
    # PIL cannot infer the format from a ".tmp" suffix, so name it explicitly.
    img.convert("RGB").save(tmp_path, format="JPEG")
    os.replace(tmp_path, path)


def _build_record(image_path, problem, answer):
    """Return one chat-format SFT record: a user turn and an assistant turn.

    Args:
        image_path: Local image file to attach as a ``file://`` URI, or None
            to produce a text-only user turn.
        problem: Problem text; it follows SYSTEM_PROMPT and a blank line.
        answer: Reference answer, used verbatim as the assistant's text.
    """
    user_content = []
    if image_path is not None:
        user_content.append({"type": "image", "image": f"file://{image_path}"})
    user_content.append({"type": "text", "text": f"{SYSTEM_PROMPT}\n\n{problem}"})
    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
    }


# Convert to our format (same message layout as the existing physics data).
converted = []
missing_images = 0
for i, sample in tqdm(enumerate(dataset), total=len(dataset), desc="Converting"):
    images = sample["images"]
    if images:
        # Only the first image of each sample is used.
        img_path = os.path.join(IMAGE_DIR, f"cs_{i}.jpg")
        # Reuse images written by a previous run; only encode new ones.
        if not os.path.exists(img_path):
            _save_image_atomically(images[0], img_path)
    else:
        # Previously an IndexError here aborted the whole conversion.
        img_path = None
        missing_images += 1
    converted.append(_build_record(img_path, sample["problem"], sample["answer"]))

if missing_images:
    print(f"Warning: {missing_images} samples had no image; kept as text-only")
def _write_jsonl(path, records):
    """Write *records* to *path* as UTF-8 JSON Lines, one object per line."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


def _read_jsonl(path):
    """Return the JSON records in the JSON Lines file *path*.

    Blank lines, such as a trailing empty line, are skipped; ``json.loads``
    would raise on them.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


# Save cold-start JSONL
print(f"\nSaving {len(converted)} cold-start samples to {OUTPUT_JSONL}")
_write_jsonl(OUTPUT_JSONL, converted)

# Load existing physics data (optional input; warn loudly if it is absent so a
# cold-start-only "merged" file is not mistaken for the full mix).
if os.path.exists(EXISTING_JSONL):
    existing = _read_jsonl(EXISTING_JSONL)
else:
    existing = []
    print(f"Warning: {EXISTING_JSONL} not found; merging cold-start data only")
print(f"Loaded {len(existing)} existing physics samples")

# Merge: cold-start records first, followed by the existing physics records.
merged = converted + existing
print(f"Total merged: {len(merged)} ({len(converted)} cold-start + {len(existing)} physics)")
_write_jsonl(MERGED_JSONL, merged)
print(f"\nMerged data saved to {MERGED_JSONL}")
print("Done!")