File size: 10,268 Bytes
59856b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
"""
tokenizer_wrapper.py — nanochat-compatible wrapper for the Victorian BPE tokenizer

nanochat's base_train.py imports:
    from nanochat.tokenizer import get_tokenizer, get_token_bytes

This wrapper provides a VictorianTokenizer class that satisfies nanochat's full
interface, plus get_tokenizer() and get_token_bytes() drop-in replacements.

Special token mapping:
    <|endoftext|>  →  bos (document boundary, prepended to every document)
    <|pad|>        →  pad
    <human>        →  user_start  (replaces nanochat's <|user_start|>)
    <victorian>    →  assistant_start  (replaces nanochat's <|assistant_start|>)

Usage — patch nanochat/tokenizer.py by adding the following at the bottom:
    from pathlib import Path
    import sys
    sys.path.insert(0, "/path/to/victorian")
    from tokenizer_wrapper import get_tokenizer, get_token_bytes
"""

from pathlib import Path
import torch
from tokenizers import Tokenizer

TOKENIZER_PATH = Path(__file__).parent / "tokenizer.json"


class VictorianTokenizer:
    """
    Wraps our HuggingFace BPE tokenizer to match nanochat's expected interface.

    Role tokens used by this wrapper:
        <|endoftext|>  bos / document boundary; also closes assistant turns
        <|pad|>        padding
        <human>        start of a user turn
        <victorian>    start of an assistant turn
    """

    # nanochat's native special-token names -> the Victorian token that stands
    # in for each one. Consulted by encode_special() only after an exact
    # vocabulary lookup has failed.
    _NANOCHAT_SPECIAL_MAP: dict[str, str] = {
        "<|assistant_start|>": "<victorian>",
        "<|assistant_end|>":   "<|endoftext|>",
        "<|user_start|>":      "<human>",
        "<|user_end|>":        "<|endoftext|>",
        "<|bos|>":             "<|endoftext|>",
        "<|eos|>":             "<|endoftext|>",
    }

    def __init__(self, tokenizer_path: str | Path = TOKENIZER_PATH):
        """Load the tokenizer JSON at *tokenizer_path*."""
        self._tok = Tokenizer.from_file(str(tokenizer_path))
        # Disable any padding/truncation configured in tokenizer.json so that
        # encode() returns exactly the tokens of each input.
        self._tok.no_padding()
        self._tok.no_truncation()

    # ------------------------------------------------------------------
    # Core nanochat interface (used by dataloader and base_train.py)
    # ------------------------------------------------------------------

    def get_vocab_size(self) -> int:
        """Return the vocabulary size, including added/special tokens."""
        return self._tok.get_vocab_size()

    def get_bos_token_id(self) -> int:
        """Prepended to every document by nanochat's dataloader."""
        return self._tok.token_to_id("<|endoftext|>")

    def encode(
        self,
        texts: list[str] | str,
        prepend: int | str | None = None,
        append: int | str | None = None,
        num_threads: int = 4,
    ) -> list[int] | list[list[int]]:
        """
        Encode strings -> token ID list(s).

        Matches nanochat's native tokenizer return shapes:
          - single string   -> list[int]
          - list of strings -> list[list[int]]

        prepend/append may be an int token ID or a special-token string
        (e.g. prepend="<|bos|>"), resolved via encode_special(), matching
        nanochat's _encode_one interface. ``num_threads`` is accepted only for
        signature compatibility and is not used.

        Raises:
            ValueError: if prepend/append is a string that does not resolve to
                a special token ID.
        """
        single = isinstance(texts, str)
        batch = [texts] if single else texts

        # Resolve string markers to IDs up front; an unknown name is an error
        # rather than a silently missing BOS/EOS in every sequence.
        if isinstance(prepend, str):
            prepend = self._require_special(prepend, "prepend")
        if isinstance(append, str):
            append = self._require_special(append, "append")

        ids = [enc.ids for enc in self._tok.encode_batch(batch, is_pretokenized=False)]

        # Compare with None rather than truthiness: token ID 0 is a valid marker.
        if prepend is not None:
            ids = [[prepend] + seq for seq in ids]
        if append is not None:
            ids = [seq + [append] for seq in ids]

        # Single string -> flat list[int] to match nanochat's native encode()
        return ids[0] if single else ids

    def _require_special(self, token: str, arg_name: str) -> int:
        """Resolve special-token *token* to its ID, raising ValueError if unknown."""
        token_id = self.encode_special(token)
        if token_id is None:
            raise ValueError(f"{arg_name}={token!r} is not a known special token")
        return token_id

    def decode(self, ids: list[int]) -> str:
        """Decode token IDs to text using the HF tokenizer's default decode settings."""
        return self._tok.decode(ids)

    # ------------------------------------------------------------------
    # Special token accessors
    # ------------------------------------------------------------------

    def encode_special(self, token: str) -> int | None:
        """
        Look up a special token ID by exact match.

        Falls back to mapping nanochat's native special-token names onto their
        Victorian equivalents (see _NANOCHAT_SPECIAL_MAP). Returns None when
        neither lookup succeeds. Required by nanochat's engine.py for sample
        generation.
        """
        # Exact match first (covers our own special tokens).
        token_id = self._tok.token_to_id(token)
        if token_id is not None:
            return token_id
        mapped = self._NANOCHAT_SPECIAL_MAP.get(token)
        if mapped:
            return self._tok.token_to_id(mapped)
        return None

    def get_pad_token_id(self) -> int:
        """Return the ID of <|pad|>."""
        return self._tok.token_to_id("<|pad|>")

    def get_user_start_id(self) -> int:
        """Maps to nanochat's <|user_start|> role."""
        return self._tok.token_to_id("<human>")

    def get_assistant_start_id(self) -> int:
        """Maps to nanochat's <|assistant_start|> role."""
        return self._tok.token_to_id("<victorian>")

    # ------------------------------------------------------------------
    # Chat / fine-tuning interface (used by chat_sft.py)
    # ------------------------------------------------------------------

    def render_conversation(
        self,
        conversation: list[dict],
        max_tokens: int = 2048,
    ) -> tuple[list[int], list[int]]:
        """
        Encode a conversation into token IDs and a loss mask.

        conversation: list of {"role": "user"|"assistant", "content": str}
        Returns: (token_ids, loss_mask) -- loss_mask is 1 for assistant tokens,
        0 otherwise. The sequence starts with a bos token (mask 0).

        Victorian mapping:
            "user"      -> <human> ...
            "assistant" -> <victorian> ... <|endoftext|>  (end token trains model to stop)

        NOTE: any role other than "user" takes the assistant path and is
        included in the loss.

        Output is cut to at most ``max_tokens``; turns after the limit is
        reached are dropped.
        """
        human_id     = self.get_user_start_id()
        victorian_id = self.get_assistant_start_id()
        bos_id       = self.get_bos_token_id()

        tokens: list[int] = [bos_id]
        mask:   list[int] = [0]

        for turn in conversation:
            role    = turn["role"]
            content = turn["content"]
            content_ids = self.encode(content)

            if role == "user":
                turn_tokens = [human_id] + content_ids
                turn_mask   = [0] * len(turn_tokens)
            else:  # assistant (and any non-"user" role)
                # <|endoftext|> closes the turn so the model learns to stop.
                turn_tokens = [victorian_id] + content_ids + [bos_id]
                turn_mask   = [1] * len(turn_tokens)

            tokens.extend(turn_tokens)
            mask.extend(turn_mask)

            if len(tokens) >= max_tokens:
                tokens = tokens[:max_tokens]
                mask   = mask[:max_tokens]
                break

        return tokens, mask

    # ------------------------------------------------------------------

    def __call__(self, texts, **kwargs):
        """Allow tokenizer(texts, ...) as an alias for encode() -- required by nanochat's core_eval."""
        return self.encode(texts, **kwargs)

    @property
    def vocab_size(self) -> int:
        """Property alias for get_vocab_size()."""
        return self.get_vocab_size()

    def __repr__(self) -> str:
        return (
            f"VictorianTokenizer(vocab_size={self.vocab_size}, "
            f"bos={self.get_bos_token_id()}, "
            f"human={self.get_user_start_id()}, "
            f"victorian={self.get_assistant_start_id()})"
        )


# ---------------------------------------------------------------------------
# nanochat drop-in functions
# ---------------------------------------------------------------------------

# Process-wide cache; populated on the first get_tokenizer() call.
_tokenizer_singleton: VictorianTokenizer | None = None


def get_tokenizer(tokenizer_path: str | Path = TOKENIZER_PATH) -> VictorianTokenizer:
    """
    Drop-in replacement for nanochat's get_tokenizer().

    Builds the VictorianTokenizer on first use and returns that same instance
    on every later call. Only the first call's ``tokenizer_path`` is used;
    subsequent calls ignore the argument, so all callers in the process
    (including get_token_bytes) share one tokenizer.
    """
    global _tokenizer_singleton
    if _tokenizer_singleton is not None:
        return _tokenizer_singleton
    _tokenizer_singleton = VictorianTokenizer(tokenizer_path)
    return _tokenizer_singleton


def get_token_bytes(device: str | torch.device = "cpu") -> torch.Tensor:
    """
    Drop-in replacement for nanochat's get_token_bytes().

    Returns a 1D long tensor of shape [vocab_size] where entry ``i`` is the
    number of raw UTF-8 bytes that token ``i`` stands for. Used by
    base_train.py to convert loss from nats/token -> bits/byte (the BPB
    evaluation metric).

    Byte counting assumes a ByteLevel BPE vocabulary. There, every raw byte is
    written as one printable stand-in character (e.g. U+0120 for a space,
    U+010A for a newline), so a token made only of those characters covers
    exactly ``len(token)`` bytes. Any other token string falls back to its
    UTF-8 length.

    The module's special tokens (<|endoftext|>, <|pad|>, <human>, <victorian>)
    and IDs with no vocab entry get 0 bytes. This follows nanochat's
    token_bytes convention, in which zero-byte targets are left out of BPB.
    """
    from tokenizers.pre_tokenizers import ByteLevel

    tok = get_tokenizer()
    vocab_size = tok.get_vocab_size()
    id_to_token = {tid: s for s, tid in tok._tok.get_vocab().items()}

    # The 256 printable characters ByteLevel uses, one per raw byte value.
    byte_alphabet = frozenset(ByteLevel.alphabet())

    special_ids: set[int] = set()
    for name in ("<|endoftext|>", "<|pad|>", "<human>", "<victorian>"):
        tid = tok._tok.token_to_id(name)
        if tid is not None:
            special_ids.add(tid)

    byte_lengths: list[int] = []
    for i in range(vocab_size):
        token_str = id_to_token.get(i)
        if token_str is None or i in special_ids:
            byte_lengths.append(0)
        elif byte_alphabet.issuperset(token_str):
            # Each ByteLevel character encodes exactly one raw byte.
            byte_lengths.append(len(token_str))
        else:
            byte_lengths.append(len(token_str.encode("utf-8")))

    return torch.tensor(byte_lengths, dtype=torch.long, device=device)


# ---------------------------------------------------------------------------
# Sanity check
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Manual sanity check: load the tokenizer, round-trip sample text, render a
    # conversation, and summarise the per-token byte lengths.
    import sys

    if not TOKENIZER_PATH.exists():
        print(f"Tokenizer not found at {TOKENIZER_PATH}")
        sys.exit(1)

    tok = get_tokenizer()
    print(tok)
    print(f"  pad={tok.get_pad_token_id()}")

    # Round-trip: encode with a leading BOS, then decode everything after it.
    texts = [
        "It is a truth universally acknowledged.",
        "The phrenological examination was most illuminating, dear fellow.",
    ]
    ids = tok.encode(texts, prepend=tok.get_bos_token_id())
    for text, seq in zip(texts, ids):
        decoded = tok.decode(seq[1:])
        ok = "\u2713" if decoded == text else "\u2717"  # check mark / ballot X
        print(f"  {ok}  {len(seq):3d} tokens  {text!r}")

    # Test render_conversation
    conv = [
        {"role": "user",      "content": "What is your opinion on the railways?"},
        {"role": "assistant", "content": "The railways are a most alarming development, yet undeniably useful."},
    ]
    token_ids, loss_mask = tok.render_conversation(conv)
    print(f"\n  render_conversation: {len(token_ids)} tokens, "
          f"{sum(loss_mask)} assistant tokens in loss mask")

    # Test get_token_bytes
    tb = get_token_bytes()
    # tb has an integer dtype and torch.mean() rejects integer inputs, so cast first.
    print(f"\n  get_token_bytes: shape={tuple(tb.shape)}, "
          f"mean={tb.float().mean():.2f} bytes/token, "
          f"min={tb.min():.0f}, max={tb.max():.0f}")