mathminakshi commited on
Commit
2bb89f8
·
verified ·
1 Parent(s): 8bb8395

Added app file

Browse files
Files changed (3) hide show
  1. app.py +113 -0
  2. bpe.py +176 -0
  3. utils.py +51 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from bpe import Tokenizer
3
+ import random
4
+ import colorsys
5
+
6
# Configure the Streamlit page. This must run before any other st.* call
# (Streamlit requires set_page_config to be the first Streamlit command).
st.set_page_config(
    page_title="English BPE Tokenizer Visualizer",
    layout="wide"
)
11
+
12
# Cached as a resource so the model file is parsed only once per Streamlit
# session and the same Tokenizer instance is reused across reruns.
@st.cache_resource
def load_tokenizer():
    """Build a Tokenizer and load the pre-trained English BPE model from disk."""
    # NOTE(review): the path ends in ".model.model" — confirm this double
    # suffix matches the file actually shipped under models/.
    tok = Tokenizer()
    tok.load("models/EnglishBPE_6999.model.model")
    return tok
18
+
19
# Load example texts once; st.cache_data memoizes the returned strings.
@st.cache_data
def load_examples():
    """Return two example texts for the demo UI.

    Reads data/testdata1.txt and data/testdata2.txt. If either file cannot
    be read, the error is shown in the UI and built-in fallback strings are
    used, so the caller always receives two strings.
    """
    # Fallback examples in case files can't be loaded. Assigning them up
    # front fixes the original bug: on a read error the function reached
    # `return example1, example2` with both names unbound (UnboundLocalError).
    example1 = "The quick brown fox jumps over the lazy dog."
    example2 = "Byte pair encoding builds a vocabulary by merging frequent pairs."
    try:
        with open("data/testdata1.txt", "r", encoding="utf-8") as f:
            example1 = f.read().strip()
        with open("data/testdata2.txt", "r", encoding="utf-8") as f:
            example2 = f.read().strip()
    except Exception as e:
        st.error(f"Error loading example texts: {str(e)}")
    return example1, example2
32
+
33
def generate_distinct_colors(n):
    """Return n visually distinct colors as "#rrggbb" hex strings.

    Hues are spread evenly around the color wheel; saturation and value are
    randomly jittered, so exact colors differ between calls.
    """
    palette = []
    for k in range(n):
        h = k / n
        s = 0.7 + random.random() * 0.3
        v = 0.8 + random.random() * 0.2
        r, g, b = colorsys.hsv_to_rgb(h, s, v)
        hex_color = "#{:02x}{:02x}{:02x}".format(
            int(r * 255), int(g * 255), int(b * 255)
        )
        palette.append(hex_color)
    return palette
45
+
46
def process_text(text, tokenizer):
    """Tokenize *text* and build a color-coded HTML visualization.

    Returns a tuple (html, tokens): html is a string of one <span> per token,
    colored by token id; tokens is the list of token ids. On failure returns
    an error message as html and None for tokens.
    """
    from html import escape  # local import: escape user text before embedding in HTML

    try:
        # Get tokens
        tokens = tokenizer.encode(text)

        # One distinct color per unique token id
        unique_tokens = list(set(tokens))
        colors = generate_distinct_colors(len(unique_tokens))
        token_colors = dict(zip(unique_tokens, colors))

        # Create HTML visualization. The caller renders this with
        # unsafe_allow_html=True, so the decoded token text MUST be escaped —
        # otherwise user input containing <, > or & breaks the markup and can
        # inject arbitrary HTML into the page.
        html_parts = []
        decoded_tokens = [tokenizer.decode([token]) for token in tokens]

        for token, token_text in zip(tokens, decoded_tokens):
            color = token_colors[token]
            html_parts.append(
                f'<span style="background-color: {color}; padding: 0 2px; '
                f'border-radius: 3px;" title="Token ID: {token}">'
                f'{escape(token_text)}</span>'
            )

        return ''.join(html_parts), tokens
    except Exception as e:
        return f"<span style='color: red'>Error processing text: {str(e)}</span>", None
67
+
68
def main():
    """Streamlit entry point: input form, tokenization, and visualization."""
    # Load tokenizer and example texts (both cached)
    tokenizer = load_tokenizer()
    example1, example2 = load_examples()

    # Page header
    st.title("English BPE Tokenizer Visualizer")
    st.markdown("Enter text to see how it gets tokenized, with color-coded visualization")

    # Let the user pick a preset example or type their own text
    choice = st.selectbox(
        "Choose an example or enter your own text below:",
        ["Custom Input", "Example 1", "Example 2"]
    )

    presets = {"Example 1": example1, "Example 2": example2}
    if choice in presets:
        text = st.text_area("Enter Text", value=presets[choice], height=100)
    else:
        text = st.text_area("Enter Text", height=100)

    # Process on button press, or automatically whenever there is any text
    if st.button("Process Text") or text:
        if text.strip():
            left, right = st.columns([2, 1])

            visualization, tokens = process_text(text, tokenizer)

            with left:
                st.subheader("Visualization")
                st.markdown(visualization, unsafe_allow_html=True)

            with right:
                if tokens is not None:
                    st.subheader("Token Information")
                    st.write(f"Token count: {len(tokens)}")
                    st.write("Tokens:", tokens)
        else:
            st.warning("Please enter some text to process.")

if __name__ == "__main__":
    main()
bpe.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import get_stats, merge, render_token
2
+ import regex as re
3
# GPT-4-style text-splitting regex (possessive quantifiers ?+ / ++ require the
# third-party `regex` module, not stdlib `re`).
# NOTE(review): defined but not referenced anywhere in this file —
# Tokenizer.__init__ hard-codes a similar, non-possessive pattern instead.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
4
class Tokenizer:
    """Byte-level BPE tokenizer in the style of minbpe.

    Token ids 0-255 are the raw bytes; higher ids are created by merge rules
    learned in train(). Models round-trip through save()/load() via a small
    line-oriented text format whose first line is the version "minbpe v1".
    """

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {}  # (int, int) -> int
        # GPT-4-style split pattern (non-possessive variant); the \p{...}
        # classes require the third-party `regex` module imported as `re`.
        self.pattern = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"  # str
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes
        # bytes-per-token ratio on the training text; set by train() / load()
        self.compression_ratio = 0

    def _build_vocab(self):
        """Derive the id -> bytes vocab from self.merges and self.special_tokens."""
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn (vocab_size - 256) merge rules from *text*.

        Populates self.merges, self.vocab and self.compression_ratio.
        Raises AssertionError if vocab_size < 256.
        """
        assert vocab_size >= 256
        # NOTE(review): joining the regex chunks with ' ' inserts spaces that
        # were not in the original text, and allows merges to cross chunk
        # boundaries — confirm this preprocessing is intentional (standard
        # GPT-style BPE merges within each chunk separately).
        text = ' '.join(self.compiled_pattern.findall(text))
        num_merges = vocab_size - 256

        # input text preprocessing
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        original_ids = ids.copy()

        # iteratively merge the most common pairs to create new tokens
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()
        # how many raw bytes each final token stands for, on average
        self.compression_ratio = round(len(original_ids)/len(ids), 1)


    def encode(self, text):
        """Encode a string into a list of token ids by greedily applying
        the learned merges in the order they were learned."""
        # given a string text, return the token ids
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def decode(self, ids):
        """Decode a list of token ids back into a Python string.

        Invalid UTF-8 byte sequences (e.g. from decoding a partial token)
        are replaced with U+FFFD via errors="replace".
        """
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and compression ratio
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            f.write(f"{self.compression_ratio}\n")  # Save compression ratio as string
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict: only the pairs are written — load() re-derives
            # the ids from line order, which matches dict insertion order
            # (ids 256, 257, ... in the order the merges were learned)
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file.

        Restores pattern, compression ratio, special tokens and merges, then
        rebuilds the vocab. Raises ValueError on a malformed header and
        AssertionError on a wrong filename suffix or version line.
        """
        assert model_file.endswith(".model")
        merges = {}
        special_tokens = {}
        # next merge id to assign; merge lines appear in learned order
        idx = 256

        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"

            # read the pattern
            self.pattern = f.readline().strip()
            self.compiled_pattern = re.compile(self.pattern)

            # read the compression ratio safely
            compression_ratio_line = f.readline().strip()
            try:
                self.compression_ratio = float(compression_ratio_line)
            except ValueError:
                raise ValueError(f"Expected a float for compression ratio, got: {compression_ratio_line}")

            # read the special tokens count safely
            num_special_line = f.readline().strip()
            if num_special_line.isdigit():  # Ensure it's a valid integer
                num_special = int(num_special_line)
            else:
                raise ValueError(f"Expected an integer for number of special tokens, got: {num_special_line}")

            # Read special tokens if any
            for _ in range(num_special):
                line = f.readline().strip()
                if line:
                    special, idx_str = line.rsplit(" ", 1)
                    special_tokens[special] = int(idx_str)

            # Read merges: ids are assigned sequentially from 256 in file
            # order, mirroring the order save() wrote them
            for line in f:
                parts = line.split()
                if len(parts) == 2:
                    idx1, idx2 = map(int, parts)
                    merges[(idx1, idx2)] = idx
                    idx += 1

        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
176
+
utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unicodedata
2
def get_stats(ids, counts=None):
    """Count consecutive pairs in a list of integers.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    If *counts* is given, tallies are added into that dict in place and it
    is returned; otherwise a fresh dict is built.
    """
    if counts is None:
        counts = {}
    for left, right in zip(ids, ids[1:]):
        counts[(left, right)] = counts.get((left, right), 0) + 1
    return counts
12
+
13
def merge(ids, pair, idx):
    """Replace every non-overlapping consecutive occurrence of *pair* in
    *ids* with the single token *idx*, scanning left to right.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    result = []
    n = len(ids)
    pos = 0
    while pos < n:
        # bounds check first, then compare the two-element window to pair
        if pos + 1 < n and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            result.append(idx)
            pos += 2  # consume both members of the matched pair
        else:
            result.append(ids[pos])
            pos += 1
    return result
30
+
31
# helper used by render_token below
def replace_control_characters(s: str) -> str:
    """Return *s* with Unicode control characters escaped as \\uXXXX.

    Control characters (Unicode general category C*) would distort printed
    output (e.g. newlines), so they are replaced with their escape sequence
    while all other characters pass through unchanged.
    https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    http://www.unicode.org/reports/tr44/#GC_Values_Table
    """
    def _render(ch: str) -> str:
        if unicodedata.category(ch).startswith("C"):
            return f"\\u{ord(ch):04x}"  # escape control character
        return ch  # printable character, keep as-is

    return "".join(_render(ch) for ch in s)
44
+
45
def render_token(t: bytes) -> str:
    """Pretty-print a token's bytes as a string.

    Invalid UTF-8 sequences become U+FFFD (errors="replace") and control
    characters are escaped so the result is always safely printable.
    """
    return replace_control_characters(t.decode('utf-8', errors='replace'))
50
+
51
+