"""
VicAI Tokenizer

Byte-Pair Encoding (BPE) tokenizer implementation.

Two tokenizers are provided:

* ``BPETokenizer`` -- word-level BPE over lower-cased text, using an
  end-of-word marker so that word boundaries survive merging.
* ``ByteLevelBPETokenizer`` -- BPE over the UTF-8 bytes of the text, so any
  input string can be encoded without an "unknown" token.
"""

import json
import pickle
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union


# Symbol appended to every word before merging; ``decode`` turns it back
# into a single space.
_END_OF_WORD = '</w>'

# Pre-tokenization pattern. The stdlib ``re`` module has no ``\p{L}`` /
# ``\p{N}`` classes, so equivalent stdlib classes are used instead:
#   [^\W\d_]       word characters that are neither decimal digits nor "_"
#   \d             decimal digits
#   (?:[^\s\w]|_)  everything that is neither whitespace, letter nor digit
# Numeric characters that are not decimal digits (e.g. superscripts) end up
# in the first class.
_PRE_TOKENIZE_RE = re.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d"
    r"| ?[^\W\d_]+"
    r"| ?\d+"
    r"| ?(?:[^\s\w]|_)+"
    r"|\s+(?!\S)|\s+"
)


def _merge_symbols(symbols, pair):
    """Return ``symbols`` with each adjacent occurrence of ``pair`` fused.

    The scan is left to right and non-overlapping, so ``('a', 'a')`` applied
    to ``a a a`` yields ``aa a``.
    """
    first, second = pair
    fused = first + second
    merged = []
    last = len(symbols) - 1
    i = 0
    while i <= last:
        if i < last and symbols[i] == first and symbols[i + 1] == second:
            merged.append(fused)
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged


class BPETokenizer:
    """Byte-Pair Encoding Tokenizer."""

    def __init__(self, vocab_size: int = 32000):
        self.vocab_size = vocab_size
        self.vocab: Dict[str, int] = {}           # token string -> token id
        self.merges: List[Tuple[str, str]] = []   # learned merges, in order
        self.special_tokens = {
            '<pad>': 0,
            '<unk>': 1,
            '<s>': 2,
            '</s>': 3,
            '<mask>': 4,
        }
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3
        self.mask_token_id = 4

    def _get_stats(self, vocab):
        """Get counts of all symbol pairs.

        ``vocab`` maps a space-separated symbol string to its frequency.
        """
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq
        return pairs

    def _merge_vocab(self, pair, vocab):
        """Merge all occurrences of pair in vocab.

        Works on the symbol lists rather than with a regex substitution, so
        symbols containing backslashes or regex metacharacters are safe.
        """
        return {
            ' '.join(_merge_symbols(word.split(), pair)): freq
            for word, freq in vocab.items()
        }

    def _pre_tokenize(self, text: str) -> List[str]:
        """Pre-tokenize text into words.

        Splits into contractions, runs of letters, runs of digits, runs of
        other symbols (each optionally preceded by one space) and runs of
        whitespace.
        """
        return _PRE_TOKENIZE_RE.findall(text)

    @staticmethod
    def _word_symbols(word: str) -> List[str]:
        """Split a pre-token into characters plus the end-of-word marker.

        Whitespace characters are dropped; a whitespace-only pre-token
        yields an empty list so it produces no tokens at all.
        """
        symbols = [ch for ch in word if not ch.isspace()]
        if symbols:
            symbols.append(_END_OF_WORD)
        return symbols

    def train(self, texts: List[str]):
        """Train BPE on a list of texts.

        Resets ``self.vocab`` and ``self.merges``, then learns merges until
        the vocabulary reaches ``self.vocab_size`` or no pairs remain.
        """
        print(f"Training BPE tokenizer with vocab_size={self.vocab_size}")

        # Initialize vocabulary with special tokens; drop merges from any
        # previous training run so they do not leak into this one.
        self.vocab = dict(self.special_tokens)
        self.merges = []

        # Build word frequency dictionary keyed by space-separated symbols.
        word_freqs = defaultdict(int)
        for text in texts:
            for word in self._pre_tokenize(text.lower()):
                symbols = self._word_symbols(word)
                if symbols:
                    word_freqs[' '.join(symbols)] += 1
        vocab = dict(word_freqs)

        # Add individual characters (and the end-of-word marker) to vocab.
        for word in vocab:
            for symbol in word.split():
                if symbol not in self.vocab:
                    self.vocab[symbol] = len(self.vocab)

        # BPE training
        num_merges = self.vocab_size - len(self.vocab)
        for i in range(num_merges):
            pairs = self._get_stats(vocab)
            if not pairs:
                break

            best = max(pairs, key=pairs.get)
            vocab = self._merge_vocab(best, vocab)
            self.merges.append(best)

            # Add merged token to vocab
            merged_token = ''.join(best)
            if merged_token not in self.vocab:
                self.vocab[merged_token] = len(self.vocab)

            if (i + 1) % 1000 == 0:
                print(f" Completed {i + 1}/{num_merges} merges")

        print(f"Final vocabulary size: {len(self.vocab)}")

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs.

        The text is lower-cased (as in training), pre-tokenized, and every
        learned merge is applied in training order. Symbols missing from
        the vocabulary map to ``unk_token_id``. With ``add_special_tokens``
        the result is wrapped in ``bos_token_id`` / ``eos_token_id``.
        """
        token_ids = []
        if add_special_tokens:
            token_ids.append(self.bos_token_id)

        cache = {}  # pre-token -> ids, so repeated words are merged once
        for word in self._pre_tokenize(text.lower()):
            ids = cache.get(word)
            if ids is None:
                symbols = self._word_symbols(word)
                for pair in self.merges:
                    if len(symbols) < 2:
                        break  # nothing left to merge
                    symbols = _merge_symbols(symbols, pair)
                ids = [self.vocab.get(s, self.unk_token_id) for s in symbols]
                cache[word] = ids
            token_ids.extend(ids)

        if add_special_tokens:
            token_ids.append(self.eos_token_id)
        return token_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text.

        Unknown ids decode to the empty string. End-of-word markers become
        single spaces and the result is stripped, so the original spacing
        and casing are not restored exactly.
        """
        reverse_vocab = {idx: token for token, idx in self.vocab.items()}
        special_ids = set(self.special_tokens.values())

        tokens = []
        for token_id in token_ids:
            if skip_special_tokens and token_id in special_ids:
                continue
            tokens.append(reverse_vocab.get(token_id, ''))

        text = ''.join(tokens).replace(_END_OF_WORD, ' ')
        return text.strip()

    def save(self, path: str):
        """Save tokenizer to file (pickle format)."""
        data = {
            'vocab': self.vocab,
            'merges': self.merges,
            'special_tokens': self.special_tokens,
            'vocab_size': self.vocab_size,
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        print(f"Tokenizer saved to {path}")

    def load(self, path: str):
        """Load tokenizer from file.

        WARNING: the file is read with ``pickle.load``, which can execute
        arbitrary code. Only load tokenizer files from a trusted source.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)

        self.vocab = data['vocab']
        self.merges = data['merges']
        self.special_tokens = data['special_tokens']
        self.vocab_size = data['vocab_size']

        self.pad_token_id = self.special_tokens['<pad>']
        self.unk_token_id = self.special_tokens['<unk>']
        self.bos_token_id = self.special_tokens['<s>']
        self.eos_token_id = self.special_tokens['</s>']
        self.mask_token_id = self.special_tokens['<mask>']
        print(f"Tokenizer loaded from {path}")

    def batch_encode(
        self,
        texts: List[str],
        max_length: int = 512,
        padding: bool = True,
        truncation: bool = True,
    ) -> Dict[str, List]:
        """Batch encode texts.

        Returns ``{'input_ids': ..., 'attention_mask': ...}``. With
        ``truncation`` each sequence is cut to ``max_length``. With
        ``padding`` every sequence is right-padded with ``pad_token_id`` to
        the longest sequence in the batch, so ids and mask always have the
        same length (also when ``truncation`` is off). An empty batch
        yields two empty lists.
        """
        encoded = [self.encode(text) for text in texts]

        if truncation:
            encoded = [seq[:max_length] for seq in encoded]

        if padding and encoded:
            target_len = max(len(seq) for seq in encoded)
            attention_mask = []
            for seq in encoded:
                pad = target_len - len(seq)
                attention_mask.append([1] * len(seq) + [0] * pad)
                seq.extend([self.pad_token_id] * pad)
        else:
            attention_mask = [[1] * len(seq) for seq in encoded]

        return {
            'input_ids': encoded,
            'attention_mask': attention_mask,
        }

    def __len__(self):
        return len(self.vocab)


class ByteLevelBPETokenizer:
    """Byte-level BPE tokenizer (similar to GPT-2/3)."""

    def __init__(self, vocab_size: int = 32000):
        self.vocab_size = vocab_size
        self.vocab: Dict[str, int] = {}
        self.merges: List[Tuple[str, str]] = []
        # Each byte value b is represented by the single character
        # chr(b + 128), which keeps byte symbols disjoint from ASCII.
        self.byte_encoder = {i: chr(i + 128) for i in range(256)}
        self.byte_decoder = {chr(i + 128): i for i in range(256)}
        self.special_tokens = {
            '<|endoftext|>': 0,
            '<|pad|>': 1,
        }
        self.eos_token_id = 0
        self.pad_token_id = 1

    def _bytes_to_unicode(self, text: str) -> str:
        """Convert string to byte-level representation."""
        return ''.join(self.byte_encoder[b] for b in text.encode('utf-8'))

    def _unicode_to_bytes(self, text: str) -> str:
        """Convert byte-level representation back to string.

        Characters that are not byte symbols are ignored; invalid UTF-8 is
        decoded with replacement characters.
        """
        decoder = self.byte_decoder
        raw = bytes(decoder[c] for c in text if c in decoder)
        return raw.decode('utf-8', errors='replace')

    def train(self, texts: List[str]):
        """Train byte-level BPE.

        Resets ``self.vocab`` and ``self.merges``. Each text is one
        training sequence; the end-of-text token is not part of the
        sequences, so it is never split up or merged with ordinary bytes.
        """
        print(f"Training byte-level BPE tokenizer with vocab_size={self.vocab_size}")

        # Initialize vocab with special tokens and all bytes
        self.vocab = dict(self.special_tokens)
        self.merges = []
        for i in range(256):
            byte_char = self.byte_encoder[i]
            if byte_char not in self.vocab:
                self.vocab[byte_char] = len(self.vocab)

        # Get initial sequence frequencies (tuple of byte symbols -> count)
        vocab = defaultdict(int)
        for text in texts:
            vocab[tuple(self._bytes_to_unicode(text))] += 1

        # BPE training
        num_merges = self.vocab_size - len(self.vocab)
        for i in range(num_merges):
            pairs = self._get_stats(vocab)
            if not pairs:
                break

            best = max(pairs, key=pairs.get)
            vocab = self._merge_vocab(best, vocab)
            self.merges.append(best)

            merged = ''.join(best)
            if merged not in self.vocab:
                self.vocab[merged] = len(self.vocab)

            if (i + 1) % 1000 == 0:
                print(f" Completed {i + 1}/{num_merges} merges")

        print(f"Final vocabulary size: {len(self.vocab)}")

    def _get_stats(self, vocab):
        """Count adjacent symbol pairs, weighted by sequence frequency."""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            for left, right in zip(word, word[1:]):
                pairs[(left, right)] += freq
        return pairs

    def _merge_vocab(self, pair, vocab):
        """Apply one merge to every sequence (tuple of symbols) in vocab."""
        return {
            tuple(_merge_symbols(word, pair)): freq
            for word, freq in vocab.items()
        }

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs.

        Merges are applied in training order. With ``add_special_tokens``
        the end-of-text id is appended as a single token. Symbols missing
        from the vocabulary (only possible before training/loading) map to
        ``pad_token_id``.
        """
        symbols = list(self._bytes_to_unicode(text))
        for pair in self.merges:
            if len(symbols) < 2:
                break  # nothing left to merge
            symbols = _merge_symbols(symbols, pair)

        token_ids = [self.vocab.get(s, self.pad_token_id) for s in symbols]
        if add_special_tokens:
            token_ids.append(self.eos_token_id)
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """Decode token IDs to text.

        Special-token ids (end-of-text, pad) and unknown ids contribute
        nothing to the output.
        """
        reverse_vocab = {idx: token for token, idx in self.vocab.items()}
        special_ids = set(self.special_tokens.values())
        text = ''.join(
            reverse_vocab.get(token_id, '')
            for token_id in token_ids
            if token_id not in special_ids
        )
        return self._unicode_to_bytes(text)

    def save(self, path: str):
        """Save tokenizer to file (pickle format)."""
        data = {
            'vocab': self.vocab,
            'merges': self.merges,
            'special_tokens': self.special_tokens,
            'vocab_size': self.vocab_size,
            'byte_encoder': self.byte_encoder,
            'byte_decoder': self.byte_decoder,
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        print(f"Tokenizer saved to {path}")

    def load(self, path: str):
        """Load tokenizer from file.

        WARNING: the file is read with ``pickle.load``, which can execute
        arbitrary code. Only load tokenizer files from a trusted source.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)

        self.vocab = data['vocab']
        self.merges = data['merges']
        self.special_tokens = data['special_tokens']
        self.vocab_size = data['vocab_size']
        self.byte_encoder = data.get('byte_encoder', self.byte_encoder)
        self.byte_decoder = data.get('byte_decoder', self.byte_decoder)

        # Ensure all special tokens exist
        if '<|endoftext|>' not in self.special_tokens:
            self.special_tokens['<|endoftext|>'] = 0
        if '<|pad|>' not in self.special_tokens:
            self.special_tokens['<|pad|>'] = 1

        self.eos_token_id = self.special_tokens.get('<|endoftext|>', 0)
        self.pad_token_id = self.special_tokens.get('<|pad|>', 1)
        print(f"Tokenizer loaded from {path}")

    def __len__(self):
        return len(self.vocab)


def create_and_train_tokenizer(texts: List[str], vocab_size: int = 32000, output_path: str = "tokenizer.pkl"):
    """Create and train a tokenizer on the given texts.

    Trains a ``ByteLevelBPETokenizer``, saves it to ``output_path`` and
    returns it.
    """
    tokenizer = ByteLevelBPETokenizer(vocab_size=vocab_size)
    tokenizer.train(texts)
    tokenizer.save(output_path)
    return tokenizer


if __name__ == "__main__":
    # Test tokenizer
    sample_texts = [
        "Hello, world! This is a test.",
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning is fascinating.",
        "Artificial intelligence will change the world.",
    ]

    tokenizer = BPETokenizer(vocab_size=1000)
    tokenizer.train(sample_texts)

    test_text = "Hello world!"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)

    print(f"\nOriginal: {test_text}")
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")