"""
tokenizer_wrapper.py — nanochat-compatible wrapper for the Victorian BPE tokenizer.

nanochat's base_train.py imports:

    from nanochat.tokenizer import get_tokenizer, get_token_bytes

This module provides a VictorianTokenizer class that satisfies nanochat's
full interface, plus get_tokenizer() and get_token_bytes() drop-in
replacements.

Special token mapping:

    <|endoftext|>          → bos (document boundary, prepended to every document)
    <|pad|>                → pad
    USER_START_TOKEN       → user_start (replaces nanochat's <|user_start|>)
    ASSISTANT_START_TOKEN  → assistant_start (replaces nanochat's <|assistant_start|>)

NOTE(review): in this copy of the file the literal text of the two role
tokens is empty (""), which looks like it was lost (e.g. stripped as markup).
Set USER_START_TOKEN / ASSISTANT_START_TOKEN to the strings actually stored in
tokenizer.json — TODO confirm.

Usage — patch nanochat/tokenizer.py by adding at the bottom:

    from pathlib import Path
    import sys
    sys.path.insert(0, "/path/to/victorian")
    from tokenizer_wrapper import get_tokenizer, get_token_bytes
"""

from pathlib import Path

import torch
from tokenizers import Tokenizer

TOKENIZER_PATH = Path(__file__).parent / "tokenizer.json"

# Special-token strings as stored in tokenizer.json.
BOS_TOKEN = "<|endoftext|>"
PAD_TOKEN = "<|pad|>"
# NOTE(review): both role-token literals are empty in this copy of the file,
# so lookups of them presumably return None until they are filled in.
# TODO confirm the real role-token text against tokenizer.json.
USER_START_TOKEN = ""
ASSISTANT_START_TOKEN = ""


# ---------------------------------------------------------------------------
# ByteLevel helpers (used by get_token_bytes)
# ---------------------------------------------------------------------------

def _bytes_to_unicode() -> dict[int, str]:
    """Return the GPT-2 ByteLevel table mapping each raw byte to a display character.

    Printable Latin-1 bytes map to themselves. The remaining bytes (control
    characters, space, 0x7F-0xA0 and 0xAD) map, in byte order, to U+0100 and
    upward. That is why a leading space appears as "Ġ" (U+0120), a newline as
    "Ċ" (U+010A) and a tab as "ĉ" (U+0109) in the vocabulary.
    """
    printable = [
        *range(0x21, 0x7F),   # "!" .. "~"
        *range(0xA1, 0xAD),   # "¡" .. "¬"
        *range(0xAE, 0x100),  # "®" .. "ÿ"
    ]
    table = {b: chr(b) for b in printable}
    next_code = 0x100
    for b in range(256):
        if b not in table:
            table[b] = chr(next_code)
            next_code += 1
    return table


# Inverse table: ByteLevel display character → the single raw byte it stands for.
_BYTE_LEVEL_DECODER: dict[str, int] = {ch: b for b, ch in _bytes_to_unicode().items()}


def _token_byte_length(token_str: str) -> int:
    """Return how many raw bytes a vocabulary entry represents.

    Every character of the ByteLevel alphabet encodes exactly one byte. So
    "Ã©" (the two bytes of "é") counts as 2, not as the 4 bytes its display
    form would take when UTF-8 encoded. Characters outside that alphabet
    (possible in special tokens stored verbatim) are counted by their own
    UTF-8 length.
    """
    return sum(
        1 if ch in _BYTE_LEVEL_DECODER else len(ch.encode("utf-8"))
        for ch in token_str
    )


class VictorianTokenizer:
    """
    Wraps our HuggingFace BPE tokenizer to match nanochat's expected interface.
    """

    # nanochat's native special tokens → the Victorian token used in their
    # place. Consulted only when the name is not itself in the vocabulary.
    _NANOCHAT_SPECIAL_MAP: dict[str, str] = {
        "<|assistant_start|>": ASSISTANT_START_TOKEN,
        "<|assistant_end|>": BOS_TOKEN,
        "<|user_start|>": USER_START_TOKEN,
        "<|user_end|>": BOS_TOKEN,
        "<|bos|>": BOS_TOKEN,
        "<|eos|>": BOS_TOKEN,
    }

    def __init__(self, tokenizer_path: str | Path = TOKENIZER_PATH):
        self._tok = Tokenizer.from_file(str(tokenizer_path))
        # Disable any padding/truncation configured in tokenizer.json so that
        # encode() always returns the raw, full-length ID sequences.
        self._tok.no_padding()
        self._tok.no_truncation()

    # ------------------------------------------------------------------
    # Core nanochat interface (used by dataloader and base_train.py)
    # ------------------------------------------------------------------

    def get_vocab_size(self) -> int:
        return self._tok.get_vocab_size()

    def get_bos_token_id(self) -> int:
        """Prepended to every document by nanochat's dataloader."""
        return self._tok.token_to_id(BOS_TOKEN)

    def encode(
        self,
        texts: list[str] | str,
        prepend: int | str | None = None,
        append: int | str | None = None,
        num_threads: int = 4,
    ) -> list[int] | list[list[int]]:
        """
        Encode strings → token ID list(s).

        Mirrors nanochat's native tokenizer:
          - single string   → list[int]
          - list of strings → list[list[int]]

        prepend/append may be an int token ID or a special-token string
        (e.g. prepend="<|bos|>"), matching nanochat's _encode_one interface.
        num_threads is accepted for interface compatibility and is unused here.

        Raises:
            ValueError: if prepend/append is a string that does not resolve to
                a known special token.
        """
        single = isinstance(texts, str)
        batch = [texts] if single else list(texts)

        prepend_id = self._resolve_special(prepend)
        append_id = self._resolve_special(append)

        # add_special_tokens=False: boundaries come only from prepend/append,
        # never from a post-processor that tokenizer.json might configure.
        encodings = self._tok.encode_batch(
            batch, is_pretokenized=False, add_special_tokens=False
        )
        head = [] if prepend_id is None else [prepend_id]
        tail = [] if append_id is None else [append_id]
        ids = [head + enc.ids + tail for enc in encodings]

        # Single string → flat list[int] to match nanochat's native encode()
        return ids[0] if single else ids

    def _resolve_special(self, token: int | str | None) -> int | None:
        """Turn a prepend/append argument into a token ID (ints and None pass through)."""
        if not isinstance(token, str):
            return token
        token_id = self.encode_special(token)
        if token_id is None:
            raise ValueError(f"Unknown special token {token!r}")
        return token_id

    def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs back to text.

        skip_special_tokens is forwarded to the HuggingFace tokenizer. Its
        default (True) is also the HuggingFace default, so decode(ids) drops
        special tokens such as <|endoftext|> from the output.
        """
        return self._tok.decode(ids, skip_special_tokens=skip_special_tokens)

    # ------------------------------------------------------------------
    # Special token accessors
    # ------------------------------------------------------------------

    def encode_special(self, token: str) -> int | None:
        """
        Look up a special token ID by exact match.

        Maps nanochat's native special tokens to Victorian equivalents where
        needed. Returns None if the token is unknown.
        Required by nanochat's engine.py for sample generation.
        """
        # Exact match first (covers our own special tokens).
        token_id = self._tok.token_to_id(token)
        if token_id is not None:
            return token_id
        mapped = self._NANOCHAT_SPECIAL_MAP.get(token)
        if not mapped:  # unknown name, or a mapping whose target text is empty
            return None
        return self._tok.token_to_id(mapped)

    def get_pad_token_id(self) -> int:
        return self._tok.token_to_id(PAD_TOKEN)

    def get_user_start_id(self) -> int:
        """Maps to nanochat's <|user_start|> role."""
        return self._tok.token_to_id(USER_START_TOKEN)

    def get_assistant_start_id(self) -> int:
        """Maps to nanochat's <|assistant_start|> role."""
        return self._tok.token_to_id(ASSISTANT_START_TOKEN)

    # ------------------------------------------------------------------
    # Chat / fine-tuning interface (used by chat_sft.py)
    # ------------------------------------------------------------------

    def render_conversation(
        self,
        conversation: list[dict],
        max_tokens: int = 2048,
    ) -> tuple[list[int], list[int]]:
        """
        Encode a conversation into token IDs and a loss mask.

        conversation: list of {"role": "user"|"assistant", "content": str}
        Returns: (token_ids, loss_mask). loss_mask is 1 for assistant tokens
        (role marker, content and end token) and 0 otherwise. Both lists are
        cut to at most max_tokens entries.

        Layout:
            <|endoftext|>                                      (mask 0)
            "user"      → USER_START_TOKEN + content            (mask 0)
            "assistant" → ASSISTANT_START_TOKEN + content + <|endoftext|>
                          (mask 1; the end token trains the model to stop)

        Raises:
            ValueError: if a required special token is missing from the
                vocabulary, or a turn's role is not "user"/"assistant".
        """
        user_id = self.get_user_start_id()
        assistant_id = self.get_assistant_start_id()
        bos_id = self.get_bos_token_id()
        missing = [
            name
            for name, token_id in (
                ("bos", bos_id),
                ("user_start", user_id),
                ("assistant_start", assistant_id),
            )
            if token_id is None
        ]
        if missing:
            raise ValueError(
                f"Special token(s) not in vocabulary: {', '.join(missing)}"
            )

        tokens: list[int] = [bos_id]
        mask: list[int] = [0]

        for turn in conversation:
            role = turn["role"]
            content_ids = self.encode(turn["content"])

            if role == "user":
                turn_tokens = [user_id, *content_ids]
                turn_mask = [0] * len(turn_tokens)
            elif role == "assistant":
                turn_tokens = [assistant_id, *content_ids, bos_id]
                turn_mask = [1] * len(turn_tokens)
            else:
                raise ValueError(
                    f"Unsupported role {role!r}; expected 'user' or 'assistant'"
                )

            tokens.extend(turn_tokens)
            mask.extend(turn_mask)
            if len(tokens) >= max_tokens:
                break  # later turns would be truncated away anyway

        return tokens[:max_tokens], mask[:max_tokens]

    # ------------------------------------------------------------------

    def __call__(self, texts, **kwargs):
        """Allow tokenizer(texts, ...) as an alias for encode() — required by nanochat's core_eval."""
        return self.encode(texts, **kwargs)

    @property
    def vocab_size(self) -> int:
        return self.get_vocab_size()

    def __repr__(self) -> str:
        return (
            f"VictorianTokenizer(vocab_size={self.vocab_size}, "
            f"bos={self.get_bos_token_id()}, "
            f"human={self.get_user_start_id()}, "
            f"victorian={self.get_assistant_start_id()})"
        )


# ---------------------------------------------------------------------------
# nanochat drop-in functions
# ---------------------------------------------------------------------------

# One tokenizer instance per resolved tokenizer.json path.
_tokenizer_cache: dict[Path, VictorianTokenizer] = {}


def get_tokenizer(tokenizer_path: str | Path = TOKENIZER_PATH) -> VictorianTokenizer:
    """Drop-in replacement for nanochat's get_tokenizer().

    Instances are cached per resolved path, so repeated calls are cheap and
    a call with a different path loads that tokenizer instead of silently
    returning the first one loaded.
    """
    key = Path(tokenizer_path).resolve()
    tok = _tokenizer_cache.get(key)
    if tok is None:
        tok = _tokenizer_cache[key] = VictorianTokenizer(key)
    return tok


def get_token_bytes(device: str | torch.device = "cpu") -> torch.Tensor:
    """
    Drop-in replacement for nanochat's get_token_bytes().

    Returns a 1D long tensor of shape [vocab_size] where each entry is the
    number of raw UTF-8 bytes that token stands for. Used by base_train.py
    to convert loss from nats/token → bits/byte (the BPB evaluation metric).

    ByteLevel BPE stores every byte as one display character (space → "Ġ",
    newline → "Ċ", 0xC3 → "Ã", ...). Each such character therefore counts as
    exactly one byte. UTF-8-encoding the display string would over-count
    every non-ASCII byte and every remapped control byte.

    Entries are floored at 1 (including IDs with no vocabulary entry) to
    avoid division by zero.
    NOTE(review): nanochat's own implementation may assign 0 bytes to special
    tokens so they drop out of BPB; confirm which convention the evaluation
    code in use expects.
    """
    tok = get_tokenizer()
    byte_lengths = [
        max(1, _token_byte_length(tok._tok.id_to_token(token_id) or ""))
        for token_id in range(tok.get_vocab_size())
    ]
    return torch.tensor(byte_lengths, dtype=torch.long, device=device)


# ---------------------------------------------------------------------------
# Sanity check
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import sys

    if not TOKENIZER_PATH.exists():
        print(f"Tokenizer not found at {TOKENIZER_PATH}")
        sys.exit(1)

    tok = get_tokenizer()
    print(tok)
    print(f" pad={tok.get_pad_token_id()}")

    texts = [
        "It is a truth universally acknowledged.",
        "The phrenological examination was most illuminating, dear fellow.",
    ]
    ids = tok.encode(texts, prepend=tok.get_bos_token_id())
    for text, seq in zip(texts, ids):
        decoded = tok.decode(seq[1:])
        ok = "✓" if decoded == text else "✗"
        print(f" {ok} {len(seq):3d} tokens {text!r}")

    # Test render_conversation
    conv = [
        {"role": "user", "content": "What is your opinion on the railways?"},
        {"role": "assistant", "content": "The railways are a most alarming development, yet undeniably useful."},
    ]
    try:
        token_ids, loss_mask = tok.render_conversation(conv)
    except ValueError as err:
        print(f"\n render_conversation: FAILED ({err})")
    else:
        print(f"\n render_conversation: {len(token_ids)} tokens, "
              f"{sum(loss_mask)} assistant tokens in loss mask")

    # Test get_token_bytes (torch.mean rejects integer tensors, so average in float)
    tb = get_token_bytes()
    print(
        f"\n get_token_bytes: shape={tuple(tb.shape)}, "
        f"mean={tb.float().mean():.2f} bytes/token, "
        f"min={tb.min():.0f}, max={tb.max():.0f}"
    )