Spaces:
Sleeping
Sleeping
| from tokenizers.bpe import Tokenizer, get_stats, merge | |
| import json | |
| class BasicTokenizer(Tokenizer): | |
| def __init__(self): | |
| super().__init__() | |
| self.token_to_id = {} | |
| self.id_to_token = {} | |
| def train(self, text, vocab_size, verbose=False): | |
| assert vocab_size >= 256 and vocab_size <= 5000 | |
| num_merges = vocab_size - 256 | |
| text_bytes = text.encode('utf-8') | |
| ids = list(text_bytes) | |
| merges = {} | |
| vocab = {idx: bytes([idx]) for idx in range(256)} | |
| # Initialize token mappings | |
| self.token_to_id = {bytes([idx]).decode('utf-8', errors='replace'): idx for idx in range(256)} | |
| self.id_to_token = {idx: bytes([idx]).decode('utf-8', errors='replace') for idx in range(256)} | |
| for i in range(num_merges): | |
| stats = get_stats(ids) | |
| if not stats: | |
| break | |
| pair = max(stats, key=stats.get) | |
| idx = 256 + i | |
| ids = merge(ids, pair, idx) | |
| merges[pair] = idx | |
| vocab[idx] = vocab[pair[0]] + vocab[pair[1]] | |
| # Update token mappings | |
| token = vocab[idx].decode('utf-8', errors='replace') | |
| self.token_to_id[token] = idx | |
| self.id_to_token[idx] = token | |
| if verbose and i % 20 == 0: | |
| print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") | |
| self.merges = merges | |
| self.vocab = vocab | |
| self._save_vocabulary() | |
| def _save_vocabulary(self): | |
| vocabulary = { | |
| 'token_to_id': self.token_to_id, | |
| 'id_to_token': self.id_to_token, | |
| 'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()} | |
| } | |
| with open('vocabulary.json', 'w', encoding='utf-8') as f: | |
| json.dump(vocabulary, f, ensure_ascii=False, indent=2) | |
| def decode(self, ids): | |
| text_bytes = b"".join(self.vocab[idx] for idx in ids) | |
| text = text_bytes.decode("utf-8", errors="replace") | |
| return text | |
| def encode(self, text): | |
| text_bytes = text.encode("utf-8") | |
| ids = list(text_bytes) | |
| while len(ids) >= 2: | |
| stats = get_stats(ids) | |
| pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) | |
| if pair not in self.merges: | |
| break | |
| idx = self.merges[pair] | |
| ids = merge(ids, pair, idx) | |
| return ids | |