| """ |
| VicAI Tokenizer |
| Byte-Pair Encoding (BPE) tokenizer implementation. |
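
Minimal usage sketch (the IDs shown are illustrative):

    tok = BPETokenizer(vocab_size=1000)
    tok.train(["hello world", "hello there"])
    ids = tok.encode("hello")   # e.g. [2, ..., 3] with <s>/</s> markers
    text = tok.decode(ids)      # "hello"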
| """ |
|
|
| import json |
| import pickle |
| import re |
| from collections import defaultdict |
| from typing import Dict, List, Optional, Union |


class BPETokenizer:
    """Byte-Pair Encoding Tokenizer."""

    def __init__(self, vocab_size: int = 32000):
        self.vocab_size = vocab_size
        self.vocab = {}
        self.merges = []
        self.special_tokens = {
            '<pad>': 0,
            '<unk>': 1,
            '<s>': 2,
            '</s>': 3,
            '<mask>': 4,
        }
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3
        self.mask_token_id = 4

    def _get_stats(self, vocab):
        """Get counts of all symbol pairs."""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def _merge_vocab(self, pair, vocab):
        """Merge all occurrences of pair in vocab."""
        bigram = re.escape(' '.join(pair))
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        new_vocab = {}
        for word in vocab:
            new_word = pattern.sub(''.join(pair), word)
            new_vocab[new_word] = vocab[word]
        return new_vocab

    def _pre_tokenize(self, text: str) -> List[str]:
        """Pre-tokenize text into words."""
        # GPT-2-style pattern. The canonical version uses \p{L}/\p{N}, which
        # the standard-library re module does not support (they need the
        # third-party `regex` package); [^\W\d_] and \d are the closest
        # re-compatible equivalents.
        pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?(?:[^\s\w]|_)+|\s+(?!\S)|\s+"
        return re.findall(pattern, text)

    def train(self, texts: List[str]):
        """Train BPE on a list of texts."""
        print(f"Training BPE tokenizer with vocab_size={self.vocab_size}")

        # Seed the vocabulary with the special tokens.
        self.vocab = dict(self.special_tokens)

        # Count word frequencies, representing each word as space-separated
        # character symbols terminated by the end-of-word marker </w>.
        vocab = defaultdict(int)
        for text in texts:
            for word in self._pre_tokenize(text.lower()):
                vocab[' '.join(list(word.strip()) + ['</w>'])] += 1

        # Add every base character as an initial token.
        for word in vocab:
            for char in word.split():
                if char not in self.vocab:
                    self.vocab[char] = len(self.vocab)

        # Iteratively merge the most frequent symbol pair.
        num_merges = self.vocab_size - len(self.vocab)
        for i in range(num_merges):
            pairs = self._get_stats(vocab)
            if not pairs:
                break

            best = max(pairs, key=pairs.get)
            vocab = self._merge_vocab(best, vocab)
            self.merges.append(best)

            # Register the merged symbol as a new vocabulary token.
            merged_token = ''.join(best)
            if merged_token not in self.vocab:
                self.vocab[merged_token] = len(self.vocab)

            if (i + 1) % 1000 == 0:
                print(f"  Completed {i + 1}/{num_merges} merges")

        print(f"Final vocabulary size: {len(self.vocab)}")

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs."""
        words = self._pre_tokenize(text)
        token_ids = []

        if add_special_tokens:
            token_ids.append(self.bos_token_id)

        for word in words:
            # Match training: lowercase, split to characters, append </w>.
            word_tokens = ' '.join(list(word.lower().strip()) + ['</w>'])

            # Apply learned merges in training order.
            for merge in self.merges:
                bigram = re.escape(' '.join(merge))
                pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
                word_tokens = pattern.sub(''.join(merge), word_tokens)

            # Map symbols to IDs, falling back to <unk>.
            for token in word_tokens.split():
                token_ids.append(self.vocab.get(token, self.unk_token_id))

        if add_special_tokens:
            token_ids.append(self.eos_token_id)

        return token_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text."""
        reverse_vocab = {v: k for k, v in self.vocab.items()}
        special_ids = set(self.special_tokens.values())

        tokens = []
        for token_id in token_ids:
            if skip_special_tokens and token_id in special_ids:
                continue
            tokens.append(reverse_vocab.get(token_id, '<unk>'))

        # End-of-word markers become spaces.
        text = ''.join(tokens)
        text = text.replace('</w>', ' ')
        return text.strip()

    def save(self, path: str):
        """Save tokenizer to file."""
        data = {
            'vocab': self.vocab,
            'merges': self.merges,
            'special_tokens': self.special_tokens,
            'vocab_size': self.vocab_size,
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        print(f"Tokenizer saved to {path}")

    def load(self, path: str):
        """Load tokenizer from file."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.vocab = data['vocab']
        self.merges = data['merges']
        self.special_tokens = data['special_tokens']
        self.vocab_size = data['vocab_size']

        self.pad_token_id = self.special_tokens['<pad>']
        self.unk_token_id = self.special_tokens['<unk>']
        self.bos_token_id = self.special_tokens['<s>']
        self.eos_token_id = self.special_tokens['</s>']
        self.mask_token_id = self.special_tokens['<mask>']
        print(f"Tokenizer loaded from {path}")

    def batch_encode(
        self,
        texts: List[str],
        max_length: int = 512,
        padding: bool = True,
        truncation: bool = True,
    ) -> Dict[str, List]:
        """Batch encode texts."""
        encoded = [self.encode(text) for text in texts]

        if truncation:
            encoded = [seq[:max_length] for seq in encoded]

        if padding:
            # Pad to the longest sequence in the batch; after truncation this
            # is at most max_length. Without truncation, longer sequences are
            # kept intact rather than producing ragged rows.
            max_len = max(len(seq) for seq in encoded)
            attention_mask = []
            for seq in encoded:
                pad_len = max_len - len(seq)
                attention_mask.append([1] * len(seq) + [0] * pad_len)
                seq.extend([self.pad_token_id] * pad_len)
        else:
            attention_mask = [[1] * len(seq) for seq in encoded]

        return {
            'input_ids': encoded,
            'attention_mask': attention_mask,
        }

    def __len__(self):
        return len(self.vocab)


class ByteLevelBPETokenizer:
    """Byte-level BPE tokenizer (similar to GPT-2/3)."""

    def __init__(self, vocab_size: int = 32000):
        self.vocab_size = vocab_size
        self.vocab = {}
        self.merges = []
        # Simple bijective byte<->unicode mapping: byte b -> chr(b + 128).
        # (GPT-2 uses a more elaborate printable mapping; any bijection
        # works for round-tripping.)
        self.byte_encoder = {i: chr(i + 128) for i in range(256)}
        self.byte_decoder = {chr(i + 128): i for i in range(256)}

        self.special_tokens = {
            '<|endoftext|>': 0,
            '<|pad|>': 1,
        }
        self.eos_token_id = 0
        self.pad_token_id = 1

    def _bytes_to_unicode(self, text: str) -> str:
        """Convert string to byte-level representation."""
        return ''.join(self.byte_encoder[b] for b in text.encode('utf-8'))

    def _unicode_to_bytes(self, text: str) -> str:
        """Convert byte-level representation back to string."""
        return bytes(self.byte_decoder[c] for c in text).decode('utf-8', errors='replace')

    def train(self, texts: List[str]):
        """Train byte-level BPE."""
        print(f"Training byte-level BPE tokenizer with vocab_size={self.vocab_size}")

        # Seed with the special tokens and all 256 base byte symbols.
        self.vocab = dict(self.special_tokens)
        for i in range(256):
            byte_char = self.byte_encoder[i]
            if byte_char not in self.vocab:
                self.vocab[byte_char] = len(self.vocab)

        # Count each text as a tuple of byte-level symbols. The end-of-text
        # marker is appended as a single symbol so it is never split into
        # individual characters.
        vocab = defaultdict(int)
        for text in texts:
            byte_text = self._bytes_to_unicode(text)
            vocab[tuple(byte_text) + ('<|endoftext|>',)] += 1

        # Iteratively merge the most frequent symbol pair.
        num_merges = self.vocab_size - len(self.vocab)

        for i in range(num_merges):
            pairs = self._get_stats(vocab)
            if not pairs:
                break

            best = max(pairs, key=pairs.get)
            vocab = self._merge_vocab(best, vocab)
            self.merges.append(best)

            merged = ''.join(best)
            if merged not in self.vocab:
                self.vocab[merged] = len(self.vocab)

            if (i + 1) % 1000 == 0:
                print(f"  Completed {i + 1}/{num_merges} merges")

        print(f"Final vocabulary size: {len(self.vocab)}")

    def _get_stats(self, vocab):
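        """Count symbol-pair frequencies over tuple-keyed words."""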
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = list(word)
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def _merge_vocab(self, pair, vocab):
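        """Merge all occurrences of pair within each symbol tuple."""
        # Worked example: merging ('a', 'b') turns ('a', 'b', 'c', 'a', 'b')
        # into ('ab', 'c', 'ab').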
        new_vocab = {}
        bigram = pair[0] + pair[1]
        for word in vocab:
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
                    new_word.append(bigram)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_vocab[tuple(new_word)] = vocab[word]
        return new_vocab

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs."""
        word = list(self._bytes_to_unicode(text))

        # Apply learned merges in training order.
        for merge in self.merges:
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == merge[0] and word[i + 1] == merge[1]:
                    new_word.append(merge[0] + merge[1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = new_word

        # Every byte is in the base vocabulary, so lookups should not miss;
        # the pad ID is a conservative fallback. The end-of-text marker is
        # appended as a single ID rather than as raw characters.
        token_ids = [self.vocab.get(token, self.pad_token_id) for token in word]
        if add_special_tokens:
            token_ids.append(self.eos_token_id)
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """Decode token IDs to text."""
        reverse_vocab = {v: k for k, v in self.vocab.items()}
        special_ids = set(self.special_tokens.values())
        # Skip special-token IDs: their strings are not byte-level symbols
        # and would break the byte mapping below.
        text = ''.join(
            reverse_vocab.get(token_id, '')
            for token_id in token_ids
            if token_id not in special_ids
        )
        text = text.replace('<|endoftext|>', '')
        return self._unicode_to_bytes(text)

    def save(self, path: str):
        """Save tokenizer to file."""
        data = {
            'vocab': self.vocab,
            'merges': self.merges,
            'special_tokens': self.special_tokens,
            'vocab_size': self.vocab_size,
            'byte_encoder': self.byte_encoder,
            'byte_decoder': self.byte_decoder,
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        print(f"Tokenizer saved to {path}")

    def load(self, path: str):
        """Load tokenizer from file."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.vocab = data['vocab']
        self.merges = data['merges']
        self.special_tokens = data['special_tokens']
        self.vocab_size = data['vocab_size']
        self.byte_encoder = data.get('byte_encoder', self.byte_encoder)
        self.byte_decoder = data.get('byte_decoder', self.byte_decoder)

        # Backfill special tokens for checkpoints saved without them.
        if '<|endoftext|>' not in self.special_tokens:
            self.special_tokens['<|endoftext|>'] = 0
        if '<|pad|>' not in self.special_tokens:
            self.special_tokens['<|pad|>'] = 1

        self.eos_token_id = self.special_tokens.get('<|endoftext|>', 0)
        self.pad_token_id = self.special_tokens.get('<|pad|>', 1)
        print(f"Tokenizer loaded from {path}")

    def __len__(self):
        return len(self.vocab)


def create_and_train_tokenizer(texts: List[str], vocab_size: int = 32000, output_path: str = "tokenizer.pkl"):
    """Create and train a tokenizer on the given texts."""
    tokenizer = ByteLevelBPETokenizer(vocab_size=vocab_size)
    tokenizer.train(texts)
    tokenizer.save(output_path)
    return tokenizer


if __name__ == "__main__":
    sample_texts = [
        "Hello, world! This is a test.",
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning is fascinating.",
        "Artificial intelligence will change the world.",
    ]

    tokenizer = BPETokenizer(vocab_size=1000)
    tokenizer.train(sample_texts)

    test_text = "Hello world!"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)

    print(f"\nOriginal: {test_text}")
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")