| | """ |
| | Byte-Pair Encoding trainer and codec optimized for JSON value strings. |
| | |
| | Uses incremental pair counting with pair→word index for fast merges. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | import re |
| | from collections import defaultdict |
| | from typing import Optional |
| |
|
| |
|
| | def _bytes_to_unicode() -> dict[int, str]: |
| | """Map bytes 0-255 to unicode chars, avoiding control/whitespace collisions.""" |
| | bs = ( |
| | list(range(ord("!"), ord("~") + 1)) |
| | + list(range(ord("¡"), ord("¬") + 1)) |
| | + list(range(ord("®"), ord("ÿ") + 1)) |
| | ) |
| | cs = bs[:] |
| | n = 0 |
| | for b in range(2**8): |
| | if b not in bs: |
| | bs.append(b) |
| | cs.append(2**8 + n) |
| | n += 1 |
| | return {b: chr(c) for b, c in zip(bs, cs)} |
| |
|
| |
|
# Byte value -> stand-in character, computed once when the module is imported.
BYTE_ENCODER = _bytes_to_unicode()
# Inverse lookup: stand-in character -> original byte value.
BYTE_DECODER = {char: byte for byte, char in BYTE_ENCODER.items()}
| |
|
# Pre-tokenisation pattern used with ``findall`` to cut text into "words".
# Alternatives are tried left to right at each position:
#   's 't 're 've 'm 'll 'd  - lowercase English contraction suffixes
#    ?[a-zA-Z_]+             - optional single space, then ASCII letters/underscores
#    ?[0-9]+                 - optional single space, then ASCII digits
#    ?[^\s\w]+               - optional single space, then a run of characters
#                              that are neither whitespace nor word characters
#   \s+                      - a run of whitespace (this is what consumes newlines)
#   .                        - any single remaining non-newline character
# Because the last two alternatives together accept any character, joining the
# matches reproduces the input text.
_PRE_TOK_PAT = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?[a-zA-Z_]+| ?[0-9]+| ?[^\s\w]+|\s+|."""
)
| |
|
| |
|
class BPETrainer:
    """Train a BPE vocabulary from a corpus of JSON value strings.

    The trainer works on the byte-level alphabet in ``BYTE_ENCODER``: every
    pre-token is UTF-8 encoded and each byte is replaced by its stand-in
    character before merges are learned or applied.

    Attributes:
        vocab_size: Target vocabulary size, including base symbols and the
            ``<UNK>`` entry.
        min_frequency: A pair is merged only while its corpus count is at
            least this value.
        merges: Learned merges in the order they were learned.
        vocab: Token string -> integer id.
    """

    def __init__(self, vocab_size: int = 4096, min_frequency: int = 2):
        self.vocab_size = vocab_size
        self.min_frequency = min_frequency
        self.merges: list[tuple[str, str]] = []
        self.vocab: dict[str, int] = {}
        # Lazily built id -> token map; reset whenever the vocab is replaced.
        self._id_to_tok: dict[int, str] | None = None

    def _pre_tokenize(self, text: str) -> list[str]:
        """Split ``text`` into the word-level pieces matched by ``_PRE_TOK_PAT``."""
        return _PRE_TOK_PAT.findall(text)

    def _text_to_bytes(self, text: str) -> tuple[str, ...]:
        """Return the UTF-8 bytes of ``text`` as a tuple of stand-in characters."""
        return tuple(BYTE_ENCODER[b] for b in text.encode("utf-8"))

    def _pair_occurrences(self, symbols: tuple[str, ...]) -> dict[tuple[str, str], int]:
        """Count each adjacent symbol pair in ``symbols``, in first-seen order."""
        counts: dict[tuple[str, str], int] = {}
        for pair in zip(symbols, symbols[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _reindex_word(
        self,
        idx: int,
        freq: int,
        old_word: tuple[str, ...],
        new_word: tuple[str, ...],
        merged_pair: tuple[str, str],
        pair_counts: dict[tuple[str, str], int],
        pair_to_words: dict[tuple[str, str], set[int]],
    ) -> None:
        """Apply the pair-count delta caused by rewriting one word.

        Only the difference between the word's pairs before and after the
        merge is applied. A word is removed from a pair's index entry only
        when that pair no longer occurs anywhere in the word, so a pair that
        also appears away from the merge site keeps the word indexed.
        """
        before = self._pair_occurrences(old_word)
        after = self._pair_occurrences(new_word)

        for pair, n_before in before.items():
            if pair == merged_pair:
                # The caller already dropped the merged pair wholesale.
                continue
            n_after = after.get(pair, 0)
            lost = n_before - n_after
            if lost > 0:
                remaining = pair_counts.get(pair, 0) - lost * freq
                if remaining > 0:
                    pair_counts[pair] = remaining
                else:
                    pair_counts.pop(pair, None)
            if n_after == 0:
                holders = pair_to_words.get(pair)
                if holders is not None:
                    holders.discard(idx)
                    if not holders:
                        del pair_to_words[pair]

        for pair, n_after in after.items():
            gained = n_after - before.get(pair, 0)
            if gained > 0:
                pair_counts[pair] = pair_counts.get(pair, 0) + gained * freq
            pair_to_words.setdefault(pair, set()).add(idx)

    def train(self, texts: list[str]) -> None:
        """Learn merges and build the vocabulary from ``texts``.

        Any previously learned merges and vocabulary are discarded first, so
        calling ``train`` twice with the same corpus yields the same result.

        The number of merges is capped at
        ``vocab_size - len(base symbols) - 1`` (one id is reserved for
        ``<UNK>``), and learning stops early once the most frequent pair
        occurs fewer than ``min_frequency`` times. Ties between equally
        frequent pairs go to the pair that entered the count table first.

        Args:
            texts: Corpus strings; each is pre-tokenized and byte-encoded.
        """
        # Start clean: without this, a second call would append to old merges.
        self.merges = []
        self.vocab = {}
        self._id_to_tok = None

        # Distinct byte-level words and how often each occurs.
        word_freqs: dict[tuple[str, ...], int] = {}
        for text in texts:
            for word in self._pre_tokenize(text):
                symbols = self._text_to_bytes(word)
                word_freqs[symbols] = word_freqs.get(symbols, 0) + 1

        base_vocab = {sym for word in word_freqs for sym in word}
        num_merges = self.vocab_size - len(base_vocab) - 1

        words: list[tuple[str, ...]] = list(word_freqs)
        freqs: list[int] = list(word_freqs.values())

        # pair -> weighted corpus count, and pair -> indices of words holding it.
        pair_counts: dict[tuple[str, str], int] = {}
        pair_to_words: dict[tuple[str, str], set[int]] = {}
        for idx, (word, freq) in enumerate(zip(words, freqs)):
            for pair, n in self._pair_occurrences(word).items():
                pair_counts[pair] = pair_counts.get(pair, 0) + n * freq
                pair_to_words.setdefault(pair, set()).add(idx)

        for _step in range(max(0, num_merges)):
            if not pair_counts:
                break

            best_pair = max(pair_counts, key=pair_counts.__getitem__)
            if pair_counts[best_pair] < self.min_frequency:
                break

            self.merges.append(best_pair)
            del pair_counts[best_pair]
            affected = pair_to_words.pop(best_pair, set())

            # Sorted so the order new pairs enter the table is reproducible.
            for idx in sorted(affected):
                old_word = words[idx]
                new_word = self._apply_merge(old_word, best_pair)
                words[idx] = new_word
                self._reindex_word(
                    idx, freqs[idx], old_word, new_word,
                    best_pair, pair_counts, pair_to_words,
                )

        # Ids: base symbols in sorted order, then merges in learned order,
        # then <UNK> last.
        vocab: dict[str, int] = {}
        for sym in sorted(base_vocab):
            vocab[sym] = len(vocab)
        for left, right in self.merges:
            token = left + right
            if token not in vocab:
                vocab[token] = len(vocab)
        vocab["<UNK>"] = len(vocab)
        self.vocab = vocab
        self._id_to_tok = None

    def _apply_merge(self, word: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]:
        """Replace each left-to-right, non-overlapping occurrence of ``pair``.

        Returns a new tuple; ``word`` is not modified.
        """
        left, right = pair
        out: list[str] = []
        i = 0
        last = len(word) - 1
        while i < len(word):
            if i < last and word[i] == left and word[i + 1] == right:
                out.append(left + right)
                i += 2
            else:
                out.append(word[i])
                i += 1
        return tuple(out)

    def encode_word(self, word: str) -> list[str]:
        """Encode one pre-token into BPE token strings.

        Merges are applied in the order they were learned.
        """
        symbols = self._text_to_bytes(word)
        for pair in self.merges:
            if len(symbols) < 2:
                # Nothing left to merge; later merges cannot match.
                break
            symbols = self._apply_merge(symbols, pair)
        return list(symbols)

    def encode(self, text: str) -> list[str]:
        """Pre-tokenize ``text`` and return the concatenated BPE tokens."""
        tokens: list[str] = []
        for word in self._pre_tokenize(text):
            tokens.extend(self.encode_word(word))
        return tokens

    def encode_to_ids(self, text: str) -> list[int]:
        """Encode ``text`` to vocabulary ids.

        Tokens missing from the vocabulary map to the ``<UNK>`` id, or to 0
        when the vocabulary has no ``<UNK>`` entry.
        """
        unk_id = self.vocab.get("<UNK>", 0)
        return [self.vocab.get(tok, unk_id) for tok in self.encode(text)]

    def id_to_token(self, token_id: int) -> str:
        """Return the token string for ``token_id``, or ``"<UNK>"`` if unknown.

        The reverse map is built on first use and cached.
        """
        if self._id_to_tok is None:
            self._id_to_tok = {v: k for k, v in self.vocab.items()}
        return self._id_to_tok.get(token_id, "<UNK>")

    def decode_ids(self, ids: list[int]) -> str:
        """Decode a sequence of vocabulary ids back to text."""
        return self.decode_tokens([self.id_to_token(i) for i in ids])

    def decode_tokens(self, tokens: list[str]) -> str:
        """Decode BPE token strings back to text.

        Each character is mapped back to its byte through ``BYTE_DECODER``.
        A character outside that alphabet is passed through: as the byte with
        its own code point when below 256, otherwise as its UTF-8 encoding
        (instead of raising ``ValueError``). Invalid UTF-8 in the resulting
        byte string is replaced with U+FFFD.
        """
        buf = bytearray()
        for ch in "".join(tokens):
            byte = BYTE_DECODER.get(ch)
            if byte is not None:
                buf.append(byte)
            elif ord(ch) < 256:
                buf.append(ord(ch))
            else:
                buf.extend(ch.encode("utf-8", errors="replace"))
        return buf.decode("utf-8", errors="replace")

    def save(self, path: str) -> None:
        """Write the trainer's settings, merges and vocabulary to ``path`` as JSON."""
        payload = {
            "version": "json-tokenizer-bpe-v1",
            "vocab_size": self.vocab_size,
            "min_frequency": self.min_frequency,
            "merges": [list(m) for m in self.merges],
            "vocab": self.vocab,
        }
        # Explicit encoding so the file does not depend on the platform default.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    @classmethod
    def load(cls, path: str) -> "BPETrainer":
        """Rebuild a trainer from a JSON file written by ``save``."""
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        t = cls(vocab_size=data["vocab_size"], min_frequency=data["min_frequency"])
        # JSON stores each merge as a list; restore the tuple form.
        t.merges = [tuple(m) for m in data["merges"]]
        t.vocab = data["vocab"]
        t._id_to_tok = None
        return t
| |
|