talkie_converted / tokenization_talkie.py
wave-on-discord's picture
Add files using upload-large-folder tool
23e626d verified
from __future__ import annotations
from pathlib import Path
import torch
from transformers import PreTrainedTokenizer
from transformers.utils import cached_file
from talkie.chat import Message, format_chat
from talkie.tokenizer import build_tokenizer
class TalkieTokenizer(PreTrainedTokenizer):
    """``PreTrainedTokenizer`` adapter around the encoding built by ``build_tokenizer``.

    Ordinary tokens are represented as the decimal string of their id
    (``"42"`` <-> ``42``); special tokens keep their literal text and are
    looked up in ``self.encoding._special_tokens``.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.txt"}
    # Number of ordinary ids reported by ``vocab_size`` and listed by ``get_vocab``.
    _BASE_VOCAB_SIZE = 65536

    def __init__(
        self,
        vocab_file="vocab.txt",
        style="it",
        vocab_size=65536,
        encode_chat_special_tokens=None,
        **kwargs,
    ):
        """Locate the vocabulary file, build the encoding and set EOS/PAD.

        Args:
            vocab_file: Path or file name of the vocabulary. ``None`` falls
                back to the name declared in ``vocab_files_names``.
            style: Tokenizer style; forwarded to ``build_tokenizer`` only when
                chat special tokens are enabled, otherwise ``"base"`` is used.
            vocab_size: Total size reported by ``len(tokenizer)``.
            encode_chat_special_tokens: Whether special-token text is encoded
                to special ids. ``None`` enables it only when ``style`` is
                ``"it"`` and ``vocab_size`` exceeds the base vocabulary size.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.

        Raises:
            OSError: If the vocabulary file cannot be found.
        """
        if vocab_file is None:
            # Path(None) would raise TypeError before any fallback lookup ran.
            vocab_file = self.vocab_files_names["vocab_file"]

        vocab_path = Path(vocab_file)
        # Lookup order: path as given -> next to this module -> repo/cache
        # identified by name_or_path.
        if not vocab_path.exists() and not vocab_path.is_absolute():
            vocab_path = Path(__file__).resolve().parent / vocab_path
        if not vocab_path.exists():
            name_or_path = kwargs.get("name_or_path") or kwargs.get("_name_or_path")
            if name_or_path:
                resolved = cached_file(name_or_path, vocab_file)
                if resolved is not None:
                    vocab_path = Path(resolved)
        if not vocab_path.exists():
            raise OSError(f"TalkieTokenizer could not find vocab file: {vocab_file!r}")

        # These attributes are assigned before super().__init__() because the
        # vocabulary/conversion methods below read them (presumably the base
        # initializer may call those methods -- keep this ordering).
        self.vocab_file = str(vocab_path)
        self.style = style
        self._vocab_size = int(vocab_size)
        if encode_chat_special_tokens is None:
            encode_chat_special_tokens = (
                style == "it" and self._vocab_size > self._BASE_VOCAB_SIZE
            )
        self.encode_chat_special_tokens = bool(encode_chat_special_tokens)

        tokenizer_style = style if self.encode_chat_special_tokens else "base"
        self.encoding = build_tokenizer(self.vocab_file, style=tokenizer_style)

        eos = "<|end|>" if self.encode_chat_special_tokens else "<|endoftext|>"
        kwargs.setdefault("eos_token", eos)
        kwargs.setdefault("pad_token", eos)
        super().__init__(**kwargs)
        # Re-assigned after the base initializer so the instance always ends
        # up with the style-appropriate marker, whatever kwargs contained.
        self.eos_token = eos
        self.pad_token = eos

    @property
    def vocab_size(self):
        """Size of the base vocabulary (fixed; see ``__len__`` for the total)."""
        return self._BASE_VOCAB_SIZE

    def get_vocab(self):
        """Return a token -> id mapping.

        Ordinary ids map from their decimal string; special tokens are added
        only when chat special tokens are enabled.
        """
        vocab = {str(i): i for i in range(self._BASE_VOCAB_SIZE)}
        if self.encode_chat_special_tokens:
            vocab.update(self.encoding._special_tokens)
        return vocab

    def __len__(self):
        """Return the ``vocab_size`` value given at construction time."""
        return self._vocab_size

    def _tokenize(self, text, **kwargs):
        """Encode ``text`` and return each id as its decimal string."""
        return [str(i) for i in self.encode(text)]

    def _convert_token_to_id(self, token):
        """Map a token string to its id.

        Special-token text is looked up in the encoding; anything else must be
        a decimal id string (``int()`` raises ``ValueError`` otherwise).
        """
        special_tokens = self.encoding._special_tokens
        if token in special_tokens:
            return special_tokens[token]
        return int(token)

    def _convert_id_to_token(self, index):
        """Map an id to its special-token text, or to its decimal string."""
        for token, token_id in self.encoding._special_tokens.items():
            if int(index) == int(token_id):
                return token
        return str(index)

    def convert_tokens_to_string(self, tokens):
        """Decode a sequence of token strings back into text."""
        return self.encoding.decode([self._convert_token_to_id(t) for t in tokens])

    def encode(self, text, add_special_tokens=False, **kwargs):
        """Encode ``text`` into a list of ids.

        ``add_special_tokens`` and any extra keyword arguments are accepted
        for signature compatibility and ignored. When chat special tokens are
        disabled, special-token text in ``text`` is encoded as ordinary text.
        """
        del add_special_tokens
        if self.encode_chat_special_tokens:
            return self.encoding.encode(text, allowed_special="all")
        return self.encoding.encode(text, allowed_special=set(), disallowed_special=())

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """Decode ids into text.

        Args:
            token_ids: A single id, an iterable of ids, or a ``torch.Tensor``
                (0-d or 1-d).
            skip_special_tokens: Drop ids that belong to special tokens before
                decoding.
            **kwargs: Accepted for signature compatibility; unused.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A lone id (or a 0-d tensor after tolist()) is not iterable.
            token_ids = [token_ids]

        ids = [int(i) for i in token_ids]
        if skip_special_tokens:
            specials = {
                self.encoding.encode_single_token(t)
                for t in self.encoding.special_tokens_set
            }
            ids = [i for i in ids if i not in specials]
        return self.encoding.decode(ids)

    def apply_chat_template(self, conversation, tokenize=False, add_generation_prompt=True, return_tensors=None, **kwargs):
        """Render a conversation with ``format_chat`` and optionally encode it.

        Args:
            conversation: Iterable of ``Message`` objects or mappings with
                ``"role"`` and ``"content"`` keys.
            tokenize: Return ids instead of the rendered text.
            add_generation_prompt: When false, a trailing ``<|assistant|>``
                marker is stripped from the rendered text.
            return_tensors: ``"pt"`` returns a ``(1, n)`` ``torch.long``
                tensor; any other value returns a plain list of ids.
            **kwargs: Accepted for signature compatibility; unused.
        """
        messages = [
            m if isinstance(m, Message) else Message(m["role"], m["content"])
            for m in conversation
        ]
        text = format_chat(messages)

        assistant_marker = "<|assistant|>"
        if not add_generation_prompt and text.endswith(assistant_marker):
            text = text[:-len(assistant_marker)]
        if not tokenize:
            return text

        ids = self.encode(text)
        if return_tensors == "pt":
            return torch.tensor([ids], dtype=torch.long)
        return ids

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Copy the vocabulary file into ``save_directory``.

        The target name is ``filename_prefix`` (if any) immediately followed
        by ``vocab.txt``. Nothing is copied when source and target resolve to
        the same file.

        Returns:
            A one-element tuple containing the target path as a string.
        """
        # NOTE(review): the prefix is joined without a separator; confirm
        # whether a "-" between prefix and file name is expected by callers.
        out = Path(save_directory) / ((filename_prefix or "") + "vocab.txt")
        source = Path(self.vocab_file)
        if source.resolve() != out.resolve():
            # Direct callers may pass a directory that does not exist yet.
            out.parent.mkdir(parents=True, exist_ok=True)
            out.write_bytes(source.read_bytes())
        return (str(out),)