| """ |
| Pre-download all SFT datasets to nanochat_data/sft_data/ so that chat_sft can run |
| without network access during training. |
| |
| Run once before SFT: |
| python -m scripts.download_sft_data |
| |
| Requires NANOCHAT_BASE_DIR or uses {repo}/nanochat_data by default. |
| """ |
|
|
| import os |
| import urllib.request |
| import logging |
|
|
| |
# Silence chatty INFO-level logs from the HTTP client and the `datasets`
# library so the per-step progress prints below stay readable.
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("datasets").setLevel(logging.WARNING)
|
|
| from nanochat.common import get_base_dir |
|
|
|
|
def main():
    """Pre-download every dataset/file needed for SFT into {base_dir}/sft_data.

    Each step is idempotent: HF datasets already saved to disk (detected via
    the presence of dataset_info.json) and plain files already present are
    skipped, so the script can safely be re-run after a partial failure.
    """
    base_dir = get_base_dir()
    sft_data_dir = os.path.join(base_dir, "sft_data")
    os.makedirs(sft_data_dir, exist_ok=True)
    print(f"Saving SFT data to: {sft_data_dir}")

    # Imported lazily so the target directory is resolved/printed even if the
    # optional `datasets` dependency is missing.
    from datasets import load_dataset

    def _dataset_saved(path):
        # save_to_disk writes dataset_info.json into the dataset directory;
        # treat its presence as the marker that the save completed.
        return os.path.isdir(path) and os.path.exists(os.path.join(path, "dataset_info.json"))

    def _download_file(url, path):
        # Fetch to a temp file and atomically rename into place so an
        # interrupted download never leaves a partial file that the
        # os.path.exists skip-check would mistake for a finished one.
        tmp_path = path + ".tmp"
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)

    # [1/6] SmolTalk: main chat SFT mixture (train + test splits)
    print("\n[1/6] Downloading SmolTalk...")
    smoltalk_dir = os.path.join(sft_data_dir, "smoltalk")
    for split in ["train", "test"]:
        path = os.path.join(smoltalk_dir, split)
        if _dataset_saved(path):
            print(f" SmolTalk/{split} already exists, skipping")
            continue
        ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42)
        ds.save_to_disk(path)
        print(f" Saved SmolTalk/{split} ({len(ds):,} rows)")

    # [2/6] MMLU: auxiliary_train for training, all/test for evaluation
    print("\n[2/6] Downloading MMLU...")
    mmlu_dir = os.path.join(sft_data_dir, "mmlu")

    path_aux = os.path.join(mmlu_dir, "auxiliary_train")
    if not _dataset_saved(path_aux):
        ds = load_dataset("cais/mmlu", "auxiliary_train", split="train").shuffle(seed=42)
        # auxiliary_train rows are nested under a single "train" column; flatten them.
        ds = ds.map(lambda row: row["train"], remove_columns=["train"])
        os.makedirs(path_aux, exist_ok=True)
        ds.save_to_disk(path_aux)
        print(f" Saved MMLU/auxiliary_train ({len(ds):,} rows)")
    else:
        print(" MMLU/auxiliary_train already exists, skipping")

    path_all = os.path.join(mmlu_dir, "all_test")
    if not _dataset_saved(path_all):
        ds = load_dataset("cais/mmlu", "all", split="test").shuffle(seed=42)
        os.makedirs(path_all, exist_ok=True)
        ds.save_to_disk(path_all)
        print(f" Saved MMLU/all_test ({len(ds):,} rows)")
    else:
        print(" MMLU/all_test already exists, skipping")

    # [3/6] GSM8K: grade-school math word problems (train + test splits)
    print("\n[3/6] Downloading GSM8K...")
    gsm8k_dir = os.path.join(sft_data_dir, "gsm8k")
    for split in ["train", "test"]:
        path = os.path.join(gsm8k_dir, f"main_{split}")
        if _dataset_saved(path):
            print(f" GSM8K/main_{split} already exists, skipping")
            continue
        ds = load_dataset("openai/gsm8k", "main", split=split).shuffle(seed=42)
        os.makedirs(path, exist_ok=True)
        ds.save_to_disk(path)
        print(f" Saved GSM8K/main_{split} ({len(ds):,} rows)")

    # [4/6] Identity conversations: small jsonl file hosted on S3
    print("\n[4/6] Downloading identity_conversations.jsonl...")
    identity_path = os.path.join(sft_data_dir, "identity_conversations.jsonl")
    if os.path.exists(identity_path):
        print(" identity_conversations.jsonl already exists, skipping")
    else:
        url = "https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
        _download_file(url, identity_path)
        print(" Saved identity_conversations.jsonl")

    # [5/6] English word list (one word per line)
    print("\n[5/6] Downloading words_alpha.txt...")
    words_path = os.path.join(sft_data_dir, "words_alpha.txt")
    if os.path.exists(words_path):
        print(" words_alpha.txt already exists, skipping")
    else:
        url = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"
        _download_file(url, words_path)
        with open(words_path) as f:
            n_words = sum(1 for _ in f)
        print(f" Saved words_alpha.txt ({n_words:,} words)")

    print("\n[6/6] All SFT data downloaded.")
    print(f"SFT data directory: {sft_data_dir}")
|
|
|
|
# Script entry point: run the one-time SFT data download.
if __name__ == "__main__":
    main()
|
|