| """BPE tokenizer for causal LM training (Phase D: Symbiotic Distillation).
|
|
|
| Minimal GPT-2 style byte-level BPE tokenizer that loads vocab.json and
|
| merges.txt files produced by the text-pipeline. No HuggingFace dependency.
|
|
|
| Matches the Julia SLM BPETokenizer (tokenizer.jl) encoding/decoding.
|
| """
|
| import json
|
| from typing import Dict, List, Tuple
|
|
|
| try:
|
| import regex as re
|
| except ImportError:
|
| import re
|
|
|
|
|
| def _build_byte_to_unicode() -> Dict[int, str]:
|
| """GPT-2 byte-to-unicode mapping (matches Julia _build_byte_to_unicode)."""
|
| bs = list(range(ord("!"), ord("~") + 1))
|
| bs += list(range(ord("¡"), ord("¬") + 1))
|
| bs += list(range(ord("®"), ord("ÿ") + 1))
|
| cs = list(bs)
|
| n = 0
|
| for b in range(256):
|
| if b not in bs:
|
| bs.append(b)
|
| cs.append(256 + n)
|
| n += 1
|
| return {b: chr(c) for b, c in zip(bs, cs)}
|
|
|
|
|
|
|
| _GPT2_PAT = re.compile(
|
| r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
| re.UNICODE,
|
| )
|
|
|
|
|
class BPETokenizer:
    """Minimal byte-level BPE tokenizer for causal LM training.

    Loads vocab.json + merges.txt and provides encode/decode methods
    compatible with the Julia SLM tokenizer (0-indexed token IDs).
    """

    def __init__(
        self,
        encoder: Dict[str, int],
        merges: List[Tuple[str, str]],
    ):
        # Token string <-> id maps (ids are 0-indexed, straight from vocab.json).
        self.encoder = encoder
        self.decoder = {tok_id: tok for tok, tok_id in encoder.items()}
        # Ordered merge list and pair -> priority rank (lower merges earlier).
        self.merges = merges
        self.merge_ranks = {pair: rank for rank, pair in enumerate(merges)}
        # Byte <-> printable-character maps for the byte-level alphabet.
        self.byte_to_unicode = _build_byte_to_unicode()
        self.unicode_to_byte = {ch: b for b, ch in self.byte_to_unicode.items()}

    @classmethod
    def from_files(cls, vocab_path: str, merges_path: str) -> "BPETokenizer":
        """Load tokenizer from vocab.json and merges.txt files."""
        with open(vocab_path, "r", encoding="utf-8") as fh:
            encoder = json.load(fh)

        merges: List[Tuple[str, str]] = []
        with open(merges_path, "r", encoding="utf-8") as fh:
            for raw in fh:
                raw = raw.strip()
                # Skip blank lines and "#..." header/comment lines.
                if not raw or raw.startswith("#"):
                    continue
                pieces = raw.split()
                # Only well-formed two-token merge lines are kept.
                if len(pieces) == 2:
                    merges.append((pieces[0], pieces[1]))

        return cls(encoder, merges)

    @property
    def vocab_size(self) -> int:
        # Number of entries in the vocabulary.
        return len(self.encoder)

    @property
    def pad_token_id(self) -> int:
        # Falls back to 0 when the vocab lacks an explicit pad token.
        return self.encoder.get("<|pad|>", 0)

    @property
    def eos_token_id(self) -> int:
        # Falls back to 1 when the vocab lacks an explicit eos token.
        return self.encoder.get("<|eos|>", 1)

    def encode(self, text: str) -> List[int]:
        """Encode text to token IDs (0-indexed)."""
        ids: List[int] = []
        for match in _GPT2_PAT.finditer(text):
            piece = match.group()
            # Byte-level step: each UTF-8 byte becomes its printable stand-in.
            symbols = self._bpe_encode_word(
                [self.byte_to_unicode[b] for b in piece.encode("utf-8")]
            )
            # Symbols missing from the vocab are skipped rather than raising.
            for sym in symbols:
                tok_id = self.encoder.get(sym)
                if tok_id is not None:
                    ids.append(tok_id)
        return ids

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to text."""
        # Unknown ids decode to the empty string.
        merged = "".join(self.decoder.get(i, "") for i in ids)

        raw = bytearray()
        for ch in merged:
            byte_val = self.unicode_to_byte.get(ch)
            if byte_val is None:
                # Character outside the byte alphabet (e.g. from a special
                # token string): pass its UTF-8 bytes through verbatim.
                raw.extend(ch.encode("utf-8"))
            else:
                raw.append(byte_val)
        return raw.decode("utf-8", errors="replace")

    def _bpe_encode_word(self, symbols: List[str]) -> List[str]:
        """Iteratively merge the highest-priority pair."""
        while len(symbols) > 1:
            # Rank every adjacent pair; unknown pairs rank as infinity.
            # Finite ranks are unique, so min() picks the same pair the
            # original linear scan would (ties only occur at infinity).
            best_rank, best_pair = min(
                (self.merge_ranks.get(pair, float("inf")), pair)
                for pair in zip(symbols, symbols[1:])
            )
            if best_rank == float("inf"):
                break  # no remaining pair appears in the merge table

            # Replace every occurrence of best_pair, scanning left to right.
            first, second = best_pair
            rebuilt: List[str] = []
            i = 0
            n = len(symbols)
            while i < n:
                if i + 1 < n and symbols[i] == first and symbols[i + 1] == second:
                    rebuilt.append(first + second)
                    i += 2
                else:
                    rebuilt.append(symbols[i])
                    i += 1
            symbols = rebuilt

        return symbols
|
|
|