| """ |
| A minimal implementation of Byte-Pair Encoding (BPE) tokenization. |
| |
| BPE is a subword tokenization algorithm that iteratively merges the most frequent pairs of bytes or characters |
| to build a vocabulary of subword tokens. This implementation is inspired by Andrej Karpathy's minbpe |
| (https://github.com/karpathy/minbpe). |
| """ |
| import unicodedata |
|
|
def get_stats(ids, freq=None):
    """Count the frequency of every consecutive pair in *ids*.

    Args:
        ids: sequence of integer token ids.
        freq: optional dict of (id, id) -> count to update in place.
            When omitted, a fresh dict is created. Passing a dict keeps
            the original mutate-in-place behavior, so existing callers
            are unaffected.

    Returns:
        The frequency dict (new or the one passed in), so the function
        can now also be used in expression position.
    """
    freq = {} if freq is None else freq
    # zipping the sequence with its one-step shift yields consecutive pairs;
    # empty and single-element inputs naturally produce no pairs
    for pair in zip(ids, ids[1:]):
        freq[pair] = freq.get(pair, 0) + 1
    return freq
|
|
def merge(ids, pair, idx):
    """Return a copy of *ids* where every occurrence of *pair* is replaced by *idx*.

    Scans left to right; overlapping occurrences are consumed greedily,
    so e.g. [1, 1, 1] with pair (1, 1) becomes [idx, 1].
    """
    first, second = pair
    out = []
    pos = 0
    n = len(ids)
    while pos < n:
        # match only when a full two-element window remains
        if pos + 1 < n and ids[pos] == first and ids[pos + 1] == second:
            out.append(idx)
            pos += 2
        else:
            out.append(ids[pos])
            pos += 1
    return out
|
|
| def visualise_tokens(token_values: list[bytes]) -> None: |
| background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]] |
| |
| |
| |
| unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values] |
|
|
| running_length = 0 |
| last_color = None |
| for token in unicode_token_values: |
| color = background[running_length % len(background)] |
| if color == last_color: |
| color = background[(running_length + 1) % len(background)] |
| assert color != last_color |
| last_color = color |
| running_length += len(token) |
| print(color + token, end="") |
| print("\u001b[0m") |
|
|
| |
def replace_control_characters(s: str) -> str:
    """Return *s* with every control character rendered as a \\uXXXX escape.

    Characters whose Unicode category starts with "C" (control-like)
    would otherwise distort printed output (e.g. a newline inside a
    token), so they are shown as their escaped code point instead.
    """
    parts = []
    for ch in s:
        is_control = unicodedata.category(ch).startswith("C")
        parts.append(f"\\u{ord(ch):04x}" if is_control else ch)
    return "".join(parts)
|
|
def render_token(t: bytes) -> str:
    """Pretty-print token bytes: lossy utf-8 decode, then escape control chars."""
    # invalid utf-8 sequences become the replacement character U+FFFD
    return replace_control_characters(t.decode("utf-8", errors="replace"))
|
|
| |
class Tokenizer:
    """Base class for BPE tokenizers.

    Holds the learned merges, an optional pre-split regex pattern,
    special tokens, and the derived id -> bytes vocabulary. Subclasses
    implement train/encode/decode.
    """

    def __init__(self):
        # (int, int) -> int; e.g. {(1, 2): 256} merges pair (1, 2) into token 256
        self.merges = {}
        # optional regex pattern string used to pre-split text (may be empty)
        self.pattern = ""
        # str -> int, e.g. {'<|endoftext|>': 100257}
        self.special_tokens = {}
        # int -> str, inverse of special_tokens
        self.inverse_special_tokens = {}
        # int -> bytes, derived deterministically from merges
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        """Derive the id -> bytes vocab: 256 raw bytes plus one entry per merge."""
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # merges preserve insertion order, so each pair's parts are
        # already present in vocab by the time the pair is processed
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Train the tokenizer on *text* up to a vocabulary of *vocab_size* tokens."""
        raise NotImplementedError

    def decode(self, ids) -> str:
        """Map a list of token ids back to a string."""
        raise NotImplementedError

    def encode(self, text, verbose=False) -> list[int]:
        """Map a string to a list of token ids."""
        raise NotImplementedError

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model file: this is the one load() reads back
        model_file = file_prefix + ".model"
        # NOTE: utf-8 here matches load(); without an explicit encoding,
        # non-ASCII special tokens would be written in the platform default
        # encoding and fail to round-trip on some systems
        with open(model_file, 'w', encoding="utf-8") as f:
            # version line first, then the split pattern
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # special tokens: count first, then one "token idx" per line
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # merges: merged ids are implicit (256, 257, ... in dict order)
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab file: human inspection only, never loaded back
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # many tokens are partial utf-8 sequences; render_token
                # substitutes the replacement character, so this format
                # is lossy and display-only
                s = render_token(token)
                if idx in inverted_merges:
                    # merged token: show its two children
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # leaf token: one of the first 256 raw bytes
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # version line must match what save() wrote
            version = f.readline().strip()
            assert version == "minbpe v1"
            # the split pattern
            self.pattern = f.readline().strip()
            # special tokens: count, then "token idx" lines
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # remaining lines are merges; ids are assigned sequentially from 256
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
        self.vocab = self._build_vocab()
|
|