# open_mp_generator/tokenizer.py
import json
from collections import Counter, defaultdict
from typing import Dict, List, Tuple


class Tokenizer:
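    """Character-level BPE tokenizer.

    Starts from the corpus characters plus four special tokens, then learns
    pair merges greedily until the vocabulary reaches `target_vocab_size`
    or no pairs remain.
    """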
def __init__(self, vocab_size: int = 1000):
self.special_tokens = ['<PAD>', '<UNK>', '<SOS>', '<EOS>']
self.char2idx: Dict[str, int] = {}
self.idx2char: Dict[int, str] = {}
self.vocab_size: int = 0
self.target_vocab_size: int = vocab_size
self.bpe_ranks: Dict[Tuple[str, str], int] = {}
for idx, token in enumerate(self.special_tokens):
self.char2idx[token] = idx
self.idx2char[idx] = token
self.vocab_size = len(self.special_tokens)
def _get_stats(self, words: Dict[Tuple[str, ...], int]) -> Counter:
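        """Count how often each adjacent symbol pair occurs, weighted by word frequency."""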
pairs = Counter()
for word, freq in words.items():
for i in range(len(word) - 1):
pairs[(word[i], word[i + 1])] += freq
return pairs
def _merge_vocab(
self, pair: Tuple[str, str], words: Dict[Tuple[str, ...], int]
) -> Dict[Tuple[str, ...], int]:
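        """Rewrite every word, replacing occurrences of `pair` with the merged symbol."""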
new_words = {}
replacement = "".join(pair)
for word in words:
new_word = []
i = 0
while i < len(word):
if (
i < len(word) - 1
and word[i] == pair[0]
and word[i + 1] == pair[1]
):
new_word.append(replacement)
i += 2
else:
new_word.append(word[i])
i += 1
new_words[tuple(new_word)] = words[word]
return new_words
def build_vocab(self, texts: List[str]) -> None:
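        """Learn a BPE vocabulary from raw texts.

        Each text is treated as a single symbol sequence (so merges may
        cross whitespace); pairs are merged greedily by frequency until the
        target vocabulary size is reached or no pairs remain.
        """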
print(f"Building BPE vocabulary from {len(texts)} texts...")
vocab = set()
for text in texts:
vocab.update(text)
for char in sorted(vocab):
if char not in self.char2idx:
self.char2idx[char] = self.vocab_size
self.idx2char[self.vocab_size] = char
self.vocab_size += 1
print(
f"Initial character vocabulary: "
f"{self.vocab_size - len(self.special_tokens)} characters"
)
words = defaultdict(int)
for text in texts:
word = tuple(text)
words[word] += 1
        # Clamp at zero: if the character vocabulary already meets the
        # target, the merge loop below should simply be empty.
        num_merges = max(0, self.target_vocab_size - self.vocab_size)
        print(f"Learning {num_merges} BPE merges...")
for i in range(num_merges):
pairs = self._get_stats(words)
if not pairs:
break
best_pair = max(pairs, key=pairs.get)
words = self._merge_vocab(best_pair, words)
new_token = ''.join(best_pair)
if new_token not in self.char2idx:
self.char2idx[new_token] = self.vocab_size
self.idx2char[self.vocab_size] = new_token
self.vocab_size += 1
self.bpe_ranks[best_pair] = i
if (i + 1) % 100 == 0:
print(
f" Learned {i + 1} merges, "
f"vocab size: {self.vocab_size}"
)
print(f"BPE Vocabulary built! Total tokens: {self.vocab_size}")
print(f" - Special tokens: {len(self.special_tokens)}")
print(f" - Base characters: {len(vocab)}")
print(f" - BPE subwords: {len(self.bpe_ranks)}")
print(f" - Sample subwords: {list(self.bpe_ranks.keys())[:5]}")
def _tokenize(self, text: str) -> List[str]:
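        """Apply the learned merges to `text`, always merging the
        lowest-rank (earliest-learned) pair first."""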
if not text:
return []
word = tuple(text)
while len(word) > 1:
pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)]
valid_pairs = [p for p in pairs if p in self.bpe_ranks]
if not valid_pairs:
break
bigram = min(valid_pairs, key=lambda p: self.bpe_ranks[p])
new_word = []
i = 0
while i < len(word):
if (
i < len(word) - 1
and word[i] == bigram[0]
and word[i + 1] == bigram[1]
):
new_word.append("".join(bigram))
i += 2
else:
new_word.append(word[i])
i += 1
word = tuple(new_word)
return list(word)
def add_token(self, token: str) -> None:
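        """Register `token` with the next free id if it is not already present."""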
if token not in self.char2idx:
idx = self.vocab_size
self.char2idx[token] = idx
self.idx2char[idx] = token
self.vocab_size += 1
def encode(
self, text: str, max_length: int, add_special_tokens: bool = True
) -> List[int]:
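        """Encode text to ids, truncated and <PAD>-padded to `max_length`
        (max_length should be >= 2 when special tokens are added)."""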
        tokens = self._tokenize(text)
        indices = []
        if add_special_tokens:
            indices.append(self.char2idx['<SOS>'])
        # Reserve slots for <SOS>/<EOS>. The max(0, ...) guard keeps the
        # slice bound from going negative (e.g. max_length=1), which would
        # otherwise keep nearly all tokens instead of truncating.
        budget = max(0, max_length - (2 if add_special_tokens else 0))
        for token in tokens[:budget]:
            indices.append(self.char2idx.get(token, self.char2idx['<UNK>']))
        if add_special_tokens:
            indices.append(self.char2idx['<EOS>'])
        while len(indices) < max_length:
            indices.append(self.char2idx['<PAD>'])
        return indices
def decode(self, indices: List[int]) -> str:
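        """Map ids back to text, stopping at <EOS> and dropping the other
        special tokens."""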
chars = []
for idx in indices:
token = self.idx2char.get(idx, '<UNK>')
if token == '<EOS>':
break
if token not in ['<PAD>', '<SOS>', '<UNK>']:
chars.append(token)
return ''.join(chars)
def save(self, filepath: str) -> None:
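        """Serialize the tokenizer state to JSON at `filepath`."""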
        state = {
            "char2idx": self.char2idx,
            "special_tokens": self.special_tokens,
            "vocab_size": self.vocab_size,
            "target_vocab_size": self.target_vocab_size,
            # Store merges as explicit [first, second, rank] triples: joining
            # the pair with "_" is ambiguous once merged tokens themselves
            # contain underscores (routine in code like omp_get_thread_num).
            "bpe_ranks": [
                [first, second, rank]
                for (first, second), rank in self.bpe_ranks.items()
            ],
        }
        with open(filepath, "w") as f:
            json.dump(state, f, indent=2)
        print(f"BPE Tokenizer saved to {filepath}")
def load(self, filepath: str) -> "Tokenizer":
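        """Restore tokenizer state from JSON at `filepath`; returns self."""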
        with open(filepath, "r") as f:
            state = json.load(f)
        self.char2idx = state["char2idx"]
        self.special_tokens = state["special_tokens"]
        self.vocab_size = state["vocab_size"]
        self.target_vocab_size = state.get("target_vocab_size", 1000)
        self.idx2char = {v: k for k, v in self.char2idx.items()}
        self.bpe_ranks = {}
        raw_ranks = state.get("bpe_ranks", [])
        if isinstance(raw_ranks, dict):
            # Legacy "first_second" keys: splitting on the first underscore
            # is only reliable when the first token contains none.
            for key, value in raw_ranks.items():
                parts = key.split("_", 1)
                if len(parts) == 2:
                    self.bpe_ranks[(parts[0], parts[1])] = value
        else:
            # Current format: [first, second, rank] triples.
            for first, second, rank in raw_ranks:
                self.bpe_ranks[(first, second)] = rank
        print(f"BPE Tokenizer loaded from {filepath}")
        print(f" - Vocab size: {self.vocab_size}")
        print(f" - BPE merges: {len(self.bpe_ranks)}")
        return self
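

# Hypothetical usage sketch, not part of the original module: the corpus,
# the tokenizer.json path, and max_length=32 are illustrative assumptions
# chosen to show the build_vocab -> encode -> decode -> save/load round trip.
if __name__ == "__main__":
    corpus = [
        "#pragma omp parallel for",
        "#pragma omp parallel for reduction(+:sum)",
        "int tid = omp_get_thread_num();",
    ]
    tok = Tokenizer(vocab_size=300)
    tok.build_vocab(corpus)

    ids = tok.encode("#pragma omp parallel", max_length=32)
    print(ids)
    print(tok.decode(ids))  # round-trips to the input text

    tok.save("tokenizer.json")
    restored = Tokenizer().load("tokenizer.json")
    assert restored.decode(ids) == tok.decode(ids)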