HawkGPT-v0.3 / tokenizer_module.py
HawkLabofficial's picture
Upload tokenizer_module.py with huggingface_hub
4006c4c verified
Raw
History Blame Contribute Delete
2.33 kB
"""HawkGPT 0.3 — Digit-aware BPE tokenizer.
Key improvement: numbers are split into individual digits.
"123" → ["1", "2", "3"] — model can learn arithmetic digit-by-digit.
"""
import os
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
import config
def _report_digit_split(tokenizer_path: str, sample: str = "123 + 456 = 579") -> bool:
    """Reload a saved tokenizer and print how it splits a numeric sample.

    Padding and truncation are disabled on the reloaded copy so the printed
    tokens are not padded out to ``config.MAX_SEQ_LEN``.

    Args:
        tokenizer_path: Path of the tokenizer JSON file to reload.
        sample: Text to encode for the check.

    Returns:
        True if the sample produced digit tokens and every token containing a
        digit is exactly one character long.
    """
    tok = Tokenizer.from_file(tokenizer_path)
    tok.no_padding()
    tok.no_truncation()
    enc = tok.encode(sample)
    # Stronger than checking that a few specific digits appear: no token may
    # contain more than one digit (or a digit glued to other characters).
    digit_tokens = [t for t in enc.tokens if any(ch.isdigit() for ch in t)]
    digit_aware = bool(digit_tokens) and all(len(t) == 1 for t in digit_tokens)
    print(f"Digit test: {sample}")
    print(f" Tokens: {enc.tokens}")
    print(f" Digit-aware: {digit_aware}")
    return digit_aware


def train_tokenizer(text_path: str, vocab_size: int = None, *, min_frequency: int = 3) -> Tokenizer:
    """Train a digit-aware BPE tokenizer on a UTF-8 text file and save it.

    The pre-tokenizer first splits with ``Whitespace`` and then applies
    ``Digits(individual_digits=True)``. Every digit therefore becomes its own
    pre-token, and BPE merges can never join digits: "123" -> ["1", "2", "3"].

    The trained tokenizer gets padding and truncation to
    ``config.MAX_SEQ_LEN`` enabled and is saved to ``config.TOKENIZER_PATH``.
    It is then reloaded from disk to print a digit-splitting sanity check.

    Args:
        text_path: Path to the UTF-8 training corpus, streamed line by line.
        vocab_size: Target vocabulary size; defaults to ``config.VOCAB_SIZE``.
        min_frequency: Minimum pair frequency required for a BPE merge.

    Returns:
        The trained tokenizer, with padding and truncation enabled.
    """
    if vocab_size is None:
        vocab_size = config.VOCAB_SIZE

    # Without unk_token, the BPE model drops characters that are not in the
    # vocab at encode time; map them to the reserved [UNK] token instead.
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    # Digit-aware: whitespace/punctuation split, then one pre-token per digit.
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Whitespace(),
        pre_tokenizers.Digits(individual_digits=True),
    ])
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[BOS]", "[EOS]", "[UNK]", "[MASK]"],
        min_frequency=min_frequency,
        # Guarantee all ten digits are in the vocab even if the corpus lacks some.
        initial_alphabet=list("0123456789"),
    )

    def line_iterator():
        # Stream the corpus instead of loading it into memory.
        with open(text_path, "r", encoding="utf-8") as f:
            yield from f

    tokenizer.train_from_iterator(line_iterator(), trainer=trainer)
    tokenizer.enable_padding(length=config.MAX_SEQ_LEN, pad_id=tokenizer.token_to_id("[PAD]"))
    tokenizer.enable_truncation(max_length=config.MAX_SEQ_LEN)

    os.makedirs(config.DATA_DIR, exist_ok=True)
    # TOKENIZER_PATH is not guaranteed to sit directly in DATA_DIR; make sure
    # its own parent directory exists before saving.
    tokenizer_dir = os.path.dirname(config.TOKENIZER_PATH)
    if tokenizer_dir:
        os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer.save(config.TOKENIZER_PATH)

    # Verify digit splitting on the saved file (also checks it round-trips).
    _report_digit_split(config.TOKENIZER_PATH)
    print(f"Tokenizer saved: {config.TOKENIZER_PATH} | vocab={tokenizer.get_vocab_size()}")
    return tokenizer
def load_tokenizer() -> Tokenizer:
    """Load the tokenizer previously saved at ``config.TOKENIZER_PATH``.

    Returns:
        The deserialized tokenizer.

    Raises:
        FileNotFoundError: If nothing exists at ``config.TOKENIZER_PATH``.
    """
    saved_path = config.TOKENIZER_PATH
    if os.path.exists(saved_path):
        return Tokenizer.from_file(saved_path)
    raise FileNotFoundError(f"Tokenizer not found at {saved_path}")
if __name__ == "__main__":
    # Train on the configured corpus, then show how a few sample strings
    # (arithmetic and Russian text) tokenize with padding/truncation disabled.
    tok = train_tokenizer(config.DATA_TEXT_PATH)
    tok.no_padding()
    tok.no_truncation()
    for sample in ["123 + 456 = 579", "Привет! Как дела?", "Реши: 3x + 5 = 20"]:
        enc = tok.encode(sample)
        # Separate the input text from its token list; they previously ran together.
        print(f" {sample} → {enc.tokens}")