BPETokenizer / src /bpe.py
mathminakshi's picture
added src folder
d6ae8d7 verified
from utils import get_stats, merge, render_token
import regex as re
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
class Tokenizer:
def __init__(self):
# default: vocab size of 256 (all bytes), no merges, no patterns
self.merges = {} # (int, int) -> int
self.pattern = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" # str
self.compiled_pattern = re.compile(self.pattern)
self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
self.vocab = self._build_vocab() # int -> bytes
self.compression_ratio = 0
def _build_vocab(self):
# vocab is simply and deterministically derived from merges
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
def train(self, text, vocab_size, verbose=False):
assert vocab_size >= 256
text = ' '.join(self.compiled_pattern.findall(text))
num_merges = vocab_size - 256
# input text preprocessing
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
original_ids = ids.copy()
# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
for i in range(num_merges):
# count up the number of times every consecutive pair appears
stats = get_stats(ids)
# find the pair with the highest count
pair = max(stats, key=stats.get)
# mint a new token: assign it the next available id
idx = 256 + i
# replace all occurrences of pair in ids with idx
ids = merge(ids, pair, idx)
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
self.compression_ratio = round(len(original_ids)/len(ids), 1)
def encode(self, text):
# given a string text, return the token ids
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
while len(ids) >= 2:
# find the pair with the lowest merge index
stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
# subtle: if there are no more merges available, the key will
# result in an inf for every single pair, and the min will be
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
return ids
def decode(self, ids):
# given ids (list of integers), return Python string
text_bytes = b"".join(self.vocab[idx] for idx in ids)
text = text_bytes.decode("utf-8", errors="replace")
return text
def save(self, file_prefix):
"""
Saves two files: file_prefix.vocab and file_prefix.model
This is inspired (but not equivalent to!) sentencepiece's model saving:
- model file is the critical one, intended for load()
- vocab file is just a pretty printed version for human inspection only
"""
# write the model: to be used in load() later
model_file = file_prefix + ".model"
with open(model_file, 'w') as f:
# write the version, pattern and compression ratio
f.write("minbpe v1\n")
f.write(f"{self.pattern}\n")
f.write(f"{self.compression_ratio}\n") # Save compression ratio as string
# write the special tokens, first the number of them, then each one
f.write(f"{len(self.special_tokens)}\n")
for special, idx in self.special_tokens.items():
f.write(f"{special} {idx}\n")
# the merges dict
for idx1, idx2 in self.merges:
f.write(f"{idx1} {idx2}\n")
# write the vocab: for the human to look at
vocab_file = file_prefix + ".vocab"
inverted_merges = {idx: pair for pair, idx in self.merges.items()}
with open(vocab_file, "w", encoding="utf-8") as f:
for idx, token in self.vocab.items():
# note: many tokens may be partial utf-8 sequences
# and cannot be decoded into valid strings. Here we're using
# errors='replace' to replace them with the replacement char �.
# this also means that we couldn't possibly use .vocab in load()
# because decoding in this way is a lossy operation!
s = render_token(token)
# find the children of this token, if any
if idx in inverted_merges:
# if this token has children, render it nicely as a merge
idx0, idx1 = inverted_merges[idx]
s0 = render_token(self.vocab[idx0])
s1 = render_token(self.vocab[idx1])
f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
else:
# otherwise this is leaf token, just print it
# (this should just be the first 256 tokens, the bytes)
f.write(f"[{s}] {idx}\n")
def load(self, model_file):
"""Inverse of save() but only for the model file"""
assert model_file.endswith(".model")
merges = {}
special_tokens = {}
idx = 256
with open(model_file, 'r', encoding="utf-8") as f:
# read the version
version = f.readline().strip()
assert version == "minbpe v1"
# read the pattern
self.pattern = f.readline().strip()
self.compiled_pattern = re.compile(self.pattern)
# read the compression ratio safely
compression_ratio_line = f.readline().strip()
try:
self.compression_ratio = float(compression_ratio_line)
except ValueError:
raise ValueError(f"Expected a float for compression ratio, got: {compression_ratio_line}")
# read the special tokens count safely
num_special_line = f.readline().strip()
if num_special_line.isdigit(): # Ensure it's a valid integer
num_special = int(num_special_line)
else:
raise ValueError(f"Expected an integer for number of special tokens, got: {num_special_line}")
# Read special tokens if any
for _ in range(num_special):
line = f.readline().strip()
if line:
special, idx_str = line.rsplit(" ", 1)
special_tokens[special] = int(idx_str)
# Read merges
for line in f:
parts = line.split()
if len(parts) == 2:
idx1, idx2 = map(int, parts)
merges[(idx1, idx2)] = idx
idx += 1
self.merges = merges
self.special_tokens = special_tokens
self.vocab = self._build_vocab()