khanusa committed on
Commit ee09fe7 · verified · 1 Parent(s): 75ea504

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ LLM/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ __pycache__/modeling_spark_tts.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
BiCodec/config.yaml ADDED
@@ -0,0 +1,60 @@
+ audio_tokenizer:
+   mel_params:
+     sample_rate: 16000
+     n_fft: 1024
+     win_length: 640
+     hop_length: 320
+     mel_fmin: 10
+     mel_fmax: null
+     num_mels: 128
+
+   encoder:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 12
+     out_channels: 1024
+     sample_ratios: [1, 1]
+
+   decoder:
+     input_channel: 1024
+     channels: 1536
+     rates: [8, 5, 4, 2]
+     kernel_sizes: [16, 11, 8, 4]
+
+   quantizer:
+     input_dim: 1024
+     codebook_size: 8192
+     codebook_dim: 8
+     commitment: 0.25
+     codebook_loss_weight: 2.0
+     use_l2_normlize: True
+     threshold_ema_dead_code: 0.2
+
+   speaker_encoder:
+     input_dim: 128
+     out_dim: 1024
+     latent_dim: 128
+     token_num: 32
+     fsq_levels: [4, 4, 4, 4, 4, 4]
+     fsq_num_quantizers: 1
+
+   prenet:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 12
+     out_channels: 1024
+     condition_dim: 1024
+     sample_ratios: [1, 1]
+     use_tanh_at_final: False
+
+   postnet:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 6
+     out_channels: 1024
+     use_tanh_at_final: False
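These numbers pin down the codec's time base: at a 16 kHz sample rate with a hop length of 320 samples, BiCodec produces 50 semantic tokens per second, and the decoder's upsampling rates (8 × 5 × 4 × 2 = 320) reconstruct exactly one hop of audio per token. A minimal sketch of checking this from the file above, assuming PyYAML is installed and the path is relative to the repo root:

import math
import yaml

# Parse BiCodec/config.yaml (the file shown above).
with open("BiCodec/config.yaml") as f:
    cfg = yaml.safe_load(f)["audio_tokenizer"]

mel = cfg["mel_params"]
print(mel["sample_rate"] / mel["hop_length"])  # 50.0 semantic tokens per second
print(math.prod(cfg["decoder"]["rates"]))      # 320 output samples per token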
BiCodec/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
+ size 625518756
LLM/added_tokens.json ADDED
The diff for this file is too large to render.
 
LLM/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "architectures": [
+     "Qwen2ForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 896,
+   "initializer_range": 0.02,
+   "intermediate_size": 4864,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 21,
+   "model_type": "qwen2",
+   "num_attention_heads": 14,
+   "num_hidden_layers": 24,
+   "num_key_value_heads": 2,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": 32768,
+   "tie_word_embeddings": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.43.1",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 166000
+ }
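This is a small Qwen2-style causal LM: hidden size 896 across 24 layers puts it in the 0.5B-parameter class, consistent with the roughly 1 GB bfloat16 checkpoint below. It uses grouped-query attention, with 14 query heads sharing 2 key/value heads for a per-head dimension of 896 / 14 = 64. The 166,000-entry vocabulary is larger than stock Qwen2, presumably to accommodate the audio and style special tokens used elsewhere in this repo. A quick sanity check with transformers (assumed installed), pointing at the local subfolder:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("./LLM")  # local LLM subfolder of this repo
print(cfg.model_type)                                      # qwen2
print(cfg.hidden_size // cfg.num_attention_heads)          # 64 (per-head dim)
print(cfg.num_attention_heads // cfg.num_key_value_heads)  # 7 query heads per KV head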
LLM/merges.txt ADDED
The diff for this file is too large to render.
 
LLM/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4719dd1d707ab915a12b58602eeb2521c8f259d0cc0dc2e0dca502a4fae22c8f
+ size 1013300536
LLM/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
LLM/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c8b057d6ca205a429cc3428b9fc815f0d6ee1d53106dd5e5b129ef9db2ff057
+ size 14129172
LLM/tokenizer_config.json ADDED
The diff for this file is too large to render.
 
LLM/vocab.json ADDED
The diff for this file is too large to render.
 
__init__.py ADDED
File without changes
__pycache__/configuration_spark_tts.cpython-312.pyc ADDED
Binary file (11.1 kB).
 
__pycache__/modeling_spark_tts.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3dc50b2b543272688fbfa8c52a85837182e7672981fbf7d232fb5c74ef77b657
+ size 134941
__pycache__/processing_spark_tts.cpython-312.pyc ADDED
Binary file (36 kB).
 
config.json ADDED
@@ -0,0 +1,83 @@
+ {
+   "model_type": "spark-tts",
+   "architectures": [
+     "SparkTTSModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
+     "AutoModel": "modeling_spark_tts.SparkTTSModel",
+     "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
+   },
+   "processor_class": "processing_spark_tts.SparkTTSProcessor",
+   "llm_model_name_or_path": "./LLM",
+   "bicodec_model_name_or_path": "./BiCodec",
+   "wav2vec2_model_name_or_path": "./wav2vec2-large-xlsr-53",
+   "sample_rate": 16000,
+   "highpass_cutoff_freq": 40,
+   "latent_hop_length": 320,
+   "ref_segment_duration": 6.0,
+   "volume_normalize": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.50.3",
+   "_commit_hash": null,
+   "bicodec_config": {
+     "mel_params": {
+       "sample_rate": 16000,
+       "n_fft": 1024,
+       "win_length": 640,
+       "hop_length": 320,
+       "mel_fmin": 10,
+       "mel_fmax": null,
+       "num_mels": 128
+     },
+     "encoder_config": {
+       "input_channels": 1024,
+       "vocos_dim": 384,
+       "vocos_intermediate_dim": 2048,
+       "vocos_num_layers": 12,
+       "out_channels": 1024,
+       "sample_ratios": [1, 1]
+     },
+     "decoder_config": {
+       "input_channel": 1024,
+       "channels": 1536,
+       "rates": [8, 5, 4, 2],
+       "kernel_sizes": [16, 11, 8, 4]
+     },
+     "quantizer_config": {
+       "input_dim": 1024,
+       "codebook_size": 8192,
+       "codebook_dim": 8,
+       "commitment": 0.25,
+       "codebook_loss_weight": 2.0,
+       "decay": 0.99,
+       "threshold_ema_dead_code": 0.2
+     },
+     "speaker_encoder_config": {
+       "input_dim": 128,
+       "out_dim": 1024,
+       "latent_dim": 128,
+       "token_num": 32,
+       "fsq_levels": [4, 4, 4, 4, 4, 4],
+       "fsq_num_quantizers": 1
+     },
+     "prenet_config": {
+       "input_channels": 1024,
+       "vocos_dim": 384,
+       "vocos_intermediate_dim": 2048,
+       "vocos_num_layers": 12,
+       "out_channels": 1024,
+       "condition_dim": 1024,
+       "sample_ratios": [1, 1],
+       "use_tanh_at_final": false
+     },
+     "postnet_config": {
+       "input_channels": 1024,
+       "vocos_dim": 384,
+       "vocos_intermediate_dim": 2048,
+       "vocos_num_layers": 6,
+       "out_channels": 1024,
+       "use_tanh_at_final": false
+     }
+   }
+ }
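The auto_map above is what makes this repo loadable through the Auto* classes: it routes AutoConfig, AutoModel, and AutoProcessor to the custom modules committed here, so trust_remote_code=True is required. A minimal loading sketch (the path is a placeholder; a Hub repo id works the same way), including the link_model step that processing_spark_tts.py expects before cloning or decoding:

from transformers import AutoModel, AutoProcessor

repo = "path/to/this/repo"  # placeholder: local checkout or Hub repo id
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
processor.link_model(model)  # wires tokenize_audio/detokenize_audio into the processor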
configuration_spark_tts.py ADDED
@@ -0,0 +1,233 @@
+ # coding=utf-8
+ # Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
+ # ... (License headers remain the same) ...
+ """SparkTTS model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from typing import List, Optional  # Added typing
+
+
+ logger = logging.get_logger(__name__)
+
+ # --- Define Individual Sub-Component Config Classes ---
+
+ class SparkTTSMelParamsConfig(PretrainedConfig):
+     """Configuration for Mel Spectrogram parameters."""
+     model_type = "spark-tts-mel-params"
+
+     def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
+                  mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
+         super().__init__(**kwargs)
+         self.sample_rate = sample_rate
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+         self.mel_fmin = mel_fmin
+         self.mel_fmax = mel_fmax
+         self.num_mels = num_mels
+
+
+ class SparkTTSEncoderConfig(PretrainedConfig):
+     """Configuration for the BiCodec Feature Encoder."""
+     model_type = "spark-tts-encoder"
+
+     def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                  vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
+         super().__init__(**kwargs)
+         self.input_channels = input_channels
+         self.vocos_dim = vocos_dim
+         self.vocos_intermediate_dim = vocos_intermediate_dim
+         self.vocos_num_layers = vocos_num_layers
+         self.out_channels = out_channels
+         self.sample_ratios = sample_ratios
+
+
+ class SparkTTSDecoderConfig(PretrainedConfig):
+     """Configuration for the BiCodec Wave Generator (Decoder)."""
+     model_type = "spark-tts-decoder"
+
+     def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
+                  kernel_sizes=[16, 11, 8, 4], **kwargs):
+         super().__init__(**kwargs)
+         self.input_channel = input_channel
+         self.channels = channels
+         self.rates = rates
+         self.kernel_sizes = kernel_sizes
+
+
+ class SparkTTSQuantizerConfig(PretrainedConfig):
+     """Configuration for the BiCodec Factorized Vector Quantizer."""
+     model_type = "spark-tts-quantizer"
+
+     def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
+                  commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
+                  threshold_ema_dead_code=0.2, **kwargs):
+         # Note: Removed use_l2_normlize as it wasn't in the original class __init__ args.
+         # Add it back if it's actually used by the FactorizedVectorQuantize class init.
+         super().__init__(**kwargs)
+         self.input_dim = input_dim
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.commitment = commitment
+         self.codebook_loss_weight = codebook_loss_weight
+         self.decay = decay
+         self.threshold_ema_dead_code = threshold_ema_dead_code
+
+
+ class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
+     """Configuration for the BiCodec Speaker Encoder."""
+     model_type = "spark-tts-speaker-encoder"
+
+     def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
+                  fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
+         super().__init__(**kwargs)
+         self.input_dim = input_dim
+         self.out_dim = out_dim
+         self.latent_dim = latent_dim
+         self.token_num = token_num
+         self.fsq_levels = fsq_levels
+         self.fsq_num_quantizers = fsq_num_quantizers
+
+
+ class SparkTTSPrenetConfig(PretrainedConfig):
+     """Configuration for the BiCodec Prenet."""
+     model_type = "spark-tts-prenet"
+
+     def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                  vocos_num_layers=12, out_channels=1024, condition_dim=1024,
+                  sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
+         super().__init__(**kwargs)
+         self.input_channels = input_channels
+         self.vocos_dim = vocos_dim
+         self.vocos_intermediate_dim = vocos_intermediate_dim
+         self.vocos_num_layers = vocos_num_layers
+         self.out_channels = out_channels
+         self.condition_dim = condition_dim
+         self.sample_ratios = sample_ratios
+         self.use_tanh_at_final = use_tanh_at_final
+
+
+ class SparkTTSPostnetConfig(PretrainedConfig):
+     """Configuration for the BiCodec Postnet."""
+     model_type = "spark-tts-postnet"
+
+     def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                  vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
+         # Note: Removed condition_dim as it wasn't in the original config example for postnet.
+         super().__init__(**kwargs)
+         self.input_channels = input_channels
+         self.vocos_dim = vocos_dim
+         self.vocos_intermediate_dim = vocos_intermediate_dim
+         self.vocos_num_layers = vocos_num_layers
+         self.out_channels = out_channels
+         self.use_tanh_at_final = use_tanh_at_final
+
+
+ # --- Define the Intermediate BiCodec Config Class ---
+
+ class SparkTTSBiCodecConfig(PretrainedConfig):
+     """
+     Intermediate configuration class for the BiCodec component within SparkTTS.
+     It holds instances of the individual sub-component configurations.
+     """
+     model_type = "spark-tts-bicodec"
+     # Map keys in the 'bicodec_config' dict to their respective classes
+     sub_configs = {
+         "mel_params": SparkTTSMelParamsConfig,
+         "encoder_config": SparkTTSEncoderConfig,
+         "decoder_config": SparkTTSDecoderConfig,
+         "quantizer_config": SparkTTSQuantizerConfig,
+         "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
+         "prenet_config": SparkTTSPrenetConfig,
+         "postnet_config": SparkTTSPostnetConfig,
+     }
+
+     def __init__(
+         self,
+         mel_params=None,
+         encoder_config=None,
+         decoder_config=None,
+         quantizer_config=None,
+         speaker_encoder_config=None,
+         prenet_config=None,
+         postnet_config=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # Instantiate sub-configs from dictionaries or use defaults/provided instances
+         self.mel_params = self._init_sub_config(mel_params, "mel_params")
+         self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
+         self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
+         self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
+         self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
+         self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
+         self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")
+
+     def _init_sub_config(self, config_input, config_key):
+         """Helper to initialize sub-configs."""
+         config_cls = self.sub_configs[config_key]
+         if isinstance(config_input, dict):
+             return config_cls(**config_input)
+         elif config_input is None:
+             return config_cls()  # Initialize with defaults
+         elif isinstance(config_input, config_cls):
+             return config_input  # Already an instance
+         else:
+             raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")
+
+
+ # --- Define the Main SparkTTS Config Class ---
+
+ class SparkTTSConfig(PretrainedConfig):
+     r"""
+     Main configuration class for SparkTTSModel, including the nested BiCodec configuration.
+
+     Args:
+         llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for the LLM.
+         bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for the BiCodec checkpoint.
+         wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
+         sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
+         # ... (other top-level args: highpass_cutoff_freq, latent_hop_length, ref_segment_duration, volume_normalize) ...
+         bicodec_config (`dict`, *optional*): Dictionary used to initialize `SparkTTSBiCodecConfig`.
+         torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
+         kwargs (*optional*): Dictionary of keyword arguments.
+     """
+     model_type = "spark-tts"
+     # Map the key in config.json to the intermediate BiCodec config class
+     sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
+     attribute_map = {"hidden_size": "d_model"}  # Example
+
+     def __init__(
+         self,
+         llm_model_name_or_path="./LLM",
+         bicodec_model_name_or_path="./BiCodec",
+         wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
+         sample_rate=16000,
+         highpass_cutoff_freq=40,
+         latent_hop_length=320,
+         ref_segment_duration=6.0,
+         volume_normalize=True,
+         bicodec_config=None,  # Expects a dictionary or None
+         torch_dtype="auto",
+         **kwargs,
+     ):
+         # --- Top-level parameters ---
+         self.llm_model_name_or_path = llm_model_name_or_path
+         self.bicodec_model_name_or_path = bicodec_model_name_or_path
+         self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
+         self.sample_rate = sample_rate
+         self.highpass_cutoff_freq = highpass_cutoff_freq
+         self.latent_hop_length = latent_hop_length
+         self.ref_segment_duration = ref_segment_duration
+         self.volume_normalize = volume_normalize
+         self.torch_dtype = torch_dtype
+
+         # --- Nested BiCodec Configuration ---
+         # Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
+         if isinstance(bicodec_config, dict):
+             self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
+         elif bicodec_config is None:
+             logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
+             self.bicodec_config = self.sub_configs["bicodec_config"]()
+         elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
+             self.bicodec_config = bicodec_config  # Use existing instance
+         else:
+             raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}. Expected dict, None, or SparkTTSBiCodecConfig.")
+
+         # Set processor class and auto_map
+         kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
+         kwargs["auto_map"] = kwargs.get("auto_map", {
+             "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
+             "AutoModel": "modeling_spark_tts.SparkTTSModel",
+             "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
+         })
+         super().__init__(**kwargs)
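A useful property of this layout is that nested configs are rebuilt recursively from plain dicts, so a partial override merges with the defaults. A minimal sketch, assuming the module is importable on sys.path:

from configuration_spark_tts import SparkTTSConfig

cfg = SparkTTSConfig(bicodec_config={"quantizer_config": {"codebook_size": 4096}})
print(cfg.bicodec_config.quantizer_config.codebook_size)  # 4096 (overridden)
print(cfg.bicodec_config.mel_params.hop_length)           # 320 (default)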
modeling_spark_tts.py ADDED
The diff for this file is too large to render.
 
processing_spark_tts.py ADDED
@@ -0,0 +1,889 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 SparkAudio & The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing.
17
+ """
18
+
19
+ import os # Needed for save_pretrained
20
+ import re # For decoding
21
+ import torch
22
+ import numpy as np
23
+ import soundfile as sf # For audio loading
24
+ import soxr # For resampling
25
+
26
+ from pathlib import Path
27
+ from typing import Optional, Union, List, Dict, Tuple, Any
28
+
29
+ from transformers.processing_utils import ProcessorMixin
30
+ from transformers.tokenization_utils_base import BatchEncoding # Return type hint
31
+ from transformers.feature_extraction_utils import BatchFeature # Return type hint
32
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
33
+ from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
34
+ from transformers.utils import logging, PushToHubMixin # Added PushToHubMixin
35
+ from numpy.lib.stride_tricks import sliding_window_view
36
+ import soxr
37
+ import soundfile
38
+ import random
39
+
40
+ # Import custom config if needed for defaults
41
+ from .configuration_spark_tts import SparkTTSConfig
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ # =============================================================================
47
+ # >> START: PASTE CODE FROM sparktts/utils/* HERE <<
48
+ # =============================================================================
49
+ # IMPORTANT: Utility functions needed for processing (audio loading, token parsing)
50
+ # must be defined or imported here.
51
+
52
+ # --- Paste sparktts/utils/audio.py content here ---
53
+
54
+ def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
55
+ """
56
+ Normalize the volume of an audio signal.
57
+
58
+ Parameters:
59
+ audio (numpy array): Input audio signal array.
60
+ coeff (float): Target coefficient for normalization, default is 0.2.
61
+
62
+ Returns:
63
+ numpy array: The volume-normalized audio signal.
64
+ """
65
+ # Sort the absolute values of the audio signal
66
+ temp = np.sort(np.abs(audio))
67
+
68
+ # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
69
+ if temp[-1] < 0.1:
70
+ scaling_factor = max(
71
+ temp[-1], 1e-3
72
+ ) # Prevent division by zero with a small constant
73
+ audio = audio / scaling_factor * 0.1
74
+
75
+ # Filter out values less than 0.01 from temp
76
+ temp = temp[temp > 0.01]
77
+ L = temp.shape[0] # Length of the filtered array
78
+
79
+ # If there are fewer than or equal to 10 significant values, return the audio without further processing
80
+ if L <= 10:
81
+ return audio
82
+
83
+ # Compute the average of the top 10% to 1% of values in temp
84
+ volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
85
+
86
+ # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
87
+ audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
88
+
89
+ # Ensure the maximum absolute value in the audio does not exceed 1
90
+ max_value = np.max(np.abs(audio))
91
+ if max_value > 1:
92
+ audio = audio / max_value
93
+
94
+ return audio
95
+
96
+
97
+ def load_audio(
98
+ adfile: Path,
99
+ sampling_rate: int = None,
100
+ length: int = None,
101
+ volume_normalize: bool = False,
102
+ segment_duration: int = None,
103
+ ) -> np.ndarray:
104
+ r"""Load audio file with target sampling rate and lsength
105
+
106
+ Args:
107
+ adfile (Path): path to audio file.
108
+ sampling_rate (int, optional): target sampling rate. Defaults to None.
109
+ length (int, optional): target audio length. Defaults to None.
110
+ volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
111
+ segment_duration (int): random select a segment with duration of {segment_duration}s.
112
+ Defualt to None which means the whole audio will be used.
113
+
114
+ Returns:
115
+ audio (np.ndarray): audio
116
+ """
117
+
118
+ audio, sr = soundfile.read(adfile)
119
+ if len(audio.shape) > 1:
120
+ audio = audio[:, 0]
121
+
122
+ if sampling_rate is not None and sr != sampling_rate:
123
+ audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
124
+ sr = sampling_rate
125
+
126
+ if segment_duration is not None:
127
+ seg_length = int(sr * segment_duration)
128
+ audio = random_select_audio_segment(audio, seg_length)
129
+
130
+ # Audio volume normalize
131
+ if volume_normalize:
132
+ audio = audio_volume_normalize(audio)
133
+ # check the audio length
134
+ if length is not None:
135
+ assert abs(audio.shape[0] - length) < 1000
136
+ if audio.shape[0] > length:
137
+ audio = audio[:length]
138
+ else:
139
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
140
+ return audio
141
+
142
+
143
+ def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
144
+ """get an audio segment given the length
145
+
146
+ Args:
147
+ audio (np.ndarray):
148
+ length (int): audio length = sampling_rate * duration
149
+ """
150
+ if audio.shape[0] < length:
151
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
152
+ start_index = random.randint(0, audio.shape[0] - length)
153
+ end_index = int(start_index + length)
154
+
155
+ return audio[start_index:end_index]
156
+
157
+ def get_ref_clip(wav: np.ndarray, config) -> np.ndarray: # Needs access to config attributes
158
+ """Get reference audio clip for speaker embedding."""
159
+ # Make sure config has sample_rate, ref_segment_duration, latent_hop_length
160
+ if not all(hasattr(config, attr) for attr in ['sample_rate', 'ref_segment_duration', 'latent_hop_length']):
161
+ raise AttributeError("Config object missing required attributes for get_ref_clip")
162
+ ref_segment_length = (
163
+ int(config.sample_rate * config.ref_segment_duration)
164
+ // config.latent_hop_length
165
+ * config.latent_hop_length
166
+ )
167
+ wav_length = len(wav)
168
+ if ref_segment_length > wav_length:
169
+ wav = np.tile(wav, ref_segment_length // wav_length + 1)
170
+ return wav[:ref_segment_length]
171
+
172
+
173
+ # --- Paste sparktts/utils/token_parser.py content here ---
174
+
175
+ TASK_TOKEN_MAP = {
176
+ "vc": "<|task_vc|>",
177
+ "tts": "<|task_tts|>",
178
+ "asr": "<|task_asr|>",
179
+ "s2s": "<|task_s2s|>",
180
+ "t2s": "<|task_t2s|>",
181
+ "understand": "<|task_understand|>",
182
+ "caption": "<|task_cap|>",
183
+ "controllable_tts": "<|task_controllable_tts|>",
184
+ "prompt_tts": "<|task_prompt_tts|>",
185
+ "speech_edit": "<|task_edit|>",
186
+ }
187
+
188
+ LEVELS_MAP = {
189
+ "very_low": 0,
190
+ "low": 1,
191
+ "moderate": 2,
192
+ "high": 3,
193
+ "very_high": 4,
194
+ }
195
+
196
+ LEVELS_MAP_UI = {
197
+ 1: 'very_low',
198
+ 2: 'low',
199
+ 3: 'moderate',
200
+ 4: 'high',
201
+ 5: 'very_high'
202
+ }
203
+
204
+ GENDER_MAP = {
205
+ "female": 0,
206
+ "male": 1,
207
+ }
208
+
209
+ AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
210
+
211
+ EMO_MAP = {
212
+ "UNKNOWN": 0,
213
+ "NEUTRAL": 1,
214
+ "ANGRY": 2,
215
+ "HAPPY": 3,
216
+ "SAD": 4,
217
+ "FEARFUL": 5,
218
+ "DISGUSTED": 6,
219
+ "SURPRISED": 7,
220
+ "SARCASTIC": 8,
221
+ "EXCITED": 9,
222
+ "SLEEPY": 10,
223
+ "CONFUSED": 11,
224
+ "EMPHASIS": 12,
225
+ "LAUGHING": 13,
226
+ "SINGING": 14,
227
+ "WORRIED": 15,
228
+ "WHISPER": 16,
229
+ "ANXIOUS": 17,
230
+ "NO-AGREEMENT": 18,
231
+ "APOLOGETIC": 19,
232
+ "CONCERNED": 20,
233
+ "ENUNCIATED": 21,
234
+ "ASSERTIVE": 22,
235
+ "ENCOURAGING": 23,
236
+ "CONTEMPT": 24,
237
+ }
238
+
239
+
240
+ class TokenParser:
241
+ """Turn label to special token"""
242
+
243
+ def __init__(self):
244
+ pass
245
+
246
+ """Parse the attributes of a person."""
247
+
248
+ def __init__(self):
249
+ pass
250
+
251
+ @staticmethod
252
+ def age(age: str) -> str:
253
+ """Turn age token."""
254
+ age_id = AGE_MAP[age]
255
+ return f"<|age_{age_id}|>"
256
+
257
+ @staticmethod
258
+ def gender(gender: str) -> str:
259
+ """Turn gender token."""
260
+ gender_id = GENDER_MAP[gender]
261
+ return f"<|gender_{gender_id}|>"
262
+
263
+ @staticmethod
264
+ def mel_value(mel: int):
265
+ """Turn special token of mel scale pitch."""
266
+ mel = max(0, int(mel))
267
+ mel = min(1000, int(mel))
268
+ return f"<|pitch_value_{mel}|>"
269
+
270
+ @staticmethod
271
+ def mel_level(level: str):
272
+ """Turn special token of mel level."""
273
+ level_tag = LEVELS_MAP[level]
274
+ return f"<|pitch_label_{level_tag}|>"
275
+
276
+ @staticmethod
277
+ def pitch_var_value(pitch_std: int):
278
+ """Turn special token of pitch_std value."""
279
+ assert isinstance(pitch_std, int)
280
+ pitch_std = max(0, int(pitch_std))
281
+ pitch_std = min(10, int(pitch_std))
282
+ return f"<|pitch_var_value_{pitch_std}|>"
283
+
284
+ @staticmethod
285
+ def pitch_var_level(level: str):
286
+ """Turn special token of pitch std level."""
287
+ level_tag = LEVELS_MAP[level]
288
+ return f"<|pitch_var_label_{level_tag}|>"
289
+
290
+ @staticmethod
291
+ def loudness_value(loudness: int):
292
+ """Turn special toak of loudness value [0, 30]"""
293
+ assert loudness >= 0
294
+ loudness = max(0, int(loudness))
295
+ loudness = min(30, int(loudness))
296
+ return f"<|loudness_value_{loudness}|>"
297
+
298
+ @staticmethod
299
+ def loudness_level(level: str):
300
+ """Turn special token of loudness level."""
301
+ level_tag = LEVELS_MAP[level]
302
+ return f"<|loudness_label_{level_tag}|>"
303
+
304
+ @staticmethod
305
+ def speed_value(speed: int):
306
+ """Turn special token of speed value."""
307
+ speed = max(0, int(speed))
308
+ speed = min(10, int(speed))
309
+ return f"<|speed_value_{speed}|>"
310
+
311
+ @staticmethod
312
+ def speed_level(level: str):
313
+ """Turn special token of speed level."""
314
+ level_tag = LEVELS_MAP[level]
315
+ return f"<|speed_label_{level_tag}|>"
316
+
317
+ @staticmethod
318
+ def task(task: str) -> str:
319
+ """Turn special token of task."""
320
+ assert task in TASK_TOKEN_MAP.keys()
321
+
322
+ return TASK_TOKEN_MAP[task]
323
+
324
+ @staticmethod
325
+ def emotion(emotion: str):
326
+ emo_id = EMO_MAP[emotion]
327
+
328
+ return f"<|emotion_{emo_id}|>"
329
+
330
+ # =============================================================================
331
+ # >> END: PASTE CODE FROM sparktts/utils/* HERE <<
332
+ # =============================================================================
333
+
334
+
335
+ class SparkTTSProcessor(ProcessorMixin, PushToHubMixin): # Added PushToHubMixin
336
+ r"""
337
+ Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic.
338
+
339
+ Args:
340
+ tokenizer ([`PreTrainedTokenizer`]):
341
+ An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM.
342
+ feature_extractor ([`Wav2Vec2FeatureExtractor`]):
343
+ An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted
344
+ within the model's `tokenize_audio`, the extractor's configuration (like sampling rate)
345
+ is useful, and it aligns with the ProcessorMixin pattern.
346
+ config ([`SparkTTSConfig`], *optional*):
347
+ An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate.
348
+ """
349
+ attributes = ["tokenizer", "feature_extractor"]
350
+ tokenizer_class = "AutoTokenizer"
351
+ feature_extractor_class = "Wav2Vec2FeatureExtractor" # Keep for consistency
352
+
353
+ def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs):
354
+ super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs)
355
+ self.model = None
356
+ self.config = config
357
+ # Set sampling rate
358
+ if config and hasattr(config, 'sample_rate'):
359
+ self.sampling_rate = config.sample_rate
360
+ elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'):
361
+ self.sampling_rate = feature_extractor.sampling_rate
362
+ else:
363
+ self.sampling_rate = 16000
364
+ logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.")
365
+
366
+ # # Ensure tokenizer pad token
367
+ # if self.tokenizer.pad_token is None:
368
+ # if self.tokenizer.eos_token is not None:
369
+ # logger.warning("Tokenizer does not have a pad token. Setting pad_token to eos_token.")
370
+ # self.tokenizer.pad_token = self.tokenizer.eos_token
371
+ # else:
372
+ # logger.warning("Tokenizer lacks pad and eos token. Adding default pad token '<|pad|>'.")
373
+ # self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
374
+
375
+ def link_model(self, model):
376
+ """Links the processor to a SparkTTSModel instance for audio processing calls."""
377
+ if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'):
378
+ raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.")
379
+ if not hasattr(model, 'config'):
380
+ logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.")
381
+
382
+ self.model = model
383
+ logger.info("SparkTTSModel successfully linked to the processor.")
384
+ # Update sampling rate based on linked model's config if available
385
+ if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'):
386
+ if self.sampling_rate != model.config.sample_rate:
387
+ logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.")
388
+ self.sampling_rate = model.config.sample_rate
389
+ # Also update feature extractor sampling rate if it differs
390
+ if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate:
391
+ logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.")
392
+ self.feature_extractor.sampling_rate = model.config.sample_rate
393
+
394
+
395
+ def __call__(
396
+ self,
397
+ text: str,
398
+ prompt_speech_path: Optional[Union[str, Path]] = None,
399
+ prompt_text: Optional[str] = None,
400
+ gender: Optional[str] = None,
401
+ pitch: Optional[str] = None,
402
+ speed: Optional[str] = None,
403
+ return_tensors: Optional[str] = "pt",
404
+ **kwargs, # Allow passing other args like padding, truncation to tokenizer
405
+ ) -> BatchEncoding:
406
+ """
407
+ Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`].
408
+
409
+ Args:
410
+ text (`str`):
411
+ The main text to be synthesized.
412
+ prompt_speech_path (`str` or `Path`, *optional*):
413
+ Path to the prompt audio file for voice cloning. Required if `gender` is not set.
414
+ prompt_text (`str`, *optional*):
415
+ Transcript of the prompt audio. Used only in voice cloning mode.
416
+ gender (`str`, *optional*):
417
+ Target gender ("male" or "female") for controllable synthesis. If set, enables control mode.
418
+ pitch (`str`, *optional*):
419
+ Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
420
+ speed (`str`, *optional*):
421
+ Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
422
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
423
+ If set, will return tensors instead of list of python integers. Only "pt" (PyTorch) is supported currently.
424
+ **kwargs:
425
+ Additional arguments passed to the underlying tokenizer's `__call__` method.
426
+
427
+ Returns:
428
+ [`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM.
429
+ In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the
430
+ global tokens extracted from the prompt audio.
431
+ """
432
+
433
+ global_token_ids_prompt = None # Initialize
434
+
435
+ # Determine mode: Control TTS or Voice Cloning (Prompt TTS)
436
+ is_control_mode = gender is not None
437
+ is_cloning_mode = prompt_speech_path is not None and not is_control_mode
438
+
439
+ if is_control_mode:
440
+ # --- Controllable TTS Prompt Construction ---
441
+ if not all([pitch, speed]):
442
+ raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.")
443
+ if prompt_speech_path is not None:
444
+ logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).")
445
+
446
+ if not all(k in GENDER_MAP for k in [gender]): # Basic check
447
+ raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}")
448
+ if not all(k in LEVELS_MAP for k in [pitch, speed]): # Basic check
449
+ raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}")
450
+
451
+ gender_id = GENDER_MAP[gender]
452
+ pitch_level_id = LEVELS_MAP[pitch]
453
+ speed_level_id = LEVELS_MAP[speed]
454
+
455
+ pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
456
+ speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
457
+ gender_tokens = f"<|gender_{gender_id}|>"
458
+
459
+ attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens])
460
+
461
+ prompt_list = [
462
+ TASK_TOKEN_MAP["controllable_tts"],
463
+ "<|start_content|>",
464
+ text,
465
+ "<|end_content|>",
466
+ "<|start_style_label|>",
467
+ attribute_tokens,
468
+ "<|end_style_label|>",
469
+ ]
470
+ prompt_string = "".join(prompt_list)
471
+
472
+ elif is_cloning_mode:
473
+ # --- Voice Cloning Prompt Construction ---
474
+ if self.model is None:
475
+ raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.")
476
+ prompt_speech_path = Path(prompt_speech_path) # Ensure it's a Path object
477
+ if not prompt_speech_path.exists():
478
+ raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}")
479
+
480
+ # Load and process prompt audio
481
+ try:
482
+ model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config
483
+ if model_config is None:
484
+ raise ValueError("Configuration not available in processor or linked model.")
485
+
486
+ # Load main wav
487
+ wav = load_audio(
488
+ prompt_speech_path,
489
+ sampling_rate=self.sampling_rate,
490
+ volume_normalize=getattr(model_config, 'volume_normalize', True), # Use getattr for safety
491
+ )
492
+ # Get reference clip
493
+ wav_ref_np = get_ref_clip(wav, model_config) # Pass config object
494
+ wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
495
+ wav_tensor = torch.from_numpy(wav).unsqueeze(0).float()
496
+
497
+ # Tokenize using the linked model's method
498
+ # Assuming tokenize_audio returns tensors with batch dim 1: [1, N_global], [1, N_semantic]
499
+ global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref)
500
+
501
+ # Store the global tokens tensor (with batch dim) for the output dict
502
+ global_token_ids_prompt = global_tokens_tensor # Keep batch dim [1, N_global]
503
+
504
+ # Convert tensors to lists of ints for string formatting
505
+ global_token_list = global_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
506
+ semantic_token_list = semantic_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
507
+
508
+ except Exception as e:
509
+ logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}")
510
+ import traceback
511
+ traceback.print_exc()
512
+ raise
513
+
514
+ # ==============================================================
515
+ # CORRECTED TOKEN STRING FORMATTING
516
+ # ==============================================================
517
+ # Create individual token strings for each ID
518
+ global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list])
519
+ semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list])
520
+ # ==============================================================
521
+
522
+ # Construct prompt list based on presence of prompt_text
523
+ if prompt_text is not None and prompt_text.strip(): # Check if prompt_text is meaningful
524
+ logger.info("Using prompt text in voice cloning prompt.")
525
+ prompt_list = [
526
+ TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]? Check original logic. Assuming "tts".
527
+ "<|start_content|>",
528
+ prompt_text, # Transcript first
529
+ text, # Then target text
530
+ "<|end_content|>",
531
+ "<|start_global_token|>",
532
+ global_tokens_str,
533
+ "<|end_global_token|>",
534
+ "<|start_semantic_token|>",
535
+ semantic_tokens_str,
536
+ # "<|end_semantic_token|>", # Original code didn't have this marker here
537
+ ]
538
+ else:
539
+ # Simpler prompt without semantic tokens if no transcript provided
540
+ logger.info("No prompt text provided, using text-only voice cloning prompt.")
541
+ prompt_list = [
542
+ TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]?
543
+ "<|start_content|>",
544
+ text, # Only target text
545
+ "<|end_content|>",
546
+ "<|start_global_token|>",
547
+ global_tokens_str,
548
+ "<|end_global_token|>",
549
+ ]
550
+ prompt_string = "".join(prompt_list)
551
+ logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...") # Log start of prompt
552
+
553
+ else:
554
+ raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.")
555
+
556
+ # --- Tokenize the final prompt string ---
557
+ # print(f"Tokenizing prompt: {prompt_string}")
558
+ inputs = self.tokenizer(
559
+ prompt_string,
560
+ return_tensors=return_tensors,
561
+ padding=kwargs.get("padding", False), # Often False for generation prompts unless batching > 1
562
+ truncation=kwargs.get("truncation", True),
563
+ max_length=kwargs.get("max_length", self.tokenizer.model_max_length),
564
+ add_special_tokens=kwargs.get("add_special_tokens", True), # Usually True unless handled manually
565
+ return_attention_mask=kwargs.get("return_attention_mask", True), # Need attention mask
566
+ **{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]}
567
+ )
568
+ logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")
569
+
570
+
571
+ # Add the prompt's global tokens (as tensor with batch dim) to the output if in cloning mode
572
+ if is_cloning_mode and global_token_ids_prompt is not None:
573
+ if return_tensors == "pt":
574
+ inputs["global_token_ids_prompt"] = global_token_ids_prompt # Already has batch dim [1, N_global]
575
+ else:
576
+ # Handle non-tensor return if necessary
577
+ inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist()
578
+
579
+ return inputs
580
+
581
+
582
+ def decode(
583
+ self,
584
+ generated_ids: torch.Tensor,
585
+ global_token_ids_prompt: Optional[torch.Tensor] = None,
586
+ input_ids_len: Optional[int] = None,
587
+ skip_special_tokens: bool = True,
588
+ ) -> Dict[str, Any]:
589
+ """
590
+ Decodes the generated token IDs from [`SparkTTSModel`] into an audio waveform.
591
+
592
+ Args:
593
+ generated_ids (`torch.Tensor`):
594
+ Tensor of token IDs generated by `model.generate()`, including the input prompt part. Shape [B, seq_len].
595
+ global_token_ids_prompt (`torch.Tensor`, *optional*):
596
+ The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning).
597
+ Shape [B, N_global]. Required if the generation was for voice cloning.
598
+ input_ids_len (`int`, *optional*):
599
+ The length of the original input prompt `input_ids` fed to `model.generate()`. Required to
600
+ correctly isolate the newly generated tokens.
601
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
602
+ Whether to skip special tokens during the text decoding step (used to extract audio tokens).
603
+
604
+ Returns:
605
+ Dict[str, Any]: A dictionary containing:
606
+ - "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio].
607
+ - "sampling_rate": The sampling rate of the audio.
608
+ """
609
+ if self.model is None:
610
+ raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.")
611
+ if input_ids_len is None:
612
+ raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.")
613
+
614
+ # --- Isolate generated part and decode text ---
615
+ # Assumes generated_ids has shape [B, full_seq_len]
616
+ # Handle case where generated sequence is shorter than prompt (shouldn't happen with max_new_tokens > 0)
617
+ if generated_ids.shape[1] < input_ids_len:
618
+ logger.warning(f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length ({input_ids_len}). Decoding might be incorrect.")
619
+ output_only_ids = generated_ids[:, input_ids_len:] # Will be empty if equal
620
+ else:
621
+ output_only_ids = generated_ids[:, input_ids_len:]
622
+
623
+
624
+ # Decode the generated part to find audio tokens
625
+ # Need to handle batch decoding if B > 1
626
+ # print("decode token", self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=False))
627
+ decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens)
628
+
629
+ # --- Extract Audio Tokens ---
630
+ # Handle batch processing correctly
631
+ batch_size = generated_ids.shape[0]
632
+ all_semantic_ids = []
633
+ all_global_tokens = []
634
+ successful_indices = [] # Keep track of which batch items were successful
635
+
636
+ for i in range(batch_size):
637
+ decoded_text = decoded_texts[i]
638
+ current_semantic_ids = None
639
+ current_global_tokens = None
640
+
641
+ # Extract semantic tokens
642
+ try:
643
+ pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]
644
+ if not pred_semantic_indices:
645
+ logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'")
646
+ continue # Skip this item
647
+
648
+ current_semantic_ids = torch.tensor(pred_semantic_indices).long() # Shape [N_semantic]
649
+ except Exception as e:
650
+ logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}")
651
+ continue # Skip this item
652
+
653
+ # Determine global tokens
654
+ if global_token_ids_prompt is not None:
655
+ # Cloning mode: Use the provided prompt global tokens for this batch item
656
+ if global_token_ids_prompt.shape[0] != batch_size:
657
+ raise ValueError(f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt has {global_token_ids_prompt.shape[0]}.")
658
+ current_global_tokens = global_token_ids_prompt[i] # Shape [N_global]
659
+ else:
660
+ # Control mode: Extract global tokens from the generated text
661
+ try:
662
+ pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)]
663
+ if not pred_global_indices:
664
+ logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'")
665
+ continue # Skip this item
666
+
667
+ current_global_tokens = torch.tensor(pred_global_indices).long() # Shape [N_global]
668
+
669
+ except Exception as e:
670
+ logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}")
671
+ continue # Skip this item
672
+
673
+ # If both tokens extracted successfully
674
+ all_semantic_ids.append(current_semantic_ids)
675
+ all_global_tokens.append(current_global_tokens)
676
+ successful_indices.append(i)
677
+
678
+ if not successful_indices:
679
+ logger.error("Failed to extract audio tokens for any item in the batch.")
680
+ return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
681
+
682
+ # Pad sequences to the max length within the successful batch items for batch detokenization
683
+ # Note: BiCodec might not support batching if sequences have different lengths. Check its implementation.
684
+ # Assuming BiCodec *can* handle batches if padded (or if lengths are naturally equal).
685
+ # This padding might be unnecessary if BiCodec handles variable lengths or if B=1 anyway.
686
+ # For now, let's assume B=1 was handled correctly and skip complex padding.
687
+ if batch_size > 1 and len(successful_indices) < batch_size:
688
+ logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.")
689
+ # Further processing might need to handle only the successful items.
690
+
691
+ # Let's proceed assuming B=1 or BiCodec handles batches appropriately.
692
+ # Stack the successful tokens.
693
+ try:
694
+ # Need to ensure tensors have the same length before stacking if BiCodec requires it.
695
+ # If BiCodec handles variable length, stacking might not be needed, just loop and call detokenize.
696
+ # Let's assume B=1 for simplicity of the example, matching original code's likely behavior.
697
+ if len(successful_indices) != 1:
698
+ raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.")
699
+
700
+ final_semantic_ids = all_semantic_ids[0].unsqueeze(0) # Add batch dim [1, N_semantic]
701
+ final_global_tokens = all_global_tokens[0].unsqueeze(0) # Add batch dim [1, N_global]
702
+
703
+ except IndexError: # Should not happen if successful_indices is not empty
704
+ logger.error("Internal error during token batch preparation.")
705
+ return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
706
+
707
+
708
+ # --- Detokenize Audio ---
709
+ try:
710
+ # Call the linked model's detokenize method
711
+ # print(f"DEBUG: Detokenizing audio with global tokens {final_global_tokens.shape}, semantic tokens {final_semantic_ids.shape}")
712
+ output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids)
713
+ # detokenize_audio now returns numpy array float32 in [-1, 1]
714
+
715
+ # Optional: Double-check dtype here if needed, but should be handled by detokenize_audio now
716
+ # if output_wav.dtype != np.float32:
717
+ # logger.warning(f"Audio dtype after detokenize is {output_wav.dtype}. Converting to float32.")
718
+ # output_wav = output_wav.astype(np.float32)
719
+ # output_wav = np.clip(output_wav, -1.0, 1.0) # Clipping done in detokenize_audio
720
+
721
+ except Exception as e:
722
+ logger.error(f"Error during audio detokenization: {e}")
723
+ import traceback
724
+ traceback.print_exc()
725
+ raise RuntimeError("Audio detokenization failed.") from e
726
+
727
+ return {"audio": output_wav, "sampling_rate": self.sampling_rate}
728
+
729
+
730
+ @classmethod
731
+ def from_pretrained(
732
+ cls,
733
+ pretrained_model_name_or_path: Union[str, os.PathLike],
734
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
735
+ force_download: bool = False,
736
+ local_files_only: bool = False,
737
+ token: Optional[Union[str, bool]] = None,
738
+ revision: str = "main",
739
+ trust_remote_code: bool = False, # Allow passing this, needed for config potentially
740
+ **kwargs,
741
+ ):
742
+ r"""
743
+ Instantiate a SparkTTSProcessor from pretrained components.
744
+ """
745
+ # Pop specific kwargs for this method
746
+ config = kwargs.pop("config", None) # Allow passing config explicitly
747
+
748
+ # --- 1. Load Config (to find component paths) ---
749
+ # We need the config even if the processor doesn't store it permanently,
750
+ # just to find where the tokenizer/feature_extractor live.
751
+ loaded_config = None
752
+ if not isinstance(config, SparkTTSConfig):
753
+ try:
754
+ # Load the specific config class
755
+ loaded_config = SparkTTSConfig.from_pretrained(
756
+ pretrained_model_name_or_path,
757
+ cache_dir=cache_dir,
758
+ force_download=force_download,
759
+ local_files_only=local_files_only,
760
+ token=token,
761
+ revision=revision,
762
+ trust_remote_code=trust_remote_code, # Config might be custom
763
+ **kwargs, # Pass relevant kwargs
764
+ )
765
+ except Exception as e:
766
+ logger.warning(
767
+ f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
768
+ f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}"
769
+ )
770
+ loaded_config = None # Fallback
771
+ else:
772
+ # Config object was passed directly
773
+ loaded_config = config
774
+
775
+
776
+ # --- 2. Determine Component Paths ---
777
+ llm_tokenizer_path_or_id = "./LLM" # Default relative path
778
+ w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53" # Default relative path
779
+
780
+ if loaded_config:
781
+ llm_tokenizer_path_or_id = getattr(loaded_config, 'llm_model_name_or_path', llm_tokenizer_path_or_id)
782
+ w2v_processor_path_or_id = getattr(loaded_config, 'wav2vec2_model_name_or_path', w2v_processor_path_or_id)
783
+
784
+ # The component `from_pretrained` methods handle resolving these paths/IDs
785
+ # whether they are relative subfolders of `pretrained_model_name_or_path`
786
+ # or separate Hub IDs.
787
+
788
+ # --- 3. Load Components ---
789
+ # Pass down relevant kwargs for loading components
790
+ component_loading_kwargs = {
791
+ "cache_dir": cache_dir,
792
+ "force_download": force_download,
793
+ "local_files_only": local_files_only,
794
+ "token": token,
795
+ "revision": revision,
796
+ **kwargs # Pass other user kwargs
797
+ }
798
+ try:
799
+ # Tokenizer might require trust_remote_code if its class is custom
800
+ tokenizer = AutoTokenizer.from_pretrained(
801
+ pretrained_model_name_or_path, # Main path
802
+ subfolder=llm_tokenizer_path_or_id.lstrip('./'), # Specify subfolder relative to main path
803
+ trust_remote_code=trust_remote_code,
804
+ **component_loading_kwargs
805
+ )
806
+ except Exception as e:
807
+ # Fallback: try loading directly using the path/id from config if different
808
+ if llm_tokenizer_path_or_id != "./LLM":
809
+ try:
810
+ logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}")
811
+ tokenizer = AutoTokenizer.from_pretrained(
812
+ llm_tokenizer_path_or_id,
813
+ trust_remote_code=trust_remote_code,
814
+ **component_loading_kwargs
815
+ )
816
+ except Exception as e2:
817
+ raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e
818
+ else:
819
+ raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
820
+
821
+
822
+ try:
823
+ # Feature extractor usually doesn't need trust_remote_code
824
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
825
+ pretrained_model_name_or_path, # Main path
826
+ subfolder=w2v_processor_path_or_id.lstrip('./'), # Specify subfolder relative to main path
827
+ **component_loading_kwargs
828
+ )
829
+ except Exception as e:
830
+ # Fallback: try loading directly using the path/id from config if different
831
+ if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53":
832
+ try:
833
+ logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}")
834
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
835
+ w2v_processor_path_or_id,
836
+ **component_loading_kwargs
837
+ )
838
+ except Exception as e2:
839
+ raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e
840
+ else:
841
+ raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
842
+
843
+
844
+         # --- 4. Instantiate the processor ---
+         # Pass the loaded config object (or None) through to __init__.
+         return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config)
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         push_to_hub: bool = False,
+         **kwargs,
+     ):
+         """
+         Save the processor's state (tokenizer and feature extractor files) to a directory.
+
+         Args:
+             save_directory (`str` or `os.PathLike`):
+                 Directory where the processor files will be saved.
+             push_to_hub (`bool`, *optional*, defaults to `False`):
+                 Whether or not to push your model to the Hugging Face Hub after saving it.
+             **kwargs:
+                 Additional keyword arguments passed along to the `push_to_hub` method.
+         """
+         save_directory = Path(save_directory)
+         save_directory.mkdir(parents=True, exist_ok=True)
+
+         # Save tokenizer
+         self.tokenizer.save_pretrained(str(save_directory), **kwargs)
+
+         # Save feature extractor
+         self.feature_extractor.save_pretrained(str(save_directory), **kwargs)
+
+         # Note: the SparkTTSConfig is normally saved with the *model*, not the
+         # processor. If the processor ever holds config needed to reload *itself*,
+         # it could be saved here; relying on the model's config is usually sufficient.
+         # if self.config:
+         #     self.config.save_pretrained(str(save_directory))
+
+         logger.info(f"Processor components saved in {save_directory}")
+
+         if push_to_hub:
+             # Commit message and other hub kwargs can be passed via **kwargs.
+             commit_message = kwargs.pop("commit_message", "Save processor")
+             return self.push_to_hub(save_directory, commit_message=commit_message, **kwargs)
+
+         return str(save_directory)  # Return path, consistent with the Mixin
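
For reference, a minimal usage sketch of the two methods above. Everything here is illustrative rather than confirmed by this commit: the repo ID is a placeholder, and resolution through `AutoProcessor` assumes this repo maps its custom processor class via `trust_remote_code`.

```python
# Hedged usage sketch: "user/Spark-TTS" is a placeholder repo ID, and loading
# through AutoProcessor assumes the repo registers the custom processor class
# in its auto_map (hence trust_remote_code=True).
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("user/Spark-TTS", trust_remote_code=True)

# Writes the tokenizer and feature extractor files flat into the target directory.
processor.save_pretrained("./spark_tts_processor")
```

Note the asymmetry: `save_pretrained` writes both components flat into `save_directory`, while `from_pretrained` looks for them under the `LLM` and `wav2vec2-large-xlsr-53` subfolders by default, so a processor-only save may not reload unless a config points at the flat layout.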
wav2vec2-large-xlsr-53/README.md ADDED
@@ -0,0 +1,29 @@
+ ---
+ language: multilingual
+ datasets:
+ - common_voice
+ tags:
+ - speech
+ license: apache-2.0
+ ---
+
+ # Wav2Vec2-XLSR-53
+
+ [Facebook's XLSR-Wav2Vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)
+
+ The base model, pretrained on 16 kHz sampled speech audio. When using the model, make sure that your speech input is also sampled at 16 kHz. Note that this model should be fine-tuned on a downstream task, such as Automatic Speech Recognition. Check out [this blog](https://huggingface.co/blog/fine-tune-wav2vec2-english) for more information.
+
+ [Paper](https://arxiv.org/abs/2006.13979)
+
+ Authors: Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli
+
+ **Abstract**
+ This paper presents XLSR which learns cross-lingual speech representations by pretraining a single model from the raw waveform of speech in multiple languages. We build on wav2vec 2.0 which is trained by solving a contrastive task over masked latent speech representations and jointly learns a quantization of the latents shared across languages. The resulting model is fine-tuned on labeled data and experiments show that cross-lingual pretraining significantly outperforms monolingual pretraining. On the CommonVoice benchmark, XLSR shows a relative phoneme error rate reduction of 72% compared to the best known results. On BABEL, our approach improves word error rate by 16% relative compared to a comparable system. Our approach enables a single multilingual speech recognition model which is competitive to strong individual models. Analysis shows that the latent discrete speech representations are shared across languages with increased sharing for related languages. We hope to catalyze research in low-resource speech understanding by releasing XLSR-53, a large model pretrained in 53 languages.
+
+ The original model can be found at https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.
+
+ # Usage
+
+ See [this notebook](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb) for more information on how to fine-tune the model.
+
+ ![model image](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/xlsr_wav2vec2.png)
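
This checkpoint ships in the `wav2vec2-large-xlsr-53/` subfolder and appears to serve as the repo's frozen speech feature extractor. As a quick illustration of the 16 kHz requirement above, a minimal sketch using the standard `transformers` API (the local path assumes a full clone of this repo, with LFS weights pulled):

```python
# Sketch: extract hidden-state features from the bundled XLSR-53 checkpoint.
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

local_dir = "./wav2vec2-large-xlsr-53"  # subfolder of this repo
extractor = Wav2Vec2FeatureExtractor.from_pretrained(local_dir)
model = Wav2Vec2Model.from_pretrained(local_dir)  # pretraining heads are unused here

waveform = torch.zeros(16000)  # 1 s of 16 kHz audio (silence as a stand-in)
inputs = extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    features = model(**inputs).last_hidden_state  # shape (1, 49, 1024)
```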
wav2vec2-large-xlsr-53/config.json ADDED
@@ -0,0 +1,83 @@
+ {
+   "activation_dropout": 0.0,
+   "apply_spec_augment": true,
+   "architectures": [
+     "Wav2Vec2ForPreTraining"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 1,
+   "codevector_dim": 768,
+   "contrastive_logits_temperature": 0.1,
+   "conv_bias": true,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "sum",
+   "ctc_zero_infinity": false,
+   "diversity_loss_weight": 0.1,
+   "do_stable_layer_norm": true,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_dropout": 0.0,
+   "feat_extract_norm": "layer",
+   "feat_proj_dropout": 0.1,
+   "feat_quantizer_dropout": 0.0,
+   "final_dropout": 0.0,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.1,
+   "mask_channel_length": 10,
+   "mask_channel_min_space": 1,
+   "mask_channel_other": 0.0,
+   "mask_channel_prob": 0.0,
+   "mask_channel_selection": "static",
+   "mask_feature_length": 10,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_min_space": 1,
+   "mask_time_other": 0.0,
+   "mask_time_prob": 0.075,
+   "mask_time_selection": "static",
+   "model_type": "wav2vec2",
+   "num_attention_heads": 16,
+   "num_codevector_groups": 2,
+   "num_codevectors_per_group": 320,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 24,
+   "num_negatives": 100,
+   "pad_token_id": 0,
+   "proj_codevector_dim": 768,
+   "transformers_version": "4.7.0.dev0",
+   "vocab_size": 32
+ }
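
One detail worth pulling out of this config: the seven convolutional feature-extractor layers (`conv_stride` above) downsample by an overall factor of 5 × 2⁶ = 320, i.e. one feature frame per 20 ms of 16 kHz audio. A quick arithmetic check in plain Python:

```python
# Overall stride of the wav2vec2 conv front end, using values from config.json above.
from math import prod

conv_stride = [5, 2, 2, 2, 2, 2, 2]
samples_per_frame = prod(conv_stride)
print(samples_per_frame)           # 320 samples per frame
print(16000 / samples_per_frame)   # 50.0 frames per second (20 ms hop)
```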
wav2vec2-large-xlsr-53/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "do_normalize": true,
+   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0,
+   "return_attention_mask": true,
+   "sampling_rate": 16000
+ }
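
These fields map directly onto the `Wav2Vec2FeatureExtractor` constructor, so loading this file is equivalent to building the extractor by hand. A sketch with the values copied from `preprocessor_config.json` above:

```python
from transformers import Wav2Vec2FeatureExtractor

# Equivalent to Wav2Vec2FeatureExtractor.from_pretrained("./wav2vec2-large-xlsr-53")
extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,            # raw mono waveform, one float per sample
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,         # zero-mean / unit-variance per utterance
    return_attention_mask=True,
    padding_side="right",
)
```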
wav2vec2-large-xlsr-53/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:314340227371a608f71adcd5f0de5933824fe77e55822aa4b24dba9c1c364dcb
+ size 1269737156
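
As with the other large binaries in this commit, this file is a Git LFS pointer: only the sha256 and size (about 1.27 GB) are stored in Git, and the actual weights are fetched from LFS storage. If you only need this one file, `huggingface_hub` can fetch it without a full clone; the repo ID below is a placeholder:

```python
# Download a single LFS-backed file from the Hub ("user/Spark-TTS" is a placeholder).
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="user/Spark-TTS",
    filename="wav2vec2-large-xlsr-53/pytorch_model.bin",
)
print(weights_path)  # local cache path of the resolved ~1.27 GB file
```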