# NOTE(review): removed page-scrape residue ("Spaces: Sleeping") that was
# not part of the module source.
import json
from collections import Counter, defaultdict
from typing import Dict, List, Tuple
class Tokenizer:
    """Character-level byte-pair-encoding (BPE) tokenizer.

    Learns merge rules from a corpus (``build_vocab``) and converts text
    to and from fixed-length index sequences (``encode`` / ``decode``).
    """

    def __init__(self, vocab_size: int = 1000):
        """Create an empty tokenizer.

        Args:
            vocab_size: target total vocabulary size (special tokens,
                base characters, and learned BPE subwords combined).
        """
        # Reserved tokens occupy the first indices, in this fixed order.
        self.special_tokens = ['<PAD>', '<UNK>', '<SOS>', '<EOS>']
        self.char2idx: Dict[str, int] = {
            token: index for index, token in enumerate(self.special_tokens)
        }
        self.idx2char: Dict[int, str] = {
            index: token for token, index in self.char2idx.items()
        }
        self.vocab_size: int = len(self.special_tokens)
        self.target_vocab_size: int = vocab_size
        # Learned merge pair -> rank (lower rank = learned earlier).
        self.bpe_ranks: Dict[Tuple[str, str], int] = {}
| def _get_stats(self, words: Dict[Tuple[str, ...], int]) -> Counter: | |
| pairs = Counter() | |
| for word, freq in words.items(): | |
| for i in range(len(word) - 1): | |
| pairs[(word[i], word[i + 1])] += freq | |
| return pairs | |
| def _merge_vocab( | |
| self, pair: Tuple[str, str], words: Dict[Tuple[str, ...], int] | |
| ) -> Dict[Tuple[str, ...], int]: | |
| new_words = {} | |
| replacement = "".join(pair) | |
| for word in words: | |
| new_word = [] | |
| i = 0 | |
| while i < len(word): | |
| if ( | |
| i < len(word) - 1 | |
| and word[i] == pair[0] | |
| and word[i + 1] == pair[1] | |
| ): | |
| new_word.append(replacement) | |
| i += 2 | |
| else: | |
| new_word.append(word[i]) | |
| i += 1 | |
| new_words[tuple(new_word)] = words[word] | |
| return new_words | |
| def build_vocab(self, texts: List[str]) -> None: | |
| print(f"Building BPE vocabulary from {len(texts)} texts...") | |
| vocab = set() | |
| for text in texts: | |
| vocab.update(text) | |
| for char in sorted(vocab): | |
| if char not in self.char2idx: | |
| self.char2idx[char] = self.vocab_size | |
| self.idx2char[self.vocab_size] = char | |
| self.vocab_size += 1 | |
| print( | |
| f"Initial character vocabulary: " | |
| f"{self.vocab_size - len(self.special_tokens)} characters" | |
| ) | |
| words = defaultdict(int) | |
| for text in texts: | |
| word = tuple(text) | |
| words[word] += 1 | |
| num_merges = self.target_vocab_size - self.vocab_size | |
| print(f"Learning {num_merges} BPE merges...") | |
| for i in range(num_merges): | |
| pairs = self._get_stats(words) | |
| if not pairs: | |
| break | |
| best_pair = max(pairs, key=pairs.get) | |
| words = self._merge_vocab(best_pair, words) | |
| new_token = ''.join(best_pair) | |
| if new_token not in self.char2idx: | |
| self.char2idx[new_token] = self.vocab_size | |
| self.idx2char[self.vocab_size] = new_token | |
| self.vocab_size += 1 | |
| self.bpe_ranks[best_pair] = i | |
| if (i + 1) % 100 == 0: | |
| print( | |
| f" Learned {i + 1} merges, " | |
| f"vocab size: {self.vocab_size}" | |
| ) | |
| print(f"BPE Vocabulary built! Total tokens: {self.vocab_size}") | |
| print(f" - Special tokens: {len(self.special_tokens)}") | |
| print(f" - Base characters: {len(vocab)}") | |
| print(f" - BPE subwords: {len(self.bpe_ranks)}") | |
| print(f" - Sample subwords: {list(self.bpe_ranks.keys())[:5]}") | |
| def _tokenize(self, text: str) -> List[str]: | |
| if not text: | |
| return [] | |
| word = tuple(text) | |
| while len(word) > 1: | |
| pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)] | |
| valid_pairs = [p for p in pairs if p in self.bpe_ranks] | |
| if not valid_pairs: | |
| break | |
| bigram = min(valid_pairs, key=lambda p: self.bpe_ranks[p]) | |
| new_word = [] | |
| i = 0 | |
| while i < len(word): | |
| if ( | |
| i < len(word) - 1 | |
| and word[i] == bigram[0] | |
| and word[i + 1] == bigram[1] | |
| ): | |
| new_word.append("".join(bigram)) | |
| i += 2 | |
| else: | |
| new_word.append(word[i]) | |
| i += 1 | |
| word = tuple(new_word) | |
| return list(word) | |
| def add_token(self, token: str) -> None: | |
| if token not in self.char2idx: | |
| idx = self.vocab_size | |
| self.char2idx[token] = idx | |
| self.idx2char[idx] = token | |
| self.vocab_size += 1 | |
| def encode( | |
| self, text: str, max_length: int, add_special_tokens: bool = True | |
| ) -> List[int]: | |
| tokens = self._tokenize(text) | |
| indices = [] | |
| if add_special_tokens: | |
| indices.append(self.char2idx['<SOS>']) | |
| for token in tokens[:max_length - (2 if add_special_tokens else 0)]: | |
| indices.append(self.char2idx.get(token, self.char2idx['<UNK>'])) | |
| if add_special_tokens: | |
| indices.append(self.char2idx['<EOS>']) | |
| while len(indices) < max_length: | |
| indices.append(self.char2idx['<PAD>']) | |
| return indices | |
| def decode(self, indices: List[int]) -> str: | |
| chars = [] | |
| for idx in indices: | |
| token = self.idx2char.get(idx, '<UNK>') | |
| if token == '<EOS>': | |
| break | |
| if token not in ['<PAD>', '<SOS>', '<UNK>']: | |
| chars.append(token) | |
| return ''.join(chars) | |
| def save(self, filepath: str) -> None: | |
| state = { | |
| "char2idx": self.char2idx, | |
| "special_tokens": self.special_tokens, | |
| "vocab_size": self.vocab_size, | |
| "target_vocab_size": self.target_vocab_size, | |
| "bpe_ranks": { | |
| f"{k[0]}_{k[1]}": v for k, v in self.bpe_ranks.items() | |
| }, | |
| } | |
| with open(filepath, "w") as f: | |
| json.dump(state, f, indent=2) | |
| print(f"BPE Tokenizer saved to {filepath}") | |
| def load(self, filepath: str) -> "Tokenizer": | |
| with open(filepath, "r") as f: | |
| state = json.load(f) | |
| self.char2idx = state["char2idx"] | |
| self.special_tokens = state["special_tokens"] | |
| self.vocab_size = state["vocab_size"] | |
| self.target_vocab_size = state.get("target_vocab_size", 1000) | |
| self.idx2char = {v: k for k, v in self.char2idx.items()} | |
| if "bpe_ranks" in state: | |
| self.bpe_ranks = {} | |
| for key, value in state["bpe_ranks"].items(): | |
| parts = key.split("_", 1) | |
| if len(parts) == 2: | |
| self.bpe_ranks[(parts[0], parts[1])] = value | |
| print(f"BPE Tokenizer loaded from {filepath}") | |
| print(f" - Vocab size: {self.vocab_size}") | |
| print(f" - BPE merges: {len(self.bpe_ranks)}") | |
| return self |