|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| import torch |
| from transformers import PreTrainedTokenizer |
| from transformers.utils import cached_file |
| from talkie.chat import Message, format_chat |
| from talkie.tokenizer import build_tokenizer |
|
|
|
|
class TalkieTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer facade over the encoding from ``build_tokenizer``.

    Tokens are represented as the decimal string of their id (``"42"``), while
    the encoding's special tokens keep their literal spelling (for example
    ``"<|end|>"``).  All real encoding/decoding is delegated to
    ``self.encoding``.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.txt"}

    # Number of ordinary (non-special) token ids.  Reported by ``vocab_size``
    # and enumerated by ``get_vocab``; a configured size above this value is
    # what switches on chat special tokens by default.
    _BASE_VOCAB_SIZE = 65536

    def __init__(
        self,
        vocab_file="vocab.txt",
        style="it",
        vocab_size=_BASE_VOCAB_SIZE,
        encode_chat_special_tokens=None,
        **kwargs,
    ):
        """Locate the vocab file, build the encoding and initialise the base class.

        Args:
            vocab_file: Path or file name of the vocabulary.  Tried as given,
                then (if relative) next to this module, then through
                ``cached_file`` using ``name_or_path`` / ``_name_or_path``
                from ``kwargs``.
            style: Tokenizer style.  Forwarded to ``build_tokenizer`` only when
                chat special tokens are enabled; otherwise ``"base"`` is used.
            vocab_size: Total size reported by ``len(tokenizer)``.
            encode_chat_special_tokens: Whether chat special tokens are
                recognised while encoding.  ``None`` enables them exactly when
                ``style == "it"`` and ``vocab_size`` exceeds the base size.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.

        Raises:
            OSError: If the vocabulary file cannot be found.
        """
        self.vocab_file = str(self._resolve_vocab_path(vocab_file, kwargs))
        self.style = style
        self._vocab_size = int(vocab_size)

        if encode_chat_special_tokens is None:
            encode_chat_special_tokens = (
                style == "it" and self._vocab_size > self._BASE_VOCAB_SIZE
            )
        self.encode_chat_special_tokens = bool(encode_chat_special_tokens)

        tokenizer_style = style if self.encode_chat_special_tokens else "base"
        # The attributes above and the encoding are set before
        # ``super().__init__`` so the overridden vocabulary methods are usable
        # if the base initializer calls them.
        self.encoding = build_tokenizer(self.vocab_file, style=tokenizer_style)

        eos = "<|end|>" if self.encode_chat_special_tokens else "<|endoftext|>"
        kwargs.setdefault("eos_token", eos)
        kwargs.setdefault("pad_token", eos)
        super().__init__(**kwargs)
        # NOTE(review): these assignments override any caller-supplied
        # eos_token / pad_token, which makes the ``setdefault`` calls above
        # only matter during the base initializer -- confirm this is intended.
        self.eos_token = eos
        self.pad_token = eos

    @staticmethod
    def _resolve_vocab_path(vocab_file, init_kwargs):
        """Return an existing ``Path`` for ``vocab_file`` or raise ``OSError``."""
        candidate = Path(vocab_file)
        if not candidate.exists() and not candidate.is_absolute():
            # Fall back to a copy shipped alongside this module.
            candidate = Path(__file__).resolve().parent / candidate
        if not candidate.exists():
            name_or_path = init_kwargs.get("name_or_path") or init_kwargs.get("_name_or_path")
            if name_or_path:
                resolved = cached_file(name_or_path, vocab_file)
                if resolved is not None:
                    candidate = Path(resolved)
        if not candidate.exists():
            raise OSError(f"TalkieTokenizer could not find vocab file: {vocab_file!r}")
        return candidate

    @property
    def vocab_size(self) -> int:
        """Size of the base vocabulary (independent of the configured total)."""
        return self._BASE_VOCAB_SIZE

    def get_vocab(self):
        """Return the token -> id mapping.

        Base ids map from their decimal string; the encoding's special tokens
        are included only when chat special tokens are enabled.
        """
        vocab = {str(i): i for i in range(self._BASE_VOCAB_SIZE)}
        if self.encode_chat_special_tokens:
            vocab.update(self.encoding._special_tokens)
        return vocab

    def __len__(self) -> int:
        """Return the total vocabulary size passed to the constructor."""
        return self._vocab_size

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into tokens, each the decimal string of its id."""
        return [str(token_id) for token_id in self.encode(text)]

    def _convert_token_to_id(self, token):
        """Map a token string to its id.

        Special tokens are looked up in the encoding; anything else must be
        a decimal id string (``int(token)`` raises ``ValueError`` otherwise).
        """
        special_tokens = self.encoding._special_tokens
        if token in special_tokens:
            return special_tokens[token]
        return int(token)

    def _convert_id_to_token(self, index):
        """Map an id to its special-token string, or to ``str(index)``."""
        wanted = int(index)  # hoisted: the original converted on every iteration
        for token, token_id in self.encoding._special_tokens.items():
            if int(token_id) == wanted:
                return token
        return str(index)

    def convert_tokens_to_string(self, tokens):
        """Decode a sequence of token strings back into text."""
        return self.encoding.decode([self._convert_token_to_id(t) for t in tokens])

    def encode(self, text, add_special_tokens=False, **kwargs):
        """Encode ``text`` into a list of token ids.

        ``add_special_tokens`` and ``**kwargs`` are accepted for signature
        compatibility and ignored.  With chat special tokens enabled the
        encoding is called with ``allowed_special="all"``; otherwise with no
        allowed and no disallowed specials (presumably so special-token
        spellings are encoded as ordinary text -- confirm against the
        encoding's API).
        """
        del add_special_tokens
        if self.encode_chat_special_tokens:
            return self.encoding.encode(text, allowed_special="all")
        return self.encoding.encode(text, allowed_special=set(), disallowed_special=())

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """Decode token ids into text.

        Args:
            token_ids: A single id, an iterable of ids, or a ``torch.Tensor``
                of ids.
            skip_special_tokens: Drop the encoding's special-token ids before
                decoding.
            **kwargs: Ignored.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        # A bare id (including what a 0-d tensor's ``tolist()`` yields) is a
        # valid argument for Hugging Face ``decode``; previously this raised
        # ``TypeError`` because an int is not iterable.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        ids = [int(i) for i in token_ids]
        if skip_special_tokens:
            specials = {
                self.encoding.encode_single_token(t)
                for t in self.encoding.special_tokens_set
            }
            ids = [i for i in ids if i not in specials]
        return self.encoding.decode(ids)

    def apply_chat_template(
        self,
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        return_tensors=None,
        **kwargs,
    ):
        """Render a conversation with ``format_chat`` and optionally encode it.

        Args:
            conversation: Iterable of ``Message`` objects or mappings with
                ``"role"`` and ``"content"`` keys.
            tokenize: Return token ids instead of the rendered text.
            add_generation_prompt: When false, a trailing ``"<|assistant|>"``
                is stripped from the rendered text.
            return_tensors: ``"pt"`` returns a ``(1, n)`` ``torch.long``
                tensor; any other value returns a plain list of ids.
            **kwargs: Ignored.

        Returns:
            The rendered string, a list of ids, or a tensor (see above).
        """
        messages = [
            m if isinstance(m, Message) else Message(m["role"], m["content"])
            for m in conversation
        ]
        text = format_chat(messages)

        generation_prompt = "<|assistant|>"
        if not add_generation_prompt and text.endswith(generation_prompt):
            text = text[: -len(generation_prompt)]
        if not tokenize:
            return text

        ids = self.encode(text)
        if return_tensors == "pt":
            return torch.tensor([ids], dtype=torch.long)
        return ids

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Copy the vocabulary file into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional string prepended directly to
                ``"vocab.txt"``.

        Returns:
            A one-element tuple holding the path of the saved file.
        """
        source = Path(self.vocab_file)
        out = Path(save_directory) / ((filename_prefix or "") + "vocab.txt")
        # Skip the copy when saving on top of the file we loaded from.
        if source.resolve() != out.resolve():
            # Previously a missing target directory raised FileNotFoundError.
            out.parent.mkdir(parents=True, exist_ok=True)
            out.write_bytes(source.read_bytes())
        return (str(out),)
|
|