SymbioSLM-GrammarExpert / tokenizer.py
LisaMegaWatts's picture
Upload tokenizer.py with huggingface_hub
c312c47 verified
"""BPE tokenizer — GPT-2 style byte-level BPE (matches Julia SLM tokenizer)."""
import json
from typing import Dict, List, Tuple
try:
import regex as re
except ImportError:
import re
# GPT-2 pre-tokenization pattern. ``\p{L}`` / ``\p{N}`` are only understood by
# the third-party ``regex`` module; the stdlib ``re`` rejects them with
# "bad escape \p", which would make this module fail at import time whenever
# the ``except ImportError: import re`` fallback above is taken.
_GPT2_PATTERN = (
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
# Stdlib approximation of the same classes:
#   [^\W\d_]       ~ letters
#   \d             ~ numbers (Unicode decimal digits only)
#   (?:[^\s\w]|_)  ~ anything that is not whitespace, letter or number
# Non-decimal numerals (e.g. "²") fall into the letter class here, so results
# can differ from the ``regex`` version on such input.
_GPT2_PATTERN_STDLIB = (
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?(?:[^\s\w]|_)+|\s+(?!\S)|\s+"""
)
try:
    _GPT2_PAT = re.compile(_GPT2_PATTERN, re.UNICODE)
except re.error:
    _GPT2_PAT = re.compile(_GPT2_PATTERN_STDLIB, re.UNICODE)
def _build_byte_to_unicode() -> Dict[int, str]:
    """Return the GPT-2 byte -> character table used for byte-level BPE.

    Printable Latin-1 bytes (``!``..``~``, ``¡``..``¬``, ``®``..``ÿ``) map to
    the character with the same code point. The remaining 68 bytes map, in
    ascending byte order, to consecutive code points starting at U+0100.
    The result is a bijection over all 256 byte values.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    printable_set = set(printable)

    # Printable bytes stand for themselves.
    table = {byte: chr(byte) for byte in printable}

    # Everything else is shifted past U+00FF so it gets a visible stand-in.
    next_code = 256
    for byte in range(256):
        if byte in printable_set:
            continue
        table[byte] = chr(next_code)
        next_code += 1
    return table
class BPETokenizer:
    """Byte-level BPE tokenizer in the GPT-2 style.

    Text is pre-split with ``_GPT2_PAT``. Each piece is converted to its
    UTF-8 bytes, and each byte is mapped to a stand-in character via the
    byte-to-unicode table. The resulting symbols are then merged greedily by
    merge rank (lowest rank first).

    Args:
        encoder: Mapping from token string to token id.
        merges: Ordered merge rules; a rule's index in the list is its rank
            (lower index merges first).
    """

    def __init__(self, encoder: Dict[str, int], merges: List[Tuple[str, str]]):
        self.encoder = encoder
        # Inverse vocabulary for ``decode``; if two strings share an id, the
        # later one in ``encoder`` wins.
        self.decoder = {v: k for k, v in encoder.items()}
        self.merges = merges
        # If a pair occurs more than once in ``merges``, its last index wins.
        self.merge_ranks = {pair: i for i, pair in enumerate(merges)}
        self.byte_to_unicode = _build_byte_to_unicode()
        self.unicode_to_byte = {v: k for k, v in self.byte_to_unicode.items()}

    @classmethod
    def from_files(cls, vocab_path: str, merges_path: str) -> "BPETokenizer":
        """Load a tokenizer from a JSON vocabulary and a text merges file.

        Args:
            vocab_path: UTF-8 JSON file holding the ``{token: id}`` mapping.
            merges_path: UTF-8 text file with one ``"left right"`` merge per
                line, in rank order.

        Returns:
            A new instance of ``cls``.

        Blank lines and ``#version`` header lines are skipped, and any other
        line that does not split into exactly two fields is ignored. Lines
        that merely start with ``#`` are NOT treated as comments: ``#`` is an
        ordinary byte-level symbol, and real merges such as ``"# #"`` begin
        with it.
        """
        with open(vocab_path, "r", encoding="utf-8") as f:
            encoder = json.load(f)
        merges: List[Tuple[str, str]] = []
        with open(merges_path, "r", encoding="utf-8") as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith("#version"):
                    continue
                parts = line.split()
                if len(parts) == 2:
                    merges.append((parts[0], parts[1]))
        return cls(encoder, merges)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (``len(encoder)``)."""
        return len(self.encoder)

    def encode(self, text: str) -> List[int]:
        """Encode ``text`` into a list of token ids.

        BPE output symbols that are missing from ``encoder`` are silently
        skipped, so an incomplete vocabulary loses text instead of raising.
        """
        byte_map = self.byte_to_unicode
        ids: List[int] = []
        for match in _GPT2_PAT.finditer(text):
            # One stand-in character per UTF-8 byte of the pre-token.
            symbols = [byte_map[b] for b in match.group().encode("utf-8")]
            for tok in self._bpe_encode_word(symbols):
                token_id = self.encoder.get(tok)
                if token_id is not None:
                    ids.append(token_id)
        return ids

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back into text.

        Unknown ids contribute nothing. Characters that are not in the
        byte-to-unicode table are emitted as their own UTF-8 bytes. Byte
        sequences that are not valid UTF-8 are replaced with U+FFFD.
        """
        joined = "".join(self.decoder.get(i, "") for i in ids)
        out = bytearray()
        for ch in joined:
            b = self.unicode_to_byte.get(ch)
            if b is None:
                out.extend(ch.encode("utf-8"))
            else:
                out.append(b)
        return out.decode("utf-8", errors="replace")

    def _bpe_encode_word(self, symbols: List[str]) -> List[str]:
        """Apply merge rules to one pre-token's symbols until none applies.

        Each round picks the adjacent pair with the lowest merge rank and
        merges every non-overlapping occurrence of it, scanning left to
        right. The input list is never mutated.
        """
        ranks = self.merge_ranks
        inf = float("inf")
        while len(symbols) > 1:
            # ``min`` returns the first pair with the lowest rank, matching a
            # left-to-right scan with a strict ``<`` comparison.
            best_pair = min(
                zip(symbols, symbols[1:]), key=lambda p: ranks.get(p, inf)
            )
            if best_pair not in ranks:
                break  # no adjacent pair has a merge rule
            first, second = best_pair
            merged: List[str] = []
            i = 0
            n = len(symbols)
            while i < n:
                if i < n - 1 and symbols[i] == first and symbols[i + 1] == second:
                    merged.append(first + second)
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            symbols = merged
        return symbols