"""
Train a tokenizer using the HuggingFace Tokenizers library,
in the style of the GPT-4 tokenizer.
"""

import os
import time
import argparse

import torch

from nanochat.tokenizer import RustBPETokenizer
from nanochat.common import get_base_dir
from nanochat.dataset import parquets_iter_batched

# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")
print(f"vocab_size: {args.vocab_size:,}")

def text_iterator():
    """
    1) Flatten the batches into a single iterator
    2) Crop every document to args.doc_cap characters
    3) Break when we've seen args.max_chars characters
    """
    nchars = 0
    for batch in parquets_iter_batched(split="train"):
        for doc in batch:
            doc_text = doc
            if len(doc_text) > args.doc_cap:
                doc_text = doc_text[:args.doc_cap]
            nchars += len(doc_text)
            yield doc_text
            if nchars > args.max_chars:
                return

text_iter = text_iterator()

# Train the tokenizer and time it
t0 = time.time()
tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size)
t1 = time.time()
train_time = t1 - t0
print(f"Training time: {train_time:.2f}s")

# Save the tokenizer to disk
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
tokenizer.save(tokenizer_dir)

# Quick sanity check: encoding then decoding must round-trip exactly
test_text = """Hello world! This is a test. |
|
|
Numbers: 123, 4567, 89 |
|
|
Contractions: I'm, you're, it's |
|
|
Special chars: @#$%^&*() |
|
|
Unicode: 你好世界 🌍""" |
|
|
encoded = tokenizer.encode(test_text) |
|
|
decoded = tokenizer.decode(encoded) |
|
|
assert decoded == test_text |
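
# Illustrative extra check (an addition, not part of the training path):
# report how many UTF-8 bytes each token covers on the test text. On plain
# English, GPT-4-style tokenizers tend to land around 4 bytes per token.
print(f"Round trip OK: {len(test_text.encode('utf-8')) / len(encoded):.2f} bytes/token on the test text")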

# Compute the number of UTF-8 bytes that each token decodes to and cache it
# to disk for downstream use (e.g. normalizing per-token losses by bytes)
vocab_size = tokenizer.get_vocab_size()
special_set = set(tokenizer.get_special_tokens())
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
token_bytes = []
for token_id in range(vocab_size):
    token_str = token_strings[token_id]
    if token_str in special_set:
        token_bytes.append(0)  # special tokens are "free": count them as 0 bytes
    else:
        id_bytes = len(token_str.encode("utf-8"))
        token_bytes.append(id_bytes)
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
with open(token_bytes_path, "wb") as f:
    torch.save(token_bytes, f)
print(f"Saved token_bytes to {token_bytes_path}")

# Log a summary of this run to the report
from nanochat.report import get_report
token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
get_report().log(section="Tokenizer training", data=[
    vars(args),
    {"train_time": train_time},
    {"num_special_tokens": len(special_set)},
    {
        "token_bytes_min": int(token_bytes_nonzero.min().item()),
        "token_bytes_max": int(token_bytes_nonzero.max().item()),
        "token_bytes_mean": token_bytes_nonzero.mean().item(),
        "token_bytes_std": token_bytes_nonzero.std().item(),
    },
])