"""BPE tokenizer — GPT-2 style byte-level BPE (matches Julia SLM tokenizer).""" import json from typing import Dict, List, Tuple try: import regex as re except ImportError: import re _GPT2_PAT = re.compile( r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", re.UNICODE, ) def _build_byte_to_unicode() -> Dict[int, str]: bs = list(range(ord("!"), ord("~") + 1)) bs += list(range(ord("¡"), ord("¬") + 1)) bs += list(range(ord("®"), ord("ÿ") + 1)) cs = list(bs) n = 0 for b in range(256): if b not in bs: bs.append(b) cs.append(256 + n) n += 1 return {b: chr(c) for b, c in zip(bs, cs)} class BPETokenizer: def __init__(self, encoder: Dict[str, int], merges: List[Tuple[str, str]]): self.encoder = encoder self.decoder = {v: k for k, v in encoder.items()} self.merges = merges self.merge_ranks = {pair: i for i, pair in enumerate(merges)} self.byte_to_unicode = _build_byte_to_unicode() self.unicode_to_byte = {v: k for k, v in self.byte_to_unicode.items()} @classmethod def from_files(cls, vocab_path: str, merges_path: str) -> "BPETokenizer": with open(vocab_path, "r", encoding="utf-8") as f: encoder = json.load(f) merges = [] with open(merges_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line.startswith("#") or not line: continue parts = line.split() if len(parts) == 2: merges.append((parts[0], parts[1])) return cls(encoder, merges) @property def vocab_size(self) -> int: return len(self.encoder) def encode(self, text: str) -> List[int]: tokens = [] for match in _GPT2_PAT.finditer(text): word = match.group() encoded_chars = [self.byte_to_unicode[b] for b in word.encode("utf-8")] symbols = self._bpe_encode_word(list(encoded_chars)) for tok in symbols: token_id = self.encoder.get(tok) if token_id is not None: tokens.append(token_id) return tokens def decode(self, ids: List[int]) -> str: token_strs = [self.decoder.get(i, "") for i in ids] joined = "".join(token_strs) out = bytearray() for c in joined: b = self.unicode_to_byte.get(c) if b is not None: out.append(b) else: out.extend(c.encode("utf-8")) return out.decode("utf-8", errors="replace") def _bpe_encode_word(self, symbols: List[str]) -> List[str]: while len(symbols) > 1: best_pair = None best_rank = float("inf") for i in range(len(symbols) - 1): pair = (symbols[i], symbols[i + 1]) rank = self.merge_ranks.get(pair, float("inf")) if rank < best_rank: best_rank = rank best_pair = pair if best_rank == float("inf"): break new_symbols = [] i = 0 while i < len(symbols): if ( i < len(symbols) - 1 and symbols[i] == best_pair[0] and symbols[i + 1] == best_pair[1] ): new_symbols.append(best_pair[0] + best_pair[1]) i += 2 else: new_symbols.append(symbols[i]) i += 1 symbols = new_symbols return symbols