from __future__ import annotations

from pathlib import Path

import torch
from transformers import PreTrainedTokenizer
from transformers.utils import cached_file

from talkie.chat import Message, format_chat
from talkie.tokenizer import build_tokenizer


class TalkieTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer facade over the Talkie encoding.

    The actual encoding/decoding is delegated to the object returned by
    ``talkie.tokenizer.build_tokenizer`` (stored as ``self.encoding``).
    At the Hugging Face "token" level, an ordinary token is represented by
    the decimal string of its id (``"123"`` <-> ``123``); special tokens
    are represented by their literal text (e.g. ``"<|end|>"``).
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.txt"}

    # Number of ordinary (non-chat-special) token ids. Reported by
    # ``vocab_size`` and used as the threshold for auto-enabling chat
    # special tokens in ``__init__``.
    _BASE_VOCAB_SIZE = 65536

    def __init__(
        self,
        vocab_file="vocab.txt",
        style="it",
        vocab_size=65536,
        encode_chat_special_tokens=None,
        **kwargs,
    ):
        """Locate the vocab file, build the encoding and set eos/pad tokens.

        Args:
            vocab_file: Path or file name of the vocabulary file. ``None``
                is treated as the default name from ``vocab_files_names``.
            style: Tokenizer style forwarded to ``build_tokenizer`` when
                chat special tokens are enabled; otherwise ``"base"`` is
                used instead.
            vocab_size: Total size reported by ``len(tokenizer)``.
            encode_chat_special_tokens: Whether chat special tokens are
                encoded as single tokens. ``None`` auto-enables them when
                ``style == "it"`` and ``vocab_size`` exceeds the base size.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.

        Raises:
            OSError: If the vocab file cannot be found in any location.
        """
        if vocab_file is None:
            # Path(None) would raise TypeError before any fallback lookup
            # below gets a chance to run.
            vocab_file = self.vocab_files_names["vocab_file"]

        vocab_path = Path(vocab_file)
        # Fallback 1: a relative path that does not resolve from the
        # current directory is retried next to this module.
        if not vocab_path.exists() and not vocab_path.is_absolute():
            vocab_path = Path(__file__).resolve().parent / vocab_path
        # Fallback 2: resolve through the repo / directory the tokenizer
        # was loaded from, if the caller told us what that is.
        if not vocab_path.exists():
            name_or_path = kwargs.get("name_or_path") or kwargs.get("_name_or_path")
            if name_or_path:
                resolved = cached_file(name_or_path, vocab_file)
                if resolved is not None:
                    vocab_path = Path(resolved)
        if not vocab_path.exists():
            raise OSError(f"TalkieTokenizer could not find vocab file: {vocab_file!r}")

        self.vocab_file = str(vocab_path)
        self.style = style
        self._vocab_size = int(vocab_size)

        if encode_chat_special_tokens is None:
            encode_chat_special_tokens = (
                style == "it" and self._vocab_size > self._BASE_VOCAB_SIZE
            )
        self.encode_chat_special_tokens = bool(encode_chat_special_tokens)

        tokenizer_style = style if self.encode_chat_special_tokens else "base"
        # The encoding must exist before super().__init__ runs, because the
        # methods below (get_vocab, _convert_token_to_id, ...) read it.
        self.encoding = build_tokenizer(self.vocab_file, style=tokenizer_style)

        eos = "<|end|>" if self.encode_chat_special_tokens else "<|endoftext|>"
        kwargs.setdefault("eos_token", eos)
        kwargs.setdefault("pad_token", eos)
        super().__init__(**kwargs)
        # eos/pad are always forced to the style-appropriate token, even if
        # the caller passed different values through kwargs.
        self.eos_token = eos
        self.pad_token = eos

    @property
    def vocab_size(self):
        """Return the fixed base vocabulary size.

        This does not depend on the ``vocab_size`` constructor argument;
        that value is what ``len(tokenizer)`` reports.
        """
        return self._BASE_VOCAB_SIZE

    def get_vocab(self):
        """Return the token -> id mapping.

        Ordinary tokens map from the decimal string of their id; special
        tokens are added only when chat special tokens are enabled.
        """
        vocab = {str(i): i for i in range(self._BASE_VOCAB_SIZE)}
        if self.encode_chat_special_tokens:
            vocab.update(self.encoding._special_tokens)
        return vocab

    def __len__(self):
        """Return the vocabulary size given to the constructor."""
        return self._vocab_size

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into tokens, each the decimal string of its id."""
        return [str(i) for i in self.encode(text)]

    def _convert_token_to_id(self, token):
        """Map a token string to its id.

        Special-token text is looked up in the encoding's special-token
        table; anything else is parsed as a decimal id (``int(token)``
        raises ``ValueError`` for non-numeric text).
        """
        special_tokens = self.encoding._special_tokens
        if token in special_tokens:
            return special_tokens[token]
        return int(token)

    def _convert_id_to_token(self, index):
        """Map an id to its token string.

        Returns the special token's text when ``index`` is a special id,
        otherwise ``str(index)``.
        """
        wanted = int(index)
        for token, token_id in self.encoding._special_tokens.items():
            if int(token_id) == wanted:
                return token
        return str(index)

    def convert_tokens_to_string(self, tokens):
        """Decode a sequence of token strings back into text."""
        return self.encoding.decode([self._convert_token_to_id(t) for t in tokens])

    def encode(self, text, add_special_tokens=False, **kwargs):
        """Encode ``text`` into a list of token ids.

        ``add_special_tokens`` and ``**kwargs`` are accepted for interface
        compatibility and ignored. With chat special tokens enabled, special
        token text in ``text`` is encoded as special tokens; otherwise it is
        encoded as ordinary text.
        """
        del add_special_tokens
        if self.encode_chat_special_tokens:
            return self.encoding.encode(text, allowed_special="all")
        return self.encoding.encode(text, allowed_special=set(), disallowed_special=())

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """Decode token ids into text.

        Args:
            token_ids: A single int id, an iterable of ids, or a
                ``torch.Tensor`` of ids (0-d or 1-d).
            skip_special_tokens: Drop ids of special tokens before decoding.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A bare id (or a 0-d tensor, whose tolist() is a plain int)
            # is not iterable; treat it as a one-element sequence.
            token_ids = [token_ids]
        ids = [int(i) for i in token_ids]
        if skip_special_tokens:
            specials = {
                self.encoding.encode_single_token(t)
                for t in self.encoding.special_tokens_set
            }
            ids = [i for i in ids if i not in specials]
        return self.encoding.decode(ids)

    def apply_chat_template(
        self,
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        return_tensors=None,
        **kwargs,
    ):
        """Render a conversation with ``format_chat`` and optionally encode it.

        Args:
            conversation: Iterable of ``Message`` objects or mappings with
                ``"role"`` and ``"content"`` keys.
            tokenize: Return token ids instead of the rendered text.
            add_generation_prompt: When false, a trailing ``<|assistant|>``
                marker is stripped from the rendered text.
            return_tensors: ``"pt"`` returns a ``(1, n)`` long tensor when
                tokenizing; any other value returns a plain list of ids.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The rendered text, a list of ids, or a tensor (see above).
        """
        messages = [
            m if isinstance(m, Message) else Message(m["role"], m["content"])
            for m in conversation
        ]
        text = format_chat(messages)
        if not add_generation_prompt and text.endswith("<|assistant|>"):
            text = text[: -len("<|assistant|>")]
        if not tokenize:
            return text
        ids = self.encode(text)
        if return_tensors == "pt":
            return torch.tensor([ids], dtype=torch.long)
        return ids

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Copy the vocab file into ``save_directory``.

        The target file name is ``filename_prefix`` (if any) immediately
        followed by ``vocab.txt``. The directory is created when missing,
        and nothing is copied when the target already is the source file.

        Returns:
            A one-element tuple holding the written path as a string.
        """
        out_dir = Path(save_directory)
        out_dir.mkdir(parents=True, exist_ok=True)
        out = out_dir / ((filename_prefix or "") + "vocab.txt")
        source = Path(self.vocab_file)
        # Skip the copy when saving in place, so the source is never
        # rewritten from itself.
        if source.resolve() != out.resolve():
            out.write_bytes(source.read_bytes())
        return (str(out),)