"""tokenization_steerling.py: tokenizer for steerling-8b, built on tiktoken."""
from __future__ import annotations

import tiktoken
from transformers import PreTrainedTokenizer
class _SteerlingTokenizer:
    """
    Core tokenizer for Steerling models.

    Wraps tiktoken's cl100k_base encoding and adds four special tokens:
    <|pad|>, <|bos|>, <|endofchunk|>, and <|mask|>.
    """
ENCODING_NAME = 'cl100k_base'
    def __init__(self):
        base_enc = tiktoken.get_encoding(self.ENCODING_NAME)
        base_vocab = base_enc.n_vocab  # 100277 for cl100k_base
        # The four new special tokens are appended directly after the base
        # vocabulary (IDs 100277..100280); EOS reuses cl100k_base's existing
        # <|endoftext|> token (ID 100257).
        self._pad_token_id = base_vocab
        self._bos_token_id = base_vocab + 1
        self._endofchunk_token_id = base_vocab + 2
        self._mask_token_id = base_vocab + 3
        self._eos_token_id = base_enc._special_tokens['<|endoftext|>']
        self._vocab_size = base_vocab + 4
        self._tokenizer = tiktoken.Encoding(
            name=f'{self.ENCODING_NAME}_steerling',
            pat_str=base_enc._pat_str,
            mergeable_ranks=base_enc._mergeable_ranks,
            special_tokens={
                **base_enc._special_tokens,
                '<|pad|>': self._pad_token_id,
                '<|bos|>': self._bos_token_id,
                '<|endofchunk|>': self._endofchunk_token_id,
                '<|mask|>': self._mask_token_id,
            },
        )
self._special_token_ids = {self._pad_token_id, self._bos_token_id, self._eos_token_id, self._endofchunk_token_id, self._mask_token_id}
def encode(self, text: str, add_special_tokens: bool=True) -> list[int]:
"""
Encode text to token IDs.
Args:
text: Input text
add_special_tokens: If True, prepend BOS and append EOS
Returns:
List of token IDs
"""
tokens = self._tokenizer.encode(text, disallowed_special=())
if add_special_tokens:
tokens = [self._bos_token_id] + tokens + [self._eos_token_id]
return tokens
def decode(self, tokens: list[int], skip_special_tokens: bool=True) -> str:
"""
Decode token IDs to text.
Args:
tokens: Token IDs (list, numpy array, or torch tensor)
skip_special_tokens: If True, filter out special tokens before decoding
Returns:
Decoded text
"""
if skip_special_tokens:
tokens = [int(t) for t in tokens if int(t) not in self._special_token_ids]
else:
tokens = [int(t) for t in tokens]
return self._tokenizer.decode(tokens)
@property
def vocab_size(self) -> int:
return self._vocab_size
@property
def pad_token_id(self) -> int:
return self._pad_token_id
@property
def bos_token_id(self) -> int:
return self._bos_token_id
@property
def eos_token_id(self) -> int:
return self._eos_token_id
@property
def endofchunk_token_id(self) -> int:
return self._endofchunk_token_id
@property
def mask_token_id(self) -> int:
return self._mask_token_id
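
# A minimal round-trip sketch for the core class (illustrative only; the
# middle IDs depend on the cl100k_base merges, so just the special-token
# frame is shown):
#
#     core = _SteerlingTokenizer()
#     ids = core.encode("hello")                    # [100278, ..., 100257]
#     core.decode(ids)                              # "hello"
#     core.decode(ids, skip_special_tokens=False)   # "<|bos|>hello<|endoftext|>"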
class SteerlingTokenizer(PreTrainedTokenizer):
vocab_files_names: dict[str, str] = {}
model_input_names = ["input_ids", "attention_mask"]
    def __init__(self, encoding_name="cl100k_base", pad_token_id=100277,
                 bos_token_id=100278, eos_token_id=100257,
                 endofchunk_token_id=100279, mask_token_id=100280, **kwargs):
        # encoding_name and the *_token_id arguments are accepted so values
        # saved in tokenizer_config.json load without error; the IDs actually
        # used are derived from the tiktoken encoding in _SteerlingTokenizer.
        self._core = _SteerlingTokenizer()
        self._endofchunk_token_id = endofchunk_token_id
        self._mask_token_id = mask_token_id
        # Drop any token strings a saved config may carry so they cannot
        # conflict with the canonical ones passed to super().__init__ below.
        for k in ("pad_token", "bos_token", "eos_token", "additional_special_tokens"):
            kwargs.pop(k, None)
        # Note: super().__init__ must run after self._core is set, since the
        # base class may call get_vocab()/_convert_token_to_id during setup.
        super().__init__(pad_token="<|pad|>", bos_token="<|bos|>", eos_token="<|endoftext|>",
                         additional_special_tokens=["<|endofchunk|>", "<|mask|>"], **kwargs)
@property
def vocab_size(self): return self._core.vocab_size
@property
def endofchunk_token_id(self): return self._core.endofchunk_token_id
@property
def mask_token_id(self): return self._core.mask_token_id
    def get_vocab(self):
        # tiktoken does not expose a full token-string -> ID mapping, so only
        # the special tokens are reported here.
        return dict(self._core._tokenizer._special_tokens)
    def _tokenize(self, text, **kwargs):
        # tiktoken works directly on bytes, so there is no natural string
        # token; each token is its integer ID rendered as a string and is
        # round-tripped by _convert_token_to_id / convert_tokens_to_string.
        return [str(i) for i in self._core._tokenizer.encode(text, disallowed_special=())]
    def _convert_token_to_id(self, token):
        special = self._core._tokenizer._special_tokens
        if token in special:
            return special[token]
        try:
            # Tokens produced by _tokenize are integer IDs rendered as strings.
            return int(token)
        except ValueError:
            # Any other string: re-encode it and take the first ID, falling
            # back to pad if it encodes to nothing.
            ids = self._core._tokenizer.encode(token, disallowed_special=())
            return ids[0] if ids else self._core.pad_token_id
    def _convert_id_to_token(self, index):
        for name, idx in self._core._tokenizer._special_tokens.items():
            if idx == index:
                return name
        try:
            return self._core._tokenizer.decode([index])
        except Exception:
            return f"<|token_{index}|>"
    def convert_tokens_to_string(self, tokens):
        ids, special = [], self._core._tokenizer._special_tokens
        for t in tokens:
            if t in special:
                continue
            try:
                tid = int(t)
                if tid not in self._core._special_token_ids:
                    ids.append(tid)
            except ValueError:
                ids.extend(self._core._tokenizer.encode(t, disallowed_special=()))
        return self._core._tokenizer.decode(ids)
def _decode(self, token_ids, skip_special_tokens=False, **kwargs):
return self._core.decode(list(token_ids) if not isinstance(token_ids, list) else token_ids,
skip_special_tokens=skip_special_tokens)
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # BOS/EOS framing is handled by _SteerlingTokenizer.encode rather than
        # here; sequence pairs are simply concatenated.
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1
    def save_vocabulary(self, save_directory, filename_prefix=None):
        # The tiktoken vocabulary ships with the tiktoken package, so there
        # are no files to write; return an empty tuple of saved filenames.
        return ()
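
if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module (assumes
    # tiktoken and transformers are installed).
    tok = SteerlingTokenizer()
    print(tok.vocab_size)    # 100281: 100277 cl100k_base IDs + 4 extra specials
    enc = tok("hello world")
    # __call__ returns raw IDs here: BOS/EOS framing lives in
    # _SteerlingTokenizer.encode, not in build_inputs_with_special_tokens.
    print(enc["input_ids"])
    print(tok.decode(enc["input_ids"], skip_special_tokens=True))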