"""tilelli.distillery.tokenize — day-0 byte-level tokenizer.
Why byte-level:
- Zero training. Deterministic. No BPE merges table, no corpus sweep.
- Universal coverage: any text, any language, any code, any math symbol
fits in 256 ids. Perfect for our four initial sources — English,
Python, Ubuntu commands, math — without a single special case.
- Aligns with the manifesto's "built from absolute zero" clause. We
literally implemented it in twenty lines.
- A BPE-style learned tokenizer can replace this later as a Distillery
upgrade. Until then, every downstream piece (shard, trainer,
probes) works against the byte interface and benefits for free when
the tokenizer improves.
Limits we accept day-0:
- Sequence length in bytes is ~3-4× that of a good BPE tokenizer for
English, ~1× for code. This matters for context-window calculations
but not for correctness. We're validating the architecture, not
pushing tokens/second yet.
"""
from __future__ import annotations
from typing import Iterable
import torch
from torch import Tensor
class ByteTokenizer:
    """UTF-8 byte-level tokenizer. Vocab size is fixed at 256.

    The encode path is ``text.encode("utf-8")`` and the decode path is
    ``bytes(ids).decode("utf-8", errors="replace")``, so
    ``decode(encode(text)) == text`` for every ``str`` that UTF-8 can
    encode. Strings containing lone surrogates (e.g. ``"\\ud800"``) make
    ``encode`` raise ``UnicodeEncodeError``.

    ``errors="replace"`` is a conservative default: malformed byte
    sequences (for instance a dangling multi-byte codepoint left behind
    when sampling stops mid-sequence) decode to U+FFFD instead of raising.
    ``decode`` still raises ``ValueError`` for tensors that are not 1-D and
    for ids that are not valid byte values (outside ``[0, 256)``).
    """

    # Every byte value 0..255 is its own id; there are no special tokens.
    vocab_size: int = 256

    def encode(self, text: str) -> Tensor:
        """str → 1-D int64 tensor of UTF-8 byte ids.

        Uses ``torch.frombuffer`` so the encoded bytes go straight into a
        tensor without first materialising a Python list of ints. The
        ``bytearray`` copy hands ``frombuffer`` a writable buffer; it warns
        when given a read-only one such as ``bytes``.

        Raises:
            UnicodeEncodeError: if ``text`` cannot be encoded as UTF-8
                (e.g. it contains lone surrogates).
        """
        data = text.encode("utf-8")
        if not data:
            # torch.frombuffer errors on a zero-length buffer, so "" is
            # special-cased to an empty int64 tensor.
            return torch.empty(0, dtype=torch.int64)
        buf = torch.frombuffer(bytearray(data), dtype=torch.uint8)
        # Widen to int64 so the ids can be used directly for embedding lookup.
        return buf.to(torch.int64)

    def decode(self, ids: Tensor | Iterable[int]) -> str:
        """1-D tensor (or iterable of ints) → str.

        Each id is coerced with ``int()``, so e.g. numpy integers, bools and
        float ids are accepted (floats truncate toward zero). Invalid UTF-8
        byte sequences decode to U+FFFD rather than raising.

        Raises:
            ValueError: if ``ids`` is a tensor that is not 1-D, or if any
                id lies outside ``[0, 256)``.
        """
        if isinstance(ids, Tensor):
            if ids.dim() != 1:
                raise ValueError(f"expected 1-D tensor, got shape {tuple(ids.shape)}")
            ids = ids.tolist()
        # Materialise once so the ids can be re-scanned for a diagnostic;
        # a generic iterable may not be re-iterable.
        values = [int(i) for i in ids]
        try:
            raw = bytes(values)
        except ValueError as err:
            # bytes() only reports "must be in range(0, 256)"; name the id
            # that broke it (e.g. a -100 ignore-index or an id from a larger
            # vocab leaking into decode).
            bad = next((v for v in values if not 0 <= v <= 255), None)
            raise ValueError(
                f"byte id {bad} is outside the valid range [0, 256)"
            ) from err
        return raw.decode("utf-8", errors="replace")
|