# sllm / tokenizer / traintokenizer.py
# geeteshcodes's picture
# Initial commit
# 7f974df verified
from datasets import load_dataset
from tokenizers import Tokenizer
# Import our components
from normalizer import normalization # our normalize function
from bpe import build_tokenizer, build_trainer, get_special_token_ids
from post_processor import add_post_processor
# ------------------------------------------------------------------ #
# CONSTANTS
# ------------------------------------------------------------------ #
# Source corpus: dataset repository id and the configuration (crawl dump)
# that is streamed for training.
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_SUBSET = "CC-MAIN-2014-49"

# Per-document filters applied while streaming.
MIN_QUALITY = 3  # keep only documents with int_score >= 3
MIN_DOC_LENGTH = 100  # drop very short documents (likely boilerplate)

# Training budget, measured with the dataset's own per-document token counts.
# Author's estimate: at roughly 4-5 characters per token this is on the order
# of 100M characters, which is considered enough for BPE training.
MAX_TOKENS = 25_000_000

import os

# The tokenizer is saved next to this script; ".json" is appended at save time.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SAVE_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")
# ------------------------------------------------------------------ #
# DATA GENERATOR
# ------------------------------------------------------------------ #
def _clean_document(doc, min_quality, min_length):
    """
    Filter a single dataset record and normalize its text.

    Returns the normalized text, or None when the record is rejected by the
    quality filter or by either length check.
    """
    # Quality gate comes first, so rejected records never have their text read.
    if doc["int_score"] < min_quality:
        return None

    raw_text = doc["text"]
    # Length gate on the raw text: avoids normalizing documents that are
    # too short to keep anyway.
    if len(raw_text) < min_length:
        return None

    cleaned = normalization(raw_text)
    # Normalization can shrink the text, so the same length gate is re-applied.
    if len(cleaned) < min_length:
        return None

    return cleaned


def fineweb_edu_iterator(
    max_tokens: int = MAX_TOKENS,
    min_quality: int = MIN_QUALITY,
    min_length: int = MIN_DOC_LENGTH,
):
    """
    Stream the configured dataset subset and yield normalized documents.

    Each record is filtered on its ``int_score`` and on its character length
    (before and after normalization). The token budget is tracked with the
    record's own ``token_count`` field, and only yielded documents count
    towards it.

    Args:
        max_tokens : stop once this many tokens have been consumed
        min_quality : only yield docs with int_score >= this value
        min_length : skip docs shorter than this many characters
    Yields:
        str: normalized document text
    """
    print(f"Loading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
    stream = load_dataset(
        DATASET_NAME,
        name=DATASET_SUBSET,
        split="train",
        streaming=True,
    )

    consumed_tokens = 0  # token_count total over yielded documents
    accepted = 0  # documents that passed every filter
    rejected = 0  # documents dropped by any filter

    for record in stream:
        # The budget is checked before looking at the record, so the document
        # that crosses the limit is still yielded in full.
        if consumed_tokens >= max_tokens:
            break

        cleaned = _clean_document(record, min_quality, min_length)
        if cleaned is None:
            rejected += 1
            continue

        consumed_tokens += record["token_count"]
        accepted += 1

        # Progress report every 100k accepted documents.
        if accepted % 100_000 == 0:
            percent = 100 * consumed_tokens / max_tokens
            print(
                f" docs yielded: {accepted:,} | "
                f"docs skipped: {rejected:,} | "
                f"tokens seen: {consumed_tokens:,} / {max_tokens:,} "
                f"({percent:.1f}%)"
            )

        yield cleaned

    # Summary; only reached when the consumer drains the generator.
    print(f"\nStream complete:")
    print(f" docs yielded : {accepted:,}")
    print(f" docs skipped : {rejected:,}")
    print(f" tokens seen : {consumed_tokens:,}")
# ------------------------------------------------------------------ #
# TRAINING
# ------------------------------------------------------------------ #
def train_tokenizer() -> Tokenizer:
    """
    Build, train, post-process, and save the tokenizer.

    Training text comes from ``fineweb_edu_iterator()`` called with its
    default budget and filters. After training, the post-processor is
    attached, the special-token IDs are printed, and the tokenizer is written
    to ``{SAVE_PATH}.json``.

    Returns:
        Trained Tokenizer object
    """
    # Untrained tokenizer plus the trainer that carries the training settings.
    tokenizer = build_tokenizer()
    trainer = build_trainer()

    print("\nStarting BPE training...")
    print(f" vocab size : {trainer.vocab_size:,}")
    print(f" min frequency : {trainer.min_frequency}")
    print(f" quality filter: int_score >= {MIN_QUALITY}")
    print(f" max tokens : {MAX_TOKENS:,}\n")

    # train_from_iterator consumes an iterable of strings; the generator
    # yields one normalized document at a time.
    #
    # No `length=` hint is passed on purpose. That argument is the total
    # number of sequences in the iterator and is used only for progress
    # tracking. The generator yields documents, whose count is not known in
    # advance, so passing the token budget (MAX_TOKENS) there reported
    # progress against the wrong total.
    tokenizer.train_from_iterator(
        iterator=fineweb_edu_iterator(),
        trainer=trainer,
    )
    print("\nTraining complete.")

    tokenizer = add_post_processor(tokenizer)

    # Show the IDs assigned to the special tokens.
    special_ids = get_special_token_ids(tokenizer)
    print("\nSpecial token IDs:")
    for token, token_id in special_ids.items():
        print(f" {token} -> {token_id}")

    # Persist the trained tokenizer to disk.
    output_file = f"{SAVE_PATH}.json"
    tokenizer.save(output_file)
    print(f"\nTokenizer saved to: {output_file}")

    return tokenizer
# ------------------------------------------------------------------ #
# QUICK VERIFICATION after training
# ------------------------------------------------------------------ #
def verify_tokenizer(tokenizer: Tokenizer):
    """
    Print a quick encode/decode report for a fixed set of sample strings.

    For every sample the tokens, IDs, token count, decoded text and whether
    the round trip reproduced the input exactly are printed. The vocabulary
    size and the ID of ``<|endoftext|>`` are printed afterwards. Nothing is
    asserted or returned; the output is meant to be read by a person.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" TOKENIZER VERIFICATION")
    print(rule + "\n")

    # Samples: plain prose, numbers/scientific notation, code, abbreviations,
    # contractions, and a non-ASCII symbol.
    samples = (
        "The mitochondria is the powerhouse of the cell.",
        "CO2 levels rose by 1.5e-3 ppm in 2024.",
        "def compute_loss(y_pred, y_true):\n return (y_pred - y_true)**2",
        "U.S.A has a Ph.D program e.g. at MIT.",
        "don't they've she'll",
        "∇f(x) = 0 is a necessary condition.",  # intended to exercise byte fallback
    )

    for sample in samples:
        encoding = tokenizer.encode(sample)
        token_ids = encoding.ids
        round_trip = tokenizer.decode(token_ids)

        report = (
            f"Input : {sample!r}",
            f"Tokens : {encoding.tokens}",
            f"IDs : {token_ids}",
            f"N tokens: {len(token_ids)}",
            f"Decoded : {round_trip!r}",
            f"Lossless: {sample == round_trip}",
        )
        for line in report:
            print(line)
        print()

    # Size of the trained vocabulary.
    print(f"Final vocab size: {tokenizer.get_vocab_size():,}")

    # ID lookup for the end-of-text token, printed as returned.
    print(f"<|endoftext|> ID: {tokenizer.token_to_id('<|endoftext|>')}")
# ------------------------------------------------------------------ #
# ENTRY POINT
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Script entry point: train and save the tokenizer, then run the
    # post-training checks on the result.
    verify_tokenizer(train_tokenizer())