| """ |
| Pre-build and cache all datasets for all stages. |
| Run this ONCE before any training to avoid waiting during training. |
| |
| All datasets are cached at EVERY seq_len they will be used at to ensure |
| exact matches during training (no truncation/padding of replay buffers). |
| |
| Usage: |
| python prebuild_cache.py |
| """ |
|
|
| from tokenizers import Tokenizer |
| from dataset import StreamingStageDataset |
| import os |
|
|
# Serialized tokenizer file loaded by this script.
TOKENIZER_PATH = "tokenizers/tokenizer_corpus.json"
# Directory where cache files are looked up and which is handed to the builder.
CACHE_DIR = "cache"

# Token budget per dataset (passed to the builder as ``max_tokens``).
# Keeping it in one table guarantees a dataset has the same budget at every
# sequence length it is cached at.
_MAX_TOKENS = {
    "tinystories": 50_000_000,
    "babylm_easy": 50_000_000,
    "babylm_hard": 60_000_000,
    "simplewiki": 220_000_000,
    "fineweb_edu": 500_000_000,
}

# (dataset, seq_len) pairs to cache, in build order.
# NOTE(review): the blank-line groups mirror the original layout; presumably
# each group is one training stage (its new dataset first, then the earlier
# datasets replayed at that stage's seq_len) -- confirm against the training
# configuration.
# Pairs that appear twice are kept on purpose: the list stays a stage-by-stage
# listing, and a repeated pair maps to the same cache file.
_CACHE_SCHEDULE = (
    ("tinystories", 256),

    ("babylm_easy", 384),
    ("tinystories", 384),

    ("babylm_hard", 512),
    ("babylm_easy", 512),
    ("tinystories", 512),

    ("simplewiki", 512),
    ("babylm_hard", 512),

    ("fineweb_edu", 512),
    ("simplewiki", 512),
)

# Final list of (dataset_name, seq_len, max_tokens) tuples consumed by the
# build loop.
CACHE_CONFIGS = [
    (name, seq_len, _MAX_TOKENS[name]) for name, seq_len in _CACHE_SCHEDULE
]
|
|
def _size_mb(path):
    """Return the size of the file at *path* in mebibytes (bytes / 1024**2)."""
    return os.path.getsize(path) / (1024**2)


# Load the tokenizer once; every dataset build below reuses the same instance.
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Horizontal rule shared by the opening and closing banners.
rule = "=" * 70

print(rule)
print("PRE-BUILDING DATASET CACHES FOR ALL STAGES")
print(rule)

for name, length, token_budget in CACHE_CONFIGS:
    # File this (dataset, seq_len) pair is expected to be cached in.
    # NOTE(review): assumes StreamingStageDataset.build() writes exactly this
    # file name inside CACHE_DIR -- confirm against the dataset module.
    target = os.path.join(CACHE_DIR, f"train_{name}_seq{length}.pkl")
    # Common prefix of both status lines, padded so the rows line up.
    label = f"{name:15} @ seq_len={length:3}"

    if os.path.exists(target):
        # Already on disk: report its size and do not rebuild.
        cached_mb = _size_mb(target)
        print(f"✓ {label}: {cached_mb:6.0f} MB (cached)")
    else:
        # Flush the progress line right away so it is visible while the
        # build is still running.
        print(f"⏳ {label}...", end=" ", flush=True)

        ds = StreamingStageDataset().build(
            dataset_name=name,
            tokenizer=tokenizer,
            seq_len=length,
            max_tokens=token_budget,
            cache_dir=CACHE_DIR,
        )

        # Finish the progress line with the chunk count and on-disk size.
        built_mb = _size_mb(target)
        print(f"{len(ds.chunks):8,} chunks → {built_mb:6.0f} MB ✓")

print("\n" + rule)
print("✅ ALL CACHES COMPLETE! Training will start instantly.")
print(rule)
|
|