Spaces:
Build error
Build error
| from typing import List, Dict, Optional | |
| from tqdm import tqdm | |
| from collections import Counter | |
| from matplotlib import pyplot as plt | |
| import json | |
| from pathlib import Path | |
class TrieNode:
    """Node in the prefix tree (trie) used for fast longest-token matching.

    Attributes:
        children: mapping of next character -> child TrieNode.
        is_token: True when the path from the root to this node spells a
            complete vocabulary token.
        token: the token string ending at this node (None otherwise).
    """

    # Tries allocate one node per character of every vocabulary token;
    # __slots__ drops the per-instance __dict__ to cut memory use.
    __slots__ = ("children", "is_token", "token")

    def __init__(self):
        self.children = {}
        self.is_token = False
        self.token = None
class BytePairEncoder:
    """Byte-pair encoder trained by iteratively merging frequent digrams.

    The vocabulary starts as the set of distinct characters in ``text``;
    each merge step replaces the most frequent adjacent token pair with a
    single new token, until the requested vocabulary size is reached.
    """

    def __init__(self, text: str):
        # Initial vocabulary: one token per distinct character.
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        # Corpus encoded as a sequence of token ids.
        self.data = [self.stoi[c] for c in text]
        # Parallel lists of per-step training statistics.
        self.stats = {
            "vocab_sizes": [len(self.chars)],
            "data_sizes": [len(self.data)],
            "compression_ratios": [1.0],
            "merge_counts": [],
            "tokens_created": [],
            "max_token_lengths": [1],
        }
        # Baseline length used for compression-ratio computation.
        self.original_length = len(self.data)
        self.max_token_length = 1

    def get_digram_stats(self) -> Counter:
        """Return a Counter mapping each adjacent token-id pair to its count."""
        # int() normalizes ids in case the data sequence holds non-int numerics.
        return Counter(
            (int(a), int(b)) for a, b in zip(self.data, self.data[1:])
        )

    def encode_to_vocab_size(self, target_vocab_size: int,
                             plot_interval: Optional[int] = None,
                             print_interval: int = 100) -> None:
        """Merge digrams until the vocabulary reaches ``target_vocab_size``.

        Args:
            target_vocab_size: Desired number of tokens in the vocabulary.
            plot_interval: If set, plot statistics every N merges.
            print_interval: If truthy, print progress every N merges.
        """
        pbar = tqdm(total=target_vocab_size, desc="Training BPE",
                    initial=len(self.chars))
        iteration = 0
        while len(self.itos) < target_vocab_size:
            # Stop early when no digram remains to merge.
            if self._merge_step() is None:
                break
            iteration += 1
            pbar.update(1)
            if print_interval and iteration % print_interval == 0:
                self._print_progress(iteration)
            if plot_interval and iteration % plot_interval == 0:
                self.plot_statistics(iteration=iteration)
        pbar.close()

    def _merge_step(self):
        """Merge the most frequent digram.

        Returns:
            (new_token_id, count) for the merged pair, or None when the
            data has no digrams left.
        """
        stats = self.get_digram_stats()
        if not stats:
            return None
        top_pair, count = max(stats.items(), key=lambda x: x[1])
        new_token = self._add_token(top_pair)
        self.data = self._replace_pairs(top_pair, new_token)
        self._update_stats(count)
        return new_token, count

    def _add_token(self, pair: tuple) -> int:
        """Register the concatenation of ``pair``'s tokens; return the new id."""
        token_str = self.itos[pair[0]] + self.itos[pair[1]]
        token_id = len(self.itos)
        self.stoi[token_str] = token_id
        self.itos[token_id] = token_str
        self.max_token_length = max(self.max_token_length, len(token_str))
        return token_id

    def _replace_pairs(self, pair: tuple, new_token: int) -> List[int]:
        """Return a copy of the data with occurrences of ``pair`` merged.

        Scans left to right, so overlapping occurrences (e.g. "aaa" for
        the pair ("a", "a")) are merged non-overlapping from the left.
        """
        result = []
        i = 0
        while i < len(self.data):
            if (i < len(self.data) - 1
                    and self.data[i] == pair[0]
                    and self.data[i + 1] == pair[1]):
                result.append(new_token)
                i += 2  # skip both members of the merged pair
            else:
                result.append(self.data[i])
                i += 1
        return result

    def _update_stats(self, merge_count: int):
        """Append one entry to every statistics series after a merge."""
        self.stats["vocab_sizes"].append(len(self.itos))
        self.stats["data_sizes"].append(len(self.data))
        self.stats["compression_ratios"].append(
            self.original_length / len(self.data))
        self.stats["merge_counts"].append(merge_count)
        # The newest token always has the highest id.
        self.stats["tokens_created"].append(self.itos[len(self.itos) - 1])
        self.stats["max_token_lengths"].append(self.max_token_length)

    def plot_statistics(self, iteration: Optional[int] = None):
        """Plot vocabulary growth, compression, merge-count and token-length curves."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        # Plot training metrics
        ax1.plot(self.stats["vocab_sizes"], self.stats["data_sizes"])
        ax1.set_title("Vocabulary vs Dataset Size")
        ax2.plot(self.stats["vocab_sizes"], self.stats["compression_ratios"])
        ax2.set_title("Compression Ratio Progress")
        if self.stats["merge_counts"]:
            ax3.hist(self.stats["merge_counts"], bins=30)
            ax3.set_title("Merge Counts Distribution")
        if self.stats["tokens_created"]:
            lengths = [len(t) for t in self.stats["tokens_created"]]
            ax4.plot(range(len(lengths)), lengths)
            ax4.set_title("Token Length Evolution")
        plt.tight_layout()
        plt.show()

    def save_to_file(self, filepath: Path):
        """Serialize vocabulary and statistics to ``filepath`` as JSON.

        Note: the encoded ``data`` sequence is not persisted — only the
        vocabulary and the training statistics.
        """
        state = {
            "chars": self.chars,
            "stoi": self.stoi,
            "max_token_length": self.max_token_length,
            "stats": self.stats,
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: Path):
        """Reconstruct an encoder from a file written by ``save_to_file``.

        The returned instance carries only the vocabulary and recorded
        statistics; its ``data`` sequence is empty.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            state = json.load(f)
        instance = cls("")  # empty instance, then restore vocabulary state
        instance.chars = state["chars"]
        instance.stoi = state["stoi"]
        # itos is rebuilt by inverting stoi (JSON object keys are strings).
        instance.itos = {int(i): s for s, i in state["stoi"].items()}
        instance.max_token_length = state["max_token_length"]
        instance.stats = state["stats"]
        return instance

    def _print_progress(self, iteration: int):
        """Print a human-readable snapshot of training progress."""
        print(f"\nIteration {iteration}:")
        print(f"Vocabulary size: {len(self.itos):,}")
        print(f"Data size: {len(self.data):,}")
        print(f"Compression ratio: {self.stats['compression_ratios'][-1]:.2f}")
        if self.stats["merge_counts"]:
            last_merge = self.stats["merge_counts"][-1]
            last_token = self.stats["tokens_created"][-1]
            print(f"Last merge count: {last_merge:,}")
            print(f"Last token created: '{last_token}'")
            print(f"Max token length: {self.max_token_length}")
class TokenizerInternal:
    """Greedy longest-match tokenizer backed by a trained BPE vocabulary."""

    def __init__(self, encoder: BytePairEncoder):
        self.stoi = encoder.stoi
        self.max_token_length = encoder.max_token_length
        self._trie = self._build_trie()

    def _build_trie(self) -> TrieNode:
        """Build a character trie over the vocabulary for fast prefix lookup."""
        root = TrieNode()
        for token in self.stoi:
            node = root
            for char in token:
                node = node.children.setdefault(char, TrieNode())
            node.is_token = True
            node.token = token
        return root

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into vocabulary tokens via greedy longest-prefix match."""
        tokens = []
        pos = 0
        while pos < len(text):
            token = self._find_longest_token(text[pos:])
            tokens.append(token)
            pos += len(token)
        return tokens

    def _find_longest_token(self, text: str) -> str:
        """Return the longest vocabulary token that prefixes ``text``.

        Falls back to the first character — even when it is not in the
        vocabulary — so tokenization always makes progress.
        """
        node = self._trie
        longest = text[0]
        # No vocabulary token exceeds max_token_length, so cap the scan.
        for char in text[:self.max_token_length]:
            if char not in node.children:
                break
            node = node.children[char]
            if node.is_token:
                longest = node.token
        return longest