# tokenization_mini_whisper.py
"""Tokenization classes for MiniWhisper"""

import json
import os
import re
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import numpy as np
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import AddedToken
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Names of the on-disk resources this tokenizer knows about.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "tokenizer_file": "tokenizer.json",
    "merges_file": "merges.txt",
    "normalizer_file": "normalizer.json",
}

# Language code -> language name. The insertion order is significant:
# `prefix_tokens` derives each language token id from the position of the
# code in this mapping, so entries must not be reordered.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# Language name (including a few aliases) -> language code.
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

TASK_IDS = ["translate", "transcribe"]

# Position of every language code in LANGUAGES, precomputed so that
# `prefix_tokens` does not rebuild a tuple and scan it on every access.
_LANGUAGE_INDEX = {code: position for position, code in enumerate(LANGUAGES)}


def _as_special_token(token):
    """Wrap a plain string in a special, non-normalized ``AddedToken``; pass other values through."""
    if isinstance(token, str):
        return AddedToken(token, lstrip=False, rstrip=False, normalized=False, special=True)
    return token


def _existing_file(path, description):
    """Return ``path`` if it names an existing file, otherwise ``None``.

    A warning is logged when a path was supplied but does not exist, so a
    typo does not silently degrade the tokenizer to its fallback mode.
    """
    if path is None:
        return None
    if os.path.isfile(path):
        return path
    logger.warning("%s %r was not found and will be ignored.", description, path)
    return None


def _build_merge_bigrams(merges):
    """Build the set of two-character strings that `_tokenize` may emit as one token.

    Each merges line is kept verbatim, and a line made of exactly two
    space-separated parts (``"a b"``) additionally contributes the joined
    form (``"ab"``).
    """
    bigrams = set()
    for line in merges:
        bigrams.add(line)
        parts = line.split(" ")
        if len(parts) == 2:
            bigrams.add(parts[0] + parts[1])
    return bigrams


class MiniWhisperTokenizer(PreTrainedTokenizer):
    """
    Construct a MiniWhisper tokenizer.

    For actual use, load pretrained:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

    Or provide path to trained vocab/merges files:
        tokenizer = MiniWhisperTokenizer(vocab_file="vocab.json", merges_file="merges.txt")

    When no vocabulary or merges are loaded, the tokenizer falls back to
    splitting text into single characters.

    Args:
        vocab_file: Path to a JSON file mapping token -> id.
        merges_file: Path to a merges file, one merge per line, with an
            optional ``#version: 0.2`` header line.
        normalizer_file: Path to a JSON file, loaded into
            ``english_spelling_normalizer``.
        tokenizer_file: Accepted for signature compatibility; not read here.
        unk_token, bos_token, eos_token, pad_token: Special tokens; plain
            strings are wrapped in ``AddedToken``.
        add_prefix_space: Stored and forwarded to the base class.
        language: Language name or code used to build the prefix tokens.
        task: ``"transcribe"`` or ``"translate"``.
        predict_timestamps: When false, ``<|notimestamps|>`` is appended to
            the prefix tokens.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        normalizer_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        language=None,
        task=None,
        predict_timestamps=False,
        **kwargs,
    ):
        bos_token = _as_special_token(bos_token)
        eos_token = _as_special_token(eos_token)
        unk_token = _as_special_token(unk_token)
        pad_token = _as_special_token(pad_token)

        # All tokenizer state is set up before super().__init__ so the
        # overridden methods are usable if the base constructor calls them.
        self._vocab = {}
        self._merges = []
        self.english_spelling_normalizer = None

        vocab_path = _existing_file(vocab_file, "Vocabulary file")
        if vocab_path is not None:
            with open(vocab_path, encoding="utf-8") as f:
                self._vocab = json.load(f)

        merges_path = _existing_file(merges_file, "Merges file")
        if merges_path is not None:
            with open(merges_path, encoding="utf-8") as f:
                lines = f.read().splitlines()
            # Drop the optional version header.
            if lines and lines[0] == "#version: 0.2":
                lines = lines[1:]
            self._merges = lines

        normalizer_path = _existing_file(normalizer_file, "Normalizer file")
        if normalizer_path is not None:
            with open(normalizer_path, encoding="utf-8") as f:
                self.english_spelling_normalizer = json.load(f)

        # Reverse lookup used by `_convert_id_to_token`.
        self.ids_to_tokens = {index: token for token, index in self._vocab.items()}
        # O(1) membership test used by `_tokenize`.
        self._merge_bigrams = _build_merge_bigrams(self._merges)
        # Per-instance cache for `timestamp_ids`.
        self._timestamp_ids_cache = {}

        self.add_prefix_space = add_prefix_space
        self.language = language
        self.task = task
        self.predict_timestamps = predict_timestamps
        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
        self.set_prefix_tokens()

    @property
    def vocab_size(self) -> int:
        """Number of entries in the loaded vocabulary, or 51865 when none is loaded."""
        if self._vocab:
            return len(self._vocab)
        return 51865

    def get_vocab(self):
        """Return the token -> id mapping, including added tokens.

        Without a loaded vocabulary a minimal mapping of the four special
        tokens is returned instead.
        """
        if self._vocab:
            return {**self._vocab, **self.added_tokens_encoder}
        return {self.unk_token: 0, self.bos_token: 1, self.eos_token: 2, self.pad_token: 3}

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into tokens with a greedy left-to-right bigram merge.

        This is a simplified approximation, not full BPE: at each position
        the next two characters are emitted as one token when they form a
        known merge, otherwise a single character is emitted. Without a
        vocabulary or merges the text is split into characters unchanged.
        """
        if not self._vocab or not self._merges:
            return list(text)

        merge_bigrams = getattr(self, "_merge_bigrams", None)
        if merge_bigrams is None:
            # Instance was created without __init__ (or _merges was set later).
            merge_bigrams = _build_merge_bigrams(self._merges)

        # NOTE(review): lower-casing is applied only on this path, not in the
        # character fallback above -- confirm this matches the vocabulary.
        text = text.lower()
        tokens = []
        position = 0
        length = len(text)
        while position < length:
            pair = text[position:position + 2]
            if len(pair) == 2 and pair in merge_bigrams:
                tokens.append(pair)
                position += 2
            else:
                tokens.append(text[position])
                position += 1
        return tokens

    def _convert_token_to_id(self, token):
        """Map a token to its id, falling back to the id of the unknown token.

        Without a loaded vocabulary, the special tokens map to 0..3 (checked
        in the order unk, bos, eos, pad) and every other token maps to 0.
        """
        if self._vocab:
            return self._vocab.get(token, self._vocab.get(self.unk_token))
        fallback_ids = (
            (self.unk_token, 0),
            (self.bos_token, 1),
            (self.eos_token, 2),
            (self.pad_token, 3),
        )
        for special_token, special_id in fallback_ids:
            if token == special_token:
                return special_id
        return 0

    def _convert_id_to_token(self, index):
        """Map an id to its token, using the unknown token for ids not in the vocabulary.

        When no vocabulary is loaded the lookup is delegated to the base class.
        """
        id_map = getattr(self, "ids_to_tokens", None)
        if id_map:
            return id_map.get(index, self.unk_token)
        return super()._convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens):
        """Concatenate tokens into a single string."""
        return "".join(tokens)

    @property
    def prefix_tokens(self) -> List[int]:
        """Ids prepended to every encoded sequence.

        The sequence is ``<|startoftranscript|>``, then the language token
        (if a language is set), the task token (if the task is
        ``"transcribe"`` or ``"translate"``) and ``<|notimestamps|>`` (unless
        timestamps are predicted). An unrecognised language falls back to
        English. Reading this property does not modify the tokenizer.
        """
        bos_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
        prefix = [bos_token_id]

        if self.language is not None:
            language = self.language.lower()
            if language in TO_LANGUAGE_CODE:
                language_code = TO_LANGUAGE_CODE[language]
            elif language in LANGUAGES:
                language_code = language
            else:
                language_code = "en"
            # Language token ids are computed as an offset from the
            # start-of-transcript id, in LANGUAGES order.
            prefix.append(bos_token_id + 1 + _LANGUAGE_INDEX[language_code])

        if self.task == "transcribe":
            prefix.append(self.convert_tokens_to_ids("<|transcribe|>"))
        elif self.task == "translate":
            prefix.append(self.convert_tokens_to_ids("<|translate|>"))

        if not self.predict_timestamps:
            prefix.append(self.convert_tokens_to_ids("<|notimestamps|>"))
        return prefix

    def set_prefix_tokens(self, language=None, task=None, predict_timestamps=None):
        """Update the language, task and timestamp settings; ``None`` leaves a setting unchanged."""
        if language is not None:
            self.language = language
        if task is not None:
            self.task = task
        if predict_timestamps is not None:
            self.predict_timestamps = predict_timestamps

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Surround the sequence(s) with the prefix tokens and a trailing end-of-sequence id."""
        sequence = self.prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            sequence = sequence + token_ids_1
        return sequence + [self.eos_token_id]

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """Return a mask with 1 for special-token positions and 0 for sequence tokens.

        The layout mirrors `build_inputs_with_special_tokens`: prefix tokens,
        the sequence(s), then one end-of-sequence position.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        sequence_length = len(token_ids_0)
        if token_ids_1 is not None:
            sequence_length += len(token_ids_1)
        return [1] * len(self.prefix_tokens) + [0] * sequence_length + [1]

    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
        """Return ``(position, token_id)`` pairs for the prefix tokens after the first.

        Positions start at 1. This also updates the tokenizer's stored
        language, task and timestamp settings.
        """
        self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
        forced_tokens = self.prefix_tokens[1:]
        return [(rank + 1, token_id) for rank, token_id in enumerate(forced_tokens)]

    def get_prompt_ids(self, text: str, return_tensors="np"):
        """Converts prompt text to IDs that can be passed to generation.

        Raises:
            ValueError: If the prompt text encodes to an id at or above the
                first entry of ``all_special_ids``.
        """
        batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)

        # Skip the leading <|startofprev|> id; only the text itself is checked.
        prompt_text_ids = batch_encoding["input_ids"][1:]
        special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
        if special_token_id is not None:
            token = self.convert_ids_to_tokens(special_token_id)
            raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")

        batch_encoding.convert_to_tensors(tensor_type=return_tensors)
        return batch_encoding["input_ids"]

    def _decode(
        self,
        token_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
        normalize=False,
        basic_normalize=False,
        remove_diacritics=False,
        decode_with_timestamps=False,
        **kwargs,
    ) -> str:
        """Convert ids to text.

        Args:
            token_ids: A single id or an iterable of ids.
            skip_special_tokens: Drop ids listed in ``all_special_ids``.
            clean_up_tokenization_spaces: Collapse whitespace runs and strip the ends.
            normalize, basic_normalize, remove_diacritics: Accepted for
                signature compatibility; currently unused.
            decode_with_timestamps: Keep ``<|x.xx|>`` timestamp markers in
                the output instead of removing them.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        if skip_special_tokens:
            # Built once rather than per token. int() keeps the membership
            # test value-based for array/tensor scalars.
            special_ids = set(self.all_special_ids)
            token_ids = [token_id for token_id in token_ids if int(token_id) not in special_ids]

        text = self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))

        # Remove timestamps before whitespace clean-up so the removal cannot
        # leave doubled spaces behind.
        if not decode_with_timestamps:
            text = self.timestamp_pat.sub("", text)
        if clean_up_tokenization_spaces:
            text = re.sub(r"\s+", " ", text).strip()
        return text

    def decode(
        self,
        token_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
        output_offsets=False,
        time_precision=0.02,
        decode_with_timestamps=False,
        **kwargs,
    ) -> str:
        """Decode ids to text.

        ``decode_with_timestamps`` keeps timestamp markers in the output.
        ``output_offsets`` and ``time_precision`` are accepted for signature
        compatibility and currently unused.
        """
        return self._decode(
            token_ids,
            skip_special_tokens,
            clean_up_tokenization_spaces,
            decode_with_timestamps=decode_with_timestamps,
            **kwargs,
        )

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Optional[Tuple[str, str]]:
        """Write the vocabulary and merges into ``save_directory``.

        Each file is written only when the corresponding data is loaded.

        Returns:
            ``(vocab_file, merges_file)`` paths, or ``None`` (after logging
            an error) when ``save_directory`` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return None

        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        merges_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["merges_file"])

        if self._vocab:
            with open(vocab_file, "w", encoding="utf-8") as f:
                json.dump(self._vocab, f, indent=2, ensure_ascii=False)

        if self._merges:
            with open(merges_file, "w", encoding="utf-8") as f:
                f.write("#version: 0.2\n")
                f.writelines(merge + "\n" for merge in self._merges)

        return (vocab_file, merges_file)

    def timestamp_ids(self, time_precision=0.02):
        """Return the ids of the 1501 timestamp tokens ``<|0.00|>`` ... at ``time_precision`` steps.

        Results are cached per instance and per ``time_precision``. The
        cached list is returned as-is, so callers must not mutate it.
        """
        # A per-instance dict replaces functools.lru_cache on the method,
        # which would hold a reference to every tokenizer in a shared cache.
        cache = self.__dict__.setdefault("_timestamp_ids_cache", {})
        if time_precision not in cache:
            cache[time_precision] = self.convert_tokens_to_ids(
                [("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)]
            )
        return cache[time_precision]


__all__ = ["MiniWhisperTokenizer"]