""" Byte-Pair Encoding trainer and codec optimized for JSON value strings. Uses incremental pair counting with pair→word index for fast merges. """ from __future__ import annotations import json import re from collections import defaultdict from typing import Optional def _bytes_to_unicode() -> dict[int, str]: """Map bytes 0-255 to unicode chars, avoiding control/whitespace collisions.""" bs = ( list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) ) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) cs.append(2**8 + n) n += 1 return {b: chr(c) for b, c in zip(bs, cs)} BYTE_ENCODER = _bytes_to_unicode() BYTE_DECODER = {v: k for k, v in BYTE_ENCODER.items()} _PRE_TOK_PAT = re.compile( r"""'s|'t|'re|'ve|'m|'ll|'d| ?[a-zA-Z_]+| ?[0-9]+| ?[^\s\w]+|\s+|.""" ) class BPETrainer: """Train a BPE vocabulary from a corpus of JSON value strings.""" def __init__(self, vocab_size: int = 4096, min_frequency: int = 2): self.vocab_size = vocab_size self.min_frequency = min_frequency self.merges: list[tuple[str, str]] = [] self.vocab: dict[str, int] = {} self._id_to_tok: dict[int, str] | None = None def _pre_tokenize(self, text: str) -> list[str]: return _PRE_TOK_PAT.findall(text) def _text_to_bytes(self, text: str) -> tuple[str, ...]: return tuple(BYTE_ENCODER[b] for b in text.encode("utf-8")) def train(self, texts: list[str]) -> None: """Train BPE with pair→word index for O(affected) merges.""" # Count word frequencies word_freqs: dict[tuple[str, ...], int] = {} for text in texts: for word in self._pre_tokenize(text): bw = self._text_to_bytes(word) word_freqs[bw] = word_freqs.get(bw, 0) + 1 # Base vocab base_vocab: set[str] = set() for word in word_freqs: base_vocab.update(word) num_merges = self.vocab_size - len(base_vocab) - 1 # Word storage: idx → [symbols], freq words: list[list[str]] = [] freqs: list[int] = [] for w, f in word_freqs.items(): words.append(list(w)) freqs.append(f) # Pair counts and pair→word indices pair_counts: dict[tuple[str, str], int] = defaultdict(int) pair_to_words: dict[tuple[str, str], set[int]] = defaultdict(set) for idx, (w, f) in enumerate(zip(words, freqs)): for i in range(len(w) - 1): p = (w[i], w[i + 1]) pair_counts[p] += f pair_to_words[p].add(idx) for _ in range(max(0, num_merges)): if not pair_counts: break # Find best pair best_pair = max(pair_counts, key=pair_counts.__getitem__) if pair_counts[best_pair] < self.min_frequency: break a, b = best_pair merged = a + b self.merges.append(best_pair) # Only process words that contain this pair affected = list(pair_to_words.pop(best_pair, set())) del pair_counts[best_pair] for idx in affected: w = words[idx] f = freqs[idx] # Find positions of the pair new_w: list[str] = [] i = 0 while i < len(w): if i < len(w) - 1 and w[i] == a and w[i + 1] == b: # Decrement old adjacent pairs if new_w: old_left = (new_w[-1], a) pair_counts[old_left] -= f if pair_counts[old_left] <= 0: pair_counts.pop(old_left, None) pair_to_words[old_left].discard(idx) if i + 2 < len(w): old_right = (b, w[i + 2]) pair_counts[old_right] -= f if pair_counts[old_right] <= 0: pair_counts.pop(old_right, None) pair_to_words[old_right].discard(idx) new_w.append(merged) # Increment new adjacent pairs if len(new_w) >= 2: nl = (new_w[-2], merged) pair_counts[nl] += f pair_to_words[nl].add(idx) if i + 2 < len(w): nr = (merged, w[i + 2]) pair_counts[nr] += f pair_to_words[nr].add(idx) i += 2 else: new_w.append(w[i]) i += 1 words[idx] = new_w # Prune dead entries periodically if _ % 50 == 0: pair_counts = defaultdict(int, {k: v for k, v in pair_counts.items() if v > 0}) # Build vocab self.vocab = {} idx = 0 for ch in sorted(base_vocab): self.vocab[ch] = idx idx += 1 for merge in self.merges: m = merge[0] + merge[1] if m not in self.vocab: self.vocab[m] = idx idx += 1 self.vocab[""] = idx self._id_to_tok = None def _apply_merge(self, word: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]: new: list[str] = [] i = 0 while i < len(word): if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]: new.append(pair[0] + pair[1]) i += 2 else: new.append(word[i]) i += 1 return tuple(new) def encode_word(self, word: str) -> list[str]: bw = self._text_to_bytes(word) if len(bw) == 1: return [bw[0]] for merge in self.merges: bw = self._apply_merge(bw, merge) return list(bw) def encode(self, text: str) -> list[str]: tokens: list[str] = [] for word in self._pre_tokenize(text): tokens.extend(self.encode_word(word)) return tokens def encode_to_ids(self, text: str) -> list[int]: tokens = self.encode(text) unk_id = self.vocab.get("", 0) return [self.vocab.get(t, unk_id) for t in tokens] def id_to_token(self, token_id: int) -> str: if self._id_to_tok is None: self._id_to_tok = {v: k for k, v in self.vocab.items()} return self._id_to_tok.get(token_id, "") def decode_ids(self, ids: list[int]) -> str: return self.decode_tokens([self.id_to_token(i) for i in ids]) def decode_tokens(self, tokens: list[str]) -> str: byte_str = "".join(tokens) return bytearray(BYTE_DECODER.get(c, ord(c)) for c in byte_str).decode("utf-8", errors="replace") def save(self, path: str) -> None: with open(path, "w") as f: json.dump({ "version": "json-tokenizer-bpe-v1", "vocab_size": self.vocab_size, "min_frequency": self.min_frequency, "merges": [list(m) for m in self.merges], "vocab": self.vocab, }, f, indent=2) @classmethod def load(cls, path: str) -> "BPETrainer": with open(path) as f: data = json.load(f) t = cls(vocab_size=data["vocab_size"], min_frequency=data["min_frequency"]) t.merges = [tuple(m) for m in data["merges"]] t.vocab = data["vocab"] t._id_to_tok = None return t