# sllm/tokenizer/wrap_tokenizer.py
# Author: geeteshcodes
# Initial commit (7f974df, verified)
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
import json
import os
# ------------------------------------------------------------------ #
# CONSTANTS
# ------------------------------------------------------------------ #
import os

# Every path is anchored to the folder that contains this script, so the
# result does not depend on the current working directory.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]

# Trained tokenizer file that wrap_tokenizer() loads (input) ...
TOKENIZER_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer.json")
# ... and the folder save_pretrained() writes the wrapped tokenizer to (output).
SAVE_DIR = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")

MODEL_MAX_LENGTH = 1_024  # context length, in tokens
PADDING_SIDE = "right"    # pad after the real tokens (causal-LM convention)
# ------------------------------------------------------------------ #
# WRAP
# ------------------------------------------------------------------ #
def wrap_tokenizer(
    tokenizer_path: str = TOKENIZER_PATH,
    save_dir: str = SAVE_DIR,
) -> PreTrainedTokenizerFast:
    """
    Wrap a trained HuggingFace ``Tokenizer`` as a ``PreTrainedTokenizerFast``
    and save it to ``save_dir`` with ``save_pretrained()``.

    This gives us:
        - datasets.map() compatibility for bulk tokenization
        - HuggingFace Trainer + DataCollator compatibility
        - Automatic padding, truncation, attention masks
        - from_pretrained() loading support
        - return_tensors="pt" for PyTorch tensors

    Args:
        tokenizer_path : path to the trained tokenizer .json file
        save_dir       : folder to save the wrapped tokenizer into
                         (created if it does not exist)

    Returns:
        PreTrainedTokenizerFast ready for training.

    Raises:
        ValueError: if ``<|endoftext|>`` is not in the trained vocabulary.
            Registering a missing token as special would append it as a
            brand-new token whose ID lies past the end of the trained vocab.
    """
    eot_token = "<|endoftext|>"

    print(f"Loading trained tokenizer from: {tokenizer_path}")
    base_tokenizer = Tokenizer.from_file(tokenizer_path)

    # Fail loudly instead of silently growing the vocabulary: an appended
    # token would get an ID that a model (presumably sized to the trained
    # vocab) has no embedding for.
    if base_tokenizer.token_to_id(eot_token) is None:
        raise ValueError(
            f"{eot_token!r} not found in the vocabulary of {tokenizer_path}; "
            "retrain the tokenizer with it as a special token."
        )

    # ---- Wrap --------------------------------------------------------
    # We map <|endoftext|> to all three roles:
    #
    #   eos_token - end of sequence marker, used during generation
    #               to know when to stop
    #
    #   bos_token - beginning of sequence; GPT-2 style reuses eos
    #               since there is no separate BOS token
    #
    #   pad_token - safe to reuse eos here because sequences are packed
    #               and never actually padded during pretraining.
    #               Defined so HuggingFace doesn't complain about a
    #               missing pad token
    #
    #   unk_token - None; assumes the trained tokenizer is byte-level,
    #               so no input is ever unknown
    special_tokens = {
        "bos_token": eot_token,
        "eos_token": eot_token,
        "pad_token": eot_token,
    }
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=base_tokenizer,
        unk_token=None,
        # Context length
        model_max_length=MODEL_MAX_LENGTH,
        # Padding behavior
        padding_side=PADDING_SIDE,
        # Truncate from the right: keep the beginning, drop the end
        truncation_side="right",
        **special_tokens,
    )
    # Defensive: explicitly register the roles as special tokens so they are
    # treated as such (e.g. by skip_special_tokens=True). The token is known
    # to exist in the vocab (checked above), so no new ID is created.
    tokenizer.add_special_tokens(special_tokens)

    # Written by hand before save_pretrained(). Depending on the installed
    # transformers version, save_pretrained() may write its own
    # special_tokens_map.json over this one; writing it here guarantees the
    # file exists either way.
    os.makedirs(save_dir, exist_ok=True)
    map_path = os.path.join(save_dir, "special_tokens_map.json")
    with open(map_path, "w", encoding="utf-8") as f:
        json.dump(special_tokens, f, indent=2)
    print("special_tokens_map.json written manually")

    # ---- Save --------------------------------------------------------
    # Saves to save_dir/:
    #   tokenizer.json          - the trained BPE tokenizer
    #   tokenizer_config.json   - max length, pad token, special tokens
    #   special_tokens_map.json - maps eos/bos/pad to actual tokens
    tokenizer.save_pretrained(save_dir)
    print(f"Tokenizer saved to: {save_dir}/")
    print("  tokenizer.json")
    print("  tokenizer_config.json")
    print("  special_tokens_map.json")

    return tokenizer
# ------------------------------------------------------------------ #
# VERIFICATION
# ------------------------------------------------------------------ #
def verify_wrapped_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    save_dir: str = SAVE_DIR,
):
    """
    Verify the wrapped tokenizer behaves correctly.

    Prints the results of five checks for manual inspection:
    basic config, a single encode/decode round trip, right-padded batch
    encoding (PyTorch tensors), truncation to MODEL_MAX_LENGTH, and
    reloading from disk with from_pretrained(). Nothing is asserted.

    Args:
        tokenizer : the tokenizer returned by wrap_tokenizer()
        save_dir  : folder the tokenizer was saved to; it is reloaded in
                    step 5 and compared against ``tokenizer``. Defaults to
                    SAVE_DIR (wrap_tokenizer()'s default) — pass the same
                    folder if wrap_tokenizer() was given a custom save_dir.
    """
    print("\n" + "=" * 60)
    print("  WRAPPED TOKENIZER VERIFICATION")
    print("=" * 60 + "\n")

    eot_id = tokenizer.eos_token_id

    # ---- 1. Basic config -----------------------------------------
    print("Config:")
    print(f"  vocab size       : {tokenizer.vocab_size:,}")
    print(f"  model_max_length : {tokenizer.model_max_length}")
    print(f"  padding_side     : {tokenizer.padding_side}")
    print(f"  eos_token        : {tokenizer.eos_token!r}  (ID: {eot_id})")
    print(f"  bos_token        : {tokenizer.bos_token!r}")
    print(f"  pad_token        : {tokenizer.pad_token!r}  (ID: {tokenizer.pad_token_id})")
    print(f"  unk_token        : {tokenizer.unk_token!r}")
    print()

    # ---- 2. Basic encode/decode ----------------------------------
    text = "The mitochondria is the powerhouse of the cell."
    encoded = tokenizer(text)
    decoded = tokenizer.decode(encoded["input_ids"])
    print("Basic encode/decode:")
    print(f"  input     : {text!r}")
    print(f"  input_ids : {encoded['input_ids']}")
    print(f"  decoded   : {decoded!r}")
    # Explicit round-trip flag so a lossy tokenizer is visible at a glance
    print(f"  round-trip: {decoded == text}")
    print()

    # ---- 3. Padding ----------------------------------------------
    # Batch of two sequences with different lengths;
    # the shorter one should be right-padded to match the longer.
    batch = [
        "Short sentence.",
        "This is a much longer sentence that has more tokens in it.",
    ]
    encoded_batch = tokenizer(
        batch,
        padding=True,          # pad to longest in batch
        return_tensors="pt",   # return PyTorch tensors
    )
    print("Batch padding (right padding):")
    print(f"  input_ids shape      : {encoded_batch['input_ids'].shape}")
    print(f"  attention_mask shape : {encoded_batch['attention_mask'].shape}")
    print(f"  input_ids[0]         : {encoded_batch['input_ids'][0].tolist()}")
    print(f"  input_ids[1]         : {encoded_batch['input_ids'][1].tolist()}")
    print(f"  attention_mask[0]    : {encoded_batch['attention_mask'][0].tolist()}")
    print()

    # ---- 4. Truncation -------------------------------------------
    # A sequence longer than model_max_length should be truncated.
    # 2000 repetitions of "word " — presumably ~1 token each, which is
    # past MODEL_MAX_LENGTH.
    long_text = "word " * 2000
    encoded_long = tokenizer(
        long_text,
        truncation=True,
        max_length=MODEL_MAX_LENGTH,
    )
    n_tokens = len(encoded_long["input_ids"])
    print("Truncation:")
    print(f"  input length : {len(long_text.split())} words")
    print(f"  token count  : {n_tokens}  (max: {MODEL_MAX_LENGTH})")
    print(f"  truncated    : {n_tokens <= MODEL_MAX_LENGTH}")
    print()

    # ---- 5. Load from disk and verify ----------------------------
    # Reload from the folder the tokenizer was actually saved to,
    # not a hard-coded default.
    print("Loading from disk:")
    reloaded = PreTrainedTokenizerFast.from_pretrained(save_dir)
    reloaded_ids = reloaded(text)["input_ids"]
    match = reloaded_ids == encoded["input_ids"]
    print("  from_pretrained() : OK")
    print(f"  IDs match original: {match}")
# ------------------------------------------------------------------ #
# ENTRY POINT
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Build, save and sanity-check the wrapped tokenizer using the
    # module-level defaults (TOKENIZER_PATH -> SAVE_DIR).
    tokenizer = wrap_tokenizer()
    verify_wrapped_tokenizer(tokenizer)

    # Copy-pasteable snippets showing how to use the saved tokenizer.
    rule = "=" * 60
    for banner_line in ("\n" + rule, "  USAGE EXAMPLES", rule):
        print(banner_line)

    usage = """
# Load anywhere with one line
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast.from_pretrained("fineweb_edu_tokenizer")
# Single encode
ids = tokenizer("Hello world")["input_ids"]
# Batch encode with padding and tensors
batch = tokenizer(
["sentence one", "sentence two"],
padding=True,
truncation=True,
max_length=1024,
return_tensors="pt",
)
# Decode
text = tokenizer.decode(ids, skip_special_tokens=True)
# Get eos token id (use as document separator when packing)
eot_id = tokenizer.eos_token_id
"""
    print(usage)