# SHOREKEEPER — scripts/01_download_stem_data.py
#!/usr/bin/env python3
"""
Download high-quality STEM datasets for SHOREKEEPER
Math, Code, Science - No random web text
"""
import json
from collections import Counter
from pathlib import Path

from datasets import load_dataset
def _write_jsonl(path, rows):
    """Write one JSON object per line to *path* (UTF-8, non-ASCII kept as-is)."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def _gsm8k_pair(item):
    """Map a GSM8K row to (prompt, response), or None if either side is empty.

    GSM8K answers end with '#### <final answer>'; only the text after the
    last '####' (the final answer) is kept, matching the original behavior.
    """
    question = item.get("question", "")
    answer = item.get("answer", "").split("####")[-1].strip()
    if question and answer:
        return question, f"|special_token| {answer} |special_token|"
    return None


def _download_source(all_data, *, banner, load_args, load_kwargs, source, noun, extract):
    """Download one HF dataset and append its formatted examples to *all_data*.

    *extract* maps a raw dataset item to a (prompt, response) tuple, or None
    to skip that item. Any failure (network, missing dataset, schema change)
    is reported and swallowed so one bad source does not abort the whole run.
    """
    print(f"\n{banner}...")
    try:
        dataset = load_dataset(*load_args, **load_kwargs)
        print(f" Loading {len(dataset)} examples...")
        added = 0
        for item in dataset:
            pair = extract(item)
            if pair is None:
                continue
            prompt, response = pair
            all_data.append({"prompt": prompt, "response": response, "source": source})
            added += 1
        # Report rows actually kept — the original printed len(dataset), which
        # over-counted for sources that filter out empty prompts/answers.
        print(f" ✓ Added {added} {noun}")
    except Exception as e:
        print(f" ✗ Failed: {e}")


def download_stem_data():
    """Download curated STEM datasets (math, code, science — no random web
    text) and write train/val JSONL splits under ./data/stem.

    Each saved row is {"prompt": ..., "response": ..., "source": ...}, with
    responses wrapped in |special_token| markers.
    """
    print("=" * 70)
    print("DOWNLOADING STEM DATASETS")
    print("=" * 70)

    data_dir = Path("./data/stem")
    data_dir.mkdir(parents=True, exist_ok=True)

    all_data = []

    specs = [
        # 1. MetaMathQA — 395k math problems with step-by-step reasoning.
        dict(
            banner="1. MetaMathQA (395k math problems)",
            load_args=("meta-math/MetaMathQA",),
            load_kwargs={"split": "train"},
            source="metamath",
            noun="math examples",
            extract=lambda it: (
                it.get("query", ""),
                f"|special_token| {it.get('response', '')} |special_token|",
            ),
        ),
        # 2. CodeFeedback — 1.2M code instructions, capped at 200k.
        dict(
            banner="2. CodeFeedback (1.2M code examples - taking 200k)",
            load_args=("m-a-p/CodeFeedback",),
            load_kwargs={"split": "train[:200000]"},
            source="codefeedback",
            noun="code examples",
            extract=lambda it: (
                (it["instruction"],
                 f"|special_token| Here's the code:\n{it['output']} |special_token|")
                if it.get("instruction") and it.get("output") else None
            ),
        ),
        # 3. NuminaMath-CoT — 860k math problems, capped at 200k.
        dict(
            banner="3. NuminaMath-CoT (860k math problems - taking 200k)",
            load_args=("AI-MO/NuminaMath-CoT",),
            load_kwargs={"split": "train[:200000]"},
            source="numinamath",
            noun="math examples",
            extract=lambda it: (
                (it["problem"],
                 f"|special_token| Let me solve this step by step.\n{it['solution']} |special_token|")
                if it.get("problem") and it.get("solution") else None
            ),
        ),
        # 4. ScienceQA — 21k science questions.
        # NOTE(review): ScienceQA's "answer" field is a choice index, so an
        # answer of 0 is falsy and gets skipped — preserved from the original;
        # confirm this filtering is intended.
        dict(
            banner="4. ScienceQA (21k science questions)",
            load_args=("derek-thomas/ScienceQA",),
            load_kwargs={"split": "train"},
            source="scienceqa",
            noun="science examples",
            extract=lambda it: (
                (it["question"],
                 f"|special_token| Science explanation:\n{it['answer']} |special_token|")
                if it.get("question") and it.get("answer") else None
            ),
        ),
        # 5. GSM8K — 8.5k grade-school math (final answer only).
        dict(
            banner="5. GSM8K (8.5k grade school math)",
            load_args=("gsm8k", "main"),
            load_kwargs={"split": "train"},
            source="gsm8k",
            noun="math examples",
            extract=_gsm8k_pair,
        ),
    ]

    for spec in specs:
        _download_source(all_data, **spec)

    print("\n" + "=" * 70)
    print(f"TOTAL STEM EXAMPLES: {len(all_data):,}")
    print("=" * 70)

    # Breakdown by source.
    counts = Counter(row["source"] for row in all_data)
    print("\nBreakdown by source:")
    for src, count in counts.items():
        print(f" {src}: {count:,}")

    # Split BEFORE writing. The original wrote all_data to stem_train.jsonl
    # and then carved the val split from the same list, so every validation
    # example was also present in the train file (train/val leakage).
    # NOTE(review): data is ordered by source, so the 5% tail used for val is
    # almost entirely GSM8K — consider shuffling with a fixed seed first.
    split_idx = int(len(all_data) * 0.95)
    train, val = all_data[:split_idx], all_data[split_idx:]

    print("\nSaving to disk...")
    _write_jsonl(data_dir / "stem_train.jsonl", train)
    _write_jsonl(data_dir / "stem_val.jsonl", val)

    print(f"✓ Saved to: {data_dir}/stem_train.jsonl")
    print(f" Total size: {len(all_data):,} examples")
    print(f" Train: {len(train):,}")
    print(f" Val: {len(val):,}")
# Script entry point: download all STEM datasets when executed directly.
if __name__ == "__main__":
    download_stem_data()