# tiny-gpt-2-1m / tokenization_tiny_gpt.py
# Published by vjkhambe in commit bd8a48b ("Publish TinyGPT checkpoint", verified); 2.73 kB.
from __future__ import annotations
import shutil
from pathlib import Path
import sentencepiece as spm
from transformers import PreTrainedTokenizer
class TinyGPTTokenizer(PreTrainedTokenizer):
    """Slow (pure-Python) tokenizer backed by a SentencePiece model file.

    Tokenization, token/id conversion and decoding are delegated to a
    ``sentencepiece.SentencePieceProcessor`` loaded from ``vocab_file``.
    No special tokens are inserted automatically: single sequences are
    returned unchanged and pairs are simply concatenated.
    """

    # Maps the constructor's ``vocab_file`` argument to its on-disk file name.
    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        **kwargs,
    ) -> None:
        """Load the SentencePiece model and initialise the base tokenizer.

        Args:
            vocab_file: Path to the SentencePiece ``.model`` file.
            unk_token: Unknown-token string passed to ``PreTrainedTokenizer``.
            bos_token: Beginning-of-sequence token string.
            eos_token: End-of-sequence token string.
            pad_token: Padding token string.
            **kwargs: Remaining ``PreTrainedTokenizer`` keyword arguments.
        """
        self.vocab_file = vocab_file
        # Load the model before the base initializer runs, so the vocabulary
        # methods below already work if the base class queries them during init.
        self.sp_model = spm.SentencePieceProcessor(model_file=vocab_file)
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the SentencePiece model (excludes added tokens)."""
        return self.sp_model.get_piece_size()

    def get_vocab(self) -> dict[str, int]:
        """Return the piece -> id mapping, including tokens added via ``added_tokens_encoder``."""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` into SentencePiece piece strings."""
        return list(self.sp_model.encode(text, out_type=str))

    def _convert_token_to_id(self, token: str) -> int:
        """Map a piece string to its id via the SentencePiece model."""
        return int(self.sp_model.piece_to_id(token))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its piece string via the SentencePiece model."""
        return self.sp_model.id_to_piece(int(index))

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Detokenize a list of piece strings with SentencePiece.

        NOTE(review): special tokens are handed to SentencePiece like any other
        piece, so whether they survive decoding verbatim depends on how the
        model defines them -- confirm if ``skip_special_tokens=False`` output
        matters for callers.
        """
        return self.sp_model.decode_pieces(tokens)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """Return model inputs without adding any special tokens.

        A single sequence is returned as-is; a pair is concatenated directly.
        """
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def get_special_tokens_mask(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
        already_has_special_tokens: bool = False,
    ) -> list[int]:
        """Return a 0/1 mask marking special-token positions.

        Args:
            token_ids_0: First sequence of ids.
            token_ids_1: Optional second sequence of ids.
            already_has_special_tokens: If True, the ids may already contain
                special tokens, and each position whose id is in
                ``self.all_special_ids`` is marked with 1. If False, the mask
                describes what ``build_inputs_with_special_tokens`` would add,
                which is nothing, so every position is 0.

        Returns:
            A list with one entry per id in ``token_ids_0`` (+ ``token_ids_1``).
        """
        if already_has_special_tokens:
            # Previously this branch returned all zeros, so special tokens
            # already present in the input were never flagged.
            special_ids = set(self.all_special_ids)
            ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
            return [1 if token_id in special_ids else 0 for token_id in ids]
        return [0] * (len(token_ids_0) + (len(token_ids_1) if token_ids_1 else 0))

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        """Write the SentencePiece model file into ``save_directory``.

        The file is named ``tokenizer.model``. When ``filename_prefix`` is
        given, the name is prefixed with ``"<filename_prefix>-"``. The
        directory is created if needed.

        If the original model file still exists, it is copied. The copy is
        skipped when source and destination resolve to the same file, because
        ``shutil.copy2`` would otherwise raise ``shutil.SameFileError``. If the
        original file is missing, the loaded in-memory model is serialized
        instead.

        Returns:
            A one-element tuple with the path of the written model file.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        filename = self.vocab_files_names["vocab_file"]
        if filename_prefix:
            filename = f"{filename_prefix}-{filename}"
        out_path = save_dir / filename

        source = Path(self.vocab_file) if self.vocab_file else None
        if source is not None and source.is_file():
            if source.resolve() != out_path.resolve():
                shutil.copy2(source, out_path)
        else:
            # Source file is gone (e.g. a cleaned-up temp/cache dir): fall back
            # to the model bytes held by the loaded processor.
            out_path.write_bytes(self.sp_model.serialized_model_proto())
        return (str(out_path),)
# Import-time registration of this custom tokenizer with the transformers
# auto-class machinery; "AutoTokenizer" is the method's default target and is
# spelled out here for clarity.
TinyGPTTokenizer.register_for_auto_class(auto_class="AutoTokenizer")