# MiniWhisper-ASR / tokenization_mini_whisper.py
# Last commit by NeuraCraft: "Update Main Files" (f8a9c31, verified)
# tokenization_mini_whisper.py
"""Tokenization classes for MiniWhisper"""
import json
import os
import re
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import AddedToken
from transformers.utils import logging
# Module-level logger from transformers.utils.logging (imported above as `logging`);
# the tokenizer uses it to report problems such as an invalid save directory.
logger = logging.get_logger(__name__)
# Default file name for each tokenizer resource, keyed by the matching
# constructor argument. Used as the class's `vocab_files_names` and by
# `save_vocabulary` when writing files back to disk.
VOCAB_FILES_NAMES = dict(
    vocab_file="vocab.json",
    tokenizer_file="tokenizer.json",
    merges_file="merges.txt",
    normalizer_file="normalizer.json",
)
# Language code -> lowercase English language name.
# ORDER MATTERS: MiniWhisperTokenizer.prefix_tokens computes a language token id
# from the code's position in this mapping, so entries must never be reordered.
LANGUAGES = dict(
    (
        ("en", "english"), ("zh", "chinese"), ("de", "german"), ("es", "spanish"),
        ("ru", "russian"), ("ko", "korean"), ("fr", "french"), ("ja", "japanese"),
        ("pt", "portuguese"), ("tr", "turkish"), ("pl", "polish"), ("ca", "catalan"),
        ("nl", "dutch"), ("ar", "arabic"), ("sv", "swedish"), ("it", "italian"),
        ("id", "indonesian"), ("hi", "hindi"), ("fi", "finnish"), ("vi", "vietnamese"),
        ("he", "hebrew"), ("uk", "ukrainian"), ("el", "greek"), ("ms", "malay"),
        ("cs", "czech"), ("ro", "romanian"), ("da", "danish"), ("hu", "hungarian"),
        ("ta", "tamil"), ("no", "norwegian"), ("th", "thai"), ("ur", "urdu"),
        ("hr", "croatian"), ("bg", "bulgarian"), ("lt", "lithuanian"), ("la", "latin"),
        ("mi", "maori"), ("ml", "malayalam"), ("cy", "welsh"), ("sk", "slovak"),
        ("te", "telugu"), ("fa", "persian"), ("lv", "latvian"), ("bn", "bengali"),
        ("sr", "serbian"), ("az", "azerbaijani"), ("sl", "slovenian"), ("kn", "kannada"),
        ("et", "estonian"), ("mk", "macedonian"), ("br", "breton"), ("eu", "basque"),
        ("is", "icelandic"), ("hy", "armenian"), ("ne", "nepali"), ("mn", "mongolian"),
        ("bs", "bosnian"), ("kk", "kazakh"), ("sq", "albanian"), ("sw", "swahili"),
        ("gl", "galician"), ("mr", "marathi"), ("pa", "punjabi"), ("si", "sinhala"),
        ("km", "khmer"), ("sn", "shona"), ("yo", "yoruba"), ("so", "somali"),
        ("af", "afrikaans"), ("oc", "occitan"), ("ka", "georgian"), ("be", "belarusian"),
        ("tg", "tajik"), ("sd", "sindhi"), ("gu", "gujarati"), ("am", "amharic"),
        ("yi", "yiddish"), ("lo", "lao"), ("uz", "uzbek"), ("fo", "faroese"),
        ("ht", "haitian creole"), ("ps", "pashto"), ("tk", "turkmen"), ("nn", "nynorsk"),
        ("mt", "maltese"), ("sa", "sanskrit"), ("lb", "luxembourgish"), ("my", "myanmar"),
        ("bo", "tibetan"), ("tl", "tagalog"), ("mg", "malagasy"), ("as", "assamese"),
        ("tt", "tatar"), ("haw", "hawaiian"), ("ln", "lingala"), ("ha", "hausa"),
        ("ba", "bashkir"), ("jw", "javanese"), ("su", "sundanese"), ("yue", "cantonese"),
    )
)

# Reverse lookup: language name -> code, followed by common alternative names.
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}
TO_LANGUAGE_CODE.update(
    burmese="my",
    valencian="ca",
    flemish="nl",
    haitian="ht",
    letzeburgesch="lb",
    pushto="ps",
    panjabi="pa",
    moldavian="ro",
    moldovan="ro",
    sinhalese="si",
    castilian="es",
    mandarin="zh",
)

# Tasks the decoder prompt can request.
TASK_IDS = "translate transcribe".split()
@lru_cache(maxsize=None)
def _bytes_to_unicode():
    """
    Return the GPT-2 byte -> unicode-character table used by byte-level BPE.

    Each of the 256 byte values is mapped to a distinct, non-whitespace unicode
    character, so any UTF-8 text can be spelled with vocabulary strings. For
    example, the space byte 0x20 becomes "Ġ". Printable Latin-1 bytes map to
    themselves; the remaining bytes are shifted to code points 256 and up.
    The result is cached, so callers must treat the returned dict as read-only.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    printable_set = set(printable)
    byte_values = list(printable)
    code_points = list(printable)
    shift = 0
    for byte in range(256):
        if byte not in printable_set:
            byte_values.append(byte)
            code_points.append(256 + shift)
            shift += 1
    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}


class MiniWhisperTokenizer(PreTrainedTokenizer):
    """
    Construct a MiniWhisper tokenizer (GPT-2 style byte-level BPE, as used by Whisper).

    For actual use, either load a pretrained Whisper tokenizer from transformers::

        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

    or provide paths to trained vocab/merges files::

        tokenizer = MiniWhisperTokenizer(vocab_file="vocab.json", merges_file="merges.txt")

    When the vocab or merges file is missing, the tokenizer falls back to a
    character-level mode: `_tokenize` returns single characters and ids come
    from a small placeholder mapping. That mode is only useful for smoke tests.

    Args:
        vocab_file: Path to a JSON mapping token -> id. Missing paths are skipped
            with a warning.
        merges_file: Path to a BPE merges file with one "a b" pair per line and an
            optional leading "#version" header.
        normalizer_file: Path to a JSON English spelling normalizer. It is stored in
            `english_spelling_normalizer` but is not applied by `decode` here.
        tokenizer_file: Accepted for API compatibility; unused.
        unk_token, bos_token, eos_token, pad_token: Special tokens. All default to
            "<|endoftext|>".
        add_prefix_space: If True, prepend a space to the text before tokenizing.
        language: Language name or code used for the decoder prefix tokens.
        task: "transcribe" or "translate".
        predict_timestamps: If False, "<|notimestamps|>" is appended to the prefix.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    # Pre-tokenization pattern. Upstream GPT-2/Whisper write it with the
    # third-party `regex` module (\p{L}, \p{N}). This stdlib-`re` approximation
    # uses "[^\W\d_]" for letters and "\d" for numbers. Every input character is
    # covered by some alternative, so findall() never drops text.
    _pretokenize_pat = re.compile(
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?(?:[^\s\w]|_)+|\s+(?!\S)|\s+"""
    )

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        normalizer_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        language=None,
        task=None,
        predict_timestamps=False,
        **kwargs,
    ):
        bos_token = self._as_added_token(bos_token)
        eos_token = self._as_added_token(eos_token)
        unk_token = self._as_added_token(unk_token)
        pad_token = self._as_added_token(pad_token)

        # All vocabulary state is set up before super().__init__(), because the
        # base initializer may query get_vocab()/vocab_size while it registers
        # the special tokens.
        self._vocab = {}
        if vocab_file is not None:
            if os.path.isfile(vocab_file):
                with open(vocab_file, encoding="utf-8") as f:
                    self._vocab = json.load(f)
            else:
                logger.warning("Vocab file %s not found; using character-level fallback.", vocab_file)

        self._merges = []
        if merges_file is not None:
            if os.path.isfile(merges_file):
                with open(merges_file, encoding="utf-8") as f:
                    lines = f.read().splitlines()
                # Skip the header written by GPT-2 / tokenizers, e.g. "#version: 0.2"
                # or "#version: 0.2 - Trained by huggingface/tokenizers".
                if lines and lines[0].startswith("#version"):
                    lines = lines[1:]
                self._merges = lines
            else:
                logger.warning("Merges file %s not found; using character-level fallback.", merges_file)

        # Reverse vocabulary for id -> token lookups.
        self._decoder = {index: token for token, index in self._vocab.items()}
        # Merge priority: a lower rank is merged first. Malformed or blank lines are skipped.
        self.bpe_ranks = {}
        for line in self._merges:
            pair = tuple(line.split())
            if len(pair) == 2 and pair not in self.bpe_ranks:
                self.bpe_ranks[pair] = len(self.bpe_ranks)
        self.byte_encoder = _bytes_to_unicode()
        self.byte_decoder = {char: byte for byte, char in self.byte_encoder.items()}
        # Memo of byte-encoded pre-token -> merged sub-tokens (unbounded, per instance).
        self._bpe_cache = {}

        self.english_spelling_normalizer = None
        if normalizer_file is not None and os.path.isfile(normalizer_file):
            with open(normalizer_file, encoding="utf-8") as f:
                self.english_spelling_normalizer = json.load(f)

        self.add_prefix_space = add_prefix_space
        self.language = language
        self.task = task
        self.predict_timestamps = predict_timestamps
        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
        # Validates (and warns about) the language/task given to the constructor.
        self.set_prefix_tokens()

    @staticmethod
    def _as_added_token(token):
        """Wrap a plain-string special token in a non-normalized, special AddedToken."""
        if isinstance(token, str):
            return AddedToken(token, lstrip=False, rstrip=False, normalized=False, special=True)
        return token

    @property
    def _byte_level(self) -> bool:
        """True when both vocab and merges are loaded, i.e. real byte-level BPE is active."""
        return bool(self._vocab) and bool(self._merges)

    @property
    def vocab_size(self) -> int:
        """Size of the loaded vocab. Without one, returns 51865 (placeholder size)."""
        if self._vocab:
            return len(self._vocab)
        return 51865

    def get_vocab(self):
        """Return token -> id for the base vocab plus tokens added at runtime."""
        if self._vocab:
            return {**self._vocab, **self.added_tokens_encoder}
        # NOTE(review): the four special tokens default to the same string, so this
        # placeholder dict usually collapses to a single entry.
        return {self.unk_token: 0, self.bos_token: 1, self.eos_token: 2, self.pad_token: 3}

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """Prepend a space when `add_prefix_space` is set (or the input is pre-split)."""
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        if is_split_into_words or add_prefix_space:
            text = " " + text
        return (text, kwargs)

    def _bpe(self, token):
        """
        Apply the learned merges to one byte-encoded pre-token.

        Returns the resulting sub-tokens joined by single spaces. This is safe
        because byte-encoded characters never contain whitespace.
        """
        cached = self._bpe_cache.get(token)
        if cached is not None:
            return cached
        word = tuple(token)
        while len(word) > 1:
            pairs = set(zip(word, word[1:]))
            best = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if best not in self.bpe_ranks:
                break  # no remaining adjacent pair is a known merge
            first, second = best
            merged = []
            i = 0
            while i < len(word):
                if i + 1 < len(word) and word[i] == first and word[i + 1] == second:
                    merged.append(first + second)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            word = tuple(merged)
        result = " ".join(word)
        self._bpe_cache[token] = result
        return result

    def _tokenize(self, text, **kwargs):
        """
        Split `text` into byte-level BPE tokens.

        Without both a vocab and merges, this falls back to one token per
        character. Case is preserved, because the BPE vocabulary is case-sensitive.
        """
        if not self._byte_level:
            return list(text)
        tokens = []
        for piece in self._pretokenize_pat.findall(text):
            encoded = "".join(self.byte_encoder[b] for b in piece.encode("utf-8"))
            tokens.extend(self._bpe(encoded).split(" "))
        return tokens

    def _convert_token_to_id(self, token):
        """Map a token to its id. Unknown tokens map to the unk token's id."""
        if self._vocab:
            return self._vocab.get(token, self._vocab.get(self.unk_token))
        # Character-level fallback: placeholder ids 0-3 for the special tokens, 0 otherwise.
        specials = (self.unk_token, self.bos_token, self.eos_token, self.pad_token)
        for offset, special in enumerate(specials):
            if token == special:
                return offset
        return 0

    def _convert_id_to_token(self, index):
        """Map an id back to its token string, or to the unk token if unknown."""
        # Honour an externally supplied `ids_to_tokens` mapping for backward compatibility.
        legacy = getattr(self, "ids_to_tokens", None)
        if legacy:
            return legacy.get(index, self.unk_token)
        return self._decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Join tokens and, in byte-level mode, decode the byte characters back to UTF-8 text."""
        text = "".join(tokens)
        if not self._byte_level:
            return text
        buffer = bytearray()
        for char in text:
            byte = self.byte_decoder.get(char)
            if byte is None:
                # Characters outside the byte table (e.g. from user-added tokens) pass through.
                buffer.extend(char.encode("utf-8"))
            else:
                buffer.append(byte)
        return buffer.decode("utf-8", errors="replace")

    @property
    def prefix_tokens(self) -> List[int]:
        """
        Decoder prompt ids: <|startoftranscript|> [language] [task] [<|notimestamps|>].

        The language token id is computed as the id of "<|startoftranscript|>"
        plus one, plus the language's position in LANGUAGES. This assumes the
        language tokens follow it contiguously in that order (TODO confirm for
        custom vocabularies). Unknown languages silently fall back to "en", and
        unknown tasks add no task token. set_prefix_tokens() warns about both.
        """
        sot_id = self.convert_tokens_to_ids("<|startoftranscript|>")
        sequence = [sot_id]
        if self.language is not None:
            language = self.language.lower()
            if language in TO_LANGUAGE_CODE:
                code = TO_LANGUAGE_CODE[language]
            elif language in LANGUAGES:
                code = language
            else:
                code = "en"
            sequence.append(sot_id + 1 + list(LANGUAGES).index(code))
        if self.task == "transcribe":
            sequence.append(self.convert_tokens_to_ids("<|transcribe|>"))
        elif self.task == "translate":
            sequence.append(self.convert_tokens_to_ids("<|translate|>"))
        if not self.predict_timestamps:
            sequence.append(self.convert_tokens_to_ids("<|notimestamps|>"))
        return sequence

    def set_prefix_tokens(self, language=None, task=None, predict_timestamps=None):
        """Update the prompt settings (None leaves a setting unchanged) and warn on unsupported values."""
        if language is not None:
            self.language = language
        if task is not None:
            self.task = task
        if predict_timestamps is not None:
            self.predict_timestamps = predict_timestamps
        if self.language is not None:
            lowered = self.language.lower()
            if lowered not in TO_LANGUAGE_CODE and lowered not in LANGUAGES:
                logger.warning(
                    "Unsupported language %r; the English ('en') language token will be used.",
                    self.language,
                )
        if self.task is not None and self.task not in ("transcribe", "translate"):
            logger.warning("Unsupported task %r; no task token will be added.", self.task)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Wrap one or two sequences as: prefix tokens + ids + eos."""
        prefix = self.prefix_tokens
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return prefix + token_ids_0 + eos
        return prefix + token_ids_0 + token_ids_1 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for prefix/eos positions and 0 for content, matching build_inputs_with_special_tokens."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1]
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
        """Return (position, token_id) pairs that force the prompt after <|startoftranscript|>."""
        self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
        forced_tokens = self.prefix_tokens[1:]
        return [(rank + 1, token_id) for rank, token_id in enumerate(forced_tokens)]

    def get_prompt_ids(self, text: str, return_tensors="np"):
        """Converts prompt text to IDs that can be passed to generation."""
        batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)
        prompt_text_ids = batch_encoding["input_ids"][1:]
        special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
        if special_token_id is not None:
            token = self.convert_ids_to_tokens(special_token_id)
            raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
        batch_encoding.convert_to_tensors(tensor_type=return_tensors)
        return batch_encoding["input_ids"]

    def _decode(
        self,
        token_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
        normalize=False,
        basic_normalize=False,
        remove_diacritics=False,
        decode_with_timestamps=False,
        **kwargs,
    ) -> str:
        """
        Turn ids into text.

        Timestamp tokens like "<|1.20|>" are removed unless `decode_with_timestamps`
        is set. The removal happens before whitespace cleanup, so no double spaces
        are left behind. `normalize`, `basic_normalize` and `remove_diacritics` are
        accepted for API compatibility but are currently ignored.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        token_ids = [int(index) for index in token_ids]
        if skip_special_tokens:
            special_ids = set(self.all_special_ids)
            token_ids = [index for index in token_ids if index not in special_ids]
        text = self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))
        if not decode_with_timestamps:
            text = self.timestamp_pat.sub("", text)
        if clean_up_tokenization_spaces:
            text = re.sub(r"\s+", " ", text).strip()
        return text

    def decode(
        self,
        token_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
        output_offsets=False,
        time_precision=0.02,
        decode_with_timestamps=False,
        **kwargs,
    ) -> str:
        """
        Decode a sequence of ids to text.

        Accepts Python ints/lists as well as objects exposing `.tolist()`, such as
        numpy arrays or tensors. `output_offsets` and `time_precision` are
        accepted for API compatibility but are not supported and are ignored.
        """
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        return self._decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            decode_with_timestamps=decode_with_timestamps,
            **kwargs,
        )

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write vocab.json and merges.txt into `save_directory`, plus normalizer.json if one was loaded.

        Returns the (vocab_file, merges_file) paths, or None if `save_directory`
        is not a directory. A file is written only when the corresponding data
        is loaded.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        merges_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["merges_file"])
        normalizer_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["normalizer_file"])
        if self._vocab:
            with open(vocab_file, "w", encoding="utf-8") as f:
                json.dump(self._vocab, f, indent=2, ensure_ascii=False)
        if self._merges:
            with open(merges_file, "w", encoding="utf-8") as f:
                f.write("#version: 0.2\n")
                f.writelines(merge + "\n" for merge in self._merges)
        if self.english_spelling_normalizer is not None:
            # Saved so that a round-trip through save/load keeps the normalizer.
            with open(normalizer_file, "w", encoding="utf-8") as f:
                json.dump(self.english_spelling_normalizer, f, indent=2, ensure_ascii=False)
        return (vocab_file, merges_file)

    def timestamp_ids(self, time_precision=0.02):
        """
        Return the ids of timestamp tokens "<|0.00|>" through "<|{1500*time_precision:.2f}|>".

        The result is memoized per instance and per precision. Unlike
        @lru_cache on a method, this cache does not keep the instance alive.
        """
        # __dict__.setdefault also works for instances built without __init__.
        cache = self.__dict__.setdefault("_timestamp_ids_cache", {})
        if time_precision not in cache:
            cache[time_precision] = self.convert_tokens_to_ids(
                ["<|%.2f|>" % (i * time_precision) for i in range(1500 + 1)]
            )
        return cache[time_precision]
# Public API of this module: only the tokenizer class is exported by `from ... import *`.
__all__ = ["MiniWhisperTokenizer"]