| |
| |
| |
| """ SparkTTS model configuration""" |
|
|
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
| from typing import List, Optional |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| |
|
|
class SparkTTSMelParamsConfig(PretrainedConfig):
    """Configuration for Mel Spectrogram parameters.

    Each argument is stored unchanged as an attribute of the same name; no
    validation or derivation happens here. Any extra keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    Args:
        sample_rate: Defaults to 16000.
        n_fft: Defaults to 1024.
        win_length: Defaults to 640.
        hop_length: Defaults to 320.
        mel_fmin: Defaults to 10.
        mel_fmax: Defaults to ``None``.
        num_mels: Defaults to 128.
        **kwargs: Passed through to the base configuration class.
    """

    model_type = "spark-tts-mel-params"

    def __init__(
        self,
        sample_rate=16000,
        n_fft=1024,
        win_length=640,
        hop_length=320,
        mel_fmin=10,
        mel_fmax=None,
        num_mels=128,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Insertion order of this mapping is the order the attributes are set.
        mel_settings = {
            "sample_rate": sample_rate,
            "n_fft": n_fft,
            "win_length": win_length,
            "hop_length": hop_length,
            "mel_fmin": mel_fmin,
            "mel_fmax": mel_fmax,
            "num_mels": num_mels,
        }
        for attr_name, attr_value in mel_settings.items():
            setattr(self, attr_name, attr_value)
|
|
class SparkTTSEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Feature Encoder.

    Each argument is stored unchanged as an attribute of the same name; this
    class performs no validation. Any extra keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.

    Args:
        input_channels: Defaults to 1024.
        vocos_dim: Defaults to 384.
        vocos_intermediate_dim: Defaults to 2048.
        vocos_num_layers: Defaults to 12.
        out_channels: Defaults to 1024.
        sample_ratios: Defaults to ``[1, 1]`` when omitted or ``None``. The
            default list is created anew for every instance.
        **kwargs: Passed through to the base configuration class.
    """

    model_type = "spark-tts-encoder"

    def __init__(
        self,
        input_channels=1024,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        out_channels=1024,
        sample_ratios: Optional[List[int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        # A list literal in the signature would be one object shared by every
        # default-constructed config; build the default per call instead so
        # mutating one instance's list cannot leak into later instances.
        self.sample_ratios = [1, 1] if sample_ratios is None else sample_ratios
|
|
class SparkTTSDecoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Wave Generator (Decoder).

    Each argument is stored unchanged as an attribute of the same name; this
    class performs no validation. Any extra keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.

    Args:
        input_channel: Defaults to 1024.
        channels: Defaults to 1536.
        rates: Defaults to ``[8, 5, 4, 2]`` when omitted or ``None``. The
            default list is created anew for every instance.
        kernel_sizes: Defaults to ``[16, 11, 8, 4]`` when omitted or
            ``None``. The default list is created anew for every instance.
        **kwargs: Passed through to the base configuration class.
    """

    model_type = "spark-tts-decoder"

    def __init__(
        self,
        input_channel=1024,
        channels=1536,
        rates: Optional[List[int]] = None,
        kernel_sizes: Optional[List[int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_channel = input_channel
        self.channels = channels
        # List literals in the signature would be shared by every
        # default-constructed config; create the defaults per call so an
        # in-place edit on one instance cannot affect another.
        self.rates = [8, 5, 4, 2] if rates is None else rates
        self.kernel_sizes = [16, 11, 8, 4] if kernel_sizes is None else kernel_sizes
|
|
class SparkTTSQuantizerConfig(PretrainedConfig):
    """Configuration for the BiCodec Factorized Vector Quantizer.

    Each argument is stored unchanged as an attribute of the same name; no
    validation is performed. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        input_dim: Defaults to 1024.
        codebook_size: Defaults to 8192.
        codebook_dim: Defaults to 8.
        commitment: Defaults to 0.25.
        codebook_loss_weight: Defaults to 2.0.
        decay: Defaults to 0.99.
        threshold_ema_dead_code: Defaults to 0.2.
        **kwargs: Passed through to the base configuration class.
    """

    model_type = "spark-tts-quantizer"

    def __init__(
        self,
        input_dim=1024,
        codebook_size=8192,
        codebook_dim=8,
        commitment=0.25,
        codebook_loss_weight=2.0,
        decay=0.99,
        threshold_ema_dead_code=0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # (attribute name, value) pairs, applied in the listed order.
        quantizer_settings = (
            ("input_dim", input_dim),
            ("codebook_size", codebook_size),
            ("codebook_dim", codebook_dim),
            ("commitment", commitment),
            ("codebook_loss_weight", codebook_loss_weight),
            ("decay", decay),
            ("threshold_ema_dead_code", threshold_ema_dead_code),
        )
        for attr_name, attr_value in quantizer_settings:
            setattr(self, attr_name, attr_value)
|
|
class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Speaker Encoder.

    Each argument is stored unchanged as an attribute of the same name; this
    class performs no validation. Any extra keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.

    Args:
        input_dim: Defaults to 128.
        out_dim: Defaults to 1024.
        latent_dim: Defaults to 128.
        token_num: Defaults to 32.
        fsq_levels: Defaults to ``[4, 4, 4, 4, 4, 4]`` when omitted or
            ``None``. The default list is created anew for every instance.
        fsq_num_quantizers: Defaults to 1.
        **kwargs: Passed through to the base configuration class.
    """

    model_type = "spark-tts-speaker-encoder"

    def __init__(
        self,
        input_dim=128,
        out_dim=1024,
        latent_dim=128,
        token_num=32,
        fsq_levels: Optional[List[int]] = None,
        fsq_num_quantizers=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.token_num = token_num
        # A list literal in the signature would be one object shared by every
        # default-constructed config; build the default per call instead.
        self.fsq_levels = [4, 4, 4, 4, 4, 4] if fsq_levels is None else fsq_levels
        self.fsq_num_quantizers = fsq_num_quantizers
|
|
class SparkTTSPrenetConfig(PretrainedConfig):
    """Configuration for the BiCodec Prenet.

    Each argument is stored unchanged as an attribute of the same name; this
    class performs no validation. Any extra keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.

    Args:
        input_channels: Defaults to 1024.
        vocos_dim: Defaults to 384.
        vocos_intermediate_dim: Defaults to 2048.
        vocos_num_layers: Defaults to 12.
        out_channels: Defaults to 1024.
        condition_dim: Defaults to 1024.
        sample_ratios: Defaults to ``[1, 1]`` when omitted or ``None``. The
            default list is created anew for every instance.
        use_tanh_at_final: Defaults to ``False``.
        **kwargs: Passed through to the base configuration class.
    """

    model_type = "spark-tts-prenet"

    def __init__(
        self,
        input_channels=1024,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        out_channels=1024,
        condition_dim=1024,
        sample_ratios: Optional[List[int]] = None,
        use_tanh_at_final=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.condition_dim = condition_dim
        # A list literal in the signature would be one object shared by every
        # default-constructed config; build the default per call instead so
        # mutating one instance's list cannot leak into later instances.
        self.sample_ratios = [1, 1] if sample_ratios is None else sample_ratios
        self.use_tanh_at_final = use_tanh_at_final
|
|
class SparkTTSPostnetConfig(PretrainedConfig):
    """Configuration for the BiCodec Postnet.

    Each argument is stored unchanged as an attribute of the same name; no
    validation is performed. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        input_channels: Defaults to 1024.
        vocos_dim: Defaults to 384.
        vocos_intermediate_dim: Defaults to 2048.
        vocos_num_layers: Defaults to 6.
        out_channels: Defaults to 1024.
        use_tanh_at_final: Defaults to ``False``.
        **kwargs: Passed through to the base configuration class.
    """

    model_type = "spark-tts-postnet"

    def __init__(
        self,
        input_channels=1024,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=6,
        out_channels=1024,
        use_tanh_at_final=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Insertion order of this mapping is the order the attributes are set.
        postnet_settings = {
            "input_channels": input_channels,
            "vocos_dim": vocos_dim,
            "vocos_intermediate_dim": vocos_intermediate_dim,
            "vocos_num_layers": vocos_num_layers,
            "out_channels": out_channels,
            "use_tanh_at_final": use_tanh_at_final,
        }
        for attr_name, attr_value in postnet_settings.items():
            setattr(self, attr_name, attr_value)
|
|
|
|
| |
|
|
class SparkTTSBiCodecConfig(PretrainedConfig):
    """
    Aggregate configuration for the BiCodec component of SparkTTS.

    One sub-configuration object is kept per BiCodec part, under the
    attribute names listed in ``sub_configs``. Each constructor argument may
    be ``None`` (use that sub-config's defaults), a ``dict`` (expanded as
    keyword arguments to the sub-config class), or an already-built instance
    of the matching sub-config class.

    Raises:
        TypeError: If any sub-config argument is of another type.
    """

    model_type = "spark-tts-bicodec"

    # Attribute name -> configuration class used for that attribute.
    sub_configs = {
        "mel_params": SparkTTSMelParamsConfig,
        "encoder_config": SparkTTSEncoderConfig,
        "decoder_config": SparkTTSDecoderConfig,
        "quantizer_config": SparkTTSQuantizerConfig,
        "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
        "prenet_config": SparkTTSPrenetConfig,
        "postnet_config": SparkTTSPostnetConfig,
    }

    def __init__(
        self,
        mel_params=None,
        encoder_config=None,
        decoder_config=None,
        quantizer_config=None,
        speaker_encoder_config=None,
        prenet_config=None,
        postnet_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Resolve and attach the sub-configs one by one, in this order.
        supplied = {
            "mel_params": mel_params,
            "encoder_config": encoder_config,
            "decoder_config": decoder_config,
            "quantizer_config": quantizer_config,
            "speaker_encoder_config": speaker_encoder_config,
            "prenet_config": prenet_config,
            "postnet_config": postnet_config,
        }
        for key, raw_value in supplied.items():
            setattr(self, key, self._init_sub_config(raw_value, key))

    def _init_sub_config(self, config_input, config_key):
        """Turn ``config_input`` into an instance of the class registered for ``config_key``.

        Args:
            config_input: ``None``, a ``dict`` of constructor keyword
                arguments, or an instance of the registered class.
            config_key: Key into ``sub_configs`` selecting the target class.

        Returns:
            A sub-config instance: freshly built for ``None``/``dict`` input,
            or ``config_input`` itself when it is already the right type.

        Raises:
            TypeError: If ``config_input`` is none of the accepted types.
        """
        expected_cls = self.sub_configs[config_key]

        if config_input is None:
            return expected_cls()
        if isinstance(config_input, dict):
            return expected_cls(**config_input)
        if isinstance(config_input, expected_cls):
            return config_input

        raise TypeError(
            f"Invalid type for {config_key}: {type(config_input)}. "
            f"Expected dict, None, or {expected_cls.__name__}."
        )
|
|
|
|
| |
|
|
class SparkTTSConfig(PretrainedConfig):
    r"""
    Top-level configuration for SparkTTSModel, with the BiCodec configuration nested inside.

    Args:
        llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for LLM.
        bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for BiCodec checkpoint.
        wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
        sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
        highpass_cutoff_freq (*optional*, defaults to 40): Stored unchanged on the config.
        latent_hop_length (*optional*, defaults to 320): Stored unchanged on the config.
        ref_segment_duration (*optional*, defaults to 6.0): Stored unchanged on the config.
        volume_normalize (*optional*, defaults to `True`): Stored unchanged on the config.
        bicodec_config (`dict` or `SparkTTSBiCodecConfig`, *optional*): A dictionary used to build a
            `SparkTTSBiCodecConfig`, or a ready instance. When omitted, a default-constructed
            `SparkTTSBiCodecConfig` is used.
        torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
        kwargs (*optional*): Extra keyword arguments forwarded to `PretrainedConfig.__init__`.
            `processor_class` and `auto_map` receive SparkTTS defaults when the caller does not supply them.

    Raises:
        TypeError: If `bicodec_config` is not a `dict`, `None`, or a `SparkTTSBiCodecConfig`.
    """

    model_type = "spark-tts"

    sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
    # NOTE(review): `hidden_size` is aliased to `d_model`, but this class never
    # assigns `d_model` itself -- confirm the alias is intended.
    attribute_map = {"hidden_size": "d_model"}

    def __init__(
        self,
        llm_model_name_or_path="./LLM",
        bicodec_model_name_or_path="./BiCodec",
        wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
        sample_rate=16000,
        highpass_cutoff_freq=40,
        latent_hop_length=320,
        ref_segment_duration=6.0,
        volume_normalize=True,
        bicodec_config=None,
        torch_dtype="auto",
        **kwargs,
    ):
        # Top-level settings are set before the base initializer is invoked.
        # NOTE(review): `torch_dtype` is not forwarded through kwargs; confirm
        # the base initializer does not reset it afterwards.
        self.llm_model_name_or_path = llm_model_name_or_path
        self.bicodec_model_name_or_path = bicodec_model_name_or_path
        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
        self.sample_rate = sample_rate
        self.highpass_cutoff_freq = highpass_cutoff_freq
        self.latent_hop_length = latent_hop_length
        self.ref_segment_duration = ref_segment_duration
        self.volume_normalize = volume_normalize
        self.torch_dtype = torch_dtype

        # Normalize `bicodec_config` to an instance of the registered class.
        bicodec_cls = self.sub_configs["bicodec_config"]
        if bicodec_config is None:
            logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
            bicodec_config = bicodec_cls()
        elif isinstance(bicodec_config, dict):
            bicodec_config = bicodec_cls(**bicodec_config)
        elif not isinstance(bicodec_config, bicodec_cls):
            raise TypeError(
                f"Invalid type for bicodec_config: {type(bicodec_config)}. "
                "Expected dict, None, or SparkTTSBiCodecConfig."
            )
        self.bicodec_config = bicodec_config

        # Fill in the processor class and auto-class mapping only when the
        # caller did not pass their own.
        kwargs.setdefault("processor_class", "SparkTTSProcessor")
        kwargs.setdefault(
            "auto_map",
            {
                "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
                "AutoModel": "modeling_spark_tts.SparkTTSModel",
                "AutoProcessor": "processing_spark_tts.SparkTTSProcessor",
            },
        )
        super().__init__(**kwargs)