# NanoGPT-X_Base / data_prepare.py
# -*- coding: utf-8 -*-
import os
from datasets import load_dataset
from tqdm import tqdm
import sentencepiece as spm
import numpy as np
# ===========================================================
# CONFIGURATION
# ===========================================================
TARGET_TOKENS = 1_000_000_000  # 1B tokens; drop to e.g. 100_000_000 for a quick test
VOCAB_SIZE = 32_000
RAW_TEXT_PATH = "dataset.txt"
TOKENIZER_MODEL_PATH = "tokenizer.model"
BIN_TRAIN_PATH = "dataset.bin"
BIN_VALID_PATH = "valid.bin"
TRAIN_RATIO = 0.98  # 98% train, 2% validation
SPECIAL_TOKENS = {
    "unk_id": 0,
    "bos_id": 1,
    "eos_id": 2,
    "pad_id": 3,
}
# ===========================================================
# 1) STREAMING DOWNLOAD OF FINEWEB -> dataset.txt
# ===========================================================
def download_fineweb_streaming():
    if os.path.exists(RAW_TEXT_PATH):
        print("✔ dataset.txt already exists, skipping.")
        return
    print("📥 Downloading FineWeb-Edu in streaming mode...")
    dataset = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        name="sample-10BT",
        split="train",
        streaming=True,
    )
    tokens_so_far = 0
    with open(RAW_TEXT_PATH, "w", encoding="utf-8") as f:
        for example in tqdm(dataset, desc="Downloading dataset"):
            text = example["text"].strip() + "\n\n"
            approx = len(text) // 4  # rough token estimate (~4 chars per token)
            if tokens_so_far + approx > TARGET_TOKENS:
                # Truncate the last document so the total stays near TARGET_TOKENS.
                remaining = TARGET_TOKENS - tokens_so_far
                chars = remaining * 4
                f.write(text[:chars])
                print("✔ dataset.txt done.")
                return
            f.write(text)
            tokens_so_far += approx
            if tokens_so_far >= TARGET_TOKENS:
                print("✔ dataset.txt done.")
                return
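
# A minimal sanity-check sketch for the download step (our addition, not
# wired into main): it reports the size of dataset.txt and the token count
# implied by the same ~4 chars/token heuristic used in the loop above.
# Bytes ≈ chars only for mostly-ASCII English text, so treat it as a rough
# estimate. The helper name `_check_raw_text` is hypothetical.
def _check_raw_text():
    size_bytes = os.path.getsize(RAW_TEXT_PATH)
    est_tokens = size_bytes // 4  # same heuristic as in download_fineweb_streaming
    print(f"dataset.txt: {size_bytes / 1e9:.2f} GB, ~{est_tokens:,} tokens (estimate)")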
# ===========================================================
# 2) TRAIN THE SENTENCEPIECE TOKENIZER
# ===========================================================
def train_tokenizer():
    if os.path.exists(TOKENIZER_MODEL_PATH):
        print("✔ Tokenizer already exists, skipping.")
        return
    print("🔧 Training SentencePiece tokenizer...")
    prefix = TOKENIZER_MODEL_PATH.replace(".model", "")
    spm.SentencePieceTrainer.train(
        input=RAW_TEXT_PATH,
        model_prefix=prefix,
        vocab_size=VOCAB_SIZE,
        model_type="unigram",
        character_coverage=1.0,
        byte_fallback=True,
        unk_id=SPECIAL_TOKENS["unk_id"],
        bos_id=SPECIAL_TOKENS["bos_id"],
        eos_id=SPECIAL_TOKENS["eos_id"],
        pad_id=SPECIAL_TOKENS["pad_id"],
        train_extremely_large_corpus=True,
    )
    print("✔ Tokenizer trained.")
# ===========================================================
# 3) STREAMING TOKENIZATION → BIN FILES (INT32)
# ===========================================================
def tokenize_to_bin_streaming():
    """
    Streaming tokenization of a large dataset into binary files (int32),
    without holding the whole dataset in memory.
    """
    if os.path.exists(BIN_TRAIN_PATH) and os.path.exists(BIN_VALID_PATH):
        print("✔ dataset.bin + valid.bin already exist.")
        return
    print("🔠 Streaming text → tokens (int32) → dataset.bin...")
    sp = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    EOS = sp.eos_id()
    # ===========================================================
    # 1️⃣ COUNT THE TOTAL NUMBER OF TOKENS
    # ===========================================================
    print("🔎 Counting tokens...")
    total_tokens = 0
    with open(RAW_TEXT_PATH, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Counting tokens"):
            line = line.strip()
            if not line:
                continue
            total_tokens += len(sp.encode(line)) + 1  # +1 for EOS
    train_tokens = int(total_tokens * TRAIN_RATIO)
    valid_tokens = total_tokens - train_tokens
    print(f"Total tokens: {total_tokens:,}")
    print(f"Train: {train_tokens:,}")
    print(f"Valid: {valid_tokens:,}")
    # ===========================================================
    # 2️⃣ CREATE THE MEMMAP FILES
    # ===========================================================
    train_mm = np.memmap(BIN_TRAIN_PATH, dtype=np.int32, mode="w+", shape=(train_tokens,))
    valid_mm = np.memmap(BIN_VALID_PATH, dtype=np.int32, mode="w+", shape=(valid_tokens,))
    # ===========================================================
    # 3️⃣ STREAMING TOKENIZATION AND WRITING
    # ===========================================================
    print("✍ Tokenizing and writing into memmaps...")
    ti, vi = 0, 0  # write indexes into the train/valid memmaps
    with open(RAW_TEXT_PATH, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Tokenizing dataset"):
            line = line.strip()
            if not line:
                continue
            ids = sp.encode(line) + [EOS]
            # Fill the train memmap first; overflow goes into valid.
            for tok in ids:
                if ti < train_tokens:
                    train_mm[ti] = tok
                    ti += 1
                else:
                    valid_mm[vi] = tok
                    vi += 1
    # ===========================================================
    # 4️⃣ FLUSH THE MEMMAPS
    # ===========================================================
    train_mm.flush()
    valid_mm.flush()
    print("✔ Done — dataset.bin + valid.bin ready for training!")
# ===========================================================
# MAIN
# ===========================================================
if __name__ == "__main__":
    download_fineweb_streaming()
    train_tokenizer()
    tokenize_to_bin_streaming()
    print("\n🎉 DONE — dataset.bin + valid.bin ready for training!")