#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The Spec-2 Authors
# Licensed under the Apache License, Version 2.0 (the "License")
"""Tokenizer for Spec-2 model"""
import json
import os
import shutil
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer
from transformers.utils import is_sentencepiece_available, logging
if is_sentencepiece_available():
    import sentencepiece as spm
else:
    raise ImportError(
        "You need to install sentencepiece to use Spec2Tokenizer "
        "(https://github.com/google/sentencepiece): `pip install sentencepiece`"
    )
logger = logging.get_logger(__name__)
class Spec2Tokenizer(PreTrainedTokenizer):
"""
Construct a Spec-2 tokenizer based on SentencePiece.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file generated by SentencePiece.
additional_special_tokens (`List[str]`, *optional*):
Additional special tokens used by the tokenizer.
bos_token (`str`, *optional*, defaults to `"<bos>"`):
The beginning of sequence token that was used during pretraining.
eos_token (`str`, *optional*, defaults to `"<eos>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
sp_model_kwargs (`dict`, *optional*):
Arguments to be passed to the SentencePiece model initialization method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
Whether or not to clean up the tokenization spaces.
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
Whether or not to use the default system prompt.
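
    Example (illustrative; the path below is a placeholder for a real
    SentencePiece model file):

    >>> tokenizer = Spec2Tokenizer("path/to/tokenizer.model")  # doctest: +SKIP
    >>> tokenizer.tokenize("Hello world")  # doctest: +SKIP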
"""
vocab_files_names = {"vocab_file": "tokenizer.model"}
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
additional_special_tokens=None,
bos_token="<bos>",
eos_token="<eos>",
unk_token="<unk>",
pad_token="<pad>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
clean_up_tokenization_spaces=True,
use_default_system_prompt=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
        # Core special tokens. Kept in a local dict: assigning to
        # `self.special_tokens_map` would shadow the property of the same name
        # defined on `PreTrainedTokenizer`.
        special_tokens = {
            "bos_token": bos_token,
            "eos_token": eos_token,
            "unk_token": unk_token,
            "pad_token": pad_token,
        }
        # Add additional special tokens
        self._additional_special_tokens = []
        if additional_special_tokens:
            self._additional_special_tokens = list(additional_special_tokens)
        self.use_default_system_prompt = use_default_system_prompt
        self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
        # Cache the SentencePiece ids of the special tokens. The matching
        # `bos_token_id`/`eos_token_id`/... attributes are provided by the base
        # class once `super().__init__()` has run, so they are not set here.
        self.special_token_ids = {
            name: self.sp_model.piece_to_id(token) for name, token in special_tokens.items()
        }
# Load additional special token mappings if available
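        # A hypothetical example of the expected shape, keyed by stringified token id:
        #   {"vocab_mapping": {"256000": {"content": "<extra_0>"}}}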
self.vocab_mapping = {}
vocab_mapping_file = os.path.join(os.path.dirname(vocab_file), "tokenizer_config.json")
if os.path.exists(vocab_mapping_file):
with open(vocab_mapping_file, "r", encoding="utf-8") as f:
config = json.load(f)
if "vocab_mapping" in config:
self.vocab_mapping = config["vocab_mapping"]
# Initialize PreTrainedTokenizer
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
additional_special_tokens=self._additional_special_tokens,
**kwargs,
)
@property
def vocab_size(self):
"""Return the size of vocabulary."""
return self.sp_model.get_piece_size()
def get_vocab(self):
"""Return vocab as a dict."""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text):
"""Tokenize a string."""
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Convert a token to an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Convert an id to a token using the vocab."""
        if index in self.added_tokens_decoder:
            # `added_tokens_decoder` maps ids to `AddedToken` objects; return their text.
            return str(self.added_tokens_decoder[index])
        if index >= self.sp_model.get_piece_size():
            # Ids beyond the SentencePiece vocab may be covered by the optional mapping.
            info = self.vocab_mapping.get(str(index))
            if info is not None:
                return info["content"]
            return self.unk_token
        return self.sp_model.id_to_piece(index)
def convert_tokens_to_string(self, tokens):
"""Convert a list of tokens to a string."""
text = self.sp_model.decode(tokens)
if self.clean_up_tokenization_spaces:
text = self.clean_up_tokenization(text)
return text
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str]:
        """Save the vocabulary and tokenizer config to a directory."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
)
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            shutil.copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content = self.sp_model.serialized_model_proto()
fi.write(content)
# Save tokenizer config with vocab mapping
config_file = os.path.join(save_directory, "tokenizer_config.json")
tokenizer_config = {
"vocab_file": self.vocab_files_names["vocab_file"],
"bos_token": self.bos_token,
"eos_token": self.eos_token,
"unk_token": self.unk_token,
"pad_token": self.pad_token,
"additional_special_tokens": self._additional_special_tokens,
"clean_up_tokenization_spaces": self.clean_up_tokenization_spaces,
"use_default_system_prompt": self.use_default_system_prompt,
"sp_model_kwargs": self.sp_model_kwargs,
"tokenizer_class": "Spec2Tokenizer",
"vocab_mapping": self.vocab_mapping
}
with open(config_file, "w", encoding="utf-8") as f:
json.dump(tokenizer_config, f, indent=2)
return (out_vocab_file, config_file)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
if token_ids_1 is None:
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
        Return a mask marking special tokens: 1 for a special token, 0 for a sequence token.
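
        Illustrative example for a single two-token sequence, where the added
        `bos`/`eos` positions are marked with 1: [1, 0, 0, 1].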
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
        Create token type ids for the given sequence(s). Spec-2 does not use
        token type ids, so a list of zeros covering the full input length is returned.
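
        Illustrative example: for `token_ids_0 = [5, 6]` and no second sequence,
        the result covers bos + tokens + eos: [0, 0, 0, 0].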
"""
eos = [self.eos_token_id]
bos = [self.bos_token_id]
if token_ids_1 is None:
return len(bos + token_ids_0 + eos) * [0]
return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0]
def prepare_for_model(
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
add_special_tokens: bool = True,
**kwargs
):
"""
Prepare inputs for the model.
"""
return super().prepare_for_model(
ids, pair_ids, add_special_tokens=add_special_tokens, **kwargs
)
def prepare_seq2seq_batch(
self,
src_texts: Union[str, List[str]],
tgt_texts: Optional[Union[str, List[str]]] = None,
**kwargs
):
"""
Prepare a batch for sequence-to-sequence tasks.
"""
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
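

# Minimal smoke test (illustrative only): round-trips a short string through the
# tokenizer. The model path is a placeholder; point it at a real SentencePiece file.
if __name__ == "__main__":
    tok = Spec2Tokenizer("tokenizer.model")
    encoded = tok("Hello, Spec-2!")
    print(encoded["input_ids"])
    print(tok.decode(encoded["input_ids"]))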