# mr_chatterbox / tokenizer_wrapper.py
# (HuggingFace Hub page residue below — kept as comments so the module parses)
# tventurella's picture
# Upload 17 files
# 59856b4 verified
"""
tokenizer_wrapper.py β€” nanochat-compatible wrapper for the Victorian BPE tokenizer
nanochat's base_train.py imports:
from nanochat.tokenizer import get_tokenizer, get_token_bytes
This wrapper provides a VictorianTokenizer class that satisfies nanochat's full
interface, plus get_tokenizer() and get_token_bytes() drop-in replacements.
Special token mapping:
<|endoftext|> β†’ bos (document boundary, prepended to every document)
<|pad|> β†’ pad
<human> β†’ user_start (replaces nanochat's <|user_start|>)
<victorian> β†’ assistant_start (replaces nanochat's <|assistant_start|>)
Usage β€” patch nanochat/tokenizer.py by adding at the bottom:
from pathlib import Path
import sys
sys.path.insert(0, "/path/to/victorian")
from tokenizer_wrapper import get_tokenizer, get_token_bytes
"""
from pathlib import Path
import torch
from tokenizers import Tokenizer
TOKENIZER_PATH = Path(__file__).parent / "tokenizer.json"
class VictorianTokenizer:
"""
Wraps our HuggingFace BPE tokenizer to match nanochat's expected interface.
"""
def __init__(self, tokenizer_path: str | Path = TOKENIZER_PATH):
self._tok = Tokenizer.from_file(str(tokenizer_path))
self._tok.no_padding()
self._tok.no_truncation()
# ------------------------------------------------------------------
# Core nanochat interface (used by dataloader and base_train.py)
# ------------------------------------------------------------------
def get_vocab_size(self) -> int:
return self._tok.get_vocab_size()
def get_bos_token_id(self) -> int:
"""Prepended to every document by nanochat's dataloader."""
return self._tok.token_to_id("<|endoftext|>")
def encode(
self,
texts: list[str] | str,
prepend: int | str | None = None,
append: int | str | None = None,
num_threads: int = 4,
) -> list[int] | list[list[int]]:
"""
Encode strings β†’ token ID list(s).
Matches nanochat's native tokenizer behaviour exactly:
- Single string β†’ list[int]
- List of strings β†’ list[list[int]]
prepend/append may be an int token ID or a special-token string
(e.g. prepend="<|bos|>"), matching nanochat's _encode_one interface.
"""
single = isinstance(texts, str)
if single:
texts = [texts]
# Resolve string prepend/append to token IDs (e.g. "<|bos|>" β†’ 0)
if isinstance(prepend, str):
prepend = self.encode_special(prepend)
if isinstance(append, str):
append = self.encode_special(append)
encodings = self._tok.encode_batch(texts, is_pretokenized=False)
ids = [enc.ids for enc in encodings]
if prepend is not None:
ids = [[prepend] + seq for seq in ids]
if append is not None:
ids = [seq + [append] for seq in ids]
# Single string β†’ flat list[int] to match nanochat's native encode()
return ids[0] if single else ids
def decode(self, ids: list[int]) -> str:
return self._tok.decode(ids)
# ------------------------------------------------------------------
# Special token accessors
# ------------------------------------------------------------------
def encode_special(self, token: str) -> int | None:
"""
Look up a special token ID by exact match.
Maps nanochat's native special tokens to Victorian equivalents where needed.
Required by nanochat's engine.py for sample generation.
"""
# Try exact match first (covers our own special tokens)
result = self._tok.token_to_id(token)
if result is not None:
return result
# Map nanochat's native chat tokens to Victorian equivalents
_map = {
"<|assistant_start|>": "<victorian>",
"<|assistant_end|>": "<|endoftext|>",
"<|user_start|>": "<human>",
"<|user_end|>": "<|endoftext|>",
"<|bos|>": "<|endoftext|>",
"<|eos|>": "<|endoftext|>",
}
mapped = _map.get(token)
if mapped:
return self._tok.token_to_id(mapped)
return None
def get_pad_token_id(self) -> int:
return self._tok.token_to_id("<|pad|>")
def get_user_start_id(self) -> int:
"""Maps to nanochat's <|user_start|> role."""
return self._tok.token_to_id("<human>")
def get_assistant_start_id(self) -> int:
"""Maps to nanochat's <|assistant_start|> role."""
return self._tok.token_to_id("<victorian>")
# ------------------------------------------------------------------
# Chat / fine-tuning interface (used by chat_sft.py)
# ------------------------------------------------------------------
def render_conversation(
self,
conversation: list[dict],
max_tokens: int = 2048,
) -> tuple[list[int], list[int]]:
"""
Encode a conversation into token IDs and a loss mask.
conversation: list of {"role": "user"|"assistant", "content": str}
Returns: (token_ids, loss_mask) β€” loss_mask is 1 for assistant tokens, 0 otherwise.
Victorian mapping:
"user" β†’ <human> ...
"assistant" β†’ <victorian> ... <|endoftext|> (end token trains model to stop)
"""
human_id = self.get_user_start_id()
victorian_id = self.get_assistant_start_id()
bos_id = self.get_bos_token_id()
tokens: list[int] = [bos_id]
mask: list[int] = [0]
for turn in conversation:
role = turn["role"]
content = turn["content"]
content_ids = self.encode(content)
if role == "user":
turn_tokens = [human_id] + content_ids
turn_mask = [0] * len(turn_tokens)
else: # assistant
turn_tokens = [victorian_id] + content_ids + [bos_id]
turn_mask = [1] * len(turn_tokens)
tokens.extend(turn_tokens)
mask.extend(turn_mask)
if len(tokens) >= max_tokens:
tokens = tokens[:max_tokens]
mask = mask[:max_tokens]
break
return tokens, mask
# ------------------------------------------------------------------
def __call__(self, texts, **kwargs):
"""Allow tokenizer(texts, ...) as an alias for encode() β€” required by nanochat's core_eval."""
return self.encode(texts, **kwargs)
@property
def vocab_size(self) -> int:
return self.get_vocab_size()
def __repr__(self) -> str:
return (
f"VictorianTokenizer(vocab_size={self.vocab_size}, "
f"bos={self.get_bos_token_id()}, "
f"human={self.get_user_start_id()}, "
f"victorian={self.get_assistant_start_id()})"
)
# ---------------------------------------------------------------------------
# nanochat drop-in functions
# ---------------------------------------------------------------------------
# Cache tokenizers per path.  The previous single-slot singleton silently
# returned the first tokenizer ever loaded even when a *different*
# tokenizer_path was requested later.
_tokenizer_cache: dict[str, VictorianTokenizer] = {}

def get_tokenizer(tokenizer_path: str | Path = TOKENIZER_PATH) -> VictorianTokenizer:
    """Drop-in replacement for nanochat's get_tokenizer().

    Returns a cached VictorianTokenizer for *tokenizer_path*, loading it on
    first use.  Repeated calls with the same path are cheap; calls with a
    new path load (and cache) that tokenizer instead of ignoring the argument.
    """
    key = str(tokenizer_path)
    cached = _tokenizer_cache.get(key)
    if cached is None:
        cached = _tokenizer_cache[key] = VictorianTokenizer(tokenizer_path)
    return cached
def get_token_bytes(device: str | torch.device = "cpu") -> torch.Tensor:
"""
Drop-in replacement for nanochat's get_token_bytes().
Returns a 1D tensor of shape [vocab_size] where each entry is the
UTF-8 byte length of that token. Used by base_train.py to convert
loss from nats/token β†’ bits/byte (the BPB evaluation metric).
"""
tok = get_tokenizer()
vocab = tok._tok.get_vocab() # {token_str: id}
vocab_size = tok.get_vocab_size()
# Build id β†’ token string mapping
id_to_token = {v: k for k, v in vocab.items()}
byte_lengths = []
for i in range(vocab_size):
token_str = id_to_token.get(i, "")
# ByteLevel BPE: Δ  represents a leading space (0x20).
# Decode the display string back to actual bytes for a correct byte count.
try:
# Replace Δ  with space, then encode to UTF-8
actual = token_str.replace("Δ ", " ").replace("Ċ", "\n").replace("Δ‰", "\t")
n_bytes = len(actual.encode("utf-8"))
except Exception:
n_bytes = 1
byte_lengths.append(max(1, n_bytes)) # floor at 1 to avoid div-by-zero
return torch.tensor(byte_lengths, dtype=torch.long, device=device)
# ---------------------------------------------------------------------------
# Sanity check
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: round-trip encoding, conversation rendering, byte lengths.
    import sys
    if not TOKENIZER_PATH.exists():
        print(f"Tokenizer not found at {TOKENIZER_PATH}")
        sys.exit(1)
    tok = get_tokenizer()
    print(tok)
    print(f" pad={tok.get_pad_token_id()}")
    texts = [
        "It is a truth universally acknowledged.",
        "The phrenological examination was most illuminating, dear fellow.",
    ]
    ids = tok.encode(texts, prepend=tok.get_bos_token_id())
    for text, seq in zip(texts, ids):
        # Drop the prepended bos token before decoding for comparison.
        decoded = tok.decode(seq[1:])
        ok = "βœ“" if decoded == text else "βœ—"
        print(f" {ok} {len(seq):3d} tokens {text!r}")
    # Test render_conversation
    conv = [
        {"role": "user", "content": "What is your opinion on the railways?"},
        {"role": "assistant", "content": "The railways are a most alarming development, yet undeniably useful."},
    ]
    token_ids, loss_mask = tok.render_conversation(conv)
    print(f"\n render_conversation: {len(token_ids)} tokens, "
          f"{sum(loss_mask)} assistant tokens in loss mask")
    # Test get_token_bytes
    tb = get_token_bytes()
    # torch.mean() is undefined for integer dtypes — cast to float first,
    # otherwise this line raises RuntimeError on the long tensor.
    print(f"\n get_token_bytes: shape={tuple(tb.shape)}, "
          f"mean={tb.float().mean():.2f} bytes/token, "
          f"min={tb.min():.0f}, max={tb.max():.0f}")