try:
from .base import get_stats, merge, visualise_tokens
from .basic import BasicTokenizer
from .patterns import GPT4_SPLIT_PATTERN
except ImportError: # allow running as a script from inside `tokenizer/`
from base import get_stats, merge, visualise_tokens
from basic import BasicTokenizer
from patterns import GPT4_SPLIT_PATTERN
from collections import Counter, defaultdict
import heapq
import regex as re
from tqdm import tqdm
import time
class RegexTokenizer(BasicTokenizer):
def __init__(self, regex: str = GPT4_SPLIT_PATTERN):
super().__init__()
self.pattern = regex
self.regex = re.compile(self.pattern)
def register_special_tokens(self, special_tokens: dict[str, int]):
self.special_tokens = special_tokens
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
@staticmethod
def _merge_word(word: tuple[int, ...], pair: tuple[int, int], new_id: int) -> tuple[int, ...]:
"""Merge all non-overlapping occurrences of `pair` in `word`."""
out: list[int] = []
i = 0
while i < len(word):
if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
out.append(new_id)
i += 2
else:
out.append(word[i])
i += 1
return tuple(out)
@staticmethod
def _pair_occurrences(word: tuple[int, ...]) -> dict[tuple[int, int], int]:
"""Return unweighted pair -> count for a single word/chunk."""
if len(word) < 2:
return {}
counts: dict[tuple[int, int], int] = {}
a = word[0]
for b in word[1:]:
p = (a, b)
counts[p] = counts.get(p, 0) + 1
a = b
return counts
    def train(
        self,
        text: str,
        vocab_size: int = 50_257,
        verbose: bool = False,
        *,
        min_chunk_freq: int = 1,
        max_chunks: int | None = None,
    ) -> None:
        """Learn `vocab_size - 256` BPE merges from `text`.

        Classic word-level BPE: the regex pre-splits the text into chunks,
        identical chunks collapse into one frequency-weighted "word", and
        merges are picked via a max-heap over global pair counts with lazy
        (deferred) invalidation of stale heap entries.

        Sets `self.merges` (pair -> new token id) and `self.vocab`
        (token id -> bytes) on completion.

        Args:
            text: training corpus.
            vocab_size: final vocabulary size; must be >= 256 since the
                first 256 ids are reserved for raw bytes.
            verbose: print stats for every 10th merge via tqdm.
            min_chunk_freq: drop chunk types seen fewer times (speed knob).
            max_chunks: cap the number of unique chunk types (speed knob).
        """
        assert vocab_size >= 256, "Vocab size must be at least 256"
        num_merges = vocab_size - 256
        # Count chunk frequencies without storing a giant list of chunks.
        # Each unique chunk becomes a "word" in classic BPE training.
        chunk_counts: Counter[bytes] = Counter()
        for m in self.regex.finditer(text):
            s = m.group(0)
            if s:
                chunk_counts[s.encode("utf-8")] += 1
        # Heuristic speed knobs: ignore rare chunks and/or cap unique chunk types.
        # This massively reduces training state on web-scale corpora and keeps code simple.
        if min_chunk_freq > 1:
            chunk_counts = Counter({b: f for b, f in chunk_counts.items() if f >= min_chunk_freq})
        if max_chunks is not None and len(chunk_counts) > max_chunks:
            chunk_counts = Counter(dict(chunk_counts.most_common(max_chunks)))
        # words: tuple(symbol_ids) -> frequency of that chunk in the corpus
        words: dict[tuple[int, ...], int] = {}
        for b, freq in chunk_counts.items():
            words[tuple(b)] = freq
        # Global pair stats and a reverse index pair -> set(words containing it),
        # so a merge only has to touch the words it actually occurs in.
        pair_counts: dict[tuple[int, int], int] = defaultdict(int)
        pair_to_words: dict[tuple[int, int], set[tuple[int, ...]]] = defaultdict(set)
        for w, freq in words.items():
            local = self._pair_occurrences(w)
            for p, occ in local.items():
                # Per-word occurrences are weighted by the word's corpus frequency.
                pair_counts[p] += freq * occ
                pair_to_words[p].add(w)
        # Max-heap for fast "most frequent pair" selection (lazy updates):
        # counts are negated because heapq is a min-heap.
        heap: list[tuple[int, tuple[int, int]]] = [(-c, p) for p, c in pair_counts.items()]
        heapq.heapify(heap)
        merges: dict[tuple[int, int], int] = {}
        vocab: dict[int, bytes] = {idx: bytes([idx]) for idx in range(256)}

        def bump_pair(p: tuple[int, int], delta: int) -> None:
            # Adjust a pair's global count. Outdated heap entries are NOT
            # removed here; they are skipped lazily when they surface at the
            # top of the heap (the classic lazy-deletion priority queue).
            if delta == 0:
                return
            new = pair_counts.get(p, 0) + delta
            if new <= 0:
                pair_counts.pop(p, None)
                pair_to_words.pop(p, None)
                return
            pair_counts[p] = new
            heapq.heappush(heap, (-new, p))

        for i in tqdm(range(num_merges), desc="Training tokenizer"):
            start_time = time.time()
            # Pop stale heap entries until the top matches current counts.
            while heap:
                negc, p = heap[0]
                c = pair_counts.get(p, 0)
                if c > 0 and -negc == c:
                    break
                heapq.heappop(heap)
            if not heap:
                break  # no pair occurs anywhere any more; stop early
            pair = heap[0][1]
            count = pair_counts.get(pair, 0)
            if count <= 0:
                break
            idx = 256 + i
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            affected = list(pair_to_words.get(pair, ()))
            if not affected:
                # Count without any indexed word should not happen; clean up defensively.
                pair_counts.pop(pair, None)
                pair_to_words.pop(pair, None)
                continue
            # Apply merge to all words that contain the best pair.
            for w in affected:
                freq = words.get(w)
                if not freq:
                    # Word was already removed (merged into another) this round.
                    continue
                new_w = self._merge_word(w, pair, idx)
                if new_w == w:
                    continue
                # Remove old word contributions (counts and reverse index).
                old_local = self._pair_occurrences(w)
                for p, occ in old_local.items():
                    bump_pair(p, -freq * occ)
                    s = pair_to_words.get(p)
                    if s is not None:
                        s.discard(w)
                        if not s:
                            pair_to_words.pop(p, None)
                # Update words dict (merge words that collapse to the same new tuple)
                del words[w]
                words[new_w] = words.get(new_w, 0) + freq
                # Add new word contributions
                new_local = self._pair_occurrences(new_w)
                for p, occ in new_local.items():
                    bump_pair(p, freq * occ)
                    pair_to_words[p].add(new_w)
            # This pair should be fully merged away.
            pair_counts.pop(pair, None)
            pair_to_words.pop(pair, None)
            if verbose and i % 10 == 0:
                time_taken = time.time() - start_time
                tqdm.write(
                    f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) "
                    f"had {count} occurrences (took {time_taken:.2f}s)"
                )
        self.merges = merges
        self.vocab = vocab
def decode(self, ids) -> str:
part_bytes = []
for id in ids:
if id in self.vocab:
part_bytes.append(self.vocab[id]) # id can be > 256 after merging
elif id in getattr(self, "inverse_special_tokens", {}):
part_bytes.append(self.inverse_special_tokens[id].encode("utf-8"))
else:
raise ValueError(f"id={id} not in vocab or special_tokens")
text_bytes = b"".join(part_bytes)
text = text_bytes.decode(encoding="utf-8", errors="replace")
return text
def _encode_chunk(self, chunk_bytes: bytes, verbose=False) -> list[int]:
tokens = list(chunk_bytes)
while len(tokens) >= 2:
if verbose:
visualise_tokens([self.vocab[token] for token in tokens]) # token can be > 256 after merging
stats = {}
get_stats(tokens, stats)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
if not pair in self.merges:
break
idx = self.merges[pair]
tokens = merge(tokens, pair, idx)
return tokens
def encode_ordinary(self, text, verbose=False) -> list[int]:
chunk_texts = re.findall(self.regex, text)
ids_list = []
for i, text in enumerate(chunk_texts):
if verbose:
print()
print(f"encoding chunk {i+1}/{len(chunk_texts)}: {text}")
chunk_bytes = text.encode("utf-8") # raw bytes
ids = self._encode_chunk(chunk_bytes, verbose)
ids_list.extend(ids)
return ids_list
def encode(self, text, verbose=False, allowed_special="none") -> list[int]:
special = {}
if allowed_special == "all":
special = self.special_tokens
elif allowed_special == "none":
special = {}
elif allowed_special == "none_raise":
special = {}
assert all(token not in text for token in self.special_tokens), "Text contains special tokens that are not allowed"
elif isinstance(allowed_special, set):
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
else:
raise ValueError(f"allowed_special={allowed_special} not understood.")
if not special:
return self.encode_ordinary(text, verbose)
special_pattern = "(" + "|".join(re.escape(token) for token in special) + ")"
parts = re.split(special_pattern, text)
ids = []
for part in parts:
if part in special:
ids.append(special[part])
else:
ids.extend(self.encode_ordinary(part, verbose))
return ids
|