| | try: |
| | from .base import get_stats, merge, visualise_tokens |
| | from .basic import BasicTokenizer |
| | from .patterns import GPT4_SPLIT_PATTERN |
| | except ImportError: |
| | from base import get_stats, merge, visualise_tokens |
| | from basic import BasicTokenizer |
| | from patterns import GPT4_SPLIT_PATTERN |
| | from collections import Counter, defaultdict |
| | import heapq |
| | import regex as re |
| | from tqdm import tqdm |
| | import time |
| |
|
class RegexTokenizer(BasicTokenizer):
    """BPE tokenizer that pre-splits text with a regex (GPT-4 style) before merging.

    Training counts byte pairs only *within* regex chunks, so learned merges
    never cross chunk boundaries. Identical chunks are deduplicated and trained
    once with a frequency weight, and an inverted index plus a lazy-deletion
    max-heap keep each merge step proportional to the number of affected words
    rather than the corpus size. Special tokens may be registered for
    encode/decode.
    """

    def __init__(self, regex: str = GPT4_SPLIT_PATTERN):
        """Create a tokenizer whose text chunking is driven by `regex`."""
        super().__init__()
        self.pattern = regex
        self.regex = re.compile(self.pattern)

    def register_special_tokens(self, special_tokens: dict[str, int]):
        """Register special tokens, e.g. {"<|endoftext|>": 100257}.

        Also builds the inverse (id -> string) map that decode() consults.
        """
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    @staticmethod
    def _merge_word(word: tuple[int, ...], pair: tuple[int, int], new_id: int) -> tuple[int, ...]:
        """Merge all non-overlapping occurrences of `pair` in `word`, left to right."""
        out: list[int] = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
                out.append(new_id)
                i += 2
            else:
                out.append(word[i])
                i += 1
        return tuple(out)

    @staticmethod
    def _pair_occurrences(word: tuple[int, ...]) -> dict[tuple[int, int], int]:
        """Return unweighted pair -> occurrence count for a single word/chunk."""
        if len(word) < 2:
            return {}
        counts: dict[tuple[int, int], int] = {}
        a = word[0]
        for b in word[1:]:
            p = (a, b)
            counts[p] = counts.get(p, 0) + 1
            a = b
        return counts

    def train(
        self,
        text: str,
        vocab_size: int = 50_257,
        verbose: bool = False,
        *,
        min_chunk_freq: int = 1,
        max_chunks: int | None = None,
    ):
        """Learn `vocab_size - 256` BPE merges from `text`.

        Args:
            text: Training corpus.
            vocab_size: Final vocabulary size; must be >= 256 (the byte alphabet).
            verbose: Print progress for every 10th merge via tqdm.write.
            min_chunk_freq: Drop regex chunks rarer than this before training.
            max_chunks: If set, keep only the most frequent chunks.

        Sets `self.merges` (pair -> new token id) and `self.vocab`
        (token id -> bytes) on completion.
        """
        assert vocab_size >= 256, "Vocab size must be at least 256"
        num_merges = vocab_size - 256

        # Count unique regex chunks; identical chunks are processed once and
        # weighted by their corpus frequency.
        chunk_counts: Counter[bytes] = Counter()
        for m in self.regex.finditer(text):
            s = m.group(0)
            if s:
                chunk_counts[s.encode("utf-8")] += 1

        # Optional pruning to bound memory/time on very large corpora.
        if min_chunk_freq > 1:
            chunk_counts = Counter({b: f for b, f in chunk_counts.items() if f >= min_chunk_freq})
        if max_chunks is not None and len(chunk_counts) > max_chunks:
            chunk_counts = Counter(dict(chunk_counts.most_common(max_chunks)))

        # word (tuple of token ids) -> corpus frequency
        words: dict[tuple[int, ...], int] = {}
        for b, freq in chunk_counts.items():
            words[tuple(b)] = freq

        # Global frequency-weighted pair counts, plus an inverted index
        # pair -> set of words currently containing that pair.
        pair_counts: dict[tuple[int, int], int] = defaultdict(int)
        pair_to_words: dict[tuple[int, int], set[tuple[int, ...]]] = defaultdict(set)
        for w, freq in words.items():
            local = self._pair_occurrences(w)
            for p, occ in local.items():
                pair_counts[p] += freq * occ
                pair_to_words[p].add(w)

        # Max-heap via negated counts, with lazy deletion: entries that no
        # longer match pair_counts are discarded when they reach the top.
        heap: list[tuple[int, tuple[int, int]]] = [(-c, p) for p, c in pair_counts.items()]
        heapq.heapify(heap)

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}

        def bump_pair(p: tuple[int, int], delta: int) -> None:
            # Adjust a pair's global count and push a fresh heap entry so the
            # heap always holds the current value somewhere; stale entries are
            # filtered lazily at pop time.
            if delta == 0:
                return
            new = pair_counts.get(p, 0) + delta
            if new <= 0:
                pair_counts.pop(p, None)
                pair_to_words.pop(p, None)
                return
            pair_counts[p] = new
            heapq.heappush(heap, (-new, p))

        for i in tqdm(range(num_merges), desc="Training tokenizer"):
            start_time = time.time()

            # Discard stale heap entries until the top agrees with pair_counts.
            while heap:
                negc, p = heap[0]
                c = pair_counts.get(p, 0)
                if c > 0 and -negc == c:
                    break
                heapq.heappop(heap)
            if not heap:
                break  # nothing left to merge

            pair = heap[0][1]
            count = pair_counts.get(pair, 0)
            if count <= 0:
                break

            idx = 256 + i
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

            affected = list(pair_to_words.get(pair, ()))
            if not affected:
                pair_counts.pop(pair, None)
                pair_to_words.pop(pair, None)
                continue

            # Re-tokenize only the words that actually contain `pair`.
            for w in affected:
                freq = words.get(w)
                if not freq:
                    continue  # word was already replaced earlier in this pass

                new_w = self._merge_word(w, pair, idx)
                if new_w == w:
                    continue

                # Remove the old word's contribution from counts and the index.
                old_local = self._pair_occurrences(w)
                for p, occ in old_local.items():
                    bump_pair(p, -freq * occ)
                    s = pair_to_words.get(p)
                    if s is not None:
                        s.discard(w)
                        if not s:
                            pair_to_words.pop(p, None)

                # Swap in the merged word; two words may merge into the same
                # tuple, so frequencies are accumulated.
                del words[w]
                words[new_w] = words.get(new_w, 0) + freq

                # Add the new word's contribution.
                new_local = self._pair_occurrences(new_w)
                for p, occ in new_local.items():
                    bump_pair(p, freq * occ)
                    pair_to_words[p].add(new_w)

            # The merged pair can no longer occur anywhere.
            pair_counts.pop(pair, None)
            pair_to_words.pop(pair, None)

            if verbose and i % 10 == 0:
                time_taken = time.time() - start_time
                tqdm.write(
                    f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) "
                    f"had {count} occurrences (took {time_taken:.2f}s)"
                )

        self.merges = merges
        self.vocab = vocab

    def decode(self, ids) -> str:
        """Decode a sequence of token ids into a string.

        Special-token ids round-trip through their registered strings.
        Raises ValueError for ids found in neither vocab nor special tokens.
        """
        part_bytes = []
        for token_id in ids:  # renamed from `id` to avoid shadowing the builtin
            if token_id in self.vocab:
                part_bytes.append(self.vocab[token_id])
            elif token_id in getattr(self, "inverse_special_tokens", {}):
                part_bytes.append(self.inverse_special_tokens[token_id].encode("utf-8"))
            else:
                raise ValueError(f"id={token_id} not in vocab or special_tokens")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode(encoding="utf-8", errors="replace")
        return text

    def _encode_chunk(self, chunk_bytes: bytes, verbose=False) -> list[int]:
        """Apply learned merges to one chunk's bytes, earliest-learned merge first."""
        tokens = list(chunk_bytes)
        while len(tokens) >= 2:
            if verbose:
                visualise_tokens([self.vocab[token] for token in tokens])
            stats = {}
            get_stats(tokens, stats)
            # Pick the adjacent pair with the lowest merge index; pairs with no
            # learned merge rank as +inf.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no applicable merges remain
            idx = self.merges[pair]
            tokens = merge(tokens, pair, idx)
        return tokens

    def encode_ordinary(self, text, verbose=False) -> list[int]:
        """Encode text, ignoring any special tokens entirely."""
        chunk_texts = self.regex.findall(text)
        ids_list = []
        # loop variable renamed from `text` so it no longer shadows the parameter
        for i, chunk in enumerate(chunk_texts):
            if verbose:
                print()
                print(f"encoding chunk {i+1}/{len(chunk_texts)}: {chunk}")
            chunk_bytes = chunk.encode("utf-8")
            ids = self._encode_chunk(chunk_bytes, verbose)
            ids_list.extend(ids)
        return ids_list

    def encode(self, text, verbose=False, allowed_special="none") -> list[int]:
        """Encode text with configurable special-token handling.

        `allowed_special` is "all", "none", "none_raise", or a set of allowed
        special-token strings. "none_raise" asserts the text contains no
        registered special tokens (NOTE: assert is stripped under `python -O`).
        """
        special = {}
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens), "Text contains special tokens that are not allowed"
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood.")
        if not special:
            return self.encode_ordinary(text, verbose)
        # Split on special tokens (capturing group keeps them in the parts list),
        # then encode the ordinary spans between them.
        special_pattern = "(" + "|".join(re.escape(token) for token in special) + ")"
        parts = re.split(special_pattern, text)
        ids = []
        for part in parts:
            if part in special:
                ids.append(special[part])
            else:
                ids.extend(self.encode_ordinary(part, verbose))
        return ids
| |
|
| |
|