# tinyvic/tokenizer.py
"""
VicAI Tokenizer
Byte-Pair Encoding (BPE) tokenizer implementation.
"""
import pickle
import re
from collections import defaultdict
from typing import Dict, List
class BPETokenizer:
"""Byte-Pair Encoding Tokenizer."""
def __init__(self, vocab_size: int = 32000):
self.vocab_size = vocab_size
self.vocab = {}
self.merges = []
self.special_tokens = {
'<pad>': 0,
'<unk>': 1,
'<s>': 2,
'</s>': 3,
'<mask>': 4,
}
self.pad_token_id = 0
self.unk_token_id = 1
self.bos_token_id = 2
self.eos_token_id = 3
self.mask_token_id = 4
def _get_stats(self, vocab):
"""Get counts of all symbol pairs."""
pairs = defaultdict(int)
for word, freq in vocab.items():
symbols = word.split()
for i in range(len(symbols) - 1):
pairs[(symbols[i], symbols[i + 1])] += freq
return pairs
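    # Illustrative example of the pair statistics (hypothetical vocab; keys are
    # the space-separated symbol strings built in train()):
    #   _get_stats({'l o w </w>': 5, 'l o w e r </w>': 2})
    #   -> {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2,
    #       ('e', 'r'): 2, ('r', '</w>'): 2}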
def _merge_vocab(self, pair, vocab):
"""Merge all occurrences of pair in vocab."""
bigram = re.escape(' '.join(pair))
pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
new_vocab = {}
for word in vocab:
new_word = pattern.sub(''.join(pair), word)
new_vocab[new_word] = vocab[word]
return new_vocab
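    # Illustrative merge example (hypothetical vocab): merging the pair
    # ('l', 'o') rewrites every standalone "l o" occurrence into "lo":
    #   _merge_vocab(('l', 'o'), {'l o w </w>': 5}) -> {'lo w </w>': 5}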
    def _pre_tokenize(self, text: str) -> List[str]:
        """Pre-tokenize text into words."""
        # Simple whitespace and punctuation tokenization. Note that the stdlib
        # `re` module does not support \p{L}/\p{N}, so ASCII character classes
        # are used here (the third-party `regex` module would be needed for
        # full Unicode categories).
        pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"
        return re.findall(pattern, text)
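    # Illustrative pre-tokenization example (assumes the ASCII pattern above;
    # note the leading space kept on non-initial words):
    #   _pre_tokenize("Hello, world!") -> ['Hello', ',', ' world', '!']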
def train(self, texts: List[str]):
"""Train BPE on a list of texts."""
print(f"Training BPE tokenizer with vocab_size={self.vocab_size}")
# Initialize vocabulary with special tokens
self.vocab = {token: i for token, i in self.special_tokens.items()}
        # Build word frequency dictionary keyed by space-separated symbols
        vocab = defaultdict(int)
        for text in texts:
            for word in self._pre_tokenize(text.lower()):
                # Split into characters and end the word with </w>; the
                # split()/join round trip also drops any leading-space symbol
                # left by pre-tokenization.
                symbols = (' '.join(list(word)) + ' </w>').split()
                vocab[' '.join(symbols)] += 1
# Add individual characters to vocab
for word in vocab:
for char in word.split():
if char not in self.vocab:
self.vocab[char] = len(self.vocab)
# BPE training
num_merges = self.vocab_size - len(self.vocab)
for i in range(num_merges):
pairs = self._get_stats(vocab)
if not pairs:
break
best = max(pairs, key=pairs.get)
vocab = self._merge_vocab(best, vocab)
self.merges.append(best)
# Add merged token to vocab
merged_token = ''.join(best)
if merged_token not in self.vocab:
self.vocab[merged_token] = len(self.vocab)
if (i + 1) % 1000 == 0:
print(f" Completed {i + 1}/{num_merges} merges")
print(f"Final vocabulary size: {len(self.vocab)}")
def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
"""Encode text to token IDs."""
words = self._pre_tokenize(text)
token_ids = []
if add_special_tokens:
token_ids.append(self.bos_token_id)
for word in words:
word = word.lower()
word_tokens = ' '.join(list(word)) + ' </w>'
# Apply BPE merges
for merge in self.merges:
bigram = re.escape(' '.join(merge))
pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
word_tokens = pattern.sub(''.join(merge), word_tokens)
# Convert to IDs
for token in word_tokens.split():
token_ids.append(self.vocab.get(token, self.unk_token_id))
if add_special_tokens:
token_ids.append(self.eos_token_id)
return token_ids
def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
"""Decode token IDs to text."""
# Build reverse vocab
reverse_vocab = {v: k for k, v in self.vocab.items()}
tokens = []
for token_id in token_ids:
if token_id in self.special_tokens.values() and skip_special_tokens:
continue
tokens.append(reverse_vocab.get(token_id, '<unk>'))
text = ''.join(tokens)
text = text.replace('</w>', ' ')
return text.strip()
def save(self, path: str):
"""Save tokenizer to file."""
data = {
'vocab': self.vocab,
'merges': self.merges,
'special_tokens': self.special_tokens,
'vocab_size': self.vocab_size,
}
with open(path, 'wb') as f:
pickle.dump(data, f)
print(f"Tokenizer saved to {path}")
def load(self, path: str):
"""Load tokenizer from file."""
with open(path, 'rb') as f:
data = pickle.load(f)
self.vocab = data['vocab']
self.merges = data['merges']
self.special_tokens = data['special_tokens']
self.vocab_size = data['vocab_size']
self.pad_token_id = self.special_tokens['<pad>']
self.unk_token_id = self.special_tokens['<unk>']
self.bos_token_id = self.special_tokens['<s>']
self.eos_token_id = self.special_tokens['</s>']
self.mask_token_id = self.special_tokens['<mask>']
print(f"Tokenizer loaded from {path}")
def batch_encode(
self,
texts: List[str],
max_length: int = 512,
padding: bool = True,
truncation: bool = True,
) -> Dict[str, List]:
"""Batch encode texts."""
encoded = [self.encode(text) for text in texts]
if truncation:
encoded = [seq[:max_length] for seq in encoded]
        if padding:
            max_len = min(max(len(seq) for seq in encoded), max_length)
            attention_mask = []
            for seq in encoded:
                # Clip anything still longer than max_len (possible when
                # truncation=False) so input_ids and attention_mask stay aligned.
                del seq[max_len:]
                mask = [1] * len(seq) + [0] * (max_len - len(seq))
                seq.extend([self.pad_token_id] * (max_len - len(seq)))
                attention_mask.append(mask)
else:
attention_mask = [[1] * len(seq) for seq in encoded]
return {
'input_ids': encoded,
'attention_mask': attention_mask,
}
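    # Illustrative batch_encode output shape (token IDs are placeholders and
    # depend on the trained vocab; here the shorter first sequence is padded):
    #   batch_encode(["hi", "hello there"], max_length=8)
    #   -> {'input_ids': [[2, ..., 3, 0, 0], [2, ..., 3]],
    #       'attention_mask': [[1, ..., 1, 0, 0], [1, ..., 1]]}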
def __len__(self):
return len(self.vocab)
class ByteLevelBPETokenizer:
"""Byte-level BPE tokenizer (similar to GPT-2/3)."""
def __init__(self, vocab_size: int = 32000):
self.vocab_size = vocab_size
self.vocab = {}
self.merges = []
        # Map each byte value to a distinct Unicode code point (offset by 128,
        # i.e. U+0080..U+017F) so raw bytes can be handled as characters during BPE
        self.byte_encoder = {i: chr(i + 128) for i in range(256)}
        self.byte_decoder = {chr(i + 128): i for i in range(256)}
self.special_tokens = {
'<|endoftext|>': 0,
'<|pad|>': 1,
}
self.eos_token_id = 0
self.pad_token_id = 1
def _bytes_to_unicode(self, text: str) -> str:
"""Convert string to byte-level representation."""
return ''.join(self.byte_encoder[b] for b in text.encode('utf-8'))
    def _unicode_to_bytes(self, text: str) -> str:
        """Convert byte-level representation back to string."""
        # Characters outside the byte table (e.g. stray special-token text) are skipped
        return bytes(self.byte_decoder[c] for c in text if c in self.byte_decoder).decode('utf-8', errors='replace')
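    # Illustrative round trip: every UTF-8 byte is mapped to a code point in
    # U+0080..U+017F and back, so for any string s:
    #   _unicode_to_bytes(_bytes_to_unicode(s)) == s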
def train(self, texts: List[str]):
"""Train byte-level BPE."""
print(f"Training byte-level BPE tokenizer with vocab_size={self.vocab_size}")
# Initialize vocab with special tokens and all bytes
self.vocab = {token: i for token, i in self.special_tokens.items()}
for i in range(256):
byte_char = self.byte_encoder[i]
if byte_char not in self.vocab:
self.vocab[byte_char] = len(self.vocab)
        # Get initial word frequencies: each text becomes a sequence of
        # byte-level characters followed by the end-of-text marker kept as a
        # single symbol (so BPE never splits or partially merges it)
        vocab = defaultdict(int)
        for text in texts:
            byte_text = self._bytes_to_unicode(text)
            vocab[tuple(byte_text) + ('<|endoftext|>',)] += 1
# BPE training
num_merges = self.vocab_size - len(self.vocab)
for i in range(num_merges):
pairs = self._get_stats(vocab)
if not pairs:
break
best = max(pairs, key=pairs.get)
vocab = self._merge_vocab(best, vocab)
self.merges.append(best)
merged = ''.join(best)
if merged not in self.vocab:
self.vocab[merged] = len(self.vocab)
if (i + 1) % 1000 == 0:
print(f" Completed {i + 1}/{num_merges} merges")
print(f"Final vocabulary size: {len(self.vocab)}")
def _get_stats(self, vocab):
pairs = defaultdict(int)
for word, freq in vocab.items():
symbols = list(word)
for i in range(len(symbols) - 1):
pairs[(symbols[i], symbols[i + 1])] += freq
return pairs
def _merge_vocab(self, pair, vocab):
new_vocab = {}
bigram = pair[0] + pair[1]
for word in vocab:
new_word = []
i = 0
while i < len(word):
if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
new_word.append(bigram)
i += 2
else:
new_word.append(word[i])
i += 1
new_vocab[tuple(new_word)] = vocab[word]
return new_vocab
    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs."""
        # Start from individual byte-level characters; the end-of-text marker is
        # appended as a single symbol so the merge loop never splits it
        word = list(self._bytes_to_unicode(text))
        if add_special_tokens:
            word.append('<|endoftext|>')
        # Apply merges in training order
        for merge in self.merges:
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == merge[0] and word[i + 1] == merge[1]:
                    new_word.append(merge[0] + merge[1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = new_word
        # Convert to IDs; unknown symbols fall back to the pad ID (there is no <unk>)
        return [self.vocab.get(token, self.pad_token_id) for token in word]
    def decode(self, token_ids: List[int]) -> str:
        """Decode token IDs to text."""
        reverse_vocab = {v: k for k, v in self.vocab.items()}
        text = ''.join(reverse_vocab.get(token_id, '') for token_id in token_ids)
        text = text.replace('<|endoftext|>', '')
        return self._unicode_to_bytes(text)
def save(self, path: str):
"""Save tokenizer to file."""
data = {
'vocab': self.vocab,
'merges': self.merges,
'special_tokens': self.special_tokens,
'vocab_size': self.vocab_size,
'byte_encoder': self.byte_encoder,
'byte_decoder': self.byte_decoder,
}
with open(path, 'wb') as f:
pickle.dump(data, f)
print(f"Tokenizer saved to {path}")
def load(self, path: str):
"""Load tokenizer from file."""
with open(path, 'rb') as f:
data = pickle.load(f)
self.vocab = data['vocab']
self.merges = data['merges']
self.special_tokens = data['special_tokens']
self.vocab_size = data['vocab_size']
self.byte_encoder = data.get('byte_encoder', self.byte_encoder)
self.byte_decoder = data.get('byte_decoder', self.byte_decoder)
# Ensure all special tokens exist
if '<|endoftext|>' not in self.special_tokens:
self.special_tokens['<|endoftext|>'] = 0
if '<|pad|>' not in self.special_tokens:
self.special_tokens['<|pad|>'] = 1
self.eos_token_id = self.special_tokens.get('<|endoftext|>', 0)
self.pad_token_id = self.special_tokens.get('<|pad|>', 1)
print(f"Tokenizer loaded from {path}")
def __len__(self):
return len(self.vocab)
def create_and_train_tokenizer(texts: List[str], vocab_size: int = 32000, output_path: str = "tokenizer.pkl"):
"""Create and train a tokenizer on the given texts."""
tokenizer = ByteLevelBPETokenizer(vocab_size=vocab_size)
tokenizer.train(texts)
tokenizer.save(output_path)
return tokenizer
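# Illustrative usage of the helper above (corpus and path are placeholders):
#   corpus = ["first training document", "second training document"]
#   tok = create_and_train_tokenizer(corpus, vocab_size=8000, output_path="tok.pkl")
#   ids = tok.encode("first training document")
#   text = tok.decode(ids)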
if __name__ == "__main__":
# Test tokenizer
sample_texts = [
"Hello, world! This is a test.",
"The quick brown fox jumps over the lazy dog.",
"Machine learning is fascinating.",
"Artificial intelligence will change the world.",
]
tokenizer = BPETokenizer(vocab_size=1000)
tokenizer.train(sample_texts)
test_text = "Hello world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
print(f"\nOriginal: {test_text}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")