# tokenizers-training/src/tokenizers_trainer.py
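"""Train and compare BPE, WordPiece, and SentencePiece-unigram tokenizers on a
shared corpus, then report fragmentation rate, compression ratio, and
reconstruction accuracy for each vocab_size / min_freq combination."""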
import os
import re
import json
CORPUS_FILE = 'core/united_core.txt'
VOCAB_SIZE = 10000
OUTPUT_DIR = 'tokenizers'
os.makedirs(OUTPUT_DIR, exist_ok=True)
def simple_tokenize(text):
    """Whitespace word split used as the reference segmentation for the metrics."""
    return re.findall(r'\S+', text)
def train_bpe(vocab_size, min_freq, corpus_path=None):
    """Train a byte-pair-encoding tokenizer and save it with HF-style config files."""
    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.trainers import BpeTrainer
    from tokenizers.pre_tokenizers import Whitespace

    tokenizer = Tokenizer(BPE(unk_token='<UNK>'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_freq,
        special_tokens=['<UNK>', '<NUM>', '<URL>', '<EMAIL>'],
        continuing_subword_prefix='',
    )
    tokenizer.train(files=[corpus_path if corpus_path else CORPUS_FILE], trainer=trainer)

    dir_path = f'{OUTPUT_DIR}/bpe_v{vocab_size//1000}k_f{min_freq}'
    os.makedirs(dir_path, exist_ok=True)
    tokenizer.save(os.path.join(dir_path, 'tokenizer.json'))

    # Minimal config so the directory can be loaded as a fast tokenizer
    # in transformers.
    tokenizer_config = {
        "added_tokens_decoder": {},
        "unk_token": "<UNK>",
        "cls_token": None,
        "sep_token": None,
        "mask_token": None,
        "model_max_length": 512,
    }
    for token in ['<UNK>', '<NUM>', '<URL>', '<EMAIL>']:
        # token_to_id resolves the id directly, without a round trip
        # through encode().
        t_id = str(tokenizer.token_to_id(token))
        tokenizer_config['added_tokens_decoder'][t_id] = {
            "content": token,
            "lstrip": False,
            "normalized": False,
            "rstrip": False,
            "single_word": False,
            "special": True
        }
    with open(os.path.join(dir_path, "tokenizer_config.json"), "w", encoding="utf-8") as file:
        json.dump(tokenizer_config, file, indent=2)

    special_tokens_map = {
        "unk_token": "<UNK>",
    }
    with open(os.path.join(dir_path, "special_tokens_map.json"), "w", encoding="utf-8") as file:
        json.dump(special_tokens_map, file, indent=2)
    return tokenizer
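# Loading sketch for the directory written above (hypothetical path for
# vocab_size=8000, min_freq=2; assumes `transformers` is installed):
#
#   from transformers import PreTrainedTokenizerFast
#   tok = PreTrainedTokenizerFast.from_pretrained('tokenizers/bpe_v8k_f2')
#   print(tok.tokenize('example text'))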
def train_wordpiece(vocab_size, min_freq, corpus_path=None):
    """Train a WordPiece tokenizer and save the raw tokenizer JSON."""
    from tokenizers import Tokenizer
    from tokenizers.models import WordPiece
    from tokenizers.trainers import WordPieceTrainer
    from tokenizers.pre_tokenizers import Whitespace

    tokenizer = Tokenizer(WordPiece(unk_token='<UNK>'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordPieceTrainer(
        vocab_size=vocab_size,
        min_frequency=min_freq,
        special_tokens=['<UNK>', '<NUM>', '<URL>', '<EMAIL>'],
        continuing_subword_prefix='',
    )
    tokenizer.train(files=[corpus_path if corpus_path else CORPUS_FILE], trainer=trainer)
    path = f'{OUTPUT_DIR}/wordpiece_v{vocab_size}_f{min_freq}.json'
    tokenizer.save(path)
    return tokenizer
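# Loading sketch for the single-file output above (path assumes
# vocab_size=8000, min_freq=2):
#
#   from tokenizers import Tokenizer
#   tok = Tokenizer.from_file('tokenizers/wordpiece_v8000_f2.json')
#   print(tok.encode('example text').tokens)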
def train_unigram(vocab_size, min_freq, corpus_path=None):
    """Train a SentencePiece unigram model.

    SentencePiece exposes no min_frequency knob, so min_freq only labels
    the output files.
    """
    import sentencepiece as spm
    model_prefix = f'{OUTPUT_DIR}/unigram_v{vocab_size}_f{min_freq}'
    spm.SentencePieceTrainer.train(
        input=corpus_path if corpus_path else CORPUS_FILE,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        model_type='unigram',
        character_coverage=0.9995,
        pad_id=0, unk_id=1, bos_id=-1, eos_id=-1,
        user_defined_symbols='<NUM>,<URL>,<EMAIL>',
        shuffle_input_sentence=True,
        input_sentence_size=100000,
        normalization_rule_name='nmt_nfkc',
        num_threads=8
    )
    sp = spm.SentencePieceProcessor()
    sp.load(f'{model_prefix}.model')
    return sp
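# Usage sketch for the trained SentencePiece model (path assumes
# vocab_size=8000, min_freq=2):
#
#   import sentencepiece as spm
#   sp = spm.SentencePieceProcessor()
#   sp.load('tokenizers/unigram_v8000_f2.model')
#   print(sp.encode_as_pieces('example text'))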
def fragmentation_rate(tokenizer_func, texts):
    """Share of whitespace-delimited words split into two or more subwords."""
    total_words = 0
    fragmented = 0
    for text in texts:
        words = simple_tokenize(text)
        for word in words:
            tokens = tokenizer_func(word)
            total_words += 1
            if len(tokens) > 1:
                fragmented += 1
    return fragmented / total_words if total_words else 0
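# Example: for ['hello world'] with a tokenizer that keeps 'hello' whole but
# splits 'world' into ['wor', 'ld'], 1 of 2 words fragments -> rate 0.5.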
def compression_ratio(tokenizer_func, texts):
    """Mean number of characters per token; higher means stronger compression."""
    total_syms = 0
    total_tokens = 0
    for text in texts:
        tokens = tokenizer_func(text)
        total_syms += len(text)
        total_tokens += len(tokens)
    return total_syms / total_tokens if total_tokens else 0
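# Example: a 20-character line encoded into 5 tokens contributes
# 20 / 5 = 4.0 characters per token.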
def reconstruction_accuracy(tokenizer_obj, texts, model_type='hf'):
    """Share of words whose decoded tokens match the original word
    (case-insensitive, ignoring whitespace and punctuation)."""
    reconstructed_ok = 0
    total_words = 0
    for text in texts:
        words = simple_tokenize(text)
        for word in words:
            total_words += 1
            try:
                if model_type == 'hf':
                    decoded = tokenizer_obj.decode(
                        tokenizer_obj.encode(word).ids,
                        skip_special_tokens=True
                    )
                elif model_type == 'sp':
                    pieces = tokenizer_obj.encode_as_pieces(word)
                    decoded = ''.join(pieces).replace('▁', '')
                else:
                    continue
                # Apply identical normalization to both sides before comparing.
                cleaned_decoded = re.sub(r'\s+|[^\w]', '', decoded.lower())
                cleaned_word = re.sub(r'\s+|[^\w]', '', word.lower())
                if cleaned_decoded == cleaned_word:
                    reconstructed_ok += 1
            except Exception:
                pass
    return reconstructed_ok / total_words if total_words else 0
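# Note: the SentencePiece models are trained with nmt_nfkc normalization,
# which can rewrite characters irreversibly, so sub-1.0 scores on the 'sp'
# path may reflect normalization rather than lossy tokenization.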
if __name__ == '__main__':
    with open(CORPUS_FILE, encoding='utf-8') as file:
        texts = file.readlines()

    vocab_sizes = [8000, 16000, 32000]
    min_freqs = [2, 3, 4, 5]
    results = []
    for vocab_size in vocab_sizes:
        for min_freq in min_freqs:
            print(f'vocab_size {vocab_size} min_freq {min_freq}')
            try:
                bpe = train_bpe(vocab_size, min_freq)
                bpe_func = lambda x: bpe.encode(x).tokens
                bpe_frag = fragmentation_rate(bpe_func, texts)
                bpe_comp = compression_ratio(bpe_func, texts)
                bpe_recon = reconstruction_accuracy(bpe, texts, model_type='hf')
                results.append({
                    'model': 'BPE',
                    'vocab_size': vocab_size,
                    'min_freq': min_freq,
                    'fragmentation_rate': bpe_frag,
                    'compression_ratio': bpe_comp,
                    'reconstruction_acc': bpe_recon
                })
            except Exception as e:
                print(f'BPE error: {e}')
            try:
                wp = train_wordpiece(vocab_size, min_freq)
                wp_func = lambda x: wp.encode(x).tokens
                wp_frag = fragmentation_rate(wp_func, texts)
                wp_comp = compression_ratio(wp_func, texts)
                wp_recon = reconstruction_accuracy(wp, texts, model_type='hf')
                results.append({
                    'model': 'WordPiece',
                    'vocab_size': vocab_size,
                    'min_freq': min_freq,
                    'fragmentation_rate': wp_frag,
                    'compression_ratio': wp_comp,
                    'reconstruction_acc': wp_recon
                })
            except Exception as e:
                print(f'WordPiece error: {e}')
            try:
                unigram = train_unigram(vocab_size, min_freq)
                uni_func = lambda x: unigram.encode_as_pieces(x)
                uni_frag = fragmentation_rate(uni_func, texts)
                uni_comp = compression_ratio(uni_func, texts)
                uni_recon = reconstruction_accuracy(unigram, texts, model_type='sp')
                results.append({
                    'model': 'Unigram',
                    'vocab_size': vocab_size,
                    'min_freq': min_freq,
                    'fragmentation_rate': uni_frag,
                    'compression_ratio': uni_comp,
                    'reconstruction_acc': uni_recon
                })
            except Exception as e:
                print(f'Unigram error: {e}')

    # Semicolon-separated report, one row per (model, vocab_size, min_freq).
    os.makedirs('reports', exist_ok=True)
    with open('reports/hf_sp_metrics.csv', 'w', encoding='utf-8') as file:
        file.write('model;vocab_size;min_freq;fragmentation_rate;compression_ratio;reconstruction_accuracy\n')
        for r in results:
            file.write(f'{r["model"]};{r["vocab_size"]};{r["min_freq"]};{round(r["fragmentation_rate"], 3)};')
            file.write(f'{round(r["compression_ratio"], 3)};{round(r["reconstruction_acc"], 3)}\n')
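# The report can be inspected with pandas (assumes pandas is installed):
#
#   import pandas as pd
#   df = pd.read_csv('reports/hf_sp_metrics.csv', sep=';')
#   print(df.sort_values('fragmentation_rate'))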