| |
| """PyTorch-friendly wrapper around the trained byte-level BPE tokenizer. |
| |
| ``JSCoderTokenizer`` is the single object the rest of the pipeline (dataset |
| packing, training loop, inference) talks to. It hides the ``tokenizers`` |
| library behind a small, typed API and returns ``torch`` tensors so it drops |
| straight into a ``DataLoader`` / training loop. |
| |
| Example:: |
| |
| from tokenizer.tokenizer import JSCoderTokenizer |
| |
| tok = JSCoderTokenizer.load() |
| ids = tok.encode("const x = 1;\\n") # list[int] |
| batch = tok.encode_batch([s1, s2], device="cpu") # padded tensors + mask |
| text = tok.decode(ids) |
| |
| # Build a fill-in-the-middle prompt for autocomplete at the cursor: |
| prompt = tok.build_fim_prompt(prefix, suffix) # ends with <fim_middle> |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Dict, List, Optional, Sequence, Union |
|
|
| import torch |
| from tokenizers import Tokenizer |
|
|
| try: |
| from .special_tokens import ( |
| EOT, |
| FIM_MIDDLE, |
| FIM_PREFIX, |
| FIM_SUFFIX, |
| PAD, |
| ) |
| except ImportError: |
| import sys |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent)) |
| from special_tokens import ( |
| EOT, |
| FIM_MIDDLE, |
| FIM_PREFIX, |
| FIM_SUFFIX, |
| PAD, |
| ) |
|
|
# Repository root: the parent of the directory that holds this module.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Default location of the serialized tokenizer read by JSCoderTokenizer.load().
DEFAULT_TOKENIZER_PATH = REPO_ROOT.joinpath("tokenizer", "js_bpe.json")

# What decode() accepts: a plain sequence of ids or a torch tensor of ids.
IdsLike = Union[Sequence[int], "torch.Tensor"]
|
|
|
|
class JSCoderTokenizer:
    """Thin, typed, torch-returning wrapper over a byte-level BPE tokenizer.

    Wraps a ``tokenizers.Tokenizer`` and resolves, once at construction, the
    ids of the special tokens the pipeline relies on (padding, end-of-text and
    the three FIM sentinels). Construction raises ``ValueError`` if any of
    them is missing from the vocabulary.
    """

    def __init__(self, tokenizer: Tokenizer):
        self._tk = tokenizer

        # Resolve every required special token up front so a tokenizer trained
        # without them is rejected here, not deep inside training/inference.
        self.pad_id = self._require_id(PAD)
        self.eot_id = self._require_id(EOT)
        self.fim_prefix_id = self._require_id(FIM_PREFIX)
        self.fim_middle_id = self._require_id(FIM_MIDDLE)
        self.fim_suffix_id = self._require_id(FIM_SUFFIX)

    # ------------------------------------------------------------------ #
    # Construction
    # ------------------------------------------------------------------ #

    @classmethod
    def load(cls, path: Union[str, Path] = DEFAULT_TOKENIZER_PATH) -> "JSCoderTokenizer":
        """Load a trained tokenizer from its JSON file.

        Args:
            path: Path to the serialized tokenizer; defaults to
                ``DEFAULT_TOKENIZER_PATH``.

        Raises:
            FileNotFoundError: if ``path`` is not an existing regular file.
            ValueError: if a required special token is missing from the vocab.
        """
        path = Path(path)
        # is_file() rather than exists(): a directory would otherwise slip
        # through and fail later with an opaque error from the Rust loader.
        if not path.is_file():
            raise FileNotFoundError(
                f"No tokenizer at {path}. Train one first:\n"
                " python3 tokenizer/train_tokenizer.py"
            )
        return cls(Tokenizer.from_file(str(path)))

    def _require_id(self, token: str) -> int:
        """Return the vocab id of ``token``; raise ``ValueError`` if absent."""
        token_id = self._tk.token_to_id(token)
        if token_id is None:
            raise ValueError(
                f"Special token {token!r} is missing from the tokenizer vocab. "
                "Retrain with tokenizer/train_tokenizer.py."
            )
        return token_id

    # ------------------------------------------------------------------ #
    # Vocabulary
    # ------------------------------------------------------------------ #

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the underlying ``get_vocab_size()``."""
        return self._tk.get_vocab_size()

    def token_to_id(self, token: str) -> Optional[int]:
        """Return the id of ``token``, or ``None`` if it is not in the vocab."""
        return self._tk.token_to_id(token)

    def id_to_token(self, token_id: int) -> Optional[str]:
        """Return the token string for ``token_id``, or ``None`` if unknown."""
        return self._tk.id_to_token(token_id)

    # ------------------------------------------------------------------ #
    # Encoding / decoding
    # ------------------------------------------------------------------ #

    def encode(self, text: str, add_eot: bool = False) -> List[int]:
        """Encode one text into a list of token ids.

        Args:
            text: The text to tokenize.
            add_eot: If true, append ``eot_id`` after the encoded tokens.
        """
        # NOTE(review): if ``text`` contains the literal spelling of a special
        # token (e.g. a FIM sentinel), the ``tokenizers`` library may map it to
        # that special id. Confirm whether untrusted source needs escaping.
        ids = self._tk.encode(text).ids
        if add_eot:
            ids.append(self.eot_id)
        return ids

    def encode_many(
        self, texts: Sequence[str], add_eot: bool = False
    ) -> List[List[int]]:
        """Encode many texts at once across all CPU cores.

        Delegates to the Rust ``tokenizers`` batch path, which releases the GIL
        and parallelises with Rayon. This is dramatically faster than calling
        :meth:`encode` in a Python loop when tokenizing large corpora.

        Args:
            texts: The texts to encode. A bare ``str`` is rejected, since it
                would otherwise be iterated one character at a time.
            add_eot: If true, append ``eot_id`` to every encoded sequence.

        Raises:
            TypeError: if ``texts`` is a single ``str``.
        """
        if isinstance(texts, str):
            raise TypeError(
                "expected a sequence of strings, got a single str; "
                "use encode() for one text"
            )
        encodings = self._tk.encode_batch(list(texts))
        if add_eot:
            return [enc.ids + [self.eot_id] for enc in encodings]
        return [enc.ids for enc in encodings]

    def encode_to_tensor(
        self,
        text: str,
        add_eot: bool = False,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> "torch.Tensor":
        """Encode one text into a 1-D ``torch.long`` tensor of ids.

        Args:
            text: The text to tokenize.
            add_eot: If true, append ``eot_id`` after the encoded tokens.
            device: Optional device on which to create the tensor.
        """
        ids = self.encode(text, add_eot=add_eot)
        return torch.tensor(ids, dtype=torch.long, device=device)

    def decode(self, ids: IdsLike, skip_special_tokens: bool = True) -> str:
        """Decode token ids back into text.

        Args:
            ids: A sequence of ints, or a 0-D/1-D tensor of ids.
            skip_special_tokens: Forwarded to the underlying decoder; when true,
                tokens registered as special are omitted from the output.

        Raises:
            ValueError: if ``ids`` is a tensor with more than one dimension
                (decode the rows of a batch individually).
        """
        if isinstance(ids, torch.Tensor):
            if ids.dim() > 1:
                raise ValueError(
                    "decode expects a 1-D sequence of ids, got a tensor of "
                    f"shape {tuple(ids.shape)}; decode each row separately"
                )
            # reshape(-1) lifts a 0-d scalar tensor to one element; a bare
            # tolist() on it would return an int, which list() cannot iterate.
            ids = ids.reshape(-1).tolist()
        return self._tk.decode(list(ids), skip_special_tokens=skip_special_tokens)

    # ------------------------------------------------------------------ #
    # Batching
    # ------------------------------------------------------------------ #

    def encode_batch(
        self,
        texts: Sequence[str],
        add_eot: bool = False,
        max_length: Optional[int] = None,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> Dict[str, "torch.Tensor"]:
        """Encode and right-pad a batch.

        Returns a dict with ``input_ids`` and ``attention_mask`` (1 for real
        tokens, 0 for padding), both ``[batch, seq_len]`` long tensors.

        Args:
            texts: The texts to encode (a bare ``str`` is rejected).
            add_eot: Append ``eot_id`` to each sequence before truncation.
            max_length: If given, keep at most this many tokens per sequence.
                Truncation is applied after EOT is appended, so sequences
                longer than ``max_length`` lose their EOT.
            device: Optional device to move the returned tensors to.

        Raises:
            TypeError: if ``texts`` is a single ``str``.
            ValueError: if ``max_length`` is negative.
        """
        if max_length is not None and max_length < 0:
            raise ValueError(f"max_length must be >= 0, got {max_length}")
        # Use the parallel Rust batch path instead of a per-text Python loop.
        sequences = self.encode_many(texts, add_eot=add_eot)
        if max_length is not None:
            sequences = [seq[:max_length] for seq in sequences]
        return self.pad(sequences, device=device)

    def pad(
        self,
        sequences: Sequence[Sequence[int]],
        pad_to: Optional[int] = None,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> Dict[str, "torch.Tensor"]:
        """Right-pad pre-tokenized id sequences into padded tensors + mask.

        Args:
            sequences: Token-id sequences (lists or 1-D tensors), possibly of
                different lengths. They are never truncated.
            pad_to: Minimum output width; the width used is the larger of this
                and the longest sequence.
            device: Optional device to move the returned tensors to.

        Returns:
            ``{"input_ids": ..., "attention_mask": ...}``, each a
            ``[batch, width]`` long tensor. Padding positions hold ``pad_id``
            in ``input_ids`` and 0 in ``attention_mask``.
        """
        longest = max((len(seq) for seq in sequences), default=0)
        width = max(longest, pad_to or 0)
        batch = len(sequences)

        input_ids = torch.full((batch, width), self.pad_id, dtype=torch.long)
        attention_mask = torch.zeros((batch, width), dtype=torch.long)
        for row, seq in enumerate(sequences):
            length = len(seq)
            if length:
                # as_tensor accepts lists and existing tensors alike;
                # torch.tensor() warns when copy-constructing from a tensor.
                input_ids[row, :length] = torch.as_tensor(seq, dtype=torch.long)
                attention_mask[row, :length] = 1

        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    # ------------------------------------------------------------------ #
    # Fill-in-the-middle
    # ------------------------------------------------------------------ #

    def build_fim_prompt(self, prefix: str, suffix: str, mode: str = "psm") -> List[int]:
        """Token ids for an *inference* fill-in-the-middle prompt.

        Feed these ids to the model and let it generate the completion; stop
        on ``eot_id``. ``mode`` mirrors the training mix:

        * ``"psm"``: FIM_PREFIX, prefix, FIM_SUFFIX, suffix, FIM_MIDDLE.
          The prompt ends at the FIM_MIDDLE sentinel.
        * ``"spm"``: FIM_PREFIX, FIM_SUFFIX, suffix, FIM_MIDDLE, prefix.
          The prompt ends with the prefix tokens, so generation continues
          directly from the prefix.

        Raises:
            ValueError: if ``mode`` is neither ``"psm"`` nor ``"spm"``.
        """
        # Validate before doing any tokenization work.
        if mode not in ("psm", "spm"):
            raise ValueError(f"unknown FIM mode {mode!r}; expected 'psm' or 'spm'")
        prefix_ids = self.encode(prefix)
        suffix_ids = self.encode(suffix)
        if mode == "spm":
            return (
                [self.fim_prefix_id, self.fim_suffix_id]
                + suffix_ids
                + [self.fim_middle_id]
                + prefix_ids
            )
        return (
            [self.fim_prefix_id]
            + prefix_ids
            + [self.fim_suffix_id]
            + suffix_ids
            + [self.fim_middle_id]
        )

    # ------------------------------------------------------------------ #
    # Dunder helpers
    # ------------------------------------------------------------------ #

    def __len__(self) -> int:
        """Same as :attr:`vocab_size`."""
        return self.vocab_size

    def __repr__(self) -> str:
        return (
            f"JSCoderTokenizer(vocab_size={self.vocab_size}, "
            f"pad_id={self.pad_id}, eot_id={self.eot_id})"
        )
|
|
|
|
def _demo() -> None:
    """Smoke-test the tokenizer: load it, round-trip a sample, show a FIM prompt."""
    tok = JSCoderTokenizer.load()
    print(tok)

    # Round-trip a short JS snippet (with a trailing EOT) through encode/decode.
    sample = "export const add = (a, b) => a + b;\n"
    sample_ids = tok.encode(sample, add_eot=True)
    print(f"\nsample: {sample!r}\nids ({len(sample_ids)}): {sample_ids}")
    print(f"round-trip ok: {tok.decode(sample_ids) == sample}")

    # Build a default (PSM) FIM prompt and show it with special tokens kept.
    fim_ids = tok.build_fim_prompt(prefix="function sum(arr) {\n ", suffix="\n}\n")
    print(f"\nFIM prompt ids ({len(fim_ids)}): {fim_ids[:12]}...")
    fim_text = tok.decode(fim_ids, skip_special_tokens=False)
    print(f"FIM prompt text: {fim_text!r}")


if __name__ == "__main__":
    _demo()
|
|