# -- coding: utf-8 -- import os from datasets import load_dataset from tqdm import tqdm import sentencepiece as spm import numpy as np # =========================================================== # KONFIGURACE # =========================================================== TARGET_TOKENS = 1_000_000_000 # 100M pro test, může být 1_000_000_000 a víc VOCAB_SIZE = 32_000 RAW_TEXT_PATH = "dataset.txt" TOKENIZER_MODEL_PATH = "tokenizer.model" BIN_TRAIN_PATH = "dataset.bin" BIN_VALID_PATH = "valid.bin" TRAIN_RATIO = 0.98 # 98% trénink, 2% valid SPECIAL_TOKENS = { "unk_id": 0, "bos_id": 1, "eos_id": 2, "pad_id": 3, } # =========================================================== # 1) STREAMOVANÉ STAŽENÍ FINEWEB -> dataset.txt # =========================================================== def download_fineweb_streaming(): if os.path.exists(RAW_TEXT_PATH): print("✔ dataset.txt už existuje, přeskočeno.") return print("📥 Stahuji FineWeb-Edu streamovacím způsobem...") dataset = load_dataset( "HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True ) tokens_so_far = 0 with open(RAW_TEXT_PATH, "w", encoding="utf-8") as f: for example in tqdm(dataset, desc="Stahuji dataset"): text = example["text"].strip() + "\n\n" approx = len(text) // 4 # odhad tokenů if tokens_so_far + approx > TARGET_TOKENS: remaining = TARGET_TOKENS - tokens_so_far chars = remaining * 4 f.write(text[:chars]) print("✔ dataset.txt hotovo.") return f.write(text) tokens_so_far += approx if tokens_so_far >= TARGET_TOKENS: print("✔ dataset.txt hotovo.") return # =========================================================== # 2) TRÉNINK SENTENCEPIECE TOKENIZERU # =========================================================== def train_tokenizer(): if os.path.exists(TOKENIZER_MODEL_PATH): print("✔ Tokenizer už existuje, přeskakuji.") return print("🔧 Trénuji SentencePiece tokenizer...") prefix = TOKENIZER_MODEL_PATH.replace(".model", "") spm.SentencePieceTrainer.train( input=RAW_TEXT_PATH, model_prefix=prefix, vocab_size=VOCAB_SIZE, model_type="unigram", character_coverage=1.0, byte_fallback=True, unk_id=SPECIAL_TOKENS["unk_id"], bos_id=SPECIAL_TOKENS["bos_id"], eos_id=SPECIAL_TOKENS["eos_id"], pad_id=SPECIAL_TOKENS["pad_id"], train_extremely_large_corpus=True, ) print("✔ Tokenizer natrénován.") # =========================================================== # 3) STREAMOVÁ TOKENIZACE → BIN FILE (INT32) # =========================================================== def tokenize_to_bin_streaming(): """ Streamovací tokenizace velkého datasetu do binárních souborů (int32), bez držení celého datasetu v paměti. """ if os.path.exists(BIN_TRAIN_PATH) and os.path.exists(BIN_VALID_PATH): print("✔ dataset.bin + valid.bin už existují.") return print("🔠 Streamuji text → tokeny (int32) → dataset.bin...") sp = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH) EOS = sp.eos_id() # =========================================================== # 1️⃣ ZJIŠTĚNÍ CELKOVÉHO POČTU TOKENŮ # =========================================================== print("🔎 Počítám tokeny...") total_tokens = 0 with open(RAW_TEXT_PATH, "r", encoding="utf-8") as f: for line in tqdm(f, desc="Počítám tokeny"): line = line.strip() if not line: continue total_tokens += len(sp.encode(line)) + 1 # +1 pro EOS train_tokens = int(total_tokens * TRAIN_RATIO) valid_tokens = total_tokens - train_tokens print(f"Celkem tokenů: {total_tokens:,}") print(f"Train: {train_tokens:,}") print(f"Valid: {valid_tokens:,}") # =========================================================== # 2️⃣ VYTVOŘENÍ MEMMAP SOUBORŮ # =========================================================== train_mm = np.memmap(BIN_TRAIN_PATH, dtype=np.int32, mode="w+", shape=(train_tokens,)) valid_mm = np.memmap(BIN_VALID_PATH, dtype=np.int32, mode="w+", shape=(valid_tokens,)) # =========================================================== # 3️⃣ STREAMOVÁ TOKENIZACE A ZÁPIS # =========================================================== print("✍ Tokenizuji a zapisují do memmap...") ti, vi = 0, 0 # indexy do train/valid memmap with open(RAW_TEXT_PATH, "r", encoding="utf-8") as f: for line in tqdm(f, desc="Tokenizuji dataset"): line = line.strip() if not line: continue ids = sp.encode(line) + [EOS] for tok in ids: if ti < train_tokens: train_mm[ti] = tok ti += 1 else: valid_mm[vi] = tok vi += 1 # =========================================================== # 4️⃣ FLUSH MEMMAP # =========================================================== train_mm.flush() valid_mm.flush() print("✔ Hotovo — dataset.bin + valid.bin připravené pro trénink!") # =========================================================== # MAIN # =========================================================== if __name__ == "__main__": download_fineweb_streaming() train_tokenizer() tokenize_to_bin_streaming() print("\n🎉 HOTOVO — dataset.bin + valid.bin připravené pro trénink!")