# SymbioGPT-10M / transformer_tokenizer.py
"""BPE tokenizer for causal LM training (Phase D: Symbiotic Distillation).
Minimal GPT-2 style byte-level BPE tokenizer that loads vocab.json and
merges.txt files produced by the text-pipeline. No HuggingFace dependency.
Matches the Julia SLM BPETokenizer (tokenizer.jl) encoding/decoding.
"""
import json
from typing import Dict, List, Tuple

try:
    import regex as re  # third-party "regex": supports \p{L} Unicode property escapes
    _HAS_REGEX = True
except ImportError:
    import re  # stdlib fallback: cannot compile \p{...} property escapes
    _HAS_REGEX = False


def _build_byte_to_unicode() -> Dict[int, str]:
    """GPT-2 byte-to-unicode mapping (matches Julia _build_byte_to_unicode)."""
    bs = list(range(ord("!"), ord("~") + 1))
    bs += list(range(ord("¡"), ord("¬") + 1))
    bs += list(range(ord("®"), ord("ÿ") + 1))
    cs = list(bs)
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}
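
# Sanity check, for intuition (not used at runtime): printable ASCII bytes map
# to themselves, while e.g. the space byte 0x20 maps to chr(256 + 32) == "Ġ",
# the familiar GPT-2 word-boundary marker:
#
#   >>> _build_byte_to_unicode()[ord(" ")]
#   'Ġ'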

# GPT-2 pre-tokenization pattern. The stdlib `re` fallback cannot compile
# \p{L}/\p{N} property escapes, so it approximates them with ASCII classes.
if _HAS_REGEX:
    _GPT2_PAT = re.compile(
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )
else:
    _GPT2_PAT = re.compile(
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"""
    )
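
# For illustration, the Unicode pattern splits "Hello, world!" into the chunks
# ["Hello", ",", " world", "!"]; a leading space stays attached to the next
# word, which is what lets BPE learn "Ġword"-style tokens.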


class BPETokenizer:
    """Minimal byte-level BPE tokenizer for causal LM training.

    Loads vocab.json + merges.txt and provides encode/decode methods
    compatible with the Julia SLM tokenizer (0-indexed token IDs).
    """

    def __init__(
        self,
        encoder: Dict[str, int],
        merges: List[Tuple[str, str]],
    ):
        self.encoder = encoder  # token string -> id
        self.decoder = {v: k for k, v in encoder.items()}  # id -> token string
        self.merges = merges
        self.merge_ranks = {pair: i for i, pair in enumerate(merges)}  # lower rank = earlier merge
        self.byte_to_unicode = _build_byte_to_unicode()
        self.unicode_to_byte = {v: k for k, v in self.byte_to_unicode.items()}

    @classmethod
    def from_files(cls, vocab_path: str, merges_path: str) -> "BPETokenizer":
        """Load tokenizer from vocab.json and merges.txt files."""
        with open(vocab_path, "r", encoding="utf-8") as f:
            encoder = json.load(f)
        merges = []
        with open(merges_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Skip blank lines and the "#version: ..." header.
                if line.startswith("#") or not line:
                    continue
                parts = line.split()
                if len(parts) == 2:
                    merges.append((parts[0], parts[1]))
        return cls(encoder, merges)
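
    # Illustrative merges.txt layout (assuming the usual GPT-2/HF convention;
    # the actual pairs depend on the text-pipeline's training run):
    #
    #   #version: 0.2
    #   Ġ t
    #   h e
    #   Ġt he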

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    @property
    def pad_token_id(self) -> int:
        # Falls back to 0 if the vocab has no explicit "<|pad|>" entry.
        return self.encoder.get("<|pad|>", 0)

    @property
    def eos_token_id(self) -> int:
        # Falls back to 1 if the vocab has no explicit "<|eos|>" entry.
        return self.encoder.get("<|eos|>", 1)

    def encode(self, text: str) -> List[int]:
        """Encode text to token IDs (0-indexed)."""
        tokens = []
        for match in _GPT2_PAT.finditer(text):
            word = match.group()
            # Map raw UTF-8 bytes to their unicode stand-ins.
            encoded_chars = [self.byte_to_unicode[b] for b in word.encode("utf-8")]
            # Apply BPE merges to the symbol sequence.
            symbols = self._bpe_encode_word(list(encoded_chars))
            # Look up token IDs; symbols missing from the vocab are dropped
            # (this should not happen with a complete vocab.json).
            for tok in symbols:
                token_id = self.encoder.get(tok)
                if token_id is not None:
                    tokens.append(token_id)
        return tokens
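
    # Worked example (hypothetical merges): encoding " lower" first maps bytes
    # to symbols ["Ġ", "l", "o", "w", "e", "r"]; trained merges would then
    # collapse these into larger units such as ["Ġlow", "er"] before lookup.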

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to text."""
        # Unknown IDs decode to the empty string rather than raising.
        token_strs = [self.decoder.get(i, "") for i in ids]
        joined = "".join(token_strs)
        # Convert unicode stand-ins back to raw bytes.
        out = bytearray()
        for c in joined:
            b = self.unicode_to_byte.get(c)
            if b is not None:
                out.append(b)
            else:
                # Any character not produced by the byte map passes through
                # as its own UTF-8 bytes.
                out.extend(c.encode("utf-8"))
        return out.decode("utf-8", errors="replace")
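
    # Note: decode(encode(text)) == text whenever every merged symbol of
    # `text` exists in the vocab; encode() silently drops out-of-vocab
    # symbols, so round-tripping is only guaranteed for a complete vocab.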

    def _bpe_encode_word(self, symbols: List[str]) -> List[str]:
        """Iteratively merge the highest-priority (lowest-rank) pair."""
        while len(symbols) > 1:
            # Find the adjacent pair with the best (lowest) merge rank.
            best_pair = None
            best_rank = float("inf")
            for i in range(len(symbols) - 1):
                pair = (symbols[i], symbols[i + 1])
                rank = self.merge_ranks.get(pair, float("inf"))
                if rank < best_rank:
                    best_rank = rank
                    best_pair = pair
            if best_pair is None:
                break  # no mergeable pair remains
            # Merge every occurrence of best_pair in one left-to-right pass.
            new_symbols = []
            i = 0
            while i < len(symbols):
                if (
                    i < len(symbols) - 1
                    and symbols[i] == best_pair[0]
                    and symbols[i + 1] == best_pair[1]
                ):
                    new_symbols.append(best_pair[0] + best_pair[1])
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols
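

if __name__ == "__main__":
    # Smoke test with a tiny hypothetical vocab/merge table; real runs load
    # the artifacts produced by the text-pipeline via from_files().
    toy = BPETokenizer(
        encoder={"l": 0, "o": 1, "w": 2, "lo": 3, "low": 4},
        merges=[("l", "o"), ("lo", "w")],
    )
    assert toy._bpe_encode_word(["l", "o", "w"]) == ["low"]
    assert toy.encode("low") == [4]
    assert toy.decode([4]) == "low"
    print("toy BPE round-trip OK")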