import os
from collections.abc import Iterator

from datasets import load_dataset
from tokenizers import Tokenizer

# Import our components
from normalizer import normalization  # our normalize function
from bpe import build_tokenizer, build_trainer, get_special_token_ids
from post_processor import add_post_processor

# ------------------------------------------------------------------ #
# CONSTANTS
# ------------------------------------------------------------------ #

DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_SUBSET = "CC-MAIN-2014-49"
MIN_QUALITY = 3  # int_score >= 3 only
MAX_TOKENS = 25_000_000  # ~100M characters worth, enough for BPE training
                         # (FineWeb-Edu tokens avg 4-5 chars each)
MIN_DOC_LENGTH = 100  # skip very short documents, likely boilerplate

# Progress-log interval, in yielded documents. A 100k interval would only
# ever fire under the 25M-token budget if documents averaged <= 250 tokens,
# so a much smaller interval is used to get visible progress.
LOG_EVERY_DOCS = 5_000

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SAVE_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")


# ------------------------------------------------------------------ #
# DATA GENERATOR
# ------------------------------------------------------------------ #

def fineweb_edu_iterator(
    max_tokens: int = MAX_TOKENS,
    min_quality: int = MIN_QUALITY,
    min_length: int = MIN_DOC_LENGTH,
    log_every: int = LOG_EVERY_DOCS,
) -> Iterator[str]:
    """
    Stream FineWeb-Edu documents, filter by quality and length, normalize
    the text, and yield clean strings for BPE training.

    Args:
        max_tokens : stop once the summed ``token_count`` of yielded docs
                     reaches this value. ``token_count`` is the dataset's
                     precomputed field for the raw (pre-normalization) doc,
                     and the budget is checked before each new doc, so the
                     final total can overshoot by one document.
        min_quality: only yield docs with int_score >= this value
        min_length : skip docs shorter than this many characters; checked
                     both before and after normalization
        log_every  : print a progress line every this many yielded docs;
                     a value <= 0 disables periodic progress lines

    Yields:
        str: normalized, clean document text
    """
    print(f"Loading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")

    ds = load_dataset(
        DATASET_NAME,
        name=DATASET_SUBSET,
        split="train",
        streaming=True,
    )

    tokens_seen = 0   # running total of tokens consumed (yielded docs only)
    docs_yielded = 0  # how many docs passed all filters
    docs_skipped = 0  # how many docs were filtered out

    for doc in ds:
        # ---- Stop condition ----------------------------------------
        if tokens_seen >= max_tokens:
            break

        # ---- Quality filter ----------------------------------------
        # int_score is 0-5, we want educational quality >= min_quality
        if doc["int_score"] < min_quality:
            docs_skipped += 1
            continue

        # ---- Extract and normalize ---------------------------------
        text = doc["text"]

        # Skip very short documents before normalization
        # (saves compute on boilerplate/empty docs)
        if len(text) < min_length:
            docs_skipped += 1
            continue

        # Run our normalization pipeline
        text = normalization(text)

        # Skip if normalization made it too short
        # (e.g. doc was mostly HTML tags or control chars)
        if len(text) < min_length:
            docs_skipped += 1
            continue

        # ---- Track progress ----------------------------------------
        tokens_seen += doc["token_count"]
        docs_yielded += 1

        # Periodic progress line; the `> 0` guard avoids modulo-by-zero.
        # No division-by-zero risk below: max_tokens <= 0 breaks the loop
        # before any doc is yielded.
        if log_every > 0 and docs_yielded % log_every == 0:
            print(
                f"  docs yielded: {docs_yielded:,} | "
                f"docs skipped: {docs_skipped:,} | "
                f"tokens seen: {tokens_seen:,} / {max_tokens:,} "
                f"({100 * tokens_seen / max_tokens:.1f}%)"
            )

        yield text

    # Final stats (only reached when the stream is consumed to the end
    # or the token budget is hit)
    print("\nStream complete:")
    print(f"  docs yielded : {docs_yielded:,}")
    print(f"  docs skipped : {docs_skipped:,}")
    print(f"  tokens seen  : {tokens_seen:,}")


# ------------------------------------------------------------------ #
# TRAINING
# ------------------------------------------------------------------ #

def train_tokenizer() -> Tokenizer:
    """
    Build, train, post-process, and save the tokenizer.

    The trained tokenizer is written to ``SAVE_PATH + ".json"``.

    Returns:
        Trained Tokenizer object (with the post-processor attached)
    """
    # Build untrained tokenizer and trainer
    tokenizer = build_tokenizer()
    trainer = build_trainer()

    print("\nStarting BPE training...")
    print(f"  vocab size    : {trainer.vocab_size:,}")
    print(f"  min frequency : {trainer.min_frequency}")
    print(f"  quality filter: int_score >= {MIN_QUALITY}")
    print(f"  max tokens    : {MAX_TOKENS:,}\n")

    # train_from_iterator expects an iterable of strings; our generator
    # yields one clean document string at a time.
    # NOTE: `length` is deliberately not passed. It means "number of
    # sequences (documents) in the iterator", not a token budget, and the
    # document count is not known up front. Passing MAX_TOKENS made the
    # progress bar total wrong by orders of magnitude.
    tokenizer.train_from_iterator(
        iterator=fineweb_edu_iterator(),
        trainer=trainer,
    )
    print("\nTraining complete.")

    tokenizer = add_post_processor(tokenizer)

    # Print special token IDs
    ids = get_special_token_ids(tokenizer)
    print("\nSpecial token IDs:")
    for token, token_id in ids.items():
        print(f"  {token} -> {token_id}")

    # Save tokenizer to disk
    save_file = f"{SAVE_PATH}.json"
    tokenizer.save(save_file)
    print(f"\nTokenizer saved to: {save_file}")

    return tokenizer
# ------------------------------------------------------------------ #
# QUICK VERIFICATION after training
# ------------------------------------------------------------------ #

def verify_tokenizer(tokenizer: Tokenizer) -> None:
    """
    Run a few quick sanity checks on a trained tokenizer and print results.

    For each sample string, this prints:
        - the tokens, the ids, and the token count
        - the decoded text
        - whether decoding reproduced the input exactly

    It then prints a lossless round-trip summary and the vocab size. Finally
    it checks that ``<|endoftext|>`` is in the vocabulary and prints a
    warning if it is missing.

    Args:
        tokenizer: trained tokenizer to inspect
    """
    print("\n" + "=" * 60)
    print("  TOKENIZER VERIFICATION")
    print("=" * 60 + "\n")

    test_cases = [
        "The mitochondria is the powerhouse of the cell.",
        "CO2 levels rose by 1.5e-3 ppm in 2024.",
        "def compute_loss(y_pred, y_true):\n return (y_pred - y_true)**2",
        "U.S.A has a Ph.D program e.g. at MIT.",
        "don't they've she'll",
        # Non-ASCII math symbol; intended to exercise byte fallback /
        # byte-level handling (depends on how build_tokenizer is configured)
        "∇f(x) = 0 is a necessary condition.",
    ]

    lossless_count = 0
    for text in test_cases:
        encoded = tokenizer.encode(text)
        decoded = tokenizer.decode(encoded.ids)
        lossless = text == decoded
        lossless_count += lossless

        print(f"Input   : {text!r}")
        print(f"Tokens  : {encoded.tokens}")
        print(f"IDs     : {encoded.ids}")
        print(f"N tokens: {len(encoded.ids)}")
        print(f"Decoded : {decoded!r}")
        print(f"Lossless: {lossless}")
        print()

    # A lossy round-trip is not necessarily a bug: the input is not
    # normalized here, while training data was, and decode() may drop
    # special tokens added by the post-processor.
    print(f"Lossless round-trips: {lossless_count}/{len(test_cases)}")

    # Verify vocab size
    vocab_size = tokenizer.get_vocab_size()
    print(f"Final vocab size: {vocab_size:,}")

    # Verify endoftext token exists (token_to_id returns None if absent)
    eot_id = tokenizer.token_to_id("<|endoftext|>")
    if eot_id is None:
        print("WARNING: <|endoftext|> is missing from the vocabulary")
    else:
        print(f"<|endoftext|> ID: {eot_id}")


# ------------------------------------------------------------------ #
# ENTRY POINT
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    tokenizer = train_tokenizer()
    verify_tokenizer(tokenizer)