from __future__ import annotations

import shutil
from pathlib import Path

import sentencepiece as spm
from transformers import PreTrainedTokenizer


class HanForgeTokenizer(PreTrainedTokenizer):
    """Slow Hugging Face tokenizer backed by a SentencePiece model file.

    Tokenization, token/id conversion and detokenization are delegated to a
    ``sentencepiece.SentencePieceProcessor`` loaded from ``vocab_file``.
    Tokens registered through the Hugging Face added-token machinery are
    merged on top of the SentencePiece pieces in :meth:`get_vocab`.

    Encoded layouts:
        single sequence: ``[bos] A [eos]``
        sequence pair:   ``[bos] A B [eos]`` (no separator between segments)
    """

    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        # NOTE(review): the special-token defaults are empty strings. An empty
        # string is unlikely to name a real piece in the SentencePiece model;
        # confirm whether callers always pass explicit tokens (e.g. via
        # tokenizer_config.json) or whether real piece names were intended.
        bos_token: str = "",
        eos_token: str = "",
        unk_token: str = "",
        pad_token: str = "",
        additional_special_tokens: list[str] | None = None,
        **kwargs,
    ):
        """Load the SentencePiece model and register the special tokens.

        Args:
            vocab_file: Path to the serialized SentencePiece ``.model`` file.
            bos_token / eos_token / unk_token / pad_token: Special-token strings
                forwarded to ``PreTrainedTokenizer``.
            additional_special_tokens: Extra special tokens; ``None`` is
                normalized to an empty list.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        self.vocab_file = vocab_file
        # Load the model before the base constructor runs: depending on the
        # transformers version, registering special tokens may query
        # ``vocab_size`` / ``get_vocab``, which both need ``sp_model``.
        self.sp_model = spm.SentencePieceProcessor(model_file=vocab_file)
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens or [],
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the SentencePiece model (added tokens excluded)."""
        return int(self.sp_model.vocab_size())

    def get_vocab(self) -> dict[str, int]:
        """Return ``{piece: id}`` for every model piece, plus added tokens.

        Added tokens are applied last, so they override a model piece that has
        the same string.
        """
        vocab = {self.sp_model.id_to_piece(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` into SentencePiece piece strings."""
        return list(self.sp_model.encode(text, out_type=str))

    def _convert_token_to_id(self, token: str) -> int:
        """Map a piece string to its id via SentencePiece's lookup.

        Unknown pieces get whatever id SentencePiece returns for them
        (presumably the model's unk id).
        """
        return int(self.sp_model.piece_to_id(token))

    def _convert_id_to_token(self, index: int) -> str:
        """Map a model id back to its piece string."""
        return str(self.sp_model.id_to_piece(index))

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Detokenize a list of piece strings with SentencePiece."""
        return self.sp_model.decode_pieces(tokens)

    def _special_wrapper_ids(self) -> tuple[list[int], list[int]]:
        """Return the (prefix, suffix) ids wrapped around every encoding.

        A special token whose id is ``None`` (not configured) contributes
        nothing, so ``None`` never leaks into an id list.
        """
        bos = [] if self.bos_token_id is None else [self.bos_token_id]
        eos = [] if self.eos_token_id is None else [self.eos_token_id]
        return bos, eos

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: list[int] | None = None
    ) -> list[int]:
        """Build model inputs as ``[bos] + token_ids_0 (+ token_ids_1) + [eos]``.

        Args:
            token_ids_0: Ids of the first sequence.
            token_ids_1: Optional ids of the second sequence, appended directly
                after the first with no separator.

        Returns:
            The combined id list. BOS/EOS are omitted if their ids are unset.
        """
        bos, eos = self._special_wrapper_ids()
        output = bos + list(token_ids_0)
        if token_ids_1 is not None:
            output += list(token_ids_1)
        return output + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
        already_has_special_tokens: bool = False,
    ) -> list[int]:
        """Return a 0/1 mask marking the special tokens of the built input.

        When ``already_has_special_tokens`` is false, the mask is laid out
        exactly like :meth:`build_inputs_with_special_tokens` (1 for the
        BOS/EOS it adds, 0 elsewhere), so both lists have the same length.
        Otherwise the base-class implementation is used on the given ids.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        bos, eos = self._special_wrapper_ids()
        n_body = len(token_ids_0) + (len(token_ids_1) if token_ids_1 is not None else 0)
        return [1] * len(bos) + [0] * n_body + [1] * len(eos)

    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: list[int] | None = None
    ) -> list[int]:
        """Return segment ids aligned with :meth:`build_inputs_with_special_tokens`.

        A single sequence is all zeros. For a pair, ``[bos] + token_ids_0``
        is 0 and ``token_ids_1 + [eos]`` is 1. The length always matches the
        built input.
        """
        bos, eos = self._special_wrapper_ids()
        if token_ids_1 is None:
            return [0] * (len(bos) + len(token_ids_0) + len(eos))
        return [0] * (len(bos) + len(token_ids_0)) + [1] * (len(token_ids_1) + len(eos))

    def save_vocabulary(
        self, save_directory: str, filename_prefix: str | None = None
    ) -> tuple[str]:
        """Write the SentencePiece model (and its ``.vocab`` sidecar) to disk.

        Args:
            save_directory: Target directory; created if missing.
            filename_prefix: Optional prefix; files become
                ``"<prefix>-tokenizer.model"`` / ``"<prefix>-tokenizer.vocab"``.

        Returns:
            One-element tuple with the path of the written ``.model`` file.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)
        prefix = f"{filename_prefix}-" if filename_prefix else ""

        src_model = Path(self.vocab_file)
        out_path = save_dir / f"{prefix}tokenizer.model"
        if not src_model.is_file():
            # The original file is gone (e.g. it lived in a temporary
            # directory): serialize the loaded model instead of failing.
            out_path.write_bytes(self.sp_model.serialized_model_proto())
        elif src_model.resolve() != out_path.resolve():
            shutil.copy2(src_model, out_path)

        # Best-effort: the human-readable .vocab file is copied only if it
        # sits next to the model file.
        vocab_src = src_model.with_suffix(".vocab")
        if vocab_src.exists():
            vocab_out = save_dir / f"{prefix}tokenizer.vocab"
            if vocab_src.resolve() != vocab_out.resolve():
                shutil.copy2(vocab_src, vocab_out)
        return (str(out_path),)