File size: 2,512 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""tilelli.distillery.tokenize — day-0 byte-level tokenizer.

Why byte-level:
  - Zero training. Deterministic. No BPE merges table, no corpus sweep.
  - Universal coverage: any text, any language, any code, any math symbol
    fits in 256 ids. Perfect for our four initial sources — English,
    Python, Ubuntu commands, math — without a single special case.
  - Aligns with the manifesto's "built from absolute zero" clause. We
    literally implemented it in twenty lines.
  - A BPE-style learned tokenizer can replace this later as a Distillery
    upgrade. Until then, every downstream piece (shard, trainer,
    probes) works against the byte interface and benefits for free when
    the tokenizer improves.

Limits we accept day-0:
  - Sequence length in bytes is ~3-4× that of a good BPE tokenizer for
    English, ~1× for code. This matters for context-window calculations
    but not for correctness. We're validating the architecture, not
    pushing tokens/second yet.
"""
from __future__ import annotations

from typing import Iterable

import torch
from torch import Tensor


class ByteTokenizer:
    """UTF-8 byte-level tokenizer. Vocab size is fixed at 256.

    Token ids are the UTF-8 bytes of the text, so every id is a byte value
    in ``[0, 255]``.

    Round-trip guarantee: ``decode(encode(text)) == text`` for every str
    that is valid Unicode. A str containing lone surrogates (e.g.
    ``"\\ud800"``) cannot be UTF-8 encoded, so ``encode`` raises
    ``UnicodeEncodeError`` for it.

    ``decode`` defaults to ``errors="replace"`` so it never raises on
    malformed byte sequences. This is useful when sampling mid-sequence
    leaves a dangling multi-byte codepoint, which then decodes to U+FFFD.
    Pass ``errors="strict"`` to surface such sequences instead.
    """

    vocab_size: int = 256

    def encode(self, text: str) -> Tensor:
        """str → 1-D int64 tensor of byte ids.

        Args:
            text: the string to tokenize.

        Returns:
            A 1-D ``torch.int64`` tensor with one element per UTF-8 byte of
            ``text``, each in ``[0, 255]``.

        Raises:
            UnicodeEncodeError: if ``text`` contains lone surrogates.

        Uses ``torch.frombuffer`` so a large text goes straight from the
        UTF-8 buffer into a tensor, with no intermediate per-byte Python
        list. The ``bytearray`` copy gives ``frombuffer`` a writable buffer.
        Passing the read-only ``bytes`` object directly would trigger
        PyTorch's non-writable-buffer UserWarning.
        """
        data = text.encode("utf-8")
        if not data:
            # frombuffer raises on a zero-length buffer, so build the empty
            # tensor explicitly.
            return torch.empty(0, dtype=torch.int64)
        buf = torch.frombuffer(bytearray(data), dtype=torch.uint8)
        # The dtype change makes .to() return a fresh int64 tensor, so the
        # result does not alias the temporary bytearray.
        return buf.to(torch.int64)

    def decode(
        self,
        ids: Tensor | Iterable[int],
        *,
        errors: str = "replace",
    ) -> str:
        """1-D tensor (or iterable of ints) → str.

        Args:
            ids: byte ids, each in ``range(0, 256)``. A tensor must be 1-D.
            errors: error handler forwarded to ``bytes.decode``. Defaults to
                ``"replace"``, which turns malformed sequences into U+FFFD.
                Use ``"strict"`` to raise instead.

        Returns:
            The decoded string.

        Raises:
            ValueError: if ``ids`` is a tensor that is not 1-D, or if any id
                falls outside ``range(0, 256)`` (raised by ``bytes``).
            UnicodeDecodeError: only with a non-tolerant ``errors`` handler
                such as ``"strict"`` when the bytes are not valid UTF-8.
        """
        if isinstance(ids, Tensor):
            if ids.dim() != 1:
                raise ValueError(f"expected 1-D tensor, got shape {tuple(ids.shape)}")
            # One bulk conversion to Python ints rather than indexing the
            # tensor element by element.
            ids = ids.tolist()
        # int() lets float/bool elements (e.g. from a float tensor) through,
        # truncating toward zero; bytes() enforces the 0-255 range.
        return bytes(int(i) for i in ids).decode("utf-8", errors=errors)