import logging
import os
from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple, Union

import cutlet
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from tokenizers import Tokenizer, processors
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy

from TTS.tts.layers.xtts.tokenizer import (
    basic_cleaners,
    chinese_transliterate,
    japanese_cleaners,
    korean_transliterate,
    multilingual_cleaners,
)
| |
|
class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """Fast tokenizer for the XTTS model built on HuggingFace's ``PreTrainedTokenizerFast``.

    Before encoding, each text is cleaned with language-specific rules,
    prefixed with a ``[lang]`` tag, and has every space replaced by the
    ``[SPACE]`` token.
    """

    # Languages cleaned with ``multilingual_cleaners`` (zh and ko get an extra
    # transliteration step in ``preprocess_text``).
    _MULTILINGUAL_LANGS = frozenset(
        {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}
    )

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs,
    ):
        """Build the tokenizer from a ``tokenizers.Tokenizer`` or a tokenizer JSON file.

        Args:
            vocab_file: Path loaded with ``Tokenizer.from_file`` when
                ``tokenizer_object`` is not given.
            tokenizer_object: Backend tokenizer; takes precedence over ``vocab_file``.
            unk_token, pad_token, bos_token, eos_token: Special token strings.
            clean_up_tokenization_spaces: Forwarded to ``PreTrainedTokenizerFast``.
            **kwargs: Forwarded to ``PreTrainedTokenizerFast``.

        Raises:
            ValueError: If ``bos_token`` or ``eos_token`` is missing from the vocabulary.
        """
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Spaces are turned into explicit [SPACE] tokens before encoding
            # (see _preprocess_for_encoding), so plain whitespace splitting suffices.
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction="right",
                # Falls back to id 0 when the pad token is not in the vocabulary.
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token,
            )

            bos_id = tokenizer_object.token_to_id(bos_token)
            eos_id = tokenizer_object.token_to_id(eos_token)
            missing = [tok for tok, tok_id in ((bos_token, bos_id), (eos_token, eos_id)) if tok_id is None]
            if missing:
                raise ValueError(f"Special tokens {missing} are not in the tokenizer vocabulary")

            # Every single sequence is wrapped as "<bos> ... <eos>".
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[(bos_token, bos_id), (eos_token, eos_id)],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

        # Per-language character limits used by check_input_length (default 250).
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # NOTE(review): not referenced by any method in this class; kept for
        # backward compatibility with code that may read it.
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        """Lazily created ``cutlet.Cutlet`` instance passed to ``japanese_cleaners``."""
        return cutlet.Cutlet()

    def check_input_length(self, text: str, lang: str) -> None:
        """Log a warning when ``text`` exceeds the character limit for ``lang``.

        The region suffix is ignored (``"en-us"`` is checked as ``"en"``);
        languages without an entry in ``char_limits`` use a limit of 250.
        """
        lang = lang.split("-")[0]
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            logging.getLogger(__name__).warning(
                "Warning: Text length exceeds %d char limit for '%s', may cause truncation.",
                limit,
                lang,
            )

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply the language-specific cleaning rules to ``text``.

        The region suffix is ignored, so ``"zh-cn"`` is cleaned as ``"zh"``.
        Languages without dedicated rules fall back to ``basic_cleaners``.
        """
        lang = lang.split("-")[0]
        if lang in self._MULTILINGUAL_LANGS:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _preprocess_for_encoding(self, text: str, lang: str) -> str:
        """Clean ``text`` and convert it to the ``[lang]``-tagged, ``[SPACE]``-joined form."""
        base_lang = lang.split("-")[0]
        self.check_input_length(text, base_lang)
        cleaned = self.preprocess_text(text, base_lang)
        # The vocabulary tags Chinese as "zh-cn"; every other language uses its base code.
        lang_tag = "zh-cn" if base_lang == "zh" else base_lang
        return f"[{lang_tag}]{cleaned}".replace(" ", "[SPACE]")

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
        """Apply language-specific preprocessing, then defer to the parent implementation.

        An optional ``lang`` keyword (a single code, or one code per item;
        default ``"en"``) selects the preprocessing for each string item.
        Non-string items (e.g. text pairs or pre-tokenized input) are forwarded
        unchanged.

        Raises:
            ValueError: If a list of language codes does not have one entry per item.
        """
        n_items = len(batch_text_or_text_pairs)
        lang = kwargs.pop("lang", None)
        if lang is None:
            lang = "en"
        langs = [lang] * n_items if isinstance(lang, str) else list(lang)
        # zip() would silently drop items on a length mismatch, so check explicitly.
        if len(langs) != n_items:
            raise ValueError(
                f"Number of texts ({n_items}) must match number of language codes ({len(langs)})"
            )

        processed_texts = [
            self._preprocess_for_encoding(item, item_lang) if isinstance(item, str) else item
            for item, item_lang in zip(batch_text_or_text_pairs, langs)
        ]

        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        **kwargs,
    ):
        """Tokenize one text or a batch of texts.

        Args:
            text: Text or list of texts to tokenize; a single string is encoded
                as a batch of one.
            lang: One language code applied to every text, or a list with one
                code per text. Region suffixes (e.g. ``"-us"``) are ignored.
            add_special_tokens: Whether to add the start/stop tokens.
            padding: ``True`` pads to ``max_length`` (unlike HuggingFace's
                default, where ``True`` means ``"longest"``); ``False`` disables
                padding; strings / ``PaddingStrategy`` values are used as given.
            truncation: ``True`` means ``"longest_first"``; ``False`` disables
                truncation; strings / ``TruncationStrategy`` values are used as given.
            max_length: Maximum length used for padding and truncation.
            stride: Stride for truncation overflow.
            return_tensors: Format of output tensors (``"pt"`` for PyTorch).
            return_token_type_ids: Whether to return token type IDs.
            return_attention_mask: Whether to return the attention mask (default True).
            **kwargs: Forwarded to ``_batch_encode_plus``.

        Returns:
            The encoding produced by ``_batch_encode_plus``.

        Raises:
            ValueError: If a list of language codes does not have one entry per text.
        """
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            # A single language code applies to every text in the batch.
            lang = [lang] * len(text)
        elif len(lang) != len(text):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        if isinstance(truncation, bool):
            truncation_strategy = (
                TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
            )
        else:
            truncation_strategy = TruncationStrategy(truncation)

        return self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs,
        )