# sllm/tokenizer/bpe.py
# geeteshcodes - "Initial commit" (7f974df, verified)
from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Sequence, ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from pretokenizer import get_pretokenizer
# Total size of the trained vocabulary; passed to the trainer as vocab_size.
VOCAB_SIZE: int = 32_000

# Passed to the trainer as min_frequency (minimum pair count for a merge).
MIN_FREQUENCY: int = 3

# Reserved tokens handed to the trainer and looked up again after training.
SPECIAL_TOKENS: list = [
    "<|endoftext|>",
]
def build_tokenizer() -> Tokenizer:
    """
    Build and return an untrained tokenizer with all components configured.

    Call .train_from_iterator() or .train() on the returned object to
    train it.

    Pipeline:
        Raw text
          -> Normalizer    (handled externally in our normalize() fn)
          -> Pre-tokenizer (custom regex splits + byte level conversion)
          -> BPE Model     (learns merge rules during training)
          -> Decoder       (reverses byte level for human readable output)

    Returns:
        A Tokenizer whose BPE model, pre-tokenizer chain and decoder are
        set, but which has not been trained yet.
    """
    # ---- 1. BPE Model ------------------------------------------------
    # unk_token=None: the ByteLevel pre-tokenizer (step 2) rewrites every
    # input byte as one of 256 printable symbols, so once that alphabet is
    # in the vocab there is nothing left that could be "unknown".
    model = BPE(
        unk_token=None,
        # NOTE(review): byte_fallback resolves unknown characters through
        # "<0xXX>" byte tokens. Nothing in this file adds such tokens to
        # the vocab, so this is presumably inert alongside ByteLevel.
        # Kept to leave the serialized model config unchanged - confirm
        # before removing.
        byte_fallback=True,
    )
    tokenizer = Tokenizer(model)

    # ---- 2. Pre-tokenizer --------------------------------------------
    # Sequence chains two pre-tokenizers in order:
    #
    #   Step A: Our custom regex splits text into meaningful chunks
    #           (contractions, abbreviations, numbers, operators etc.)
    #
    #   Step B: ByteLevel converts each chunk's characters to their
    #           byte representation using a 256-char printable alphabet
    #           e.g. é (bytes 0xC3 0xA9) -> "Ã©"
    #
    # add_prefix_space=False because our regex already handles
    # whitespace explicitly as its own token category.
    #
    # use_regex=False because ByteLevel otherwise applies its own
    # GPT-2 split pattern on top of Step A, which would re-split the
    # chunks our regex deliberately keeps together. Step B must only
    # do the byte -> symbol mapping.
    tokenizer.pre_tokenizer = Sequence([
        get_pretokenizer(),  # Step A - our regex
        ByteLevel(add_prefix_space=False, use_regex=False),  # Step B - byte conversion only
    ])

    # ---- 3. Decoder --------------------------------------------------
    # Reverses the ByteLevel encoding so output is human readable.
    # Without this tokenizer.decode() would return "Ã©" instead of "é".
    tokenizer.decoder = ByteLevelDecoder()

    return tokenizer
# ------------------------------------------------------------------ #
# TRAINER CONFIG
# ------------------------------------------------------------------ #
def build_trainer() -> BpeTrainer:
    """
    Create the BpeTrainer used to train the tokenizer.

    How the vocab_size budget is spent:
        256    base byte tokens (one per possible byte value)
      + 31,743 learned BPE merge tokens
      + 1      special token (<|endoftext|>)
      = 32,000 total

    The base tokens count towards vocab_size, so vocab_size=32_000 is the
    final vocabulary size rather than the number of merges.

    Returns:
        A BpeTrainer configured from VOCAB_SIZE, MIN_FREQUENCY and
        SPECIAL_TOKENS.
    """
    # Seeding the trainer with the full ByteLevel alphabet puts all 256
    # byte symbols in the vocab before any merge is learned, even if some
    # of them never occur in the training data.
    byte_alphabet = ByteLevel.alphabet()

    trainer_options = {
        "vocab_size": VOCAB_SIZE,
        "min_frequency": MIN_FREQUENCY,
        "special_tokens": SPECIAL_TOKENS,
        "show_progress": True,  # progress bar while training
        "initial_alphabet": byte_alphabet,
    }
    return BpeTrainer(**trainer_options)
# CONVENIENCE: get special token IDs after training
def get_special_token_ids(tokenizer: Tokenizer) -> dict:
    """
    Look up the ID of every entry in SPECIAL_TOKENS.

    Meant to be called AFTER training, once the IDs are final.

    Args:
        tokenizer: The tokenizer to query via token_to_id().

    Returns:
        A dict mapping each special token string to whatever
        tokenizer.token_to_id() returns for it.

    Example:
        ids = get_special_token_ids(tokenizer)
        eot_id = ids["<|endoftext|>"]  # typically 0
    """
    lookup = tokenizer.token_to_id
    return {special: lookup(special) for special in SPECIAL_TOKENS}
# QUICK SANITY CHECK
def _run_sanity_check() -> None:
    """Build the tokenizer and trainer, then print their configuration."""
    print("Building tokenizer...")
    tok = build_tokenizer()
    print("Building trainer...")
    bpe_trainer = build_trainer()

    # Components attached to the (still untrained) tokenizer
    print("\nPre-tokenizer chain:")
    print(f" {tok.pre_tokenizer}")
    print("\nDecoder:")
    print(f" {tok.decoder}")

    # Settings the trainer was created with
    print("\nTrainer config:")
    print(f" vocab_size : {bpe_trainer.vocab_size}")
    print(f" min_frequency : {bpe_trainer.min_frequency}")
    print(f" special_tokens: {bpe_trainer.special_tokens}")
    alphabet_size = len(ByteLevel.alphabet())
    print(f" base alphabet : {alphabet_size} byte tokens")

    print("\nAll good - ready to train.")
    print("Next step: pipe FineWeb-Edu text into tokenizer.train_from_iterator()")


if __name__ == "__main__":
    _run_sanity_check()