import json
from collections import Counter
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple


class Tokenizer:
    """Character-level BPE tokenizer with a fixed set of special tokens.

    The vocabulary starts with the special tokens (ids 0-3), is extended
    with every distinct character seen by :meth:`build_vocab`, and then
    with the sub-word units produced by greedy pair merging.

    Attributes:
        special_tokens: Reserved tokens, in id order.
        char2idx: Token string -> integer id.
        idx2char: Integer id -> token string.
        vocab_size: Number of tokens currently in the vocabulary.
        target_vocab_size: Vocabulary size :meth:`build_vocab` aims for.
        bpe_ranks: Learned merges, ``(left, right) -> rank``; a lower rank
            means the merge was learned earlier and is applied first.
    """

    # NOTE(review): SOURCE showed four identical empty strings here, which
    # collapsed all special tokens to one id. The names and their order are
    # inferred from how encode/decode use them -- confirm against any saved
    # tokenizer files.
    PAD_TOKEN = "<PAD>"
    SOS_TOKEN = "<SOS>"
    EOS_TOKEN = "<EOS>"
    UNK_TOKEN = "<UNK>"

    def __init__(self, vocab_size: int = 1000):
        """Create a tokenizer holding only the special tokens.

        Args:
            vocab_size: Target vocabulary size used by :meth:`build_vocab`
                to decide how many merges to learn.
        """
        self.special_tokens = [
            self.PAD_TOKEN,
            self.SOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        ]
        self.char2idx: Dict[str, int] = {}
        self.idx2char: Dict[int, str] = {}
        self.vocab_size: int = 0
        self.target_vocab_size: int = vocab_size
        self.bpe_ranks: Dict[Tuple[str, str], int] = {}

        for idx, token in enumerate(self.special_tokens):
            self.char2idx[token] = idx
            self.idx2char[idx] = token
        self.vocab_size = len(self.special_tokens)

    @staticmethod
    def _merge_pair(
        symbols: Tuple[str, ...], pair: Tuple[str, str]
    ) -> Tuple[str, ...]:
        """Return ``symbols`` with each adjacent ``pair`` fused, left to right."""
        first, second = pair
        fused = first + second
        merged = []
        last = len(symbols) - 1
        i = 0
        while i <= last:
            if i < last and symbols[i] == first and symbols[i + 1] == second:
                merged.append(fused)
                i += 2  # consume both halves so matches never overlap
            else:
                merged.append(symbols[i])
                i += 1
        return tuple(merged)

    def _get_stats(self, words: Dict[Tuple[str, ...], int]) -> Counter:
        """Count adjacent symbol pairs, weighted by word frequency.

        Args:
            words: Symbol tuple -> number of occurrences.

        Returns:
            Counter mapping ``(left, right)`` to its weighted frequency.
        """
        pairs = Counter()
        for word, freq in words.items():
            for pair in zip(word, word[1:]):
                pairs[pair] += freq
        return pairs

    def _merge_vocab(
        self, pair: Tuple[str, str], words: Dict[Tuple[str, ...], int]
    ) -> Dict[Tuple[str, ...], int]:
        """Apply one merge to every word.

        Args:
            pair: The ``(left, right)`` symbols to fuse.
            words: Symbol tuple -> number of occurrences.

        Returns:
            A new dict keyed by the merged symbol tuples. Words that become
            identical after the merge have their frequencies summed.
        """
        new_words: Dict[Tuple[str, ...], int] = {}
        for word, freq in words.items():
            merged = self._merge_pair(word, pair)
            new_words[merged] = new_words.get(merged, 0) + freq
        return new_words

    def build_vocab(self, texts: List[str]) -> None:
        """Learn the character vocabulary and BPE merges from ``texts``.

        Each text is treated as one "word" (a tuple of its characters).
        Merges are learned until ``target_vocab_size`` is reached or no
        adjacent pairs remain. Progress is reported with ``print``.

        Args:
            texts: Training strings.
        """
        print(f"Building BPE vocabulary from {len(texts)} texts...")

        vocab = set()
        for text in texts:
            vocab.update(text)

        # Sorted so that character ids are deterministic across runs.
        for char in sorted(vocab):
            self.add_token(char)

        print(
            f"Initial character vocabulary: "
            f"{self.vocab_size - len(self.special_tokens)} characters"
        )

        words: Dict[Tuple[str, ...], int] = defaultdict(int)
        for text in texts:
            words[tuple(text)] += 1

        num_merges = self.target_vocab_size - self.vocab_size
        print(f"Learning {num_merges} BPE merges...")

        for i in range(num_merges):
            pairs = self._get_stats(words)
            if not pairs:
                break

            best_pair = max(pairs, key=pairs.get)
            words = self._merge_vocab(best_pair, words)
            self.add_token("".join(best_pair))
            # Always record the merge: it has been applied to the training
            # words, so _tokenize must replay it even when the fused string
            # was already in the vocabulary.
            self.bpe_ranks[best_pair] = i

            if (i + 1) % 100 == 0:
                print(
                    f" Learned {i + 1} merges, "
                    f"vocab size: {self.vocab_size}"
                )

        print(f"BPE Vocabulary built! Total tokens: {self.vocab_size}")
        print(f" - Special tokens: {len(self.special_tokens)}")
        print(f" - Base characters: {len(vocab)}")
        print(f" - BPE subwords: {len(self.bpe_ranks)}")
        print(f" - Sample subwords: {list(self.bpe_ranks.keys())[:5]}")

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into sub-word tokens using the learned merges.

        Starting from single characters, repeatedly applies the
        lowest-ranked merge present until none applies.

        Args:
            text: String to tokenize.

        Returns:
            List of token strings; empty for empty input.
        """
        if not text:
            return []

        word: Tuple[str, ...] = tuple(text)
        while len(word) > 1:
            known = [p for p in zip(word, word[1:]) if p in self.bpe_ranks]
            if not known:
                break
            best = min(known, key=self.bpe_ranks.__getitem__)
            word = self._merge_pair(word, best)
        return list(word)

    def add_token(self, token: str) -> None:
        """Add ``token`` to the vocabulary if it is not already present."""
        if token not in self.char2idx:
            idx = self.vocab_size
            self.char2idx[token] = idx
            self.idx2char[idx] = token
            self.vocab_size += 1

    def encode(
        self, text: str, max_length: int, add_special_tokens: bool = True
    ) -> List[int]:
        """Convert ``text`` to a list of exactly ``max_length`` ids.

        Tokens beyond the available budget are dropped, unknown tokens map
        to the unknown-token id, and short sequences are right-padded with
        the pad id.

        Args:
            text: String to encode.
            max_length: Length of the returned list (negative counts as 0).
            add_special_tokens: Wrap the sequence in start/end markers.

        Returns:
            List of ``max_length`` integer ids.
        """
        limit = max(max_length, 0)
        # Clamp the budget: a negative slice bound would count from the end
        # and let the result grow past ``max_length``.
        budget = max(limit - (2 if add_special_tokens else 0), 0)
        unk_id = self.char2idx[self.UNK_TOKEN]

        indices: List[int] = []
        if add_special_tokens:
            indices.append(self.char2idx[self.SOS_TOKEN])
        for token in self._tokenize(text)[:budget]:
            indices.append(self.char2idx.get(token, unk_id))
        if add_special_tokens:
            indices.append(self.char2idx[self.EOS_TOKEN])

        pad_id = self.char2idx[self.PAD_TOKEN]
        indices.extend([pad_id] * (limit - len(indices)))
        # With max_length < 2 the markers alone can exceed the limit.
        return indices[:limit]

    def decode(self, indices: List[int]) -> str:
        """Convert ids back to text.

        Decoding stops at the first end marker; pad, start and unknown
        tokens are dropped, and ids missing from the vocabulary count as
        unknown.

        Args:
            indices: Integer ids, e.g. the output of :meth:`encode`.

        Returns:
            The concatenated token strings.
        """
        skipped = (self.PAD_TOKEN, self.SOS_TOKEN, self.UNK_TOKEN)
        chars = []
        for idx in indices:
            token = self.idx2char.get(idx, self.UNK_TOKEN)
            if token == self.EOS_TOKEN:
                break
            if token not in skipped:
                chars.append(token)
        return "".join(chars)

    def save(self, filepath: str) -> None:
        """Write the tokenizer state to ``filepath`` as JSON.

        Merges are written twice: as ``bpe_merges`` (a lossless list of
        ``[left, right, rank]``) and as the legacy ``bpe_ranks`` mapping
        with ``"left_right"`` keys, kept so older readers still work. The
        legacy form is ambiguous when a symbol contains ``_``.

        Args:
            filepath: Destination path.
        """
        state = {
            "char2idx": self.char2idx,
            "special_tokens": self.special_tokens,
            "vocab_size": self.vocab_size,
            "target_vocab_size": self.target_vocab_size,
            "bpe_ranks": {
                f"{k[0]}_{k[1]}": v for k, v in self.bpe_ranks.items()
            },
            "bpe_merges": [
                [left, right, rank]
                for (left, right), rank in self.bpe_ranks.items()
            ],
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(state, f, indent=2)
        print(f"BPE Tokenizer saved to {filepath}")

    def _split_legacy_key(self, key: str) -> Optional[Tuple[str, str]]:
        """Recover ``(left, right)`` from a legacy ``"left_right"`` key.

        Tries every ``_`` as the separator and prefers the split whose two
        halves and their concatenation are all vocabulary tokens. Falls
        back to the first ``_`` when none validates. Returns ``None`` when
        the key has no ``_``.
        """
        candidates = [
            (key[:pos], key[pos + 1:])
            for pos, char in enumerate(key)
            if char == "_"
        ]
        if not candidates:
            return None
        for left, right in candidates:
            if (
                left in self.char2idx
                and right in self.char2idx
                and left + right in self.char2idx
            ):
                return left, right
        return candidates[0]

    def load(self, filepath: str) -> "Tokenizer":
        """Replace this tokenizer's state with the one stored in ``filepath``.

        Reads files written by :meth:`save`, including older files that
        only carry the legacy ``bpe_ranks`` mapping.

        Args:
            filepath: Path of a JSON file produced by :meth:`save`.

        Returns:
            ``self``, so calls can be chained.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            state = json.load(f)

        self.char2idx = state["char2idx"]
        self.special_tokens = state["special_tokens"]
        self.vocab_size = state["vocab_size"]
        self.target_vocab_size = state.get("target_vocab_size", 1000)
        self.idx2char = {v: k for k, v in self.char2idx.items()}

        # Reset first so merges from a previous state never leak through.
        self.bpe_ranks = {}
        if "bpe_merges" in state:
            for left, right, rank in state["bpe_merges"]:
                self.bpe_ranks[(left, right)] = rank
        elif "bpe_ranks" in state:
            for key, value in state["bpe_ranks"].items():
                pair = self._split_legacy_key(key)
                if pair is not None:
                    self.bpe_ranks[pair] = value

        print(f"BPE Tokenizer loaded from {filepath}")
        print(f" - Vocab size: {self.vocab_size}")
        print(f" - BPE merges: {len(self.bpe_ranks)}")
        return self