""" A minimal implementation of Byte-Pair Encoding (BPE) tokenization. BPE is a subword tokenization algorithm that iteratively merges the most frequent pairs of bytes or characters to build a vocabulary of subword tokens. This implementation is inspired by Andrej Karpathy's minbpe (https://github.com/karpathy/minbpe). """ import unicodedata def get_stats(ids, freq): for pair in zip(ids[:-1], ids[1:]): freq[pair] = freq.get(pair, 0) + 1 def merge(ids, pair, idx): newids = [] i = 0 while i < len(ids): if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]: newids.append(idx) i += 2 else: newids.append(ids[i]) i += 1 return newids def visualise_tokens(token_values: list[bytes]) -> None: background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]] # If token boundaries do not occur at unicode character boundaries, it's unclear how best to # visualise the token. Here, we'll just use the unicode replacement character to represent some # fraction of a character. unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values] running_length = 0 last_color = None for token in unicode_token_values: color = background[running_length % len(background)] if color == last_color: color = background[(running_length + 1) % len(background)] assert color != last_color last_color = color running_length += len(token) print(color + token, end="") print("\u001b[0m") # first two helper functions... def replace_control_characters(s: str) -> str: # we don't want to print control characters # which distort the output (e.g. \n or much worse) # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117 # http://www.unicode.org/reports/tr44/#GC_Values_Table chars = [] for ch in s: if unicodedata.category(ch)[0] != "C": chars.append(ch) # this character is ok else: chars.append(f"\\u{ord(ch):04x}") # escape return "".join(chars) def render_token(t: bytes) -> str: # pretty print a token, escaping control characters s = t.decode('utf-8', errors='replace') s = replace_control_characters(s) return s #-------------------------------------------------------------------------------------------------- class Tokenizer: def __init__(self): self.merges = {} # (int, int) -> int self.pattern = "" # str self.special_tokens = {} # str -> int e.g {'<|endoftext|>': 100257} self.inverse_special_tokens = {} # int -> str self.vocab = self._build_vocab() # int -> bytes def _build_vocab(self): vocab = {idx: bytes([idx]) for idx in range(256)} for (p0, p1), idx in self.merges.items(): vocab[idx] = vocab[p0] + vocab[p1] return vocab def train(self, text, vocab_size, verbose=False): raise NotImplementedError def decode(self, ids) -> str: raise NotImplementedError def encode(self, text, verbose=False) -> list[int]: raise NotImplementedError def save(self, file_prefix): """ Saves two files: file_prefix.vocab and file_prefix.model This is inspired (but not equivalent to!) 
        sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is a leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256  # merged tokens start after the 256 raw byte tokens
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
        self.vocab = self._build_vocab()
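
#--------------------------------------------------------------------------------------------------
# Example subclass: a minimal sketch of how the abstract train()/encode()/decode() methods above
# could be filled in, in the spirit of minbpe's BasicTokenizer (byte-level BPE, no regex splitting,
# no special-token handling). The class name and the demo text in __main__ are illustrative only.

class BasicTokenizer(Tokenizer):

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        ids = list(text.encode("utf-8"))  # raw bytes as integers in 0..255
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            # count adjacent pairs and merge the most frequent one into a new token
            freq = {}
            get_stats(ids, freq)
            if not freq:
                break  # text too short, nothing left to merge
            pair = max(freq, key=freq.get)
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]!r})")
        self.merges = merges
        self.vocab = vocab

    def decode(self, ids) -> str:
        # map token ids back to bytes, then decode (lossy for partial utf-8 sequences)
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        return text_bytes.decode("utf-8", errors="replace")

    def encode(self, text, verbose=False) -> list[int]:
        ids = list(text.encode("utf-8"))
        while len(ids) >= 2:
            # greedily apply the merge that was learned earliest (lowest token id)
            freq = {}
            get_stats(ids, freq)
            pair = min(freq, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no remaining pair is mergeable
            ids = merge(ids, pair, self.merges[pair])
        return ids


if __name__ == "__main__":
    # tiny demo on the classic "aaabdaaabac" BPE example (placeholder text, not real training data)
    tokenizer = BasicTokenizer()
    tokenizer.train("aaabdaaabac", vocab_size=259, verbose=True)
    ids = tokenizer.encode("aaabdaaabac")
    print(ids, "->", repr(tokenizer.decode(ids)))
    visualise_tokens([tokenizer.vocab[i] for i in ids])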