#!/usr/bin/env python3
"""PyTorch-friendly wrapper around the trained byte-level BPE tokenizer.

``JSCoderTokenizer`` is the single object the rest of the pipeline (dataset
packing, training loop, inference) talks to. It hides the ``tokenizers``
library behind a small, typed API and returns ``torch`` tensors so it drops
straight into a ``DataLoader`` / training loop.

Example::

    from tokenizer.tokenizer import JSCoderTokenizer

    tok = JSCoderTokenizer.load()
    ids = tok.encode("const x = 1;\\n")                  # list[int]
    batch = tok.encode_batch([s1, s2], device="cpu")    # padded tensors + mask
    text = tok.decode(ids)

    # Build a fill-in-the-middle prompt for autocomplete at the cursor:
    prompt = tok.build_fim_prompt(prefix, suffix)   # "psm": ends with the FIM_MIDDLE token
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

import torch
from tokenizers import Tokenizer

try:
    from .special_tokens import (
        EOT,
        FIM_MIDDLE,
        FIM_PREFIX,
        FIM_SUFFIX,
        PAD,
    )
except ImportError:  # pragma: no cover - script execution fallback
    import sys

    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from special_tokens import (  # type: ignore
        EOT,
        FIM_MIDDLE,
        FIM_PREFIX,
        FIM_SUFFIX,
        PAD,
    )

REPO_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_TOKENIZER_PATH = REPO_ROOT / "tokenizer" / "js_bpe.json"

# What ``decode`` accepts: a sequence of ids, or a tensor that ``.tolist()``
# turns into one flat list of ids.
IdsLike = Union[Sequence[int], "torch.Tensor"]


class JSCoderTokenizer:
    """Thin, typed, torch-returning wrapper over a byte-level BPE tokenizer.

    The control-token ids (``pad_id``, ``eot_id``, ``fim_prefix_id``,
    ``fim_middle_id``, ``fim_suffix_id``) are resolved once at construction.
    Construction raises ``ValueError`` if any of them is missing from the vocab.
    """

    def __init__(self, tokenizer: Tokenizer):
        self._tk = tokenizer
        # Resolve and cache the control-token ids once; every later call just
        # reads these attributes.
        self.pad_id = self._require_id(PAD)
        self.eot_id = self._require_id(EOT)
        self.fim_prefix_id = self._require_id(FIM_PREFIX)
        self.fim_middle_id = self._require_id(FIM_MIDDLE)
        self.fim_suffix_id = self._require_id(FIM_SUFFIX)

    # ------------------------------------------------------------------ #
    # Construction
    # ------------------------------------------------------------------ #
    @classmethod
    def load(cls, path: Union[str, Path] = DEFAULT_TOKENIZER_PATH) -> "JSCoderTokenizer":
        """Load a serialized tokenizer JSON file and wrap it.

        Args:
            path: Location of the tokenizer file. Defaults to
                ``<repo>/tokenizer/js_bpe.json``.

        Raises:
            FileNotFoundError: If ``path`` does not exist. The message says
                how to train a tokenizer.
            ValueError: If a required special token is missing from the vocab.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(
                f"No tokenizer at {path}. Train one first:\n"
                f"  python3 tokenizer/train_tokenizer.py"
            )
        return cls(Tokenizer.from_file(str(path)))

    def _require_id(self, token: str) -> int:
        """Return the id of ``token``, or raise ``ValueError`` if it is absent."""
        token_id = self._tk.token_to_id(token)
        if token_id is None:
            raise ValueError(
                f"Special token {token!r} is missing from the tokenizer vocab. "
                "Retrain with tokenizer/train_tokenizer.py."
            )
        return token_id

    # ------------------------------------------------------------------ #
    # Vocab / metadata
    # ------------------------------------------------------------------ #
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the underlying ``get_vocab_size()``."""
        return self._tk.get_vocab_size()

    def token_to_id(self, token: str) -> Optional[int]:
        """Return the id for ``token``, or ``None`` if it is not in the vocab."""
        return self._tk.token_to_id(token)

    def id_to_token(self, token_id: int) -> Optional[str]:
        """Return the token string for ``token_id``, or ``None`` if unknown."""
        return self._tk.id_to_token(token_id)

    # ------------------------------------------------------------------ #
    # Core encode / decode
    # ------------------------------------------------------------------ #
    def encode(self, text: str, add_eot: bool = False) -> List[int]:
        """Encode one string to a list of token ids.

        Args:
            text: Source text to tokenize.
            add_eot: If true, append ``eot_id`` after the encoded ids.
        """
        ids = self._tk.encode(text).ids
        if add_eot:
            ids.append(self.eot_id)
        return ids

    def encode_many(
        self, texts: Sequence[str], add_eot: bool = False
    ) -> List[List[int]]:
        """Encode many texts at once across all CPU cores.

        Delegates to the Rust ``tokenizers`` batch path, which releases the GIL
        and parallelises with Rayon. This is dramatically faster than calling
        :meth:`encode` in a Python loop when tokenizing large corpora.

        Args:
            texts: Strings to encode.
            add_eot: If true, append ``eot_id`` to every sequence.
        """
        encodings = self._tk.encode_batch(list(texts))
        if add_eot:
            return [enc.ids + [self.eot_id] for enc in encodings]
        return [enc.ids for enc in encodings]

    def encode_to_tensor(
        self,
        text: str,
        add_eot: bool = False,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> "torch.Tensor":
        """Encode one string into a 1-D ``torch.long`` tensor on ``device``."""
        ids = self.encode(text, add_eot=add_eot)
        return torch.tensor(ids, dtype=torch.long, device=device)

    def decode(self, ids: IdsLike, skip_special_tokens: bool = True) -> str:
        """Decode ids (a list or a tensor) back to text.

        Args:
            ids: Token ids. Tensors are converted with ``.tolist()``.
            skip_special_tokens: Passed through to the underlying decoder.
                When true, special tokens such as ``eot_id`` are dropped from
                the output text.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        return self._tk.decode(list(ids), skip_special_tokens=skip_special_tokens)

    # ------------------------------------------------------------------ #
    # Batched encoding with padding (ready for a DataLoader collate_fn)
    # ------------------------------------------------------------------ #
    def encode_batch(
        self,
        texts: Sequence[str],
        add_eot: bool = False,
        max_length: Optional[int] = None,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> Dict[str, "torch.Tensor"]:
        """Encode and right-pad a batch.

        Args:
            texts: Strings to encode.
            add_eot: If true, append ``eot_id`` to each sequence *before*
                truncation. A sequence that gets truncated therefore loses its
                trailing ``eot_id``.
            max_length: If given, keep at most this many ids per sequence.
            device: Optional device to move the returned tensors to.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (1 for real tokens,
            0 for padding), both ``[batch, seq_len]`` long tensors.

        Raises:
            ValueError: If ``max_length`` is negative.
        """
        if max_length is not None and max_length < 0:
            # A negative slice bound would silently drop ids from the *end*
            # rather than truncate to a length.
            raise ValueError(f"max_length must be >= 0, got {max_length}")
        # Use the parallel Rust batch path instead of a Python loop over encode().
        sequences = self.encode_many(texts, add_eot=add_eot)
        if max_length is not None:
            sequences = [seq[:max_length] for seq in sequences]
        return self.pad(sequences, device=device)

    def pad(
        self,
        sequences: Sequence[Sequence[int]],
        pad_to: Optional[int] = None,
        device: Optional[Union[str, "torch.device"]] = None,
    ) -> Dict[str, "torch.Tensor"]:
        """Right-pad pre-tokenized id sequences into padded tensors + mask.

        Args:
            sequences: Id sequences (lists, or 1-D tensors) of any length.
            pad_to: Minimum output width. The width is the larger of this and
                the longest sequence; sequences are never truncated here.
            device: Optional device to move the returned tensors to.

        Returns:
            ``{"input_ids", "attention_mask"}``, both ``[batch, width]`` long
            tensors. Padding positions hold ``pad_id`` and mask value 0.
        """
        longest = max((len(seq) for seq in sequences), default=0)
        width = max(longest, pad_to or 0)
        batch = len(sequences)

        input_ids = torch.full((batch, width), self.pad_id, dtype=torch.long)
        attention_mask = torch.zeros((batch, width), dtype=torch.long)
        for row, seq in enumerate(sequences):
            length = len(seq)
            if length:
                # as_tensor accepts both lists and tensors. torch.tensor would
                # warn when asked to copy-construct from an existing tensor.
                input_ids[row, :length] = torch.as_tensor(seq, dtype=torch.long)
                attention_mask[row, :length] = 1

        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    # ------------------------------------------------------------------ #
    # Fill-in-the-Middle helpers
    # ------------------------------------------------------------------ #
    def build_fim_prompt(self, prefix: str, suffix: str, mode: str = "psm") -> List[int]:
        """Token ids for an *inference* fill-in-the-middle prompt.

        Feed these ids to the model and let it generate the completion; stop on
        ``eot_id``. ``mode`` mirrors the training mix:

        * ``"psm"`` (prefix, suffix, middle):
          ``FIM_PREFIX prefix FIM_SUFFIX suffix FIM_MIDDLE``. The prompt ends
          at the ``FIM_MIDDLE`` token.
        * ``"spm"`` (suffix, prefix, middle):
          ``FIM_PREFIX FIM_SUFFIX suffix FIM_MIDDLE prefix``. The prompt ends
          with the encoded prefix, so generation continues it directly.

        Raises:
            ValueError: If ``mode`` is neither ``"psm"`` nor ``"spm"``.
        """
        # Validate before doing any tokenization work.
        if mode not in ("psm", "spm"):
            raise ValueError(f"unknown FIM mode {mode!r}; expected 'psm' or 'spm'")

        prefix_ids = self.encode(prefix)
        suffix_ids = self.encode(suffix)
        if mode == "spm":
            return (
                [self.fim_prefix_id, self.fim_suffix_id]
                + suffix_ids
                + [self.fim_middle_id]
                + prefix_ids
            )
        return (
            [self.fim_prefix_id]
            + prefix_ids
            + [self.fim_suffix_id]
            + suffix_ids
            + [self.fim_middle_id]
        )

    # ------------------------------------------------------------------ #
    # Convenience dunders
    # ------------------------------------------------------------------ #
    def __len__(self) -> int:
        return self.vocab_size

    def __repr__(self) -> str:  # pragma: no cover - debug aid
        return (
            f"JSCoderTokenizer(vocab_size={self.vocab_size}, "
            f"pad_id={self.pad_id}, eot_id={self.eot_id})"
        )


def _demo() -> None:
    """Quick smoke test: load, round-trip, and show a FIM prompt."""
    tok = JSCoderTokenizer.load()
    print(tok)

    sample = "export const add = (a, b) => a + b;\n"
    ids = tok.encode(sample, add_eot=True)
    print(f"\nsample: {sample!r}")
    print(f"ids ({len(ids)}): {ids}")
    # decode() skips special tokens by default, so the appended EOT is dropped.
    print(f"round-trip ok: {tok.decode(ids) == sample}")

    prompt = tok.build_fim_prompt(prefix="function sum(arr) {\n  ", suffix="\n}\n")
    print(f"\nFIM prompt ids ({len(prompt)}): {prompt[:12]}...")
    print(f"FIM prompt text: {tok.decode(prompt, skip_special_tokens=False)!r}")


if __name__ == "__main__":
    _demo()