#!/usr/bin/env python3
# Source: jscoder-300m / tokenizer/tokenizer.py
# Uploaded by Shadid with huggingface_hub (commit d1ee0ec, 9.07 kB).
"""PyTorch-friendly wrapper around the trained byte-level BPE tokenizer.

``JSCoderTokenizer`` is the single object the rest of the pipeline (dataset
packing, training loop, inference) talks to. It hides the ``tokenizers``
library behind a small, typed API and returns ``torch`` tensors so it drops
straight into a ``DataLoader`` / training loop.

Example::

    from tokenizer.tokenizer import JSCoderTokenizer

    tok = JSCoderTokenizer.load()
    ids = tok.encode("const x = 1;\\n")                # list[int]
    batch = tok.encode_batch([s1, s2], device="cpu")  # padded tensors + mask
    text = tok.decode(ids)

    # Build a fill-in-the-middle prompt for autocomplete at the cursor:
    prompt = tok.build_fim_prompt(prefix, suffix)     # ends with <fim_middle>
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

import torch
from tokenizers import Tokenizer

try:
    from .special_tokens import (
        EOT,
        FIM_MIDDLE,
        FIM_PREFIX,
        FIM_SUFFIX,
        PAD,
    )
except ImportError:  # pragma: no cover - script execution fallback
    import sys

    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from special_tokens import (  # type: ignore
        EOT,
        FIM_MIDDLE,
        FIM_PREFIX,
        FIM_SUFFIX,
        PAD,
    )
# Repository root: two directory levels above this file (<repo>/tokenizer/tokenizer.py).
REPO_ROOT = Path(__file__).resolve().parent.parent
# Location ``JSCoderTokenizer.load()`` reads when no explicit path is given.
DEFAULT_TOKENIZER_PATH = REPO_ROOT.joinpath("tokenizer", "js_bpe.json")
# Id containers accepted by ``JSCoderTokenizer.decode``: an int sequence or a tensor.
IdsLike = Union[Sequence[int], "torch.Tensor"]
class JSCoderTokenizer:
    """Thin, typed, torch-returning wrapper over a byte-level BPE tokenizer.

    Delegates to a ``tokenizers.Tokenizer`` and exposes encode/decode helpers
    that return plain ``list[int]`` ids or ``torch.long`` tensors. It also
    provides right-padding for batches and fill-in-the-middle (FIM) prompt
    construction.

    The ids of the control tokens are resolved once at construction:
    ``pad_id``, ``eot_id``, ``fim_prefix_id``, ``fim_middle_id`` and
    ``fim_suffix_id``.
    """

    def __init__(self, tokenizer: Tokenizer):
        """Wrap an already-constructed ``tokenizers.Tokenizer``.

        Args:
            tokenizer: the trained tokenizer to delegate to.

        Raises:
            ValueError: if any required special token (PAD, EOT, or one of the
                three FIM markers) is missing from ``tokenizer``'s vocab.
        """
        self._tk = tokenizer
        # Resolve and cache the control-token ids once. A missing token fails
        # fast here rather than surfacing later during encoding/padding.
        self.pad_id = self._require_id(PAD)
        self.eot_id = self._require_id(EOT)
        self.fim_prefix_id = self._require_id(FIM_PREFIX)
        self.fim_middle_id = self._require_id(FIM_MIDDLE)
        self.fim_suffix_id = self._require_id(FIM_SUFFIX)

    # ------------------------------------------------------------------ #
    # Construction
    # ------------------------------------------------------------------ #

    @classmethod
    def load(cls, path: Union[str, Path] = DEFAULT_TOKENIZER_PATH) -> "JSCoderTokenizer":
        """Load a tokenizer JSON file and wrap it.

        Args:
            path: tokenizer file to read; defaults to ``DEFAULT_TOKENIZER_PATH``.

        Returns:
            A new ``JSCoderTokenizer`` around ``Tokenizer.from_file(path)``.

        Raises:
            FileNotFoundError: if ``path`` does not exist. The message points
                at ``tokenizer/train_tokenizer.py`` for producing one.
            ValueError: if the loaded vocab lacks a required special token.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(
                f"No tokenizer at {path}. Train one first:\n"
                f" python3 tokenizer/train_tokenizer.py"
            )
        return cls(Tokenizer.from_file(str(path)))

    def _require_id(self, token: str) -> int:
        """Return the vocab id of ``token``.

        Raises:
            ValueError: if ``token`` is not in the vocab.
        """
        token_id = self._tk.token_to_id(token)
        if token_id is None:
            raise ValueError(
                f"Special token {token!r} is missing from the tokenizer vocab. "
                "Retrain with tokenizer/train_tokenizer.py."
            )
        return token_id

    # ------------------------------------------------------------------ #
    # Vocab / metadata
    # ------------------------------------------------------------------ #

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``Tokenizer.get_vocab_size()``."""
        return self._tk.get_vocab_size()

    def token_to_id(self, token: str) -> Optional[int]:
        """Id of ``token``, or ``None`` if it is not in the vocab."""
        return self._tk.token_to_id(token)

    def id_to_token(self, token_id: int) -> Optional[str]:
        """Token string for ``token_id``, or ``None`` if the id is unknown."""
        return self._tk.id_to_token(token_id)

    # ------------------------------------------------------------------ #
    # Core encode / decode
    # ------------------------------------------------------------------ #

    def encode(self, text: str, add_eot: bool = False) -> List[int]:
        """Encode a single string into a list of token ids.

        Args:
            text: the text to tokenize.
            add_eot: if true, append ``eot_id`` after the text's ids.
        """
        ids = self._tk.encode(text).ids
        if add_eot:
            ids.append(self.eot_id)
        return ids

    def encode_many(
        self, texts: Sequence[str], add_eot: bool = False
    ) -> List[List[int]]:
        """Encode many texts at once across all CPU cores.

        Delegates to the Rust ``tokenizers`` batch path, which releases the GIL
        and parallelises with Rayon. This is dramatically faster than calling
        :meth:`encode` in a Python loop when tokenizing large corpora, and
        yields the same ids per text.

        Args:
            texts: the strings to tokenize.
            add_eot: if true, append ``eot_id`` to every sequence.
        """
        encodings = self._tk.encode_batch(list(texts))
        if add_eot:
            return [enc.ids + [self.eot_id] for enc in encodings]
        return [enc.ids for enc in encodings]

    def encode_to_tensor(
        self,
        text: str,
        add_eot: bool = False,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> "torch.Tensor":
        """Encode ``text`` like :meth:`encode` and return a 1-d long tensor.

        Args:
            text: the text to tokenize.
            add_eot: if true, append ``eot_id``.
            device: optional device for the returned tensor.
        """
        ids = self.encode(text, add_eot=add_eot)
        return torch.tensor(ids, dtype=torch.long, device=device)

    def decode(self, ids: IdsLike, skip_special_tokens: bool = True) -> str:
        """Turn token ids back into text.

        Args:
            ids: a sequence of ints or a tensor of ids. A 0-d tensor is
                treated as a single id.
            skip_special_tokens: forwarded to ``tokenizers.Tokenizer.decode``.

        Returns:
            The decoded string.
        """
        if isinstance(ids, torch.Tensor):
            # ``Tensor.tolist()`` on a 0-d tensor returns a bare int, which
            # ``list()`` below cannot iterate; promote it to 1-d first.
            ids = torch.atleast_1d(ids).tolist()
        return self._tk.decode(list(ids), skip_special_tokens=skip_special_tokens)

    # ------------------------------------------------------------------ #
    # Batched encoding with padding (ready for a DataLoader collate_fn)
    # ------------------------------------------------------------------ #

    def encode_batch(
        self,
        texts: Sequence[str],
        add_eot: bool = False,
        max_length: Optional[int] = None,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> Dict[str, "torch.Tensor"]:
        """Encode and right-pad a batch.

        Tokenization goes through :meth:`encode_many` (the parallel batch
        path) instead of a per-text Python loop; the resulting ids are the
        same.

        Args:
            texts: strings to encode.
            add_eot: append ``eot_id`` to each sequence *before* truncation.
            max_length: if given, keep only the first ``max_length`` ids of
                each sequence. An appended EOT is dropped when it falls past
                the cut.
            device: optional device for the returned tensors.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (1 for real
            tokens, 0 for padding), both ``[batch, seq_len]`` long tensors.

        Raises:
            ValueError: if ``max_length`` is negative. A negative value would
                otherwise silently slice tokens off the *end* of every
                sequence.
        """
        if max_length is not None and max_length < 0:
            raise ValueError(f"max_length must be non-negative, got {max_length}")
        sequences = self.encode_many(texts, add_eot=add_eot)
        if max_length is not None:
            sequences = [seq[:max_length] for seq in sequences]
        return self.pad(sequences, device=device)

    def pad(
        self,
        sequences: Sequence[Sequence[int]],
        pad_to: Optional[int] = None,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> Dict[str, "torch.Tensor"]:
        """Right-pad pre-tokenized id sequences into padded tensors + mask.

        Args:
            sequences: token-id sequences; none are truncated.
            pad_to: minimum output width. The width is
                ``max(longest sequence, pad_to)``.
            device: if given, both tensors are moved there after being built
                on the CPU.

        Returns:
            ``{"input_ids", "attention_mask"}``, both ``[batch, width]`` long
            tensors. Padding slots hold ``pad_id`` in ``input_ids`` and 0 in
            ``attention_mask``.
        """
        longest = max((len(seq) for seq in sequences), default=0)
        width = max(longest, pad_to or 0)
        batch = len(sequences)
        input_ids = torch.full((batch, width), self.pad_id, dtype=torch.long)
        attention_mask = torch.zeros((batch, width), dtype=torch.long)
        for row, seq in enumerate(sequences):
            length = len(seq)
            if length:
                input_ids[row, :length] = torch.tensor(seq, dtype=torch.long)
                attention_mask[row, :length] = 1
        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    # ------------------------------------------------------------------ #
    # Fill-in-the-Middle helpers
    # ------------------------------------------------------------------ #

    def build_fim_prompt(self, prefix: str, suffix: str, mode: str = "psm") -> List[int]:
        """Token ids for an *inference* FIM prompt (ends at ``<fim_middle>``).

        Feed these ids to the model and let it generate the completion; stop
        on ``eot_id``. ``mode`` mirrors the training mix: ``"psm"`` (prefix,
        suffix, middle) or ``"spm"`` (suffix, prefix, middle).

        Raises:
            ValueError: if ``mode`` is neither ``"psm"`` nor ``"spm"``. This is
                checked before any encoding work is done.
        """
        if mode not in ("psm", "spm"):
            raise ValueError(f"unknown FIM mode {mode!r}; expected 'psm' or 'spm'")
        prefix_ids = self.encode(prefix)
        suffix_ids = self.encode(suffix)
        if mode == "spm":
            # <fim_prefix><fim_suffix> suffix <fim_middle> prefix
            return (
                [self.fim_prefix_id, self.fim_suffix_id]
                + suffix_ids
                + [self.fim_middle_id]
                + prefix_ids
            )
        # psm: <fim_prefix> prefix <fim_suffix> suffix <fim_middle>
        return (
            [self.fim_prefix_id]
            + prefix_ids
            + [self.fim_suffix_id]
            + suffix_ids
            + [self.fim_middle_id]
        )

    # ------------------------------------------------------------------ #
    # Convenience dunders
    # ------------------------------------------------------------------ #

    def __len__(self) -> int:
        """Same as :attr:`vocab_size`."""
        return self.vocab_size

    def __repr__(self) -> str:  # pragma: no cover - debug aid
        return (
            f"JSCoderTokenizer(vocab_size={self.vocab_size}, "
            f"pad_id={self.pad_id}, eot_id={self.eot_id})"
        )
def _demo() -> None:
    """Smoke-test the tokenizer end to end.

    Loads the default tokenizer and prints its repr. It then encodes a sample
    (with EOT) and reports whether decoding reproduces the sample. Finally it
    builds a fill-in-the-middle prompt with the default mode and prints a
    preview of its ids plus the decoded text with special tokens kept.
    """
    tokenizer = JSCoderTokenizer.load()
    print(tokenizer)

    snippet = "export const add = (a, b) => a + b;\n"
    token_ids = tokenizer.encode(snippet, add_eot=True)
    print(f"\nsample: {snippet!r}")
    print(f"ids ({len(token_ids)}): {token_ids}")
    # The default decode skips special tokens, so the trailing EOT is dropped.
    same_text = tokenizer.decode(token_ids) == snippet
    print(f"round-trip ok: {same_text}")

    fim_ids = tokenizer.build_fim_prompt(prefix="function sum(arr) {\n ", suffix="\n}\n")
    preview = fim_ids[:12]
    print(f"\nFIM prompt ids ({len(fim_ids)}): {preview}...")
    rendered = tokenizer.decode(fim_ids, skip_special_tokens=False)
    print(f"FIM prompt text: {rendered!r}")


if __name__ == "__main__":
    _demo()