| | """BPE tokenizer — GPT-2 style byte-level BPE (matches Julia SLM tokenizer).""" |
| | import json |
| | from typing import Dict, List, Tuple |
| |
|
| | try: |
| | import regex as re |
| | except ImportError: |
| | import re |
| |
|
| | _GPT2_PAT = re.compile( |
| | r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", |
| | re.UNICODE, |
| | ) |
| |
|
| |
|
| | def _build_byte_to_unicode() -> Dict[int, str]: |
| | bs = list(range(ord("!"), ord("~") + 1)) |
| | bs += list(range(ord("¡"), ord("¬") + 1)) |
| | bs += list(range(ord("®"), ord("ÿ") + 1)) |
| | cs = list(bs) |
| | n = 0 |
| | for b in range(256): |
| | if b not in bs: |
| | bs.append(b) |
| | cs.append(256 + n) |
| | n += 1 |
| | return {b: chr(c) for b, c in zip(bs, cs)} |
| |
|
| |
|
class BPETokenizer:
    """Byte-level BPE tokenizer (GPT-2 style).

    Text is split with the GPT-2 pre-tokenization regex, each piece is
    mapped byte-by-byte to printable unicode stand-ins, and adjacent
    symbols are greedily merged in the priority order learned in training.
    """

    def __init__(self, encoder: Dict[str, int], merges: List[Tuple[str, str]]):
        """Build a tokenizer from a token->id map and an ordered merge list.

        Args:
            encoder: Mapping from BPE token string to integer id.
            merges: Merge pairs in training order (earlier = higher priority).
        """
        self.encoder = encoder
        self.decoder = {v: k for k, v in encoder.items()}
        self.merges = merges
        # Rank = position in the merge list; lower rank merges first.
        self.merge_ranks = {pair: i for i, pair in enumerate(merges)}
        self.byte_to_unicode = _build_byte_to_unicode()
        self.unicode_to_byte = {v: k for k, v in self.byte_to_unicode.items()}
        # Word-level memoization of the merge loop: natural text repeats
        # words constantly and _bpe_encode_word is O(len^2), so caching
        # (as the GPT-2 reference encoder does) is a large, free speedup.
        self._bpe_cache: Dict[str, List[str]] = {}

    @classmethod
    def from_files(cls, vocab_path: str, merges_path: str) -> "BPETokenizer":
        """Load a vocab (JSON token->id) and a merges file (one pair per line).

        Blank lines, '#'-prefixed lines (e.g. the '#version' header of GPT-2
        merges files), and lines without exactly two fields are skipped.
        """
        with open(vocab_path, "r", encoding="utf-8") as f:
            encoder = json.load(f)
        merges: List[Tuple[str, str]] = []
        with open(merges_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) == 2:
                    merges.append((parts[0], parts[1]))
        return cls(encoder, merges)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self.encoder)

    def encode(self, text: str) -> List[int]:
        """Encode text into a list of token ids.

        NOTE: symbols absent from the vocabulary are silently dropped, so
        decode(encode(text)) can be lossy with an incomplete vocab.
        """
        tokens: List[int] = []
        for match in _GPT2_PAT.finditer(text):
            # Map each UTF-8 byte of the piece to its printable stand-in.
            mapped = "".join(
                self.byte_to_unicode[b] for b in match.group().encode("utf-8")
            )
            symbols = self._bpe_cache.get(mapped)
            if symbols is None:
                symbols = self._bpe_encode_word(list(mapped))
                self._bpe_cache[mapped] = symbols
            for tok in symbols:
                token_id = self.encoder.get(tok)
                if token_id is not None:
                    tokens.append(token_id)
        return tokens

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back to text.

        Unknown ids decode to the empty string; byte sequences that are not
        valid UTF-8 are replaced with U+FFFD.
        """
        joined = "".join(self.decoder.get(i, "") for i in ids)
        out = bytearray()
        for c in joined:
            b = self.unicode_to_byte.get(c)
            if b is not None:
                out.append(b)
            else:
                # Character outside the byte table (e.g. a special token):
                # pass its UTF-8 bytes through unchanged.
                out.extend(c.encode("utf-8"))
        return out.decode("utf-8", errors="replace")

    def _bpe_encode_word(self, symbols: List[str]) -> List[str]:
        """Greedily apply the lowest-ranked applicable merge until none is left."""
        while len(symbols) > 1:
            # Pick the adjacent pair with the best (lowest) merge rank;
            # min() keeps the leftmost pair on ties, matching greedy BPE.
            pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
            best_pair = min(pairs, key=lambda p: self.merge_ranks.get(p, float("inf")))
            if best_pair not in self.merge_ranks:
                break  # no mergeable pair remains
            # Merge every occurrence of best_pair in one left-to-right pass.
            new_symbols: List[str] = []
            i = 0
            while i < len(symbols):
                if (
                    i < len(symbols) - 1
                    and symbols[i] == best_pair[0]
                    and symbols[i + 1] == best_pair[1]
                ):
                    new_symbols.append(best_pair[0] + best_pair[1])
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return symbols
| |
|