from __future__ import annotations

from typing import Any

import tiktoken
from transformers import PreTrainedTokenizer


class _SteerlingTokenizer:
    """
    Tokenizer core for Steerling models.

    Wraps tiktoken's ``cl100k_base`` encoding and appends 4 additional special
    tokens directly after the base vocabulary: ``<|pad|>``, ``<|bos|>``,
    ``<|endofchunk|>`` and ``<|mask|>`` (ids ``n_vocab`` .. ``n_vocab + 3``).
    EOS reuses the base encoding's existing ``<|endoftext|>`` token.
    """

    ENCODING_NAME = 'cl100k_base'

    def __init__(self):
        base_enc = tiktoken.get_encoding(self.ENCODING_NAME)
        base_vocab = base_enc.n_vocab

        # The new special tokens take the ids immediately past the base vocab.
        self._pad_token_id = base_vocab
        self._bos_token_id = base_vocab + 1
        self._endofchunk_token_id = base_vocab + 2
        self._mask_token_id = base_vocab + 3
        # EOS is not a new token; it is cl100k_base's own <|endoftext|>.
        self._eos_token_id = base_enc._special_tokens['<|endoftext|>']
        self._vocab_size = base_vocab + 4

        # Rebuild an Encoding from the base encoding's internals so the extra
        # special tokens are known to tiktoken for encode/decode.
        self._tokenizer = tiktoken.Encoding(
            name=f'{self.ENCODING_NAME}_steerling',
            pat_str=base_enc._pat_str,
            mergeable_ranks=base_enc._mergeable_ranks,
            special_tokens={
                **base_enc._special_tokens,
                '<|pad|>': self._pad_token_id,
                '<|bos|>': self._bos_token_id,
                '<|endofchunk|>': self._endofchunk_token_id,
                '<|mask|>': self._mask_token_id,
            },
        )
        # Ids removed by ``decode(..., skip_special_tokens=True)``. Other base
        # special tokens (e.g. FIM markers) are deliberately not included.
        self._special_token_ids = {
            self._pad_token_id,
            self._bos_token_id,
            self._eos_token_id,
            self._endofchunk_token_id,
            self._mask_token_id,
        }

    def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
        """
        Encode text to token IDs.

        Special-token text appearing inside ``text`` is encoded as ordinary
        text (``disallowed_special=()`` with no ``allowed_special``), so user
        input cannot inject control tokens.

        Args:
            text: Input text
            add_special_tokens: If True, prepend BOS and append EOS

        Returns:
            List of token IDs
        """
        tokens = self._tokenizer.encode(text, disallowed_special=())
        if add_special_tokens:
            tokens = [self._bos_token_id] + tokens + [self._eos_token_id]
        return tokens

    def decode(self, tokens: list[int], skip_special_tokens: bool = True) -> str:
        """
        Decode token IDs to text.

        Args:
            tokens: Token IDs (list, numpy array, or torch tensor); each
                element is converted with ``int()``.
            skip_special_tokens: If True, filter out pad/bos/eos/endofchunk/mask
                before decoding

        Returns:
            Decoded text
        """
        if skip_special_tokens:
            tokens = [int(t) for t in tokens if int(t) not in self._special_token_ids]
        else:
            tokens = [int(t) for t in tokens]
        return self._tokenizer.decode(tokens)

    @property
    def vocab_size(self) -> int:
        """Base vocabulary size plus the 4 added special tokens."""
        return self._vocab_size

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def endofchunk_token_id(self) -> int:
        return self._endofchunk_token_id

    @property
    def mask_token_id(self) -> int:
        return self._mask_token_id


class SteerlingTokenizer(PreTrainedTokenizer):
    """
    Hugging Face ``PreTrainedTokenizer`` adapter around ``_SteerlingTokenizer``.

    "Tokens" on the HF string path are decimal token-id strings (see
    ``_tokenize``). No BOS/EOS are inserted by ``build_inputs_with_special_tokens``,
    unlike ``_SteerlingTokenizer.encode`` whose default adds them.
    """

    vocab_files_names: dict[str, str] = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        encoding_name="cl100k_base",
        pad_token_id=100277,
        bos_token_id=100278,
        eos_token_id=100257,
        endofchunk_token_id=100279,
        mask_token_id=100280,
        **kwargs,
    ):
        """
        Build the tokenizer.

        ``encoding_name`` and the ``*_token_id`` arguments are accepted (e.g. so
        a saved config can be passed back in), but the ids actually used all
        come from ``_SteerlingTokenizer``; ``encoding_name``, ``pad_token_id``,
        ``bos_token_id`` and ``eos_token_id`` are otherwise unused.
        """
        # Set before super().__init__ — presumably because the base initializer
        # queries the vocab while registering special tokens (confirm).
        self._core = _SteerlingTokenizer()
        self._endofchunk_token_id = endofchunk_token_id
        self._mask_token_id = mask_token_id
        # These are passed explicitly below; drop any copies in kwargs so the
        # call does not fail with "got multiple values for keyword argument".
        for k in ("pad_token", "bos_token", "eos_token", "additional_special_tokens"):
            kwargs.pop(k, None)
        super().__init__(
            pad_token="<|pad|>",
            bos_token="<|bos|>",
            eos_token="<|endoftext|>",
            additional_special_tokens=["<|endofchunk|>", "<|mask|>"],
            **kwargs,
        )

    @property
    def vocab_size(self):
        return self._core.vocab_size

    @property
    def endofchunk_token_id(self):
        return self._core.endofchunk_token_id

    @property
    def mask_token_id(self):
        return self._core.mask_token_id

    def get_vocab(self):
        """Return the special-token name -> id mapping (not the full BPE vocab)."""
        # NOTE(review): only special tokens are returned. If the installed
        # transformers version derives ``len(tokenizer)`` from get_vocab(),
        # that length will be far below ``vocab_size`` — verify before relying
        # on len().
        return dict(self._core._tokenizer._special_tokens)

    def _tokenize(self, text, **kwargs):
        """Split text into token-id strings (e.g. ``"9906"``)."""
        return [str(i) for i in self._core._tokenizer.encode(text, disallowed_special=())]

    def _convert_token_to_id(self, token):
        """Map a token string (special name, id string, or raw text) to an id."""
        special = self._core._tokenizer._special_tokens
        if token in special:
            return special[token]
        try:
            return int(token)
        except ValueError:
            # Raw text: fall back to its first BPE id (pad id if it encodes to
            # nothing). This is lossy for text spanning several tokens.
            ids = self._core._tokenizer.encode(token, disallowed_special=())
            return ids[0] if ids else self._core.pad_token_id

    def _convert_id_to_token(self, index):
        """Map an id to its special-token name or its decoded text piece."""
        for name, idx in self._core._tokenizer._special_tokens.items():
            if idx == index:
                return name
        # NOTE(review): this returns decoded text, while ``_tokenize`` produces
        # id strings. A decoded piece made only of digits (e.g. "12") would be
        # read back by ``_convert_token_to_id``/``convert_tokens_to_string`` as
        # an id, not as text — round-trips through token strings are not exact.
        try:
            return self._core._tokenizer.decode([index])
        except Exception:
            return f"<|token_{index}|>"

    def convert_tokens_to_string(self, tokens):
        """Join token strings back into text, dropping special tokens."""
        ids, special = [], self._core._tokenizer._special_tokens
        for t in tokens:
            if t in special:
                continue
            try:
                tid = int(t)
                if tid not in self._core._special_token_ids:
                    ids.append(tid)
            except ValueError:
                # Not an id string: treat it as raw text and re-encode it.
                ids.extend(self._core._tokenizer.encode(t, disallowed_special=()))
        return self._core._tokenizer.decode(ids)

    def _decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """Decode ids via the core tokenizer; extra HF kwargs are ignored."""
        if isinstance(token_ids, list):
            ids = token_ids
        else:
            try:
                ids = list(token_ids)
            except TypeError:
                # Not iterable (a single int id, or a 0-d numpy/torch scalar):
                # decode it as a one-token sequence instead of crashing.
                ids = [token_ids]
        return self._core.decode(ids, skip_special_tokens=skip_special_tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Return model inputs for one sequence or a pair.

        No special tokens are inserted. A pair is concatenated rather than
        having its second sequence dropped, which keeps the result consistent
        in length with the base class's token-type-id and special-token-mask
        helpers.
        """
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Nothing to write: the vocab is rebuilt from tiktoken at load time."""
        return ()