# (Scraped page-header residue, not part of the program: "Spaces:" / "Sleeping" / "Sleeping")
| from utils import get_stats, merge, render_token | |
| import regex as re | |
# Text-splitting regex, assembled from its alternatives (implicit concatenation of
# raw strings yields exactly one pattern string). \p{L} / \p{N} are Unicode
# letter / number classes and ?+ / ++ are possessive quantifiers; both need the
# third-party `regex` module, which this file imports as `re`.
GPT4_SPLIT_PATTERN = (
    r"'(?i:[sdmt]|ll|ve|re)"  # contractions: 's 'd 'm 't 'll 've 're, case-insensitive
    r"|[^\r\n\p{L}\p{N}]?+\p{L}+"  # letters, optionally led by one non-letter/non-digit/non-newline char
    r"|\p{N}{1,3}"  # digits, in groups of at most three
    r"| ?[^\s\p{L}\p{N}]++[\r\n]*"  # symbol/punctuation run, optional leading space, trailing newlines
    r"|\s*[\r\n]"  # whitespace ending in a newline
    r"|\s+(?!\S)"  # whitespace not immediately followed by a non-whitespace char
    r"|\s+"  # any remaining whitespace
)
class Tokenizer:
    """Byte-level BPE tokenizer that splits text with a regex before merging.

    Attributes:
        merges: learned merges, ``(int, int) -> int``, in the order they were created.
        pattern: the split regex (str); ``compiled_pattern`` is its compiled form.
        special_tokens: ``str -> int``, e.g. ``{'<|endoftext|>': 100257}``.
        vocab: ``int -> bytes``, derived from ``merges`` and ``special_tokens``.
        compression_ratio: input bytes per output token measured at the end of
            train() (rounded to one decimal), or the value read by load(); 0 if unset.
    """

    # Default split regex: the same alternatives as GPT4_SPLIT_PATTERN, written
    # without possessive quantifiers. The \p{...} classes require the `regex`
    # module, which this file imports as `re`.
    DEFAULT_PATTERN = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"

    def __init__(self, pattern=None):
        """Create an untrained tokenizer: 256 byte tokens, no merges, no special tokens.

        Args:
            pattern: optional regex used to split text into chunks before BPE;
                defaults to ``DEFAULT_PATTERN``.
        """
        self.merges = {}  # (int, int) -> int
        self.pattern = self.DEFAULT_PATTERN if pattern is None else pattern  # str
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes
        self.compression_ratio = 0  # set by train() or load()

    def _build_vocab(self):
        """Return the ``int -> bytes`` vocab derived from merges and special tokens.

        Each merge's children must already be in the vocab when it is visited;
        merges created by train() or read back by load() are ordered that way.
        """
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def _split_chunks(self, text):
        """Split ``text`` into the regex chunks that BPE merges may not cross.

        Characters the pattern does not match are not part of any chunk. An empty
        pattern would only yield empty matches (losing the text), so in that case
        the whole text is treated as one chunk.
        """
        if not self.pattern:
            return [text] if text else []
        return [m.group() for m in self.compiled_pattern.finditer(text)]

    def train(self, text, vocab_size, verbose=False):
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        The text is split into regex chunks and pairs are counted only within
        chunks, so no learned token spans two chunks. Training stops early when
        no adjacent pair is left to merge, in which case fewer merges are made.

        Args:
            text: training text.
            vocab_size: target vocabulary size; must be at least 256.
            verbose: print each merge as it is made.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        # One list of byte values per chunk. Chunks are never joined, so no
        # separator bytes are invented and merges cannot cross chunk boundaries.
        chunk_ids = [list(chunk.encode("utf-8")) for chunk in self._split_chunks(text)]
        num_input_bytes = sum(len(ids) for ids in chunk_ids)
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes, for verbose output
        for i in range(num_merges):
            # count every consecutive pair, summed over all chunks
            stats = {}
            for ids in chunk_ids:
                if len(ids) < 2:
                    continue
                for pair, count in get_stats(ids).items():
                    stats[pair] = stats.get(pair, 0) + count
            if not stats:
                break  # every chunk is already a single token
            # mint a new token for the most frequent pair
            pair = max(stats, key=stats.get)
            idx = 256 + i
            chunk_ids = [merge(ids, pair, idx) if len(ids) >= 2 else ids for ids in chunk_ids]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
        self.merges = merges  # used in encode()
        self.vocab = self._build_vocab()  # used in decode(); also keeps special tokens
        num_tokens = sum(len(ids) for ids in chunk_ids)
        # empty input would otherwise divide by zero
        self.compression_ratio = round(num_input_bytes / num_tokens, 1) if num_tokens else 0

    def encode(self, text):
        """Return the token ids for ``text``.

        The text is split with the same regex as train() and each chunk is
        encoded independently, so tokens never span chunk boundaries.
        """
        ids = []
        for chunk in self._split_chunks(text):
            ids.extend(self._encode_chunk(chunk.encode("utf-8")))
        return ids

    def _encode_chunk(self, text_bytes):
        """Apply learned merges to one chunk's bytes, earliest-learned merge first."""
        merges = self.merges
        ids = list(text_bytes)
        while len(ids) >= 2:
            stats = get_stats(ids)
            # pair with the lowest merge index; pairs without a merge rank as inf
            pair = min(stats, key=lambda p: merges.get(p, float("inf")))
            # when no pair has a merge, min() just returns the first pair
            if pair not in merges:
                break
            ids = merge(ids, pair, merges[pair])
        return ids

    def decode(self, ids):
        """Return the text for a sequence of token ids.

        Invalid UTF-8 (e.g. a multi-byte character split across tokens) is
        replaced with U+FFFD instead of raising; unknown ids raise KeyError.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        return text_bytes.decode("utf-8", errors="replace")

    def save(self, file_prefix):
        """Save ``file_prefix.model`` (for load()) and ``file_prefix.vocab`` (for humans).

        This is inspired by (but not equivalent to) sentencepiece's model saving:
        - the .model file is the critical one and is what load() reads;
        - the .vocab file is a pretty-printed listing for human inspection only.
        """
        # Write the model, to be used in load() later. UTF-8 explicitly, matching
        # the encoding load() reads with (special tokens may be non-ASCII).
        model_file = file_prefix + ".model"
        with open(model_file, "w", encoding="utf-8") as f:
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            f.write(f"{self.compression_ratio}\n")
            # special tokens: first their count, then one "<token> <id>" per line
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # merges in insertion order; load() renumbers them 256, 257, ...
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # Many tokens are partial UTF-8 sequences; render_token (from utils)
                # is expected to render them lossily, which is why the .vocab file
                # cannot be used by load().
                s = render_token(token)
                if idx in inverted_merges:
                    # token built from a merge: show its two children
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # leaf token (the 256 raw bytes, or a special token)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save(), reading only the .model file.

        The tokenizer is updated only after the whole file has been parsed, so a
        malformed file leaves the current state untouched.

        Raises:
            AssertionError: if ``model_file`` does not end with ".model" or the
                version header is not "minbpe v1".
            ValueError: if the compression ratio, the special-token count, a
                special-token id, or a merge line is malformed.
        """
        assert model_file.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256  # merges are numbered in file order, after the 256 byte tokens
        with open(model_file, "r", encoding="utf-8") as f:
            version = f.readline().strip()
            assert version == "minbpe v1"
            # drop only the line terminator: strip() would also remove spaces that
            # belong to the pattern (e.g. a leading " ?")
            pattern = f.readline().rstrip("\r\n")
            compiled_pattern = re.compile(pattern)
            compression_ratio_line = f.readline().strip()
            try:
                compression_ratio = float(compression_ratio_line)
            except ValueError as err:
                raise ValueError(f"Expected a float for compression ratio, got: {compression_ratio_line}") from err
            num_special_line = f.readline().strip()
            if not num_special_line.isdigit():
                raise ValueError(f"Expected an integer for number of special tokens, got: {num_special_line}")
            num_special = int(num_special_line)
            for _ in range(num_special):
                line = f.readline().strip()
                if line:
                    # rsplit so a special token containing spaces still parses
                    special, idx_str = line.rsplit(" ", 1)
                    special_tokens[special] = int(idx_str)
            for line in f:
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(parts) != 2:
                    raise ValueError(f"Expected two merge ids per line, got: {line.strip()}")
                idx1, idx2 = map(int, parts)
                merges[(idx1, idx2)] = idx
                idx += 1
        self.pattern = pattern
        self.compiled_pattern = compiled_pattern
        self.compression_ratio = compression_ratio
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()