| from dataclasses import asdict, dataclass |
| from typing import Dict, List, Optional |
| from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
@dataclass()
class SpeakerEncoderConfig:
    """Settings for the speaker-encoder sub-model.

    A plain dataclass with defaults for every field, so ``SpeakerEncoderConfig()``
    is a complete configuration and ``dataclasses.asdict`` turns it into a
    flat dict.
    """

    # Name identifying the encoder model.
    model_name: str = "speaker_encoder"
    # Optional free-form preprocessing settings; the schema is defined by the
    # consumer of this config (not visible here).
    preprocess_config: Optional[Dict] = None
    # Optional free-form model settings; schema likewise defined elsewhere.
    model_config: Optional[Dict] = None
    # Width of the speaker embedding.
    speaker_embedding_dim: int = 512
    # Presumably selects a torch-based spectrogram front end — TODO confirm.
    use_torch_spec: bool = True
|
|
|
|
@dataclass()
class XTTSAudioConfig:
    """Audio and spectrogram parameters used by XTTS.

    Every field has a default, so ``XTTSAudioConfig()`` is a complete
    configuration; ``dataclasses.asdict`` yields a JSON-friendly dict.
    """

    # Sample rates (presumably in Hz) of the model's input and output audio.
    sample_rate: int = 22050
    output_sample_rate: int = 24000
    # Spectrogram analysis parameters.
    mel_channels: int = 80
    hop_length: int = 256
    win_length: int = 1024
    n_fft: int = 1024
    fmin: int = 0
    fmax: int = 8000
    power: float = 1.0
    # Optional path to mel normalization statistics; file format is defined by
    # whatever loads it (not visible here).
    mel_norms_file: Optional[str] = None
|
|
|
|
class XTTSConfig(PretrainedConfig):
    """Combined XTTS configuration: HiFiGAN decoder, GPT model, generation and audio settings.

    Every constructor argument is stored as an instance attribute of the same
    name. ``speaker_encoder_config`` is normalized to a plain ``dict`` and
    ``audio_config`` to an :class:`XTTSAudioConfig` instance; :meth:`to_dict`
    converts the latter back into a ``dict`` for serialization.
    """

    model_type = "xtts"

    def __init__(
        self,
        # HiFiGAN decoder: audio I/O
        input_sample_rate: int = 22050,
        output_sample_rate: int = 24000,
        output_hop_length: int = 256,
        # HiFiGAN decoder: conditioning
        decoder_input_dim: int = 1024,
        d_vector_dim: int = 512,
        cond_d_vector_in_each_upsampling_layer: bool = True,
        # HiFiGAN decoder: upsampling stack (None -> defaults below)
        upsample_rates: Optional[List[int]] = None,
        upsample_kernel_sizes: Optional[List[int]] = None,
        upsample_initial_channel: int = 512,
        # HiFiGAN decoder: residual blocks (None -> defaults below)
        resblock_kernel_sizes: Optional[List[int]] = None,
        resblock_dilation_sizes: Optional[List[List[int]]] = None,
        # Speaker encoder
        speaker_encoder_config: Optional[Dict] = None,
        # Text vocabulary
        vocab_size: int = 256,
        num_chars: int = 255,
        # GPT model
        gpt_batch_size: int = 1,
        gpt_max_audio_tokens: int = 605,
        gpt_max_text_tokens: int = 402,
        gpt_max_prompt_tokens: int = 70,
        gpt_layers: int = 30,
        gpt_n_model_channels: int = 1024,
        gpt_n_heads: int = 16,
        gpt_number_text_tokens: int = 6681,
        gpt_start_text_token: Optional[int] = None,
        gpt_stop_text_token: Optional[int] = None,
        gpt_num_audio_tokens: int = 1026,
        gpt_start_audio_token: int = 1024,
        gpt_stop_audio_token: int = 1025,
        gpt_code_stride_len: int = 1024,
        gpt_use_masking_gt_prompt_approach: bool = True,
        gpt_use_perceiver_resampler: bool = True,
        gpt_checkpointing: bool = False,
        gpt_train_solo_embeddings: bool = False,
        # GPT options
        enable_redaction: bool = False,
        kv_cache: bool = True,
        perceiver_cond_length_compression: int = 256,
        label_smoothing: float = 0.0,
        # Generation / inference
        temperature: float = 0.75,
        length_penalty: float = 1.0,
        repetition_penalty: float = 5.0,
        top_k: int = 50,
        top_p: float = 0.85,
        gpt_cond_len: int = 30,
        gpt_cond_chunk_len: int = 4,
        max_ref_len: int = 30,
        sound_norm_refs: bool = False,
        # Audio processing
        audio_config: Optional[XTTSAudioConfig] = None,
        # Misc
        duration_const: int = 102400,
        char_limits: Optional[Dict[str, int]] = None,
        languages: Optional[List[str]] = None,
        # Special token ids, forwarded to PretrainedConfig
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        """Build the configuration.

        Arguments are grouped the same way as :meth:`get_hifigan_config`,
        :meth:`get_gpt_config` and :meth:`get_generation_config`. List/dict
        arguments left as ``None`` are replaced by freshly built default
        objects, so instances never share mutable defaults.

        Args:
            speaker_encoder_config: ``None`` (``SpeakerEncoderConfig``
                defaults), a dict of overrides merged onto those defaults, or
                a ``SpeakerEncoderConfig`` instance. Always stored as a dict.
            gpt_start_text_token, gpt_stop_text_token: default ``None``;
                :meth:`update_with_tokenizer` sets them from the tokenizer's
                BOS/EOS ids.
            audio_config: ``None`` (``XTTSAudioConfig`` defaults), a dict of
                ``XTTSAudioConfig`` fields, or an ``XTTSAudioConfig``
                instance. Always stored as an ``XTTSAudioConfig``.
            char_limits: presumably the maximum text length per language,
                keyed by language code.
            languages: supported language codes.
            pad_token_id, bos_token_id, eos_token_id, **kwargs: forwarded to
                ``PretrainedConfig.__init__``.

        Raises:
            ValueError: if ``speaker_encoder_config`` or ``audio_config`` is
                not ``None``, a dict, or the matching dataclass instance.
        """
        # Base initializer runs first; the assignments below therefore take
        # precedence over anything it set under the same attribute name.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Fresh default containers per instance (never shared between configs).
        if upsample_rates is None:
            upsample_rates = [8, 8, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 16, 4, 4]
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        # NOTE(review): these keys are base language codes ("zh"), while
        # `languages` lists "zh-cn" and also "hi", which has no entry here.
        # Confirm the consumer strips region suffixes and has a fallback limit.
        if char_limits is None:
            char_limits = {
                "en": 250, "de": 253, "fr": 273, "es": 239,
                "it": 213, "pt": 203, "pl": 224, "zh": 82,
                "ar": 166, "cs": 186, "ru": 182, "nl": 251,
                "tr": 226, "ja": 71, "hu": 224, "ko": 95,
            }

        if languages is None:
            languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
                "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi",
            ]

        # HiFiGAN decoder: audio I/O
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length

        # HiFiGAN decoder: conditioning
        self.decoder_input_dim = decoder_input_dim
        self.d_vector_dim = d_vector_dim
        self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer

        # HiFiGAN decoder: upsampling stack
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.upsample_initial_channel = upsample_initial_channel

        # HiFiGAN decoder: residual blocks
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes

        # Speaker encoder: always stored as a plain dict so it serializes as-is.
        if speaker_encoder_config is None:
            self.speaker_encoder_config = asdict(SpeakerEncoderConfig())
        elif isinstance(speaker_encoder_config, dict):
            # Partial dicts override only the keys they provide.
            default_config = asdict(SpeakerEncoderConfig())
            default_config.update(speaker_encoder_config)
            self.speaker_encoder_config = default_config
        elif isinstance(speaker_encoder_config, SpeakerEncoderConfig):
            self.speaker_encoder_config = asdict(speaker_encoder_config)
        else:
            raise ValueError("speaker_encoder_config must be either a dictionary or SpeakerEncoderConfig instance")

        # Text vocabulary
        self.vocab_size = vocab_size
        self.num_chars = num_chars

        # GPT model
        self.gpt_batch_size = gpt_batch_size
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.gpt_max_text_tokens = gpt_max_text_tokens
        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
        self.gpt_layers = gpt_layers
        self.gpt_n_model_channels = gpt_n_model_channels
        self.gpt_n_heads = gpt_n_heads
        self.gpt_number_text_tokens = gpt_number_text_tokens
        self.gpt_start_text_token = gpt_start_text_token
        self.gpt_stop_text_token = gpt_stop_text_token
        self.gpt_num_audio_tokens = gpt_num_audio_tokens
        self.gpt_start_audio_token = gpt_start_audio_token
        self.gpt_stop_audio_token = gpt_stop_audio_token
        self.gpt_code_stride_len = gpt_code_stride_len
        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
        self.gpt_checkpointing = gpt_checkpointing
        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings

        # GPT options
        self.enable_redaction = enable_redaction
        self.kv_cache = kv_cache
        self.perceiver_cond_length_compression = perceiver_cond_length_compression
        self.label_smoothing = label_smoothing

        # Generation / inference
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.gpt_cond_len = gpt_cond_len
        self.gpt_cond_chunk_len = gpt_cond_chunk_len
        self.max_ref_len = max_ref_len
        self.sound_norm_refs = sound_norm_refs

        # Audio processing: normalized to an XTTSAudioConfig instance. Reject
        # other types now rather than failing later inside to_dict().
        if audio_config is None:
            audio_config = XTTSAudioConfig()
        elif isinstance(audio_config, dict):
            audio_config = XTTSAudioConfig(**audio_config)
        elif not isinstance(audio_config, XTTSAudioConfig):
            raise ValueError("audio_config must be either a dictionary or XTTSAudioConfig instance")
        self.audio_config = audio_config

        # Misc
        self.duration_const = duration_const
        self.char_limits = char_limits
        self.languages = languages

    def to_dict(self) -> Dict:
        """Convert the config to a JSON-friendly dictionary.

        Every XTTS attribute is an instance attribute, so the base
        ``PretrainedConfig.to_dict`` (which deep-copies ``__dict__``) already
        includes all of them as independent copies. The returned lists and
        dicts are therefore not aliased to this config's state. Only
        ``audio_config`` needs converting, because it is a dataclass.
        """
        output = super().to_dict()
        # Guard: the attribute may have been replaced by a plain dict after
        # construction (e.g. via attribute assignment); the deep copy is fine then.
        if isinstance(self.audio_config, XTTSAudioConfig):
            output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, **kwargs) -> "XTTSConfig":
        """Create a config instance from a dictionary.

        Honors the ``PretrainedConfig.from_dict`` contract: extra keyword
        arguments (e.g. ``return_unused_kwargs`` passed by ``from_pretrained``
        and the Auto* loaders) are forwarded to the base implementation. In
        that case the return value is ``(config, unused_kwargs)`` rather than
        just the config. The input dict is not mutated.
        """
        config_copy = dict(config_dict)
        audio_config = config_copy.get("audio_config")
        # Serialized configs store audio_config as a dict; an instance is
        # passed through untouched (previously it crashed here).
        if isinstance(audio_config, dict):
            config_copy["audio_config"] = XTTSAudioConfig(**audio_config)
        return super().from_dict(config_copy, **kwargs)

    def get_speaker_encoder_config(self) -> SpeakerEncoderConfig:
        """Return the speaker encoder settings as a ``SpeakerEncoderConfig``.

        Raises:
            TypeError: if the stored dict contains keys that are not
                ``SpeakerEncoderConfig`` fields. This is possible because
                ``__init__`` merges arbitrary user-supplied keys.
        """
        return SpeakerEncoderConfig(**self.speaker_encoder_config)

    def update_with_tokenizer(self, tokenizer=None):
        """Copy vocabulary size and special-token ids from ``tokenizer``.

        Does nothing when ``tokenizer`` is ``None``. The tokenizer must expose
        ``get_vocab_size()`` plus ``bos_token_id``, ``eos_token_id`` and
        ``pad_token_id`` attributes.
        """
        if tokenizer is None:
            return
        vocab_size = tokenizer.get_vocab_size()
        self.gpt_number_text_tokens = vocab_size
        self.vocab_size = vocab_size
        self.gpt_start_text_token = tokenizer.bos_token_id
        self.gpt_stop_text_token = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

    def get_hifigan_config(self) -> Dict:
        """Return the HiFiGAN-decoder settings (values are the live attributes, not copies)."""
        return {
            "input_sample_rate": self.input_sample_rate,
            "output_sample_rate": self.output_sample_rate,
            "output_hop_length": self.output_hop_length,
            "decoder_input_dim": self.decoder_input_dim,
            "d_vector_dim": self.d_vector_dim,
            "cond_d_vector_in_each_upsampling_layer": self.cond_d_vector_in_each_upsampling_layer,
            "upsample_rates": self.upsample_rates,
            "upsample_kernel_sizes": self.upsample_kernel_sizes,
            "upsample_initial_channel": self.upsample_initial_channel,
            "resblock_kernel_sizes": self.resblock_kernel_sizes,
            "resblock_dilation_sizes": self.resblock_dilation_sizes,
            "speaker_encoder_config": self.speaker_encoder_config,
        }

    def get_gpt_config(self) -> Dict:
        """Return the GPT-model settings.

        ``audio_config`` is returned as the stored ``XTTSAudioConfig`` object,
        not as a dict (unlike :meth:`to_dict`).
        """
        return {
            "vocab_size": self.vocab_size,
            "num_chars": self.num_chars,
            "gpt_batch_size": self.gpt_batch_size,
            "gpt_max_audio_tokens": self.gpt_max_audio_tokens,
            "gpt_max_text_tokens": self.gpt_max_text_tokens,
            "gpt_max_prompt_tokens": self.gpt_max_prompt_tokens,
            "gpt_layers": self.gpt_layers,
            "gpt_n_model_channels": self.gpt_n_model_channels,
            "gpt_n_heads": self.gpt_n_heads,
            "gpt_number_text_tokens": self.gpt_number_text_tokens,
            "gpt_start_text_token": self.gpt_start_text_token,
            "gpt_stop_text_token": self.gpt_stop_text_token,
            "gpt_num_audio_tokens": self.gpt_num_audio_tokens,
            "gpt_start_audio_token": self.gpt_start_audio_token,
            "gpt_stop_audio_token": self.gpt_stop_audio_token,
            "gpt_code_stride_len": self.gpt_code_stride_len,
            "gpt_use_masking_gt_prompt_approach": self.gpt_use_masking_gt_prompt_approach,
            "gpt_use_perceiver_resampler": self.gpt_use_perceiver_resampler,
            "gpt_checkpointing": self.gpt_checkpointing,
            "gpt_train_solo_embeddings": self.gpt_train_solo_embeddings,
            "enable_redaction": self.enable_redaction,
            "kv_cache": self.kv_cache,
            "perceiver_cond_length_compression": self.perceiver_cond_length_compression,
            "label_smoothing": self.label_smoothing,
            "audio_config": self.audio_config,
            "pad_token_id": self.pad_token_id,
            "bos_token_id": self.bos_token_id,
            "eos_token_id": self.eos_token_id,
        }

    def get_generation_config(self) -> Dict:
        """Return the inference / sampling settings."""
        return {
            "temperature": self.temperature,
            "length_penalty": self.length_penalty,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "gpt_cond_len": self.gpt_cond_len,
            "gpt_cond_chunk_len": self.gpt_cond_chunk_len,
            "max_ref_len": self.max_ref_len,
            "sound_norm_refs": self.sound_norm_refs,
        }

    def validate(self):
        """Check basic structural constraints.

        Raises:
            ValueError: on the first violated constraint.
        """
        if self.gpt_max_text_tokens <= 0:
            raise ValueError("gpt_max_text_tokens must be positive")
        if self.gpt_max_audio_tokens <= 0:
            raise ValueError("gpt_max_audio_tokens must be positive")
        if self.gpt_layers <= 0:
            raise ValueError("gpt_layers must be positive")
        if self.gpt_n_heads <= 0:
            raise ValueError("gpt_n_heads must be positive")
        if self.gpt_n_model_channels <= 0:
            raise ValueError("gpt_n_model_channels must be positive")
        # Parallel lists: one kernel size per upsampling rate.
        if len(self.upsample_rates) != len(self.upsample_kernel_sizes):
            raise ValueError("upsample_rates and upsample_kernel_sizes must have same length")
        if not all(isinstance(x, int) and x > 0 for x in self.upsample_rates):
            raise ValueError("all upsample_rates must be positive integers")
        # Parallel lists as well (defaults: 3 kernel sizes, 3 dilation lists);
        # presumably zipped by the decoder, where a mismatch would silently drop blocks.
        if len(self.resblock_kernel_sizes) != len(self.resblock_dilation_sizes):
            raise ValueError("resblock_kernel_sizes and resblock_dilation_sizes must have same length")

    def get_audio_config(self) -> XTTSAudioConfig:
        """Return the stored audio configuration object."""
        return self.audio_config

    @property
    def num_hidden_layers(self) -> int:
        """Number of GPT layers (read-only alias for ``gpt_layers``)."""
        return self.gpt_layers

    @property
    def hidden_size(self) -> int:
        """GPT model width (read-only alias for ``gpt_n_model_channels``)."""
        return self.gpt_n_model_channels

    @property
    def num_attention_heads(self) -> int:
        """Number of GPT attention heads (read-only alias for ``gpt_n_heads``)."""
        return self.gpt_n_heads