import os
from typing import Iterator

from datasets import load_dataset
from tokenizers import Tokenizer

from normalizer import normalization
from bpe import build_tokenizer, build_trainer, get_special_token_ids
from post_processor import add_post_processor
|
|
|
|
|
|
|
|
|
# Hugging Face dataset repository and the configuration (subset) streamed from it.
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_SUBSET = "CC-MAIN-2014-49"

# Documents whose `int_score` field is below this value are skipped.
MIN_QUALITY = 3

# Token budget for the training stream, measured with each document's
# `token_count` field.
MAX_TOKENS = 25_000_000

# Documents shorter than this many characters are skipped (checked both before
# and after normalization).
MIN_DOC_LENGTH = 100

# The trained tokenizer is written next to this script as f"{SAVE_PATH}.json".
# `os` is imported with the other imports at the top of the file.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SAVE_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")
|
|
|
|
|
|
|
|
|
|
|
|
|
def fineweb_edu_iterator(
    max_tokens: int = MAX_TOKENS,
    min_quality: int = MIN_QUALITY,
    min_length: int = MIN_DOC_LENGTH,
    log_every: int = 100_000,
) -> Iterator[str]:
    """
    Streams FineWeb-Edu documents, filters by quality and length,
    normalizes text, and yields clean strings for BPE training.

    The token budget is checked before each document is processed, so the
    running total may exceed `max_tokens` by the size of the last yielded
    document. Only yielded documents count towards the budget.

    Args:
        max_tokens  : stop once the `token_count` of yielded docs reaches this
        min_quality : only yield docs with int_score >= this value
        min_length  : skip docs shorter than this many characters, measured
                      both before and after normalization
        log_every   : print a progress line every time this many docs have
                      been yielded; a value <= 0 disables progress lines

    Yields:
        str: normalized document text
    """
    print(f"Loading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
    ds = load_dataset(
        DATASET_NAME,
        name=DATASET_SUBSET,
        split="train",
        streaming=True,
    )

    tokens_seen = 0
    docs_yielded = 0
    docs_skipped = 0

    for doc in ds:
        # Budget exhausted: stop pulling from the stream.
        if tokens_seen >= max_tokens:
            break

        # Quality filter on the dataset-provided score.
        if doc["int_score"] < min_quality:
            docs_skipped += 1
            continue

        text = doc["text"]

        # Cheap length check first so short docs never reach normalization.
        if len(text) < min_length:
            docs_skipped += 1
            continue

        text = normalization(text)

        # Re-check the length, this time on the normalized text.
        if len(text) < min_length:
            docs_skipped += 1
            continue

        tokens_seen += doc["token_count"]
        docs_yielded += 1

        # The `log_every > 0` guard also protects the modulo from a zero divisor.
        # `max_tokens` is > 0 here, otherwise the loop would already have exited.
        if log_every > 0 and docs_yielded % log_every == 0:
            print(
                f" docs yielded: {docs_yielded:,} | "
                f"docs skipped: {docs_skipped:,} | "
                f"tokens seen: {tokens_seen:,} / {max_tokens:,} "
                f"({100 * tokens_seen / max_tokens:.1f}%)"
            )

        yield text

    # Reached only when the consumer drains the generator to the end.
    print("\nStream complete:")
    print(f" docs yielded : {docs_yielded:,}")
    print(f" docs skipped : {docs_skipped:,}")
    print(f" tokens seen : {tokens_seen:,}")
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_tokenizer() -> Tokenizer:
    """
    Builds, trains, and saves the tokenizer.

    Creates the tokenizer and trainer via `build_tokenizer()` and
    `build_trainer()`, trains on the stream from `fineweb_edu_iterator()`,
    attaches the post-processor, prints the special token IDs, and writes
    the result to f"{SAVE_PATH}.json".

    Returns:
        Trained Tokenizer object
    """
    tokenizer = build_tokenizer()
    trainer = build_trainer()

    print("\nStarting BPE training...")
    print(f" vocab size : {trainer.vocab_size:,}")
    print(f" min frequency : {trainer.min_frequency}")
    print(f" quality filter: int_score >= {MIN_QUALITY}")
    print(f" max tokens : {MAX_TOKENS:,}\n")

    # `length` is deliberately not passed. train_from_iterator documents it as
    # the total number of sequences in the iterator (used for progress
    # tracking). MAX_TOKENS is a token budget, not a document count, and the
    # number of documents the stream will yield is not known in advance.
    tokenizer.train_from_iterator(
        iterator=fineweb_edu_iterator(),
        trainer=trainer,
    )

    print("\nTraining complete.")

    tokenizer = add_post_processor(tokenizer)

    ids = get_special_token_ids(tokenizer)
    print("\nSpecial token IDs:")
    for token, token_id in ids.items():
        print(f" {token} -> {token_id}")

    output_file = f"{SAVE_PATH}.json"
    tokenizer.save(output_file)
    print(f"\nTokenizer saved to: {output_file}")

    return tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
def verify_tokenizer(tokenizer: Tokenizer):
    """
    Prints a round-trip report for a handful of sample strings.

    Each sample is encoded and decoded again; the tokens, IDs, token count,
    decoded text, and whether the decoded text equals the input are printed.
    The final vocabulary size and the ID of "<|endoftext|>" follow.
    """
    rule = "="*60
    print("\n" + rule)
    print(" TOKENIZER VERIFICATION")
    print(rule + "\n")

    samples = (
        "The mitochondria is the powerhouse of the cell.",
        "CO2 levels rose by 1.5e-3 ppm in 2024.",
        "def compute_loss(y_pred, y_true):\n return (y_pred - y_true)**2",
        "U.S.A has a Ph.D program e.g. at MIT.",
        "don't they've she'll",
        "∇f(x) = 0 is a necessary condition.",
    )

    for sample in samples:
        encoding = tokenizer.encode(sample)
        roundtrip = tokenizer.decode(encoding.ids)
        token_total = len(encoding.ids)

        print(f"Input : {sample!r}")
        print(f"Tokens : {encoding.tokens}")
        print(f"IDs : {encoding.ids}")
        print(f"N tokens: {token_total}")
        print(f"Decoded : {roundtrip!r}")
        print(f"Lossless: {sample == roundtrip}")
        print()

    # Overall vocabulary size after training.
    print(f"Final vocab size: {tokenizer.get_vocab_size():,}")

    # Look up the end-of-text special token in the vocabulary.
    end_of_text_id = tokenizer.token_to_id("<|endoftext|>")
    print(f"<|endoftext|> ID: {end_of_text_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: train the tokenizer, then print the verification report.
    verify_tokenizer(train_tokenizer())