Spaces:
Build error
Build error
Added app file
Browse files
app.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from bpe import Tokenizer
|
| 3 |
+
import random
|
| 4 |
+
import colorsys
|
| 5 |
+
|
| 6 |
+
# Set page config
st.set_page_config(
    page_title="English BPE Tokenizer Visualizer",  # browser tab title
    layout="wide"  # use the full browser width for the two-column layout
)
|
| 11 |
+
|
| 12 |
+
# Load the trained tokenizer
@st.cache_resource
def load_tokenizer():
    """Build a Tokenizer and load the pre-trained model (cached by Streamlit)."""
    tok = Tokenizer()
    # NOTE(review): the doubled ".model.model" suffix looks odd, but save()
    # appends ".model" to the given prefix, so a prefix that already ends in
    # ".model" produces exactly this filename — confirm against the repo.
    tok.load("models/EnglishBPE_6999.model.model")
    return tok
|
| 18 |
+
|
| 19 |
+
# Load example texts
@st.cache_data
def load_examples():
    """Load the two bundled example texts from disk.

    Returns:
        (example1, example2): stripped text of data/testdata1.txt and
        data/testdata2.txt. If either file cannot be read, a Streamlit
        error is shown and short built-in sample sentences are returned
        instead, so the app keeps working.
    """
    # Fallback examples in case files can't be loaded. Binding these before
    # the try-block fixes the original UnboundLocalError: on any read failure
    # the function reached `return example1, example2` with nothing bound.
    example1 = "The quick brown fox jumps over the lazy dog."
    example2 = "Byte Pair Encoding merges frequent character pairs into new tokens."
    try:
        with open("data/testdata1.txt", "r", encoding="utf-8") as f:
            example1 = f.read().strip()
        with open("data/testdata2.txt", "r", encoding="utf-8") as f:
            example2 = f.read().strip()
    except Exception as e:
        st.error(f"Error loading example texts: {str(e)}")
    return example1, example2
|
| 32 |
+
|
| 33 |
+
def generate_distinct_colors(n):
    """Return n visually distinct hex color strings (e.g. "#a1b2c3").

    Hues are spread evenly around the color wheel; saturation and value
    get a small random jitter so adjacent tokens are easier to tell apart.
    """
    palette = []
    for step in range(n):
        hue = step / n
        sat = 0.7 + random.random() * 0.3
        val = 0.8 + random.random() * 0.2
        red, green, blue = colorsys.hsv_to_rgb(hue, sat, val)
        palette.append(
            "#{:02x}{:02x}{:02x}".format(
                int(red * 255), int(green * 255), int(blue * 255)
            )
        )
    return palette
|
| 45 |
+
|
| 46 |
+
def process_text(text, tokenizer):
    """Tokenize *text* and build a color-coded HTML visualization.

    Args:
        text: the input string to tokenize.
        tokenizer: object exposing encode(str) -> list[int] and
            decode(list[int]) -> str.

    Returns:
        (html_string, tokens) on success, or (error_html, None) on failure.
    """
    import html  # local import keeps the file's top-level imports unchanged

    try:
        # Get tokens
        tokens = tokenizer.encode(text)

        # One stable color per distinct token id
        unique_tokens = list(set(tokens))
        colors = generate_distinct_colors(len(unique_tokens))
        token_colors = dict(zip(unique_tokens, colors))

        # Create HTML visualization
        html_parts = []
        decoded_tokens = [tokenizer.decode([token]) for token in tokens]

        for token, token_text in zip(tokens, decoded_tokens):
            color = token_colors[token]
            # Escape the decoded token text: characters like <, > and & in
            # user input would otherwise break or inject markup, since this
            # string is rendered with unsafe_allow_html=True.
            safe_text = html.escape(token_text)
            html_parts.append(
                f'<span style="background-color: {color}; padding: 0 2px; '
                f'border-radius: 3px;" title="Token ID: {token}">{safe_text}</span>'
            )

        return ''.join(html_parts), tokens
    except Exception as e:
        return f"<span style='color: red'>Error processing text: {str(e)}</span>", None
|
| 67 |
+
|
| 68 |
+
def main():
    """Streamlit entry point: collect input, tokenize, and render results."""
    # Load tokenizer and examples
    tokenizer = load_tokenizer()
    example1, example2 = load_examples()

    # Title and description
    st.title("English BPE Tokenizer Visualizer")
    st.markdown("Enter text to see how it gets tokenized, with color-coded visualization")

    # Example selector
    example_option = st.selectbox(
        "Choose an example or enter your own text below:",
        ["Custom Input", "Example 1", "Example 2"]
    )

    # Text input, pre-filled when one of the examples is selected
    prefills = {"Example 1": example1, "Example 2": example2}
    if example_option in prefills:
        text = st.text_area("Enter Text", value=prefills[example_option], height=100)
    else:
        text = st.text_area("Enter Text", height=100)

    # Process on button press, or automatically once text is present
    if st.button("Process Text") or text:
        if not text.strip():
            st.warning("Please enter some text to process.")
            return

        # Two columns: visualization on the left, token details on the right
        col1, col2 = st.columns([2, 1])

        visualization, tokens = process_text(text, tokenizer)

        with col1:
            st.subheader("Visualization")
            st.markdown(visualization, unsafe_allow_html=True)

        with col2:
            if tokens is not None:
                st.subheader("Token Information")
                st.write(f"Token count: {len(tokens)}")
                st.write("Tokens:", tokens)
|
| 111 |
+
|
| 112 |
+
# Run the app when executed directly (e.g. via `streamlit run app.py`).
if __name__ == "__main__":
    main()
|
bpe.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import get_stats, merge, render_token
|
| 2 |
+
import regex as re
|
| 3 |
+
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
| 4 |
+
class Tokenizer:
    """Byte-pair-encoding tokenizer in the style of minbpe.

    Ids 0..255 map to single raw bytes; each learned merge mints the next
    sequential id starting at 256. Supports train()/encode()/decode() plus
    save()/load() of a simple "minbpe v1" text model format.
    """

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no special tokens
        self.merges = {}  # (int, int) -> int
        self.pattern = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"  # str
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes
        self.compression_ratio = 0

    def _build_vocab(self):
        """Derive the id -> bytes vocab deterministically from merges."""
        # relies on dict insertion order: each merge only references ids
        # created before it, so its parents are already in vocab
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn up to vocab_size - 256 merges from `text`.

        Populates self.merges / self.vocab and records the achieved
        compression ratio (original byte count / final token count).
        """
        assert vocab_size >= 256
        # pre-split with the GPT-style regex, then rejoin chunks with spaces
        text = ' '.join(self.compiled_pattern.findall(text))
        num_merges = vocab_size - 256

        # input text preprocessing
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        original_ids = ids.copy()

        # iteratively merge the most common pairs to create new tokens
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            if not stats:
                # fewer than two ids remain, so no pair exists to merge;
                # guards the max() below against crashing on an empty dict
                break
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()
        # empty input would divide by zero; report 0 in that degenerate case
        self.compression_ratio = round(len(original_ids) / len(ids), 1) if ids else 0

    def encode(self, text):
        """Return the list of token ids for `text` using the learned merges."""
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def decode(self, ids):
        """Given ids (a list of integers), return the decoded Python string."""
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        # partial utf-8 sequences are replaced with � rather than raising
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and compression ratio
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            f.write(f"{self.compression_ratio}\n")  # Save compression ratio as string
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict; ids are implicit (256, 257, ... in file order)
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256

        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"

            # read the pattern
            self.pattern = f.readline().strip()
            self.compiled_pattern = re.compile(self.pattern)

            # read the compression ratio safely
            compression_ratio_line = f.readline().strip()
            try:
                self.compression_ratio = float(compression_ratio_line)
            except ValueError:
                raise ValueError(f"Expected a float for compression ratio, got: {compression_ratio_line}")

            # read the special tokens count safely
            num_special_line = f.readline().strip()
            if num_special_line.isdigit():  # Ensure it's a valid integer
                num_special = int(num_special_line)
            else:
                raise ValueError(f"Expected an integer for number of special tokens, got: {num_special_line}")

            # Read special tokens if any
            for _ in range(num_special):
                line = f.readline().strip()
                if line:
                    special, idx_str = line.rsplit(" ", 1)
                    special_tokens[special] = int(idx_str)

            # Read merges: new token ids are assigned sequentially from 256
            for line in f:
                parts = line.split()
                if len(parts) == 2:
                    idx1, idx2 = map(int, parts)
                    merges[(idx1, idx2)] = idx
                    idx += 1

        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
|
utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unicodedata
|
| 2 |
+
def get_stats(ids, counts=None):
    """
    Count consecutive pairs in a list of integers.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    An existing counts dict may be passed in to be updated in place.
    """
    if counts is None:
        counts = {}
    # zip the list against itself shifted by one to walk consecutive pairs
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts
|
| 12 |
+
|
| 13 |
+
def merge(ids, pair, idx):
    """
    Replace every consecutive occurrence of `pair` in `ids` with the
    single token `idx`, scanning left to right without overlap.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    first, second = pair
    out = []
    pos = 0
    length = len(ids)
    while pos < length:
        # a match needs one more element after pos, hence the bound check
        if pos + 1 < length and ids[pos] == first and ids[pos + 1] == second:
            out.append(idx)
            pos += 2  # consume both halves of the pair
        else:
            out.append(ids[pos])
            pos += 1
    return out
|
| 30 |
+
|
| 31 |
+
# first two helper functions...
|
| 32 |
+
def replace_control_characters(s: str) -> str:
    # Control characters (Unicode category "C", e.g. \n) distort printed
    # output, so render them as \uXXXX escapes instead of emitting them raw.
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    out = []
    for ch in s:
        if unicodedata.category(ch).startswith("C"):
            out.append(f"\\u{ord(ch):04x}")  # escape the control character
        else:
            out.append(ch)  # printable character, keep as-is
    return "".join(out)
|
| 44 |
+
|
| 45 |
+
def render_token(t: bytes) -> str:
    """Pretty-print a token's bytes, escaping control characters."""
    # lossy decode: invalid utf-8 becomes the replacement character
    decoded = t.decode('utf-8', errors='replace')
    return replace_control_characters(decoded)
|
| 50 |
+
|
| 51 |
+
|