|
|
""" |
|
|
Byte-Pair Encoding (BPE) Tokenizer — Built From Scratch
|
|
======================================================== |
|
|
A minimal but complete BPE tokenizer implementation. |
|
|
Supports training from raw text, encoding/decoding, and special chat tokens. |
|
|
|
|
|
For production use, you'd typically use SentencePiece or tiktoken, |
|
|
but this demonstrates the full tokenizer pipeline. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import os |
|
|
import re |
|
|
from collections import Counter |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
|
|
|
class BPETokenizer:
    """
    Byte-Pair Encoding tokenizer with special token support.

    Token ids are laid out in three consecutive ranges: the special tokens,
    then the 256 raw byte values (shifted by the number of special tokens),
    then one id per learned merge, in the order the merges were learned.

    Special tokens:
        <pad>         = 0   Padding token
        <bos>         = 1   Beginning of sequence
        <eos>         = 2   End of sequence
        <unk>         = 3   Unknown token
        <|system|>    = 4   System prompt delimiter
        <|user|>      = 5   User turn delimiter
        <|assistant|> = 6   Assistant turn delimiter
        <|end|>       = 7   End of turn
    """

    SPECIAL_TOKENS = {
        "<pad>": 0,
        "<bos>": 1,
        "<eos>": 2,
        "<unk>": 3,
        "<|system|>": 4,
        "<|user|>": 5,
        "<|assistant|>": 6,
        "<|end|>": 7,
    }

    # Chat role name -> special token that opens a turn of that role.
    _ROLE_TOKENS = {
        "system": "<|system|>",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
    }

    # Pre-tokenization pattern: English contractions, runs of word characters
    # or punctuation with an optional leading space, and whitespace runs.
    # Merges are learned and applied inside each matched chunk only.
    PAT = re.compile(
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""",
        re.UNICODE,
    )

    def __init__(self, vocab_size: int = 32_000):
        """
        Create an untrained tokenizer (special tokens + 256 byte tokens).

        Args:
            vocab_size: Target vocabulary size that ``train`` grows towards,
                counting special tokens, byte tokens and merges.
        """
        self.target_vocab_size = vocab_size
        self.special_tokens = dict(self.SPECIAL_TOKENS)
        self.num_special = len(self.special_tokens)

        # Byte values are shifted past the special-token id range.
        self.byte_to_id: Dict[int, int] = {
            b: b + self.num_special for b in range(256)
        }
        self.id_to_byte: Dict[int, int] = {v: k for k, v in self.byte_to_id.items()}

        # Learned merges, in learning order, and the id assigned to each.
        self.merges: List[Tuple[int, int]] = []
        self.merge_to_id: Dict[Tuple[int, int], int] = {}

        # Token id -> byte string it decodes to.
        self.vocab: Dict[int, bytes] = {}
        self._build_vocab()

    def _build_vocab(self):
        """Reconstruct the full id -> bytes vocabulary from the merges."""
        self.vocab = {}

        for tok, idx in self.special_tokens.items():
            self.vocab[idx] = tok.encode("utf-8")

        for b in range(256):
            self.vocab[self.num_special + b] = bytes([b])

        # Insertion order is learning order, so both halves of every merge
        # are already present in the vocabulary when the merge is expanded.
        for (a, b), idx in self.merge_to_id.items():
            self.vocab[idx] = self.vocab[a] + self.vocab[b]

    @property
    def vocab_size(self) -> int:
        """Current number of entries in the vocabulary."""
        return len(self.vocab)

    def train(self, text: str, verbose: bool = True):
        """
        Train BPE merges from raw text.

        Any previously learned merges are discarded first, so calling
        ``train`` again (or on a loaded tokenizer) retrains from scratch
        instead of mixing old and new merges under colliding ids.

        Args:
            text: Raw training text
            verbose: Print progress
        """
        if verbose:
            print(f"Training BPE tokenizer (target vocab: {self.target_vocab_size:,})...")

        # New ids are handed out starting right after the byte range, which
        # is only consistent if the merge tables start empty.
        self.merges = []
        self.merge_to_id = {}

        # Count each distinct pre-tokenized chunk once, as a tuple of byte ids.
        word_freqs: Counter = Counter()
        for word in self.PAT.findall(text):
            byte_ids = tuple(self.byte_to_id[b] for b in word.encode("utf-8"))
            word_freqs[byte_ids] += 1

        current_vocab_size = self.num_special + 256
        num_merges = max(0, self.target_vocab_size - current_vocab_size)

        for i in range(num_merges):
            # Frequency of every adjacent token pair, weighted by chunk count.
            pair_counts: Counter = Counter()
            for word, freq in word_freqs.items():
                for pair in zip(word, word[1:]):
                    pair_counts[pair] += freq

            # Every chunk has collapsed to a single token: nothing left to merge.
            if not pair_counts:
                break

            best_pair, best_freq = pair_counts.most_common(1)[0]
            new_id = current_vocab_size

            self.merges.append(best_pair)
            self.merge_to_id[best_pair] = new_id

            # Rewrite the corpus with the new token; chunks that become
            # identical after the merge have their counts summed.
            new_word_freqs: Counter = Counter()
            for word, freq in word_freqs.items():
                new_word = self._apply_merge(word, best_pair, new_id)
                new_word_freqs[new_word] += freq
            word_freqs = new_word_freqs

            current_vocab_size += 1

            if verbose and (i + 1) % 1000 == 0:
                print(f" Merge {i + 1}/{num_merges}: "
                      f"({best_pair[0]}, {best_pair[1]}) -> {new_id}, "
                      f"freq={best_freq}")

        self._build_vocab()
        if verbose:
            print(f"Done! Final vocab size: {self.vocab_size:,}")

    @staticmethod
    def _apply_merge(
        word: Tuple[int, ...], pair: Tuple[int, int], new_id: int
    ) -> Tuple[int, ...]:
        """
        Apply a single merge rule to a word.

        Scans left to right and replaces every non-overlapping occurrence of
        ``pair`` with ``new_id``.

        Args:
            word: Token ids of one pre-tokenized chunk
            pair: The adjacent (left, right) ids to merge
            new_id: Id that replaces each occurrence of ``pair``

        Returns:
            The rewritten token id tuple
        """
        result = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                result.append(new_id)
                i += 2
            else:
                result.append(word[i])
                i += 1
        return tuple(result)

    def _encode_word(self, word: str) -> List[int]:
        """Encode one pre-tokenized chunk by repeatedly applying the earliest-learned merge present."""
        ids = tuple(self.byte_to_id[b] for b in word.encode("utf-8"))
        # Merge ids grow in learning order, so the id doubles as the rank.
        ranks = self.merge_to_id
        unranked = float("inf")
        while len(ids) > 1:
            best = min(zip(ids, ids[1:]), key=lambda pair: ranks.get(pair, unranked))
            if best not in ranks:
                break  # no learned merge applies any more
            ids = self._apply_merge(ids, best, ranks[best])
        return list(ids)

    def encode(
        self,
        text: str,
        add_special_tokens: bool = False,
        *,
        parse_special: bool = True,
    ) -> List[int]:
        """
        Encode text to token IDs.

        Args:
            text: Input text
            add_special_tokens: Whether to wrap with <bos>/<eos>
            parse_special: When True (default), literal special-token strings
                found in ``text`` are emitted as their special ids. Pass
                False to treat them as ordinary text, e.g. for untrusted
                input that must not be able to inject control tokens.

        Returns:
            List of token IDs
        """
        tokens: List[int] = []

        if parse_special:
            parts = self._split_special_tokens(text)
        else:
            parts = [(text, False)]

        for part, is_special in parts:
            if is_special:
                tokens.append(self.special_tokens[part])
                continue
            for word in self.PAT.findall(part):
                tokens.extend(self._encode_word(word))

        if add_special_tokens:
            tokens = [self.special_tokens["<bos>"]] + tokens + [self.special_tokens["<eos>"]]

        return tokens

    def _split_special_tokens(self, text: str) -> List[Tuple[str, bool]]:
        """
        Split text on special token boundaries.

        Returns:
            ``(segment, is_special)`` pairs in input order; empty plain
            segments are omitted.
        """
        # Longest first, so a special token that is a prefix of another one
        # cannot shadow the longer token in the alternation.
        ordered = sorted(self.special_tokens, key=len, reverse=True)
        pattern = "|".join(re.escape(tok) for tok in ordered)
        if not pattern:
            return [(text, False)]

        parts = []
        last_end = 0
        for match in re.finditer(pattern, text):
            if match.start() > last_end:
                parts.append((text[last_end:match.start()], False))
            parts.append((match.group(), True))
            last_end = match.end()
        if last_end < len(text):
            parts.append((text[last_end:], False))
        return parts

    def decode(self, ids: List[int], skip_special: bool = True) -> str:
        """
        Decode token IDs to text.

        Ids that are neither special tokens nor in the vocabulary are
        skipped. Byte sequences that are not valid UTF-8 are decoded with
        replacement characters.

        Args:
            ids: List of token IDs
            skip_special: Whether to skip special tokens

        Returns:
            Decoded text string
        """
        # Built once per call; setdefault keeps the first token for an id.
        special_by_id: Dict[int, str] = {}
        for tok, tid in self.special_tokens.items():
            special_by_id.setdefault(tid, tok)

        byte_chunks = []
        for idx in ids:
            special = special_by_id.get(idx)
            if special is not None:
                if not skip_special:
                    byte_chunks.append(special.encode("utf-8"))
            elif idx in self.vocab:
                byte_chunks.append(self.vocab[idx])
        return b"".join(byte_chunks).decode("utf-8", errors="replace")

    def encode_chat(
        self,
        messages: List[Dict[str, str]],
        add_generation_prompt: bool = True,
        *,
        parse_special: bool = True,
    ) -> List[int]:
        """
        Encode a chat conversation into token IDs.

        The layout is ``<bos>``, then for each message its role token, the
        encoded content and ``<|end|>``, optionally followed by a trailing
        ``<|assistant|>`` to prompt generation.

        Args:
            messages: List of {"role": "system"|"user"|"assistant", "content": "..."}
            add_generation_prompt: Add the assistant turn start token at the end
            parse_special: Forwarded to ``encode`` for message content. The
                default (True) maps special-token strings inside content to
                control tokens; pass False when content is untrusted.

        Returns:
            List of token IDs

        Raises:
            ValueError: If a message has a role other than system/user/assistant.
        """
        tokens = [self.special_tokens["<bos>"]]

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            role_token = self._ROLE_TOKENS.get(role)
            if role_token is None:
                # Emitting the content without a role marker would silently
                # produce a malformed conversation.
                raise ValueError(f"Unknown chat role: {role!r}")

            tokens.append(self.special_tokens[role_token])
            tokens.extend(self.encode(content, parse_special=parse_special))
            tokens.append(self.special_tokens["<|end|>"])

        if add_generation_prompt:
            tokens.append(self.special_tokens["<|assistant|>"])

        return tokens

    def save(self, path: str):
        """
        Save tokenizer to JSON.

        Only the target vocabulary size and the merge list are written; the
        rest of the state is rebuilt from them by ``load``. Missing parent
        directories are created.
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        data = {
            "target_vocab_size": self.target_vocab_size,
            "merges": self.merges,
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    @classmethod
    def load(cls, path: str) -> "BPETokenizer":
        """
        Load tokenizer from JSON.

        Args:
            path: File previously written by ``save``

        Returns:
            A tokenizer with merges, merge ids and vocabulary restored
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        tok = cls(vocab_size=data["target_vocab_size"])
        # JSON stores the pairs as lists; they must be tuples to be dict keys.
        tok.merges = [tuple(m) for m in data["merges"]]
        tok.merge_to_id = {
            pair: idx
            for idx, pair in enumerate(tok.merges, start=tok.num_special + 256)
        }
        tok._build_vocab()
        return tok
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: train a tiny tokenizer, then round-trip plain text and a chat.
    demo_tokenizer = BPETokenizer(vocab_size=500)

    corpus_sentences = (
        "Hello, world! This is a test of the BPE tokenizer. ",
        "The quick brown fox jumps over the lazy dog. ",
        "Machine learning is fascinating and powerful. ",
    )
    # The three sentences together are repeated 20 times as training text.
    training_text = "".join(corpus_sentences) * 20
    demo_tokenizer.train(training_text, verbose=True)

    # Plain-text round trip.
    prompt = "Hello, world! Machine learning is great."
    prompt_ids = demo_tokenizer.encode(prompt)
    round_trip = demo_tokenizer.decode(prompt_ids)
    print(f"\nOriginal: {prompt}")
    print(f"Token IDs: {prompt_ids[:20]}...")
    print(f"Decoded: {round_trip}")

    # Chat round trip; special tokens are kept so the turn markers show.
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello!"},
    ]
    conversation_ids = demo_tokenizer.encode_chat(conversation)
    rendered_chat = demo_tokenizer.decode(conversation_ids, skip_special=False)
    print(f"\nChat IDs: {conversation_ids[:20]}...")
    print(f"Chat decoded: {rendered_chat}")
|
|
|