# xttsv2/xtts2_config.py — configuration classes for the XTTS-v2 model
# (HifiGAN vocoder + GPT components), built on transformers' PretrainedConfig.
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional
from transformers.configuration_utils import PretrainedConfig
@dataclass
class SpeakerEncoderConfig:
    """Configuration for the speaker encoder component"""

    # Identifier of the speaker-encoder model variant.
    model_name: str = "speaker_encoder"
    # Optional preprocessing settings — schema not visible here; presumably
    # audio feature-extraction options consumed by the encoder loader. TODO confirm.
    preprocess_config: Optional[Dict] = None
    # Optional encoder architecture settings; schema defined by the consumer.
    model_config: Optional[Dict] = None
    # Dimensionality of the speaker embedding vector the encoder produces.
    speaker_embedding_dim: int = 512
    # NOTE(review): name suggests spectrograms are computed with torch ops
    # rather than an external pipeline — verify against the encoder code.
    use_torch_spec: bool = True
@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters"""

    # Input/model-side sampling rate in Hz.
    sample_rate: int = 22050
    # Sampling rate of the synthesized output waveform in Hz.
    output_sample_rate: int = 24000
    # Number of mel filterbank channels.
    mel_channels: int = 80
    # STFT hop length in samples.
    hop_length: int = 256
    # STFT window length in samples.
    win_length: int = 1024
    # FFT size.
    n_fft: int = 1024
    # Lower/upper mel filterbank frequency bounds in Hz.
    fmin: int = 0
    fmax: int = 8000
    # Spectrogram magnitude exponent (1.0 = magnitude spectrogram).
    power: float = 1.0
    # Optional path to precomputed mel normalization statistics — file format
    # not visible from here; presumably loaded by the audio frontend. TODO confirm.
    mel_norms_file: Optional[str] = None
class XTTSConfig(PretrainedConfig):
    """Combined configuration for XTTS, covering both the HifiGAN vocoder and
    the GPT autoregressive components.

    All values are stored as plain attributes so transformers'
    ``PretrainedConfig`` serialization machinery can round-trip them.
    Nested configs are normalized at construction time:

    * ``speaker_encoder_config`` is always stored as a plain ``dict``,
      merged over :class:`SpeakerEncoderConfig` defaults.
    * ``audio_config`` is always stored as an :class:`XTTSAudioConfig`
      instance (accepted as a dict, instance, or ``None``).
    """

    model_type = "xtts"

    def __init__(
        self,
        # HifiGAN Audio parameters
        input_sample_rate: int = 22050,
        output_sample_rate: int = 24000,
        output_hop_length: int = 256,
        # HifiGAN Model architecture
        decoder_input_dim: int = 1024,
        d_vector_dim: int = 512,
        cond_d_vector_in_each_upsampling_layer: bool = True,
        # HifiGAN Upsampling parameters
        upsample_rates: List[int] = None,
        upsample_kernel_sizes: List[int] = None,
        upsample_initial_channel: int = 512,
        # HifiGAN Resblock parameters
        resblock_kernel_sizes: List[int] = None,
        resblock_dilation_sizes: List[List[int]] = None,
        # HifiGAN Speaker encoder
        speaker_encoder_config: Optional[Dict] = None,
        # GPT Model architecture
        vocab_size: int = 256,
        num_chars: int = 255,
        # GPT parameters
        gpt_batch_size: int = 1,
        gpt_max_audio_tokens: int = 605,
        gpt_max_text_tokens: int = 402,
        gpt_max_prompt_tokens: int = 70,
        gpt_layers: int = 30,
        gpt_n_model_channels: int = 1024,
        gpt_n_heads: int = 16,
        gpt_number_text_tokens: int = 6681,
        gpt_start_text_token: Optional[int] = None,
        gpt_stop_text_token: Optional[int] = None,
        gpt_num_audio_tokens: int = 1026,
        gpt_start_audio_token: int = 1024,
        gpt_stop_audio_token: int = 1025,
        gpt_code_stride_len: int = 1024,
        gpt_use_masking_gt_prompt_approach: bool = True,
        gpt_use_perceiver_resampler: bool = True,
        gpt_checkpointing: bool = False,
        gpt_train_solo_embeddings: bool = False,
        # GPT Training parameters
        enable_redaction: bool = False,
        kv_cache: bool = True,
        perceiver_cond_length_compression: int = 256,
        label_smoothing: float = 0.0,
        # GPT Generation parameters
        temperature: float = 0.75,
        length_penalty: float = 1.0,
        repetition_penalty: float = 5.0,
        top_k: int = 50,
        top_p: float = 0.85,
        gpt_cond_len: int = 30,
        gpt_cond_chunk_len: int = 4,
        max_ref_len: int = 30,
        sound_norm_refs: bool = False,
        # GPT Audio processing
        audio_config: Optional[XTTSAudioConfig] = None,
        # GPT Constants and limits
        duration_const: int = 102400,
        char_limits: Optional[Dict[str, int]] = None,
        languages: Optional[List[str]] = None,
        # Base config parameters
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        """Build the combined config, applying defaults for any list/dict
        parameter left as ``None`` (mutable defaults are deliberately avoided
        in the signature)."""
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Default HifiGAN upsampling/resblock layouts.
        if upsample_rates is None:
            upsample_rates = [8, 8, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 16, 4, 4]
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        # Per-language character limits for text chunking (defaults below).
        if char_limits is None:
            char_limits = {
                "en": 250, "de": 253, "fr": 273, "es": 239,
                "it": 213, "pt": 203, "pl": 224, "zh": 82,
                "ar": 166, "cs": 186, "ru": 182, "nl": 251,
                "tr": 226, "ja": 71, "hu": 224, "ko": 95,
            }
        if languages is None:
            languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
                "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]

        # --- HifiGAN: audio parameters ---
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length
        # --- HifiGAN: model architecture ---
        self.decoder_input_dim = decoder_input_dim
        self.d_vector_dim = d_vector_dim
        self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer
        # --- HifiGAN: upsampling ---
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.upsample_initial_channel = upsample_initial_channel
        # --- HifiGAN: resblocks ---
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes

        # Speaker encoder config is normalized to a plain dict so it
        # serializes cleanly; user-supplied dicts are merged over defaults.
        if speaker_encoder_config is None:
            self.speaker_encoder_config = asdict(SpeakerEncoderConfig())
        elif isinstance(speaker_encoder_config, dict):
            default_config = asdict(SpeakerEncoderConfig())
            default_config.update(speaker_encoder_config)
            self.speaker_encoder_config = default_config
        elif isinstance(speaker_encoder_config, SpeakerEncoderConfig):
            self.speaker_encoder_config = asdict(speaker_encoder_config)
        else:
            raise ValueError("speaker_encoder_config must be either a dictionary or SpeakerEncoderConfig instance")

        # --- GPT: vocabulary ---
        self.vocab_size = vocab_size
        self.num_chars = num_chars
        # --- GPT: model parameters ---
        self.gpt_batch_size = gpt_batch_size
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.gpt_max_text_tokens = gpt_max_text_tokens
        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
        self.gpt_layers = gpt_layers
        self.gpt_n_model_channels = gpt_n_model_channels
        self.gpt_n_heads = gpt_n_heads
        self.gpt_number_text_tokens = gpt_number_text_tokens
        self.gpt_start_text_token = gpt_start_text_token
        self.gpt_stop_text_token = gpt_stop_text_token
        self.gpt_num_audio_tokens = gpt_num_audio_tokens
        self.gpt_start_audio_token = gpt_start_audio_token
        self.gpt_stop_audio_token = gpt_stop_audio_token
        self.gpt_code_stride_len = gpt_code_stride_len
        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
        self.gpt_checkpointing = gpt_checkpointing
        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings
        # --- GPT: training ---
        self.enable_redaction = enable_redaction
        self.kv_cache = kv_cache
        self.perceiver_cond_length_compression = perceiver_cond_length_compression
        self.label_smoothing = label_smoothing
        # --- GPT: generation ---
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.gpt_cond_len = gpt_cond_len
        self.gpt_cond_chunk_len = gpt_cond_chunk_len
        self.max_ref_len = max_ref_len
        self.sound_norm_refs = sound_norm_refs

        # Audio processing config is normalized to an XTTSAudioConfig instance.
        if audio_config is None:
            audio_config = XTTSAudioConfig()
        elif isinstance(audio_config, dict):
            audio_config = XTTSAudioConfig(**audio_config)
        self.audio_config = audio_config

        # --- Constants and limits ---
        self.duration_const = duration_const
        self.char_limits = char_limits
        self.languages = languages

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary format.

        Starts from the parent-class dict (which carries ``model_type`` and
        transformers bookkeeping) and overlays every XTTS attribute; the
        nested ``audio_config`` is flattened to a plain dict via ``asdict``
        so the result is JSON-serializable.
        """
        output = super().to_dict()
        output.update({
            # HifiGAN parameters
            "input_sample_rate": self.input_sample_rate,
            "output_sample_rate": self.output_sample_rate,
            "output_hop_length": self.output_hop_length,
            "decoder_input_dim": self.decoder_input_dim,
            "d_vector_dim": self.d_vector_dim,
            "cond_d_vector_in_each_upsampling_layer": self.cond_d_vector_in_each_upsampling_layer,
            "upsample_rates": self.upsample_rates,
            "upsample_kernel_sizes": self.upsample_kernel_sizes,
            "upsample_initial_channel": self.upsample_initial_channel,
            "resblock_kernel_sizes": self.resblock_kernel_sizes,
            "resblock_dilation_sizes": self.resblock_dilation_sizes,
            "speaker_encoder_config": self.speaker_encoder_config,
            # GPT parameters
            "vocab_size": self.vocab_size,
            "num_chars": self.num_chars,
            "gpt_batch_size": self.gpt_batch_size,
            "gpt_max_audio_tokens": self.gpt_max_audio_tokens,
            "gpt_max_text_tokens": self.gpt_max_text_tokens,
            "gpt_max_prompt_tokens": self.gpt_max_prompt_tokens,
            "gpt_layers": self.gpt_layers,
            "gpt_n_model_channels": self.gpt_n_model_channels,
            "gpt_n_heads": self.gpt_n_heads,
            "gpt_number_text_tokens": self.gpt_number_text_tokens,
            "gpt_start_text_token": self.gpt_start_text_token,
            "gpt_stop_text_token": self.gpt_stop_text_token,
            "gpt_num_audio_tokens": self.gpt_num_audio_tokens,
            "gpt_start_audio_token": self.gpt_start_audio_token,
            "gpt_stop_audio_token": self.gpt_stop_audio_token,
            "gpt_code_stride_len": self.gpt_code_stride_len,
            "gpt_use_masking_gt_prompt_approach": self.gpt_use_masking_gt_prompt_approach,
            "gpt_use_perceiver_resampler": self.gpt_use_perceiver_resampler,
            "gpt_checkpointing": self.gpt_checkpointing,
            "gpt_train_solo_embeddings": self.gpt_train_solo_embeddings,
            "enable_redaction": self.enable_redaction,
            "kv_cache": self.kv_cache,
            "perceiver_cond_length_compression": self.perceiver_cond_length_compression,
            "label_smoothing": self.label_smoothing,
            "temperature": self.temperature,
            "length_penalty": self.length_penalty,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "gpt_cond_len": self.gpt_cond_len,
            "gpt_cond_chunk_len": self.gpt_cond_chunk_len,
            "max_ref_len": self.max_ref_len,
            "sound_norm_refs": self.sound_norm_refs,
            "audio_config": asdict(self.audio_config),
            "duration_const": self.duration_const,
            "char_limits": self.char_limits,
            "languages": self.languages,
        })
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict) -> "XTTSConfig":
        """Create a config instance from a dictionary.

        ``audio_config`` may appear as a nested dict (the serialized form
        produced by :meth:`to_dict`), as an ``XTTSAudioConfig`` instance, or
        be missing/``None`` (defaults are applied in ``__init__``).
        """
        config_copy = config_dict.copy()
        audio_cfg = config_copy.get("audio_config")
        # BUGFIX: only rebuild from a mapping. The previous unconditional
        # XTTSAudioConfig(**value) raised TypeError when the value was already
        # an XTTSAudioConfig instance or None (in-memory round-trips).
        if isinstance(audio_cfg, dict):
            config_copy["audio_config"] = XTTSAudioConfig(**audio_cfg)
        return cls(**config_copy)

    def get_speaker_encoder_config(self) -> SpeakerEncoderConfig:
        """Get speaker encoder config as a SpeakerEncoderConfig instance"""
        return SpeakerEncoderConfig(**self.speaker_encoder_config)

    def update_with_tokenizer(self, tokenizer=None):
        """Update configuration values based on tokenizer.

        Copies vocab size and special-token ids from the tokenizer into both
        the GPT-specific fields and the base-config token ids. No-op when
        ``tokenizer`` is None.
        """
        if tokenizer is not None:
            self.gpt_number_text_tokens = tokenizer.get_vocab_size()
            self.gpt_start_text_token = tokenizer.bos_token_id
            self.gpt_stop_text_token = tokenizer.eos_token_id
            self.vocab_size = tokenizer.get_vocab_size()
            self.pad_token_id = tokenizer.pad_token_id
            self.bos_token_id = tokenizer.bos_token_id
            self.eos_token_id = tokenizer.eos_token_id

    def get_hifigan_config(self) -> Dict:
        """Extract HiFiGAN-specific configuration"""
        return {
            "input_sample_rate": self.input_sample_rate,
            "output_sample_rate": self.output_sample_rate,
            "output_hop_length": self.output_hop_length,
            "decoder_input_dim": self.decoder_input_dim,
            "d_vector_dim": self.d_vector_dim,
            "cond_d_vector_in_each_upsampling_layer": self.cond_d_vector_in_each_upsampling_layer,
            "upsample_rates": self.upsample_rates,
            "upsample_kernel_sizes": self.upsample_kernel_sizes,
            "upsample_initial_channel": self.upsample_initial_channel,
            "resblock_kernel_sizes": self.resblock_kernel_sizes,
            "resblock_dilation_sizes": self.resblock_dilation_sizes,
            "speaker_encoder_config": self.speaker_encoder_config
        }

    def get_gpt_config(self) -> Dict:
        """Extract GPT-specific configuration"""
        return {
            "vocab_size": self.vocab_size,
            "num_chars": self.num_chars,
            "gpt_batch_size": self.gpt_batch_size,
            "gpt_max_audio_tokens": self.gpt_max_audio_tokens,
            "gpt_max_text_tokens": self.gpt_max_text_tokens,
            "gpt_max_prompt_tokens": self.gpt_max_prompt_tokens,
            "gpt_layers": self.gpt_layers,
            "gpt_n_model_channels": self.gpt_n_model_channels,
            "gpt_n_heads": self.gpt_n_heads,
            "gpt_number_text_tokens": self.gpt_number_text_tokens,
            "gpt_start_text_token": self.gpt_start_text_token,
            "gpt_stop_text_token": self.gpt_stop_text_token,
            "gpt_num_audio_tokens": self.gpt_num_audio_tokens,
            "gpt_start_audio_token": self.gpt_start_audio_token,
            "gpt_stop_audio_token": self.gpt_stop_audio_token,
            "gpt_code_stride_len": self.gpt_code_stride_len,
            "gpt_use_masking_gt_prompt_approach": self.gpt_use_masking_gt_prompt_approach,
            "gpt_use_perceiver_resampler": self.gpt_use_perceiver_resampler,
            "gpt_checkpointing": self.gpt_checkpointing,
            "gpt_train_solo_embeddings": self.gpt_train_solo_embeddings,
            "enable_redaction": self.enable_redaction,
            "kv_cache": self.kv_cache,
            "perceiver_cond_length_compression": self.perceiver_cond_length_compression,
            "label_smoothing": self.label_smoothing,
            # Note: passed through as the XTTSAudioConfig instance, not a dict.
            "audio_config": self.audio_config,
            "pad_token_id": self.pad_token_id,
            "bos_token_id": self.bos_token_id,
            "eos_token_id": self.eos_token_id
        }

    def get_generation_config(self) -> Dict:
        """Extract generation-specific configuration"""
        return {
            "temperature": self.temperature,
            "length_penalty": self.length_penalty,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "gpt_cond_len": self.gpt_cond_len,
            "gpt_cond_chunk_len": self.gpt_cond_chunk_len,
            "max_ref_len": self.max_ref_len,
            "sound_norm_refs": self.sound_norm_refs
        }

    def validate(self):
        """Validate configuration values.

        Raises:
            ValueError: if any size/count field is non-positive or parallel
                HifiGAN layer lists disagree in length.
        """
        if self.gpt_max_text_tokens <= 0:
            raise ValueError("gpt_max_text_tokens must be positive")
        if self.gpt_max_audio_tokens <= 0:
            raise ValueError("gpt_max_audio_tokens must be positive")
        if self.gpt_layers <= 0:
            raise ValueError("gpt_layers must be positive")
        if self.gpt_n_heads <= 0:
            raise ValueError("gpt_n_heads must be positive")
        if self.gpt_n_model_channels <= 0:
            raise ValueError("gpt_n_model_channels must be positive")
        if len(self.upsample_rates) != len(self.upsample_kernel_sizes):
            raise ValueError("upsample_rates and upsample_kernel_sizes must have same length")
        # ADDED: the resblock lists are consumed pairwise just like the
        # upsample lists, so a mismatch is equally invalid.
        if len(self.resblock_kernel_sizes) != len(self.resblock_dilation_sizes):
            raise ValueError("resblock_kernel_sizes and resblock_dilation_sizes must have same length")
        if not all(isinstance(x, int) and x > 0 for x in self.upsample_rates):
            raise ValueError("all upsample_rates must be positive integers")

    def get_audio_config(self) -> XTTSAudioConfig:
        """Get the audio configuration"""
        return self.audio_config

    @property
    def num_hidden_layers(self) -> int:
        """Get number of hidden layers (alias for gpt_layers)"""
        return self.gpt_layers

    @property
    def hidden_size(self) -> int:
        """Get hidden size (alias for gpt_n_model_channels)"""
        return self.gpt_n_model_channels

    @property
    def num_attention_heads(self) -> int:
        """Get number of attention heads (alias for gpt_n_heads)"""
        return self.gpt_n_heads