| """HawkGPT 0.3 — Digit-aware BPE tokenizer. |
| |
| Key improvement: numbers are split into individual digits. |
| "123" → ["1", "2", "3"] — model can learn arithmetic digit-by-digit. |
| """ |
|
|
| import os |
| from tokenizers import Tokenizer, models, pre_tokenizers, trainers |
|
|
| import config |
|
|
|
|
def train_tokenizer(
    text_path: str,
    vocab_size: int = None,
    min_frequency: int = 3,
) -> Tokenizer:
    """Train, save and sanity-check a digit-aware BPE tokenizer.

    Text is pre-tokenized on whitespace/punctuation, and every digit is split
    into its own pre-token. Numbers are therefore always encoded one digit per
    token (``"123"`` -> ``["1", "2", "3"]``).

    Args:
        text_path: UTF-8 text file to train on; it is streamed line by line.
        vocab_size: Target vocabulary size. Defaults to ``config.VOCAB_SIZE``.
        min_frequency: Minimum pair frequency for a BPE merge to be learned.
            Defaults to 3, the previously hard-coded value.

    Returns:
        The trained tokenizer, with padding and truncation to
        ``config.MAX_SEQ_LEN`` enabled. The same tokenizer, including those
        settings, is saved to ``config.TOKENIZER_PATH``.
    """
    if vocab_size is None:
        vocab_size = config.VOCAB_SIZE

    # Without unk_token, BPE silently drops characters that are missing from
    # the vocabulary at encode time; map them to the declared [UNK] instead.
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))

    # Split on whitespace/punctuation first, then break every digit run into
    # single digits so BPE can never merge multi-digit numbers.
    # NOTE(review): Whitespace() discards the whitespace itself and no decoder
    # is set, so decode() cannot restore the original spacing. Switching to a
    # Metaspace/ByteLevel scheme would fix that, but it changes the vocabulary
    # and any data already tokenized with this scheme — confirm before changing.
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Whitespace(),
        pre_tokenizers.Digits(individual_digits=True),
    ])

    # Special tokens are added first, so their IDs follow this list order
    # ([PAD] first).
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[BOS]", "[EOS]", "[UNK]", "[MASK]"],
        min_frequency=min_frequency,
    )

    def line_iterator():
        # Stream the corpus instead of loading it into memory at once.
        with open(text_path, "r", encoding="utf-8") as f:
            yield from f

    tokenizer.train_from_iterator(line_iterator(), trainer=trainer)

    # Fixed-length batches: pad with [PAD] and cut to MAX_SEQ_LEN.
    tokenizer.enable_padding(length=config.MAX_SEQ_LEN, pad_id=tokenizer.token_to_id("[PAD]"))
    tokenizer.enable_truncation(max_length=config.MAX_SEQ_LEN)

    os.makedirs(config.DATA_DIR, exist_ok=True)
    # TOKENIZER_PATH is not guaranteed to sit inside DATA_DIR; make sure its
    # own parent directory exists as well, or save() would fail.
    tokenizer_dir = os.path.dirname(config.TOKENIZER_PATH)
    if tokenizer_dir:
        os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer.save(config.TOKENIZER_PATH)

    # Round-trip check: reload the saved file and confirm digits are split.
    # Padding/truncation are turned off so only the real tokens are shown.
    sample = "123 + 456 = 579"
    reloaded = Tokenizer.from_file(config.TOKENIZER_PATH)
    reloaded.no_padding()
    reloaded.no_truncation()
    tokens = reloaded.encode(sample).tokens
    # Every digit of the sample must come out as its own token, in order. A
    # merged number or a digit mapped to [UNK] breaks this equality.
    digit_tokens = [tok for tok in tokens if any(ch.isdigit() for ch in tok)]
    digit_aware = digit_tokens == [ch for ch in sample if ch.isdigit()]
    print(f"Digit test: {sample}")
    print(f" Tokens: {tokens}")
    print(f" Digit-aware: {digit_aware}")

    print(f"Tokenizer saved: {config.TOKENIZER_PATH} | vocab={tokenizer.get_vocab_size()}")
    return tokenizer
|
|
|
|
def load_tokenizer() -> Tokenizer:
    """Load the tokenizer that was saved at ``config.TOKENIZER_PATH``.

    Whatever settings were saved with the file come back with it. That includes
    the padding/truncation that ``train_tokenizer`` enables before saving;
    callers that want raw token sequences should call ``no_padding()`` and
    ``no_truncation()`` themselves.

    Returns:
        The deserialized ``Tokenizer``.

    Raises:
        FileNotFoundError: if no file exists at ``config.TOKENIZER_PATH``.
    """
    path = config.TOKENIZER_PATH
    if os.path.exists(path):
        return Tokenizer.from_file(path)
    raise FileNotFoundError(f"Tokenizer not found at {path}")
|
|
|
|
if __name__ == "__main__":
    # Train on the configured corpus, then print how a few mixed
    # math/Russian samples are tokenized.
    trained = train_tokenizer(config.DATA_TEXT_PATH)
    # The returned tokenizer pads/truncates to a fixed length; disable both
    # so each sample shows only its real tokens.
    trained.no_padding()
    trained.no_truncation()
    samples = ("123 + 456 = 579", "Привет! Как дела?", "Реши: 3x + 5 = 20")
    for sample in samples:
        print(f" {sample} → {trained.encode(sample).tokens}")
|
|