| | """ |
| | A minimal implementation of Byte-Pair Encoding (BPE) tokenization. |
| | |
| | BPE is a subword tokenization algorithm that iteratively merges the most frequent pairs of bytes or characters |
| | to build a vocabulary of subword tokens. This implementation is inspired by Andrej Karpathy's minbpe |
| | (https://github.com/karpathy/minbpe). |
| | """ |
| | import unicodedata |
| |
|
def get_stats(ids, freq=None):
    """Count how often each consecutive pair of ids occurs; returns the counts dict."""
    freq = {} if freq is None else freq
    for pair in zip(ids[:-1], ids[1:]):
        freq[pair] = freq.get(pair, 0) + 1
    return freq
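
# Example (illustrative): pairs are counted left to right, so
#   get_stats([1, 2, 3, 1, 2]) == {(1, 2): 2, (2, 3): 1, (3, 1): 1}
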

def merge(ids, pair, idx):
    """Replace all consecutive occurrences of pair in ids with the new token id idx."""
    newids = []
    i = 0
    while i < len(ids):
        # if this position starts the pair (and we are not at the last element), merge it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
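
# Example (illustrative): merging the pair (1, 2) into a new token 256:
#   merge([1, 2, 3, 1, 2], (1, 2), 256) == [256, 3, 256]
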

def visualise_tokens(token_values: list[bytes]) -> None:
    """Print the tokens side by side, each on its own ANSI background colour."""
    background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
    # token boundaries may fall inside a multi-byte UTF-8 character; errors="replace"
    # renders such fragments with the unicode replacement character
    unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values]

    running_length = 0
    last_color = None
    for token in unicode_token_values:
        color = background[running_length % len(background)]
        if color == last_color:
            # avoid giving two adjacent tokens the same colour
            color = background[(running_length + 1) % len(background)]
            assert color != last_color
        last_color = color
        running_length += len(token)
        print(color + token, end="")
    print("\u001b[0m")  # reset the terminal colour
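
# Example (illustrative): each token is printed on a distinct background colour,
# making the token boundaries visible in the terminal:
#   visualise_tokens([b"Hello", b" world", b"!"])
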

def replace_control_characters(s: str) -> str:
    # control characters (unicode category "C", e.g. \n) would distort the
    # printed output, so escape them instead of emitting them verbatim
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch)  # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}")  # escape it
    return "".join(chars)

def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s
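
# Example (illustrative): the newline (a control character) gets escaped:
#   render_token(b"hello\nworld") == "hello\\u000aworld"
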

class Tokenizer:
    """Base class for Tokenizers."""

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no pattern
        self.merges = {}  # (int, int) -> int
        self.pattern = ""  # optional regex pattern for splitting text
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.inverse_special_tokens = {}  # int -> str
        self.vocab = self._build_vocab()  # int -> bytes

    def _build_vocab(self):
        # the vocab is deterministically derived from merges and special tokens
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    # train/decode/encode are left to subclasses (see the illustrative
    # sketch at the bottom of this file)
    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

    def decode(self, ids) -> str:
        raise NotImplementedError

    def encode(self, text, verbose=False) -> list[int]:
        raise NotImplementedError

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired by (but not equivalent to!) sentencepiece's model saving:
        - the model file is the critical one, intended for load()
        - the vocab file is just a pretty-printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version and the pattern
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens: first the count, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict: only the pairs are needed, their order implies the ids
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences and cannot be
                # decoded into valid strings; errors='replace' substitutes the
                # replacement character. this makes decoding lossy, which is why
                # the .vocab file cannot be used in load()!
                s = render_token(token)
                if idx in inverted_merges:
                    # this token has children: render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise it is a leaf token (one of the first 256 bytes,
                    # or a special token)
                    f.write(f"[{s}] {idx}\n")
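
    # Example (illustrative) .model file produced by save(), for a tokenizer
    # with an empty pattern, one special token, and two merges (which load()
    # will assign ids 256 and 257 in file order):
    #
    #   minbpe v1
    #
    #   1
    #   <|endoftext|> 100257
    #   101 32
    #   115 116
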
    def load(self, model_file):
        """Inverse of save(), but only for the model file."""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges; ids are assigned in file order, starting at 256
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
        self.vocab = self._build_vocab()
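

# ----------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the base class above): one
# way a subclass could implement train/encode/decode directly on raw bytes,
# in the spirit of minbpe's BasicTokenizer. Included only to make the intended
# interplay of get_stats() and merge() concrete.

class SketchBasicTokenizer(Tokenizer):

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        ids = list(text.encode("utf-8"))  # raw bytes as ints in 0..255
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            freq = get_stats(ids)
            if not freq:
                break  # fewer than two tokens left, nothing to merge
            pair = max(freq, key=freq.get)  # the most frequent pair
            idx = 256 + i  # mint a new token id
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]!r})")
        self.merges = merges
        self.vocab = vocab

    def decode(self, ids) -> str:
        # concatenate the bytes of each token, then decode (lossily) to a string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        return text_bytes.decode("utf-8", errors="replace")

    def encode(self, text, verbose=False) -> list[int]:
        ids = list(text.encode("utf-8"))
        while len(ids) >= 2:
            # replay merges in training order: pick the pair merged earliest
            freq = get_stats(ids)
            pair = min(freq, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no more applicable merges
            ids = merge(ids, pair, self.merges[pair])
        return ids

# Usage (illustrative):
#   tok = SketchBasicTokenizer()
#   tok.train("aaabdaaabac", vocab_size=259)  # 256 bytes + 3 merges
#   assert tok.decode(tok.encode("aaabdaaabac")) == "aaabdaaabac"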