Delete tokenizer.py
tokenizer.py  +0 -233  DELETED

@@ -1,233 +0,0 @@
from typing import List, Optional, Union, Dict, Any
from functools import cached_property

from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing

from hangul_romanize import Transliter
from hangul_romanize.rule import academic
import cutlet

from TTS.tts.layers.xtts.tokenizer import (multilingual_cleaners, basic_cleaners,
                                           chinese_transliterate, korean_transliterate,
                                           japanese_cleaners)

class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast tokenizer for the XTTS model, built on HuggingFace's PreTrainedTokenizerFast.

    Adds language-aware preprocessing (cleaning and transliteration) and
    per-language input-length checks on top of the standard fast tokenizer.
    """
    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the tokenizer: split on whitespace, pad on the right,
            # and wrap every sequence as "[START] ... [STOP]".
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction='right',
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token
            )
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language (used by check_input_length)
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        # cutlet is expensive to construct; cached_property builds it once on first use
        return cutlet.Cutlet()

    def check_input_length(self, text: str, lang: str):
        """Warn if the input text exceeds the character limit for the language."""
        lang = lang.split("-")[0]  # strip region code, e.g. "zh-cn" -> "zh"
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: text length exceeds the {limit}-char limit for '{lang}' and may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply language-specific cleaning and transliteration."""
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to handle language-specific preprocessing.
        """
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            lang = [lang] * len(batch_text_or_text_pairs)

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Check length and preprocess
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Prefix the language tag and encode spaces as [SPACE] tokens
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        # Call the parent class's encoding method with the processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        **kwargs
    ):
        """
        Main tokenization method.

        Args:
            text: Text or list of texts to tokenize
            lang: Language code or list of language codes, one per text
            add_special_tokens: Whether to add special tokens
            padding: Padding strategy (defaults to True, i.e. pad to max_length)
            truncation: Truncation strategy (defaults to True)
            max_length: Maximum sequence length
            stride: Stride for truncation overflow
            return_tensors: Output tensor format ("pt" for PyTorch)
            return_token_type_ids: Whether to return token type IDs
            return_attention_mask: Whether to return the attention mask (defaults to True)
        """
        # Convert single string to list for batch processing
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        # Ensure the text and lang lists have the same length
        if len(text) != len(lang):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        # Convert padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert truncation strategy
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Encode the batch
        return self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )
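For reference, a minimal usage sketch of the tokenizer as it existed before this deletion. The vocab path and sample texts are hypothetical; it assumes a tokenizers-format vocab.json whose vocabulary defines the [START], [STOP], [PAD], and [SPACE] tokens configured above.

# Hypothetical usage sketch -- "path/to/vocab.json" is a placeholder.
tokenizer = XTTSTokenizerFast(vocab_file="path/to/vocab.json")

batch = tokenizer(
    ["Hello world", "Bonjour le monde"],
    lang=["en", "fr"],
    return_tensors="pt",
)

# Each text is cleaned, prefixed with its language tag (e.g. "[en]"), and has
# spaces rewritten to [SPACE] before encoding; with the defaults (padding=True,
# max_length=402) input_ids and attention_mask come back as (2, 402) tensors.
print(batch["input_ids"].shape, batch["attention_mask"].shape)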