# nanochat/scripts/download_sft_data.py
"""
Pre-download all SFT datasets to nanochat_data/sft_data/ so that chat_sft can run
without network access during training.
Run once before SFT:
python -m scripts.download_sft_data
Uses NANOCHAT_BASE_DIR if set; otherwise defaults to {repo}/nanochat_data.
"""
import os
import urllib.request
import logging
# Reduce HTTP request logging during download
# (httpx is used internally by the huggingface hub client; datasets logs per-shard progress)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("datasets").setLevel(logging.WARNING)
# Project-local helper: resolves the base data directory (see module docstring).
from nanochat.common import get_base_dir
def _dataset_exists(path: str) -> bool:
    """Return True if *path* holds a dataset previously written by save_to_disk.

    save_to_disk always emits a dataset_info.json, so its presence is used as
    the completion marker for the skip-if-already-downloaded logic.
    """
    return os.path.isdir(path) and os.path.exists(os.path.join(path, "dataset_info.json"))


def _download_file(url: str, dest: str) -> None:
    """Download *url* to *dest* atomically.

    Fetches into a temporary sibling file and renames it into place, so an
    interrupted download never leaves a partial file at *dest* that a later
    run would mistakenly treat as complete and skip.
    """
    tmp_path = dest + ".tmp"
    urllib.request.urlretrieve(url, tmp_path)
    os.replace(tmp_path, dest)  # atomic on POSIX; replaces any stale file


def main():
    """Pre-download all SFT datasets and files into {base_dir}/sft_data.

    Idempotent: anything already on disk is skipped, so the script can be
    re-run safely after a partial failure. Intended to be run once before
    chat_sft so that training needs no network access.
    """
    base_dir = get_base_dir()
    sft_data_dir = os.path.join(base_dir, "sft_data")
    os.makedirs(sft_data_dir, exist_ok=True)
    print(f"Saving SFT data to: {sft_data_dir}")

    # Imported lazily: `datasets` is a heavy import and only needed here.
    from datasets import load_dataset

    # -------------------------------------------------------------------------
    # 1. SmolTalk (HuggingFaceTB/smol-smoltalk)
    # -------------------------------------------------------------------------
    print("\n[1/6] Downloading SmolTalk...")
    smoltalk_dir = os.path.join(sft_data_dir, "smoltalk")
    for split in ["train", "test"]:
        path = os.path.join(smoltalk_dir, split)
        if _dataset_exists(path):
            print(f" SmolTalk/{split} already exists, skipping")
            continue
        # shuffle(seed=42) matches the ordering chat_sft expects at load time
        ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42)
        ds.save_to_disk(path)  # save_to_disk creates the directory itself
        print(f" Saved SmolTalk/{split} ({len(ds):,} rows)")

    # -------------------------------------------------------------------------
    # 2. MMLU (cais/mmlu) - auxiliary_train and all
    # -------------------------------------------------------------------------
    print("\n[2/6] Downloading MMLU...")
    mmlu_dir = os.path.join(sft_data_dir, "mmlu")
    # auxiliary_train (train split)
    path_aux = os.path.join(mmlu_dir, "auxiliary_train")
    if _dataset_exists(path_aux):
        print(" MMLU/auxiliary_train already exists, skipping")
    else:
        ds = load_dataset("cais/mmlu", "auxiliary_train", split="train").shuffle(seed=42)
        # auxiliary_train rows are nested one level under a "train" key; flatten.
        ds = ds.map(lambda row: row["train"], remove_columns=["train"])
        ds.save_to_disk(path_aux)
        print(f" Saved MMLU/auxiliary_train ({len(ds):,} rows)")
    # all (test split)
    path_all = os.path.join(mmlu_dir, "all_test")
    if _dataset_exists(path_all):
        print(" MMLU/all_test already exists, skipping")
    else:
        ds = load_dataset("cais/mmlu", "all", split="test").shuffle(seed=42)
        ds.save_to_disk(path_all)
        print(f" Saved MMLU/all_test ({len(ds):,} rows)")

    # -------------------------------------------------------------------------
    # 3. GSM8K (openai/gsm8k)
    # -------------------------------------------------------------------------
    print("\n[3/6] Downloading GSM8K...")
    gsm8k_dir = os.path.join(sft_data_dir, "gsm8k")
    for split in ["train", "test"]:
        path = os.path.join(gsm8k_dir, f"main_{split}")
        if _dataset_exists(path):
            print(f" GSM8K/main_{split} already exists, skipping")
            continue
        ds = load_dataset("openai/gsm8k", "main", split=split).shuffle(seed=42)
        ds.save_to_disk(path)
        print(f" Saved GSM8K/main_{split} ({len(ds):,} rows)")

    # -------------------------------------------------------------------------
    # 4. Identity conversations (identity_conversations.jsonl)
    # -------------------------------------------------------------------------
    print("\n[4/6] Downloading identity_conversations.jsonl...")
    identity_path = os.path.join(sft_data_dir, "identity_conversations.jsonl")
    if os.path.exists(identity_path):
        print(" identity_conversations.jsonl already exists, skipping")
    else:
        url = "https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
        _download_file(url, identity_path)
        print(" Saved identity_conversations.jsonl")

    # -------------------------------------------------------------------------
    # 5. Words alpha (for SimpleSpelling and SpellingBee)
    # -------------------------------------------------------------------------
    print("\n[5/6] Downloading words_alpha.txt...")
    words_path = os.path.join(sft_data_dir, "words_alpha.txt")
    if os.path.exists(words_path):
        print(" words_alpha.txt already exists, skipping")
    else:
        url = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"
        _download_file(url, words_path)
        # Count lines purely for the progress message (one word per line).
        with open(words_path) as f:
            n_words = sum(1 for _ in f)
        print(f" Saved words_alpha.txt ({n_words:,} words)")

    print("\n[6/6] All SFT data downloaded.")
    print(f"SFT data directory: {sft_data_dir}")
# Script entry point: python -m scripts.download_sft_data
if __name__ == "__main__":
    main()