Upload folder using huggingface_hub
- .gitattributes +2 -0
- BiCodec/config.yaml +60 -0
- BiCodec/model.safetensors +3 -0
- LLM/added_tokens.json +0 -0
- LLM/config.json +27 -0
- LLM/merges.txt +0 -0
- LLM/model.safetensors +3 -0
- LLM/special_tokens_map.json +31 -0
- LLM/tokenizer.json +3 -0
- LLM/tokenizer_config.json +0 -0
- LLM/vocab.json +0 -0
- __init__.py +0 -0
- __pycache__/configuration_spark_tts.cpython-312.pyc +0 -0
- __pycache__/modeling_spark_tts.cpython-312.pyc +3 -0
- __pycache__/processing_spark_tts.cpython-312.pyc +0 -0
- config.json +83 -0
- configuration_spark_tts.py +233 -0
- modeling_spark_tts.py +0 -0
- processing_spark_tts.py +889 -0
- wav2vec2-large-xlsr-53/README.md +29 -0
- wav2vec2-large-xlsr-53/config.json +83 -0
- wav2vec2-large-xlsr-53/preprocessor_config.json +9 -0
- wav2vec2-large-xlsr-53/pytorch_model.bin +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+LLM/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+__pycache__/modeling_spark_tts.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
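
(These two additions follow the pattern that `git lfs track` writes, e.g. `git lfs track "LLM/tokenizer.json"`, so the large tokenizer file and the committed .pyc are stored as Git LFS pointers rather than inline blobs.)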
BiCodec/config.yaml
ADDED
@@ -0,0 +1,60 @@
+audio_tokenizer:
+  mel_params:
+    sample_rate: 16000
+    n_fft: 1024
+    win_length: 640
+    hop_length: 320
+    mel_fmin: 10
+    mel_fmax: null
+    num_mels: 128
+
+  encoder:
+    input_channels: 1024
+    vocos_dim: 384
+    vocos_intermediate_dim: 2048
+    vocos_num_layers: 12
+    out_channels: 1024
+    sample_ratios: [1,1]
+
+  decoder:
+    input_channel: 1024
+    channels: 1536
+    rates: [8, 5, 4, 2]
+    kernel_sizes: [16,11,8,4]
+
+  quantizer:
+    input_dim: 1024
+    codebook_size: 8192
+    codebook_dim: 8
+    commitment: 0.25
+    codebook_loss_weight: 2.0
+    use_l2_normlize: True
+    threshold_ema_dead_code: 0.2
+
+  speaker_encoder:
+    input_dim: 128
+    out_dim: 1024
+    latent_dim: 128
+    token_num: 32
+    fsq_levels: [4, 4, 4, 4, 4, 4]
+    fsq_num_quantizers: 1
+
+  prenet:
+    input_channels: 1024
+    vocos_dim: 384
+    vocos_intermediate_dim: 2048
+    vocos_num_layers: 12
+    out_channels: 1024
+    condition_dim: 1024
+    sample_ratios: [1,1]
+    use_tanh_at_final: False
+
+  postnet:
+    input_channels: 1024
+    vocos_dim: 384
+    vocos_intermediate_dim: 2048
+    vocos_num_layers: 6
+    out_channels: 1024
+    use_tanh_at_final: False
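
Two derived quantities worth noting: hop_length 320 at sample_rate 16000 gives 50 codec frames per second, and the decoder rates [8, 5, 4, 2] multiply back to 320, so one semantic token per hop is upsampled to exactly one hop of waveform. A minimal sketch of checking this from the file (PyYAML and a local clone of this repo are assumptions):

    import yaml  # PyYAML, an assumed dependency

    # Path assumes the current directory is a local clone of this repo.
    with open("BiCodec/config.yaml") as f:
        tok = yaml.safe_load(f)["audio_tokenizer"]

    # 16000 / 320 = 50 semantic-token frames per second of audio.
    print(tok["mel_params"]["sample_rate"] / tok["mel_params"]["hop_length"])  # 50.0

    # The decoder upsampling rates multiply back to the hop length:
    # 8 * 5 * 4 * 2 = 320, so each codec frame maps to one hop of waveform.
    upsample = 1
    for r in tok["decoder"]["rates"]:
        upsample *= r
    print(upsample)  # 320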
BiCodec/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
+size 625518756
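
A Git LFS pointer file records the SHA-256 and byte size of the real payload, so a downloaded artifact can be checked against it. A minimal sketch (the local path is an assumption):

    import hashlib
    import os

    # Hypothetical local path after fetching the real payload via git-lfs.
    path = "BiCodec/model.safetensors"
    expected_oid = "e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec"
    expected_size = 625518756

    assert os.path.getsize(path) == expected_size, "size mismatch"
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
            h.update(chunk)
    assert h.hexdigest() == expected_oid, "sha256 mismatch"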
LLM/added_tokens.json
ADDED
Diff too large to render.
LLM/config.json
ADDED
@@ -0,0 +1,27 @@
+{
+  "architectures": [
+    "Qwen2ForCausalLM"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 151643,
+  "eos_token_id": 151645,
+  "hidden_act": "silu",
+  "hidden_size": 896,
+  "initializer_range": 0.02,
+  "intermediate_size": 4864,
+  "max_position_embeddings": 32768,
+  "max_window_layers": 21,
+  "model_type": "qwen2",
+  "num_attention_heads": 14,
+  "num_hidden_layers": 24,
+  "num_key_value_heads": 2,
+  "rms_norm_eps": 1e-06,
+  "rope_theta": 1000000.0,
+  "sliding_window": 32768,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.43.1",
+  "use_cache": true,
+  "use_sliding_window": false,
+  "vocab_size": 166000
+}
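
Since LLM/ ships a complete Qwen2 checkpoint (config, tokenizer files, weights), it can be inspected on its own, independently of the TTS wrapper. A minimal sketch, assuming a local clone:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # "./LLM" assumes the current directory is a clone of this repo.
    llm = AutoModelForCausalLM.from_pretrained("./LLM", torch_dtype=torch.bfloat16)
    tok = AutoTokenizer.from_pretrained("./LLM")
    # vocab_size is 166000 here; the entries beyond the stock Qwen2 vocabulary
    # cover the <|bicodec_semantic_*|>, <|bicodec_global_*|> and style tokens
    # used by processing_spark_tts.py below.
    print(llm.config.vocab_size, len(tok))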
LLM/merges.txt
ADDED
Diff too large to render.
LLM/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4719dd1d707ab915a12b58602eeb2521c8f259d0cc0dc2e0dca502a4fae22c8f
+size 1013300536
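
(For scale: 1 013 300 536 bytes at bfloat16, i.e. 2 bytes per parameter, is roughly 0.51 B parameters, consistent with the Qwen2-0.5B geometry and the enlarged 166 000-entry vocabulary declared in LLM/config.json.)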
LLM/special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
+{
+  "additional_special_tokens": [
+    "<|im_start|>",
+    "<|im_end|>",
+    "<|object_ref_start|>",
+    "<|object_ref_end|>",
+    "<|box_start|>",
+    "<|box_end|>",
+    "<|quad_start|>",
+    "<|quad_end|>",
+    "<|vision_start|>",
+    "<|vision_end|>",
+    "<|vision_pad|>",
+    "<|image_pad|>",
+    "<|video_pad|>"
+  ],
+  "eos_token": {
+    "content": "<|im_end|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
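
(Note the internal consistency: the eos_token <|im_end|> here corresponds to eos_token_id 151645 in LLM/config.json, and the pad token <|endoftext|> is id 151643 in the stock Qwen2 vocabulary, the declared bos_token_id.)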
LLM/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c8b057d6ca205a429cc3428b9fc815f0d6ee1d53106dd5e5b129ef9db2ff057
+size 14129172
LLM/tokenizer_config.json
ADDED
Diff too large to render.
LLM/vocab.json
ADDED
Diff too large to render.
__init__.py
ADDED
File without changes (empty file).
__pycache__/configuration_spark_tts.cpython-312.pyc
ADDED
Binary file (11.1 kB).
__pycache__/modeling_spark_tts.cpython-312.pyc
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3dc50b2b543272688fbfa8c52a85837182e7672981fbf7d232fb5c74ef77b657
+size 134941
__pycache__/processing_spark_tts.cpython-312.pyc
ADDED
Binary file (36 kB).
config.json
ADDED
@@ -0,0 +1,83 @@
+{
+  "model_type": "spark-tts",
+  "architectures": [
+    "SparkTTSModel"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
+    "AutoModel": "modeling_spark_tts.SparkTTSModel",
+    "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
+  },
+  "processor_class": "processing_spark_tts.SparkTTSProcessor",
+  "llm_model_name_or_path": "./LLM",
+  "bicodec_model_name_or_path": "./BiCodec",
+  "wav2vec2_model_name_or_path": "./wav2vec2-large-xlsr-53",
+  "sample_rate": 16000,
+  "highpass_cutoff_freq": 40,
+  "latent_hop_length": 320,
+  "ref_segment_duration": 6.0,
+  "volume_normalize": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.50.3",
+  "_commit_hash": null,
+  "bicodec_config": {
+    "mel_params": {
+      "sample_rate": 16000,
+      "n_fft": 1024,
+      "win_length": 640,
+      "hop_length": 320,
+      "mel_fmin": 10,
+      "mel_fmax": null,
+      "num_mels": 128
+    },
+    "encoder_config": {
+      "input_channels": 1024,
+      "vocos_dim": 384,
+      "vocos_intermediate_dim": 2048,
+      "vocos_num_layers": 12,
+      "out_channels": 1024,
+      "sample_ratios": [1, 1]
+    },
+    "decoder_config": {
+      "input_channel": 1024,
+      "channels": 1536,
+      "rates": [8, 5, 4, 2],
+      "kernel_sizes": [16, 11, 8, 4]
+    },
+    "quantizer_config": {
+      "input_dim": 1024,
+      "codebook_size": 8192,
+      "codebook_dim": 8,
+      "commitment": 0.25,
+      "codebook_loss_weight": 2.0,
+      "decay": 0.99,
+      "threshold_ema_dead_code": 0.2
+    },
+    "speaker_encoder_config": {
+      "input_dim": 128,
+      "out_dim": 1024,
+      "latent_dim": 128,
+      "token_num": 32,
+      "fsq_levels": [4, 4, 4, 4, 4, 4],
+      "fsq_num_quantizers": 1
+    },
+    "prenet_config": {
+      "input_channels": 1024,
+      "vocos_dim": 384,
+      "vocos_intermediate_dim": 2048,
+      "vocos_num_layers": 12,
+      "out_channels": 1024,
+      "condition_dim": 1024,
+      "sample_ratios": [1, 1],
+      "use_tanh_at_final": false
+    },
+    "postnet_config": {
+      "input_channels": 1024,
+      "vocos_dim": 384,
+      "vocos_intermediate_dim": 2048,
+      "vocos_num_layers": 6,
+      "out_channels": 1024,
+      "use_tanh_at_final": false
+    }
+  }
+}
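
The auto_map above wires the bundled *_spark_tts.py modules into transformers' Auto classes, so the checkpoint is meant to be loaded with trust_remote_code=True. A minimal sketch (the path is a placeholder for a local clone or the Hub repo id):

    from transformers import AutoModel, AutoProcessor

    repo = "."  # placeholder: local clone or Hub repo id
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
    # The processor delegates audio tokenization to the model (see link_model
    # in processing_spark_tts.py below).
    processor.link_model(model)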
configuration_spark_tts.py
ADDED
@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
+# ... (License headers remain the same) ...
+""" SparkTTS model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+from typing import List, Optional  # Added typing
+
+
+logger = logging.get_logger(__name__)
+
+# --- Define Individual Sub-Component Config Classes ---
+
+class SparkTTSMelParamsConfig(PretrainedConfig):
+    """Configuration for Mel Spectrogram parameters."""
+    model_type = "spark-tts-mel-params"
+    def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
+                 mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
+        super().__init__(**kwargs)
+        self.sample_rate = sample_rate
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.num_mels = num_mels
+
+class SparkTTSEncoderConfig(PretrainedConfig):
+    """Configuration for the BiCodec Feature Encoder."""
+    model_type = "spark-tts-encoder"
+    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                 vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
+        super().__init__(**kwargs)
+        self.input_channels = input_channels
+        self.vocos_dim = vocos_dim
+        self.vocos_intermediate_dim = vocos_intermediate_dim
+        self.vocos_num_layers = vocos_num_layers
+        self.out_channels = out_channels
+        self.sample_ratios = sample_ratios
+
+class SparkTTSDecoderConfig(PretrainedConfig):
+    """Configuration for the BiCodec Wave Generator (Decoder)."""
+    model_type = "spark-tts-decoder"
+    def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
+                 kernel_sizes=[16, 11, 8, 4], **kwargs):
+        super().__init__(**kwargs)
+        self.input_channel = input_channel
+        self.channels = channels
+        self.rates = rates
+        self.kernel_sizes = kernel_sizes
+
+class SparkTTSQuantizerConfig(PretrainedConfig):
+    """Configuration for the BiCodec Factorized Vector Quantizer."""
+    model_type = "spark-tts-quantizer"
+    def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
+                 commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
+                 threshold_ema_dead_code=0.2, **kwargs):
+        # Note: Removed use_l2_normlize as it wasn't in the original class __init__ args.
+        # Add it back if it's actually used by the FactorizedVectorQuantize class init.
+        super().__init__(**kwargs)
+        self.input_dim = input_dim
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+        self.commitment = commitment
+        self.codebook_loss_weight = codebook_loss_weight
+        self.decay = decay
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+
+class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
+    """Configuration for the BiCodec Speaker Encoder."""
+    model_type = "spark-tts-speaker-encoder"
+    def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
+                 fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
+        super().__init__(**kwargs)
+        self.input_dim = input_dim
+        self.out_dim = out_dim
+        self.latent_dim = latent_dim
+        self.token_num = token_num
+        self.fsq_levels = fsq_levels
+        self.fsq_num_quantizers = fsq_num_quantizers
+
+class SparkTTSPrenetConfig(PretrainedConfig):
+    """Configuration for the BiCodec Prenet."""
+    model_type = "spark-tts-prenet"
+    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                 vocos_num_layers=12, out_channels=1024, condition_dim=1024,
+                 sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
+        super().__init__(**kwargs)
+        self.input_channels = input_channels
+        self.vocos_dim = vocos_dim
+        self.vocos_intermediate_dim = vocos_intermediate_dim
+        self.vocos_num_layers = vocos_num_layers
+        self.out_channels = out_channels
+        self.condition_dim = condition_dim
+        self.sample_ratios = sample_ratios
+        self.use_tanh_at_final = use_tanh_at_final
+
+class SparkTTSPostnetConfig(PretrainedConfig):
+    """Configuration for the BiCodec Postnet."""
+    model_type = "spark-tts-postnet"
+    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                 vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
+        # Note: Removed condition_dim as it wasn't in the original config example for postnet.
+        super().__init__(**kwargs)
+        self.input_channels = input_channels
+        self.vocos_dim = vocos_dim
+        self.vocos_intermediate_dim = vocos_intermediate_dim
+        self.vocos_num_layers = vocos_num_layers
+        self.out_channels = out_channels
+        self.use_tanh_at_final = use_tanh_at_final
+
+
+# --- Define the Intermediate BiCodec Config Class ---
+
+class SparkTTSBiCodecConfig(PretrainedConfig):
+    """
+    Intermediate configuration class for the BiCodec component within SparkTTS.
+    It holds instances of the individual sub-component configurations.
+    """
+    model_type = "spark-tts-bicodec"
+    # Map keys in the 'bicodec_config' dict to their respective classes
+    sub_configs = {
+        "mel_params": SparkTTSMelParamsConfig,
+        "encoder_config": SparkTTSEncoderConfig,
+        "decoder_config": SparkTTSDecoderConfig,
+        "quantizer_config": SparkTTSQuantizerConfig,
+        "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
+        "prenet_config": SparkTTSPrenetConfig,
+        "postnet_config": SparkTTSPostnetConfig,
+    }
+
+    def __init__(
+        self,
+        mel_params=None,
+        encoder_config=None,
+        decoder_config=None,
+        quantizer_config=None,
+        speaker_encoder_config=None,
+        prenet_config=None,
+        postnet_config=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # Instantiate sub-configs from dictionaries or use defaults/provided instances
+        self.mel_params = self._init_sub_config(mel_params, "mel_params")
+        self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
+        self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
+        self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
+        self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
+        self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
+        self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")
+
+    def _init_sub_config(self, config_input, config_key):
+        """Helper to initialize sub-configs."""
+        config_cls = self.sub_configs[config_key]
+        if isinstance(config_input, dict):
+            return config_cls(**config_input)
+        elif config_input is None:
+            return config_cls()  # Initialize with defaults
+        elif isinstance(config_input, config_cls):
+            return config_input  # Already an instance
+        else:
+            raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")
+
+
+# --- Define the Main SparkTTS Config Class ---
+
+class SparkTTSConfig(PretrainedConfig):
+    r"""
+    Main configuration class for SparkTTSModel, including nested BiCodec configuration.
+
+    Args:
+        llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for LLM.
+        bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for BiCodec checkpoint.
+        wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
+        sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
+        # ... (other top-level args: highpass_cutoff_freq, latent_hop_length, ref_segment_duration, volume_normalize) ...
+        bicodec_config (`dict`, *optional*): Dictionary to initialize `SparkTTSBiCodecConfig`.
+        torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
+        kwargs (*optional*): Dictionary of keyword arguments.
+    """
+    model_type = "spark-tts"
+    # Map the key in config.json to the intermediate BiCodec config class
+    sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
+    attribute_map = {"hidden_size": "d_model"}  # Example
+
+    def __init__(
+        self,
+        llm_model_name_or_path="./LLM",
+        bicodec_model_name_or_path="./BiCodec",
+        wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
+        sample_rate=16000,
+        highpass_cutoff_freq=40,
+        latent_hop_length=320,
+        ref_segment_duration=6.0,
+        volume_normalize=True,
+        bicodec_config=None,  # Expects a dictionary or None
+        torch_dtype="auto",
+        **kwargs,
+    ):
+        # --- Top-level parameters ---
+        self.llm_model_name_or_path = llm_model_name_or_path
+        self.bicodec_model_name_or_path = bicodec_model_name_or_path
+        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
+        self.sample_rate = sample_rate
+        self.highpass_cutoff_freq = highpass_cutoff_freq
+        self.latent_hop_length = latent_hop_length
+        self.ref_segment_duration = ref_segment_duration
+        self.volume_normalize = volume_normalize
+        self.torch_dtype = torch_dtype
+
+        # --- Nested BiCodec Configuration ---
+        # Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
+        if isinstance(bicodec_config, dict):
+            self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
+        elif bicodec_config is None:
+            logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
+            self.bicodec_config = self.sub_configs["bicodec_config"]()
+        elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
+            self.bicodec_config = bicodec_config  # Use existing instance
+        else:
+            raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}. Expected dict, None, or SparkTTSBiCodecConfig.")
+
+        # Set processor class and auto_map
+        kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
+        kwargs["auto_map"] = kwargs.get("auto_map", {
+            "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
+            "AutoModel": "modeling_spark_tts.SparkTTSModel",
+            "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
+        })
+        super().__init__(**kwargs)
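
A quick sketch of how the nested configuration resolves, run from a clone of this repo so the module imports directly; the printed defaults mirror the shipped config.json:

    from configuration_spark_tts import SparkTTSConfig

    cfg = SparkTTSConfig()  # no args: every sub-config falls back to its defaults
    print(cfg.sample_rate)                                        # 16000
    print(cfg.bicodec_config.quantizer_config.codebook_size)      # 8192
    print(cfg.bicodec_config.speaker_encoder_config.fsq_levels)   # [4, 4, 4, 4, 4, 4]
    print(type(cfg.bicodec_config.prenet_config).__name__)        # SparkTTSPrenetConfig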
modeling_spark_tts.py
ADDED
Diff too large to render.
processing_spark_tts.py
ADDED
|
@@ -0,0 +1,889 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 SparkAudio & The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os # Needed for save_pretrained
|
| 20 |
+
import re # For decoding
|
| 21 |
+
import torch
|
| 22 |
+
import numpy as np
|
| 23 |
+
import soundfile as sf # For audio loading
|
| 24 |
+
import soxr # For resampling
|
| 25 |
+
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Optional, Union, List, Dict, Tuple, Any
|
| 28 |
+
|
| 29 |
+
from transformers.processing_utils import ProcessorMixin
|
| 30 |
+
from transformers.tokenization_utils_base import BatchEncoding # Return type hint
|
| 31 |
+
from transformers.feature_extraction_utils import BatchFeature # Return type hint
|
| 32 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
| 33 |
+
from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
| 34 |
+
from transformers.utils import logging, PushToHubMixin # Added PushToHubMixin
|
| 35 |
+
from numpy.lib.stride_tricks import sliding_window_view
|
| 36 |
+
import soxr
|
| 37 |
+
import soundfile
|
| 38 |
+
import random
|
| 39 |
+
|
| 40 |
+
# Import custom config if needed for defaults
|
| 41 |
+
from .configuration_spark_tts import SparkTTSConfig
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# =============================================================================
|
| 47 |
+
# >> START: PASTE CODE FROM sparktts/utils/* HERE <<
|
| 48 |
+
# =============================================================================
|
| 49 |
+
# IMPORTANT: Utility functions needed for processing (audio loading, token parsing)
|
| 50 |
+
# must be defined or imported here.
|
| 51 |
+
|
| 52 |
+
# --- Paste sparktts/utils/audio.py content here ---
|
| 53 |
+
|
| 54 |
+
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
|
| 55 |
+
"""
|
| 56 |
+
Normalize the volume of an audio signal.
|
| 57 |
+
|
| 58 |
+
Parameters:
|
| 59 |
+
audio (numpy array): Input audio signal array.
|
| 60 |
+
coeff (float): Target coefficient for normalization, default is 0.2.
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
numpy array: The volume-normalized audio signal.
|
| 64 |
+
"""
|
| 65 |
+
# Sort the absolute values of the audio signal
|
| 66 |
+
temp = np.sort(np.abs(audio))
|
| 67 |
+
|
| 68 |
+
# If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
|
| 69 |
+
if temp[-1] < 0.1:
|
| 70 |
+
scaling_factor = max(
|
| 71 |
+
temp[-1], 1e-3
|
| 72 |
+
) # Prevent division by zero with a small constant
|
| 73 |
+
audio = audio / scaling_factor * 0.1
|
| 74 |
+
|
| 75 |
+
# Filter out values less than 0.01 from temp
|
| 76 |
+
temp = temp[temp > 0.01]
|
| 77 |
+
L = temp.shape[0] # Length of the filtered array
|
| 78 |
+
|
| 79 |
+
# If there are fewer than or equal to 10 significant values, return the audio without further processing
|
| 80 |
+
if L <= 10:
|
| 81 |
+
return audio
|
| 82 |
+
|
| 83 |
+
# Compute the average of the top 10% to 1% of values in temp
|
| 84 |
+
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
|
| 85 |
+
|
| 86 |
+
# Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
|
| 87 |
+
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
|
| 88 |
+
|
| 89 |
+
# Ensure the maximum absolute value in the audio does not exceed 1
|
| 90 |
+
max_value = np.max(np.abs(audio))
|
| 91 |
+
if max_value > 1:
|
| 92 |
+
audio = audio / max_value
|
| 93 |
+
|
| 94 |
+
return audio
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def load_audio(
|
| 98 |
+
adfile: Path,
|
| 99 |
+
sampling_rate: int = None,
|
| 100 |
+
length: int = None,
|
| 101 |
+
volume_normalize: bool = False,
|
| 102 |
+
segment_duration: int = None,
|
| 103 |
+
) -> np.ndarray:
|
| 104 |
+
r"""Load audio file with target sampling rate and lsength
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
adfile (Path): path to audio file.
|
| 108 |
+
sampling_rate (int, optional): target sampling rate. Defaults to None.
|
| 109 |
+
length (int, optional): target audio length. Defaults to None.
|
| 110 |
+
volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
|
| 111 |
+
segment_duration (int): random select a segment with duration of {segment_duration}s.
|
| 112 |
+
Defualt to None which means the whole audio will be used.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
audio (np.ndarray): audio
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
audio, sr = soundfile.read(adfile)
|
| 119 |
+
if len(audio.shape) > 1:
|
| 120 |
+
audio = audio[:, 0]
|
| 121 |
+
|
| 122 |
+
if sampling_rate is not None and sr != sampling_rate:
|
| 123 |
+
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
|
| 124 |
+
sr = sampling_rate
|
| 125 |
+
|
| 126 |
+
if segment_duration is not None:
|
| 127 |
+
seg_length = int(sr * segment_duration)
|
| 128 |
+
audio = random_select_audio_segment(audio, seg_length)
|
| 129 |
+
|
| 130 |
+
# Audio volume normalize
|
| 131 |
+
if volume_normalize:
|
| 132 |
+
audio = audio_volume_normalize(audio)
|
| 133 |
+
# check the audio length
|
| 134 |
+
if length is not None:
|
| 135 |
+
assert abs(audio.shape[0] - length) < 1000
|
| 136 |
+
if audio.shape[0] > length:
|
| 137 |
+
audio = audio[:length]
|
| 138 |
+
else:
|
| 139 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
| 140 |
+
return audio
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
|
| 144 |
+
"""get an audio segment given the length
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
audio (np.ndarray):
|
| 148 |
+
length (int): audio length = sampling_rate * duration
|
| 149 |
+
"""
|
| 150 |
+
if audio.shape[0] < length:
|
| 151 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
| 152 |
+
start_index = random.randint(0, audio.shape[0] - length)
|
| 153 |
+
end_index = int(start_index + length)
|
| 154 |
+
|
| 155 |
+
return audio[start_index:end_index]
|
| 156 |
+
|
| 157 |
+
def get_ref_clip(wav: np.ndarray, config) -> np.ndarray: # Needs access to config attributes
|
| 158 |
+
"""Get reference audio clip for speaker embedding."""
|
| 159 |
+
# Make sure config has sample_rate, ref_segment_duration, latent_hop_length
|
| 160 |
+
if not all(hasattr(config, attr) for attr in ['sample_rate', 'ref_segment_duration', 'latent_hop_length']):
|
| 161 |
+
raise AttributeError("Config object missing required attributes for get_ref_clip")
|
| 162 |
+
ref_segment_length = (
|
| 163 |
+
int(config.sample_rate * config.ref_segment_duration)
|
| 164 |
+
// config.latent_hop_length
|
| 165 |
+
* config.latent_hop_length
|
| 166 |
+
)
|
| 167 |
+
wav_length = len(wav)
|
| 168 |
+
if ref_segment_length > wav_length:
|
| 169 |
+
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
| 170 |
+
return wav[:ref_segment_length]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# --- Paste sparktts/utils/token_parser.py content here ---
|
| 174 |
+
|
| 175 |
+
TASK_TOKEN_MAP = {
|
| 176 |
+
"vc": "<|task_vc|>",
|
| 177 |
+
"tts": "<|task_tts|>",
|
| 178 |
+
"asr": "<|task_asr|>",
|
| 179 |
+
"s2s": "<|task_s2s|>",
|
| 180 |
+
"t2s": "<|task_t2s|>",
|
| 181 |
+
"understand": "<|task_understand|>",
|
| 182 |
+
"caption": "<|task_cap|>",
|
| 183 |
+
"controllable_tts": "<|task_controllable_tts|>",
|
| 184 |
+
"prompt_tts": "<|task_prompt_tts|>",
|
| 185 |
+
"speech_edit": "<|task_edit|>",
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
LEVELS_MAP = {
|
| 189 |
+
"very_low": 0,
|
| 190 |
+
"low": 1,
|
| 191 |
+
"moderate": 2,
|
| 192 |
+
"high": 3,
|
| 193 |
+
"very_high": 4,
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
LEVELS_MAP_UI = {
|
| 197 |
+
1: 'very_low',
|
| 198 |
+
2: 'low',
|
| 199 |
+
3: 'moderate',
|
| 200 |
+
4: 'high',
|
| 201 |
+
5: 'very_high'
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
GENDER_MAP = {
|
| 205 |
+
"female": 0,
|
| 206 |
+
"male": 1,
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
|
| 210 |
+
|
| 211 |
+
EMO_MAP = {
|
| 212 |
+
"UNKNOWN": 0,
|
| 213 |
+
"NEUTRAL": 1,
|
| 214 |
+
"ANGRY": 2,
|
| 215 |
+
"HAPPY": 3,
|
| 216 |
+
"SAD": 4,
|
| 217 |
+
"FEARFUL": 5,
|
| 218 |
+
"DISGUSTED": 6,
|
| 219 |
+
"SURPRISED": 7,
|
| 220 |
+
"SARCASTIC": 8,
|
| 221 |
+
"EXCITED": 9,
|
| 222 |
+
"SLEEPY": 10,
|
| 223 |
+
"CONFUSED": 11,
|
| 224 |
+
"EMPHASIS": 12,
|
| 225 |
+
"LAUGHING": 13,
|
| 226 |
+
"SINGING": 14,
|
| 227 |
+
"WORRIED": 15,
|
| 228 |
+
"WHISPER": 16,
|
| 229 |
+
"ANXIOUS": 17,
|
| 230 |
+
"NO-AGREEMENT": 18,
|
| 231 |
+
"APOLOGETIC": 19,
|
| 232 |
+
"CONCERNED": 20,
|
| 233 |
+
"ENUNCIATED": 21,
|
| 234 |
+
"ASSERTIVE": 22,
|
| 235 |
+
"ENCOURAGING": 23,
|
| 236 |
+
"CONTEMPT": 24,
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class TokenParser:
|
| 241 |
+
"""Turn label to special token"""
|
| 242 |
+
|
| 243 |
+
def __init__(self):
|
| 244 |
+
pass
|
| 245 |
+
|
| 246 |
+
"""Parse the attributes of a person."""
|
| 247 |
+
|
| 248 |
+
def __init__(self):
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
@staticmethod
|
| 252 |
+
def age(age: str) -> str:
|
| 253 |
+
"""Turn age token."""
|
| 254 |
+
age_id = AGE_MAP[age]
|
| 255 |
+
return f"<|age_{age_id}|>"
|
| 256 |
+
|
| 257 |
+
@staticmethod
|
| 258 |
+
def gender(gender: str) -> str:
|
| 259 |
+
"""Turn gender token."""
|
| 260 |
+
gender_id = GENDER_MAP[gender]
|
| 261 |
+
return f"<|gender_{gender_id}|>"
|
| 262 |
+
|
| 263 |
+
@staticmethod
|
| 264 |
+
def mel_value(mel: int):
|
| 265 |
+
"""Turn special token of mel scale pitch."""
|
| 266 |
+
mel = max(0, int(mel))
|
| 267 |
+
mel = min(1000, int(mel))
|
| 268 |
+
return f"<|pitch_value_{mel}|>"
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def mel_level(level: str):
|
| 272 |
+
"""Turn special token of mel level."""
|
| 273 |
+
level_tag = LEVELS_MAP[level]
|
| 274 |
+
return f"<|pitch_label_{level_tag}|>"
|
| 275 |
+
|
| 276 |
+
@staticmethod
|
| 277 |
+
def pitch_var_value(pitch_std: int):
|
| 278 |
+
"""Turn special token of pitch_std value."""
|
| 279 |
+
assert isinstance(pitch_std, int)
|
| 280 |
+
pitch_std = max(0, int(pitch_std))
|
| 281 |
+
pitch_std = min(10, int(pitch_std))
|
| 282 |
+
return f"<|pitch_var_value_{pitch_std}|>"
|
| 283 |
+
|
| 284 |
+
@staticmethod
|
| 285 |
+
def pitch_var_level(level: str):
|
| 286 |
+
"""Turn special token of pitch std level."""
|
| 287 |
+
level_tag = LEVELS_MAP[level]
|
| 288 |
+
return f"<|pitch_var_label_{level_tag}|>"
|
| 289 |
+
|
| 290 |
+
@staticmethod
|
| 291 |
+
def loudness_value(loudness: int):
|
| 292 |
+
"""Turn special toak of loudness value [0, 30]"""
|
| 293 |
+
assert loudness >= 0
|
| 294 |
+
loudness = max(0, int(loudness))
|
| 295 |
+
loudness = min(30, int(loudness))
|
| 296 |
+
return f"<|loudness_value_{loudness}|>"
|
| 297 |
+
|
| 298 |
+
@staticmethod
|
| 299 |
+
def loudness_level(level: str):
|
| 300 |
+
"""Turn special token of loudness level."""
|
| 301 |
+
level_tag = LEVELS_MAP[level]
|
| 302 |
+
return f"<|loudness_label_{level_tag}|>"
|
| 303 |
+
|
| 304 |
+
@staticmethod
|
| 305 |
+
def speed_value(speed: int):
|
| 306 |
+
"""Turn special token of speed value."""
|
| 307 |
+
speed = max(0, int(speed))
|
| 308 |
+
speed = min(10, int(speed))
|
| 309 |
+
return f"<|speed_value_{speed}|>"
|
| 310 |
+
|
| 311 |
+
@staticmethod
|
| 312 |
+
def speed_level(level: str):
|
| 313 |
+
"""Turn special token of speed level."""
|
| 314 |
+
level_tag = LEVELS_MAP[level]
|
| 315 |
+
return f"<|speed_label_{level_tag}|>"
|
| 316 |
+
|
| 317 |
+
@staticmethod
|
| 318 |
+
def task(task: str) -> str:
|
| 319 |
+
"""Turn special token of task."""
|
| 320 |
+
assert task in TASK_TOKEN_MAP.keys()
|
| 321 |
+
|
| 322 |
+
return TASK_TOKEN_MAP[task]
|
| 323 |
+
|
| 324 |
+
@staticmethod
|
| 325 |
+
def emotion(emotion: str):
|
| 326 |
+
emo_id = EMO_MAP[emotion]
|
| 327 |
+
|
| 328 |
+
return f"<|emotion_{emo_id}|>"
|
| 329 |
+
|
| 330 |
+
# =============================================================================
|
| 331 |
+
# >> END: PASTE CODE FROM sparktts/utils/* HERE <<
|
| 332 |
+
# =============================================================================
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class SparkTTSProcessor(ProcessorMixin, PushToHubMixin): # Added PushToHubMixin
|
| 336 |
+
r"""
|
| 337 |
+
Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic.
|
| 338 |
+
|
| 339 |
+
Args:
|
| 340 |
+
tokenizer ([`PreTrainedTokenizer`]):
|
| 341 |
+
An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM.
|
| 342 |
+
feature_extractor ([`Wav2Vec2FeatureExtractor`]):
|
| 343 |
+
An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted
|
| 344 |
+
within the model's `tokenize_audio`, the extractor's configuration (like sampling rate)
|
| 345 |
+
is useful, and it aligns with the ProcessorMixin pattern.
|
| 346 |
+
config ([`SparkTTSConfig`], *optional*):
|
| 347 |
+
An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate.
|
| 348 |
+
"""
|
| 349 |
+
attributes = ["tokenizer", "feature_extractor"]
|
| 350 |
+
tokenizer_class = "AutoTokenizer"
|
| 351 |
+
feature_extractor_class = "Wav2Vec2FeatureExtractor" # Keep for consistency
|
| 352 |
+
|
| 353 |
+
def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs):
|
| 354 |
+
super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs)
|
| 355 |
+
self.model = None
|
| 356 |
+
self.config = config
|
| 357 |
+
# Set sampling rate
|
| 358 |
+
if config and hasattr(config, 'sample_rate'):
|
| 359 |
+
self.sampling_rate = config.sample_rate
|
| 360 |
+
elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'):
|
| 361 |
+
self.sampling_rate = feature_extractor.sampling_rate
|
| 362 |
+
else:
|
| 363 |
+
self.sampling_rate = 16000
|
| 364 |
+
logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.")
|
| 365 |
+
|
| 366 |
+
# # Ensure tokenizer pad token
|
| 367 |
+
# if self.tokenizer.pad_token is None:
|
| 368 |
+
# if self.tokenizer.eos_token is not None:
|
| 369 |
+
# logger.warning("Tokenizer does not have a pad token. Setting pad_token to eos_token.")
|
| 370 |
+
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 371 |
+
# else:
|
| 372 |
+
# logger.warning("Tokenizer lacks pad and eos token. Adding default pad token '<|pad|>'.")
|
| 373 |
+
# self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
|
| 374 |
+
|
| 375 |
+
def link_model(self, model):
|
| 376 |
+
"""Links the processor to a SparkTTSModel instance for audio processing calls."""
|
| 377 |
+
if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'):
|
| 378 |
+
raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.")
|
| 379 |
+
if not hasattr(model, 'config'):
|
| 380 |
+
logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.")
|
| 381 |
+
|
| 382 |
+
self.model = model
|
| 383 |
+
logger.info("SparkTTSModel successfully linked to the processor.")
|
| 384 |
+
# Update sampling rate based on linked model's config if available
|
| 385 |
+
if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'):
|
| 386 |
+
if self.sampling_rate != model.config.sample_rate:
|
| 387 |
+
logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.")
|
| 388 |
+
self.sampling_rate = model.config.sample_rate
|
| 389 |
+
# Also update feature extractor sampling rate if it differs
|
| 390 |
+
if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate:
|
| 391 |
+
logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.")
|
| 392 |
+
self.feature_extractor.sampling_rate = model.config.sample_rate
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def __call__(
|
| 396 |
+
self,
|
| 397 |
+
text: str,
|
| 398 |
+
prompt_speech_path: Optional[Union[str, Path]] = None,
|
| 399 |
+
prompt_text: Optional[str] = None,
|
| 400 |
+
gender: Optional[str] = None,
|
| 401 |
+
pitch: Optional[str] = None,
|
| 402 |
+
speed: Optional[str] = None,
|
| 403 |
+
return_tensors: Optional[str] = "pt",
|
| 404 |
+
**kwargs, # Allow passing other args like padding, truncation to tokenizer
|
| 405 |
+
) -> BatchEncoding:
|
| 406 |
+
"""
|
| 407 |
+
Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`].
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
text (`str`):
|
| 411 |
+
The main text to be synthesized.
|
| 412 |
+
prompt_speech_path (`str` or `Path`, *optional*):
|
| 413 |
+
Path to the prompt audio file for voice cloning. Required if `gender` is not set.
|
| 414 |
+
prompt_text (`str`, *optional*):
|
| 415 |
+
Transcript of the prompt audio. Used only in voice cloning mode.
|
| 416 |
+
gender (`str`, *optional*):
|
| 417 |
+
Target gender ("male" or "female") for controllable synthesis. If set, enables control mode.
|
| 418 |
+
pitch (`str`, *optional*):
|
| 419 |
+
Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
|
| 420 |
+
speed (`str`, *optional*):
|
| 421 |
+
Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
|
| 422 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 423 |
+
If set, will return tensors instead of list of python integers. Only "pt" (PyTorch) is supported currently.
|
| 424 |
+
**kwargs:
|
| 425 |
+
Additional arguments passed to the underlying tokenizer's `__call__` method.
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
[`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM.
|
| 429 |
+
In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the
|
| 430 |
+
global tokens extracted from the prompt audio.
|
| 431 |
+
"""
|
| 432 |
+
|
| 433 |
+
global_token_ids_prompt = None # Initialize
|
| 434 |
+
|
| 435 |
+
# Determine mode: Control TTS or Voice Cloning (Prompt TTS)
|
| 436 |
+
is_control_mode = gender is not None
|
| 437 |
+
is_cloning_mode = prompt_speech_path is not None and not is_control_mode
|
| 438 |
+
|
| 439 |
+
if is_control_mode:
|
| 440 |
+
# --- Controllable TTS Prompt Construction ---
|
| 441 |
+
if not all([pitch, speed]):
|
| 442 |
+
raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.")
|
| 443 |
+
if prompt_speech_path is not None:
|
| 444 |
+
logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).")
|
| 445 |
+
|
| 446 |
+
if not all(k in GENDER_MAP for k in [gender]): # Basic check
|
| 447 |
+
raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}")
|
| 448 |
+
if not all(k in LEVELS_MAP for k in [pitch, speed]): # Basic check
|
| 449 |
+
raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}")
|
| 450 |
+
|
| 451 |
+
gender_id = GENDER_MAP[gender]
|
| 452 |
+
pitch_level_id = LEVELS_MAP[pitch]
|
| 453 |
+
speed_level_id = LEVELS_MAP[speed]
|
| 454 |
+
|
| 455 |
+
pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
|
| 456 |
+
speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
|
| 457 |
+
gender_tokens = f"<|gender_{gender_id}|>"
|
| 458 |
+
|
| 459 |
+
attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens])
|
| 460 |
+
|
| 461 |
+
prompt_list = [
|
| 462 |
+
TASK_TOKEN_MAP["controllable_tts"],
|
| 463 |
+
"<|start_content|>",
|
| 464 |
+
text,
|
| 465 |
+
"<|end_content|>",
|
| 466 |
+
"<|start_style_label|>",
|
| 467 |
+
attribute_tokens,
|
| 468 |
+
"<|end_style_label|>",
|
| 469 |
+
]
|
| 470 |
+
prompt_string = "".join(prompt_list)
|
| 471 |
+
|
| 472 |
+
elif is_cloning_mode:
|
| 473 |
+
# --- Voice Cloning Prompt Construction ---
|
| 474 |
+
if self.model is None:
|
| 475 |
+
raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.")
|
| 476 |
+
prompt_speech_path = Path(prompt_speech_path) # Ensure it's a Path object
|
| 477 |
+
if not prompt_speech_path.exists():
|
| 478 |
+
raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}")
|
| 479 |
+
|
| 480 |
+
# Load and process prompt audio
|
| 481 |
+
try:
|
| 482 |
+
model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config
|
| 483 |
+
if model_config is None:
|
| 484 |
+
raise ValueError("Configuration not available in processor or linked model.")
|
| 485 |
+
|
| 486 |
+
# Load main wav
|
| 487 |
+
wav = load_audio(
|
| 488 |
+
prompt_speech_path,
|
| 489 |
+
sampling_rate=self.sampling_rate,
|
| 490 |
+
volume_normalize=getattr(model_config, 'volume_normalize', True), # Use getattr for safety
|
| 491 |
+
)
|
| 492 |
+
# Get reference clip
|
| 493 |
+
wav_ref_np = get_ref_clip(wav, model_config) # Pass config object
|
| 494 |
+
wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
|
| 495 |
+
wav_tensor = torch.from_numpy(wav).unsqueeze(0).float()
|
| 496 |
+
|
| 497 |
+
# Tokenize using the linked model's method
|
| 498 |
+
# Assuming tokenize_audio returns tensors with batch dim 1: [1, N_global], [1, N_semantic]
|
| 499 |
+
global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref)
|
| 500 |
+
|
| 501 |
+
# Store the global tokens tensor (with batch dim) for the output dict
|
| 502 |
+
global_token_ids_prompt = global_tokens_tensor # Keep batch dim [1, N_global]
|
| 503 |
+
|
| 504 |
+
# Convert tensors to lists of ints for string formatting
|
| 505 |
+
global_token_list = global_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
|
| 506 |
+
semantic_token_list = semantic_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
|
| 507 |
+
|
| 508 |
+
except Exception as e:
|
| 509 |
+
logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}")
|
| 510 |
+
import traceback
|
| 511 |
+
traceback.print_exc()
|
| 512 |
+
raise
|
| 513 |
+
|
| 514 |
+
# ==============================================================
|
| 515 |
+
# CORRECTED TOKEN STRING FORMATTING
|
| 516 |
+
# ==============================================================
|
| 517 |
+
# Create individual token strings for each ID
|
| 518 |
+
global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list])
|
| 519 |
+
semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list])
|
| 520 |
+
# ==============================================================
|
| 521 |
+
|
| 522 |
+
# Construct prompt list based on presence of prompt_text
|
| 523 |
+
if prompt_text is not None and prompt_text.strip(): # Check if prompt_text is meaningful
|
| 524 |
+
logger.info("Using prompt text in voice cloning prompt.")
|
| 525 |
+
prompt_list = [
|
| 526 |
+
TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]? Check original logic. Assuming "tts".
|
| 527 |
+
"<|start_content|>",
|
| 528 |
+
prompt_text, # Transcript first
|
| 529 |
+
text, # Then target text
|
| 530 |
+
"<|end_content|>",
|
| 531 |
+
"<|start_global_token|>",
|
| 532 |
+
global_tokens_str,
|
| 533 |
+
"<|end_global_token|>",
|
| 534 |
+
"<|start_semantic_token|>",
|
| 535 |
+
semantic_tokens_str,
|
| 536 |
+
# "<|end_semantic_token|>", # Original code didn't have this marker here
|
| 537 |
+
]
|
| 538 |
+
else:
|
| 539 |
+
# Simpler prompt without semantic tokens if no transcript provided
|
| 540 |
+
logger.info("No prompt text provided, using text-only voice cloning prompt.")
|
| 541 |
+
prompt_list = [
|
| 542 |
+
TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]?
|
| 543 |
+
"<|start_content|>",
|
| 544 |
+
text, # Only target text
|
| 545 |
+
"<|end_content|>",
|
| 546 |
+
"<|start_global_token|>",
|
| 547 |
+
global_tokens_str,
|
| 548 |
+
"<|end_global_token|>",
|
| 549 |
+
]
|
| 550 |
+
prompt_string = "".join(prompt_list)
|
| 551 |
+
logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...") # Log start of prompt
|
| 552 |
+
|
| 553 |
+
else:
|
| 554 |
+
raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.")
|
| 555 |
+
|
| 556 |
+
# --- Tokenize the final prompt string ---
|
| 557 |
+
# print(f"Tokenizing prompt: {prompt_string}")
|
| 558 |
+
inputs = self.tokenizer(
|
| 559 |
+
prompt_string,
|
| 560 |
+
return_tensors=return_tensors,
|
| 561 |
+
padding=kwargs.get("padding", False), # Often False for generation prompts unless batching > 1
|
| 562 |
+
truncation=kwargs.get("truncation", True),
|
| 563 |
+
max_length=kwargs.get("max_length", self.tokenizer.model_max_length),
|
| 564 |
+
add_special_tokens=kwargs.get("add_special_tokens", True), # Usually True unless handled manually
|
| 565 |
+
return_attention_mask=kwargs.get("return_attention_mask", True), # Need attention mask
|
| 566 |
+
**{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]}
|
| 567 |
+
)
|
| 568 |
+
logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
# Add the prompt's global tokens (as tensor with batch dim) to the output if in cloning mode
|
| 572 |
+
if is_cloning_mode and global_token_ids_prompt is not None:
|
| 573 |
+
if return_tensors == "pt":
|
| 574 |
+
inputs["global_token_ids_prompt"] = global_token_ids_prompt # Already has batch dim [1, N_global]
|
| 575 |
+
else:
|
| 576 |
+
# Handle non-tensor return if necessary
|
| 577 |
+
inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist()
|
| 578 |
+
|
| 579 |
+
return inputs
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
    def decode(
        self,
        generated_ids: torch.Tensor,
        global_token_ids_prompt: Optional[torch.Tensor] = None,
        input_ids_len: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Dict[str, Any]:
        """
        Decodes the token IDs generated by [`SparkTTSModel`] into an audio waveform.

        Args:
            generated_ids (`torch.Tensor`):
                Tensor of token IDs returned by `model.generate()`, including the input prompt part. Shape [B, seq_len].
            global_token_ids_prompt (`torch.Tensor`, *optional*):
                The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning).
                Shape [B, N_global]. Required if the generation was for voice cloning.
            input_ids_len (`int`, *optional*):
                The length of the original prompt `input_ids` fed to `model.generate()`. Required to
                correctly isolate the newly generated tokens.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to skip special tokens during the text decoding step (used to extract audio tokens).

        Returns:
            Dict[str, Any]: A dictionary containing:
                - "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio].
                - "sampling_rate": The sampling rate of the audio.
        """
        if self.model is None:
            raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.")
        if input_ids_len is None:
            raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.")

        # --- Isolate the generated part ---
        # `generated_ids` has shape [B, full_seq_len]; with max_new_tokens > 0 it
        # should never be shorter than the prompt, but warn just in case.
        if generated_ids.shape[1] < input_ids_len:
            logger.warning(
                f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length "
                f"({input_ids_len}). Decoding might be incorrect."
            )
        output_only_ids = generated_ids[:, input_ids_len:]  # Empty if nothing was generated

        # Decode the generated part to text in order to find the audio tokens (handles B > 1)
        decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens)

        # --- Extract audio tokens per batch item ---
        batch_size = generated_ids.shape[0]
        if global_token_ids_prompt is not None and global_token_ids_prompt.shape[0] != batch_size:
            raise ValueError(
                f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt "
                f"has {global_token_ids_prompt.shape[0]}."
            )

        all_semantic_ids = []
        all_global_tokens = []
        successful_indices = []  # Track which batch items were parsed successfully

        for i in range(batch_size):
            decoded_text = decoded_texts[i]

            # Extract semantic tokens
            try:
                pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]
                if not pred_semantic_indices:
                    logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'")
                    continue  # Skip this item
                current_semantic_ids = torch.tensor(pred_semantic_indices).long()  # Shape [N_semantic]
            except Exception as e:
                logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}")
                continue  # Skip this item

            # Determine global tokens
            if global_token_ids_prompt is not None:
                # Cloning mode: use the prompt's global tokens for this batch item
                current_global_tokens = global_token_ids_prompt[i]  # Shape [N_global]
            else:
                # Control mode: extract global tokens from the generated text
                try:
                    pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)]
                    if not pred_global_indices:
                        logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'")
                        continue  # Skip this item
                    current_global_tokens = torch.tensor(pred_global_indices).long()  # Shape [N_global]
                except Exception as e:
                    logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}")
                    continue  # Skip this item

            # Both token sequences were extracted successfully
            all_semantic_ids.append(current_semantic_ids)
            all_global_tokens.append(current_global_tokens)
            successful_indices.append(i)

        if not successful_indices:
            logger.error("Failed to extract audio tokens for any item in the batch.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        if batch_size > 1 and len(successful_indices) < batch_size:
            logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.")

        # Batch detokenization would require padding the token sequences to a common
        # length (and verifying that BiCodec supports batched, padded input). Until
        # that is verified, only B=1 is supported here, matching the original code.
        if len(successful_indices) != 1:
            raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.")

        final_semantic_ids = all_semantic_ids[0].unsqueeze(0)  # Add batch dim: [1, N_semantic]
        final_global_tokens = all_global_tokens[0].unsqueeze(0)  # Add batch dim: [1, N_global]

        # --- Detokenize audio ---
        try:
            # `detokenize_audio` returns a float32 NumPy array in [-1, 1] (clipping is done there)
            output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids)
        except Exception as e:
            logger.error(f"Error during audio detokenization: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError("Audio detokenization failed.") from e

        return {"audio": output_wav, "sampling_rate": self.sampling_rate}

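    # Example usage of `decode` (sketch; `max_new_tokens` and the output path
    # are placeholders, and `sf` refers to the soundfile package):
    #
    #     generated = model.generate(
    #         input_ids=inputs["input_ids"],
    #         attention_mask=inputs["attention_mask"],
    #         max_new_tokens=3000,
    #     )
    #     result = processor.decode(
    #         generated_ids=generated,
    #         global_token_ids_prompt=inputs.get("global_token_ids_prompt"),
    #         input_ids_len=inputs["input_ids"].shape[1],
    #     )
    #     sf.write("output.wav", result["audio"], result["sampling_rate"])
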
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        trust_remote_code: bool = False,  # May be needed for a custom config/tokenizer class
        **kwargs,
    ):
        r"""
        Instantiate a SparkTTSProcessor from pretrained components.
        """
        config = kwargs.pop("config", None)  # Allow passing a config explicitly

        # --- 1. Load config (to find component paths) ---
        # The config is needed even if the processor does not store it permanently,
        # just to locate the tokenizer and feature extractor.
        loaded_config = None
        if not isinstance(config, SparkTTSConfig):
            try:
                # Load the specific config class
                loaded_config = SparkTTSConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    trust_remote_code=trust_remote_code,  # The config class may be custom
                    **kwargs,
                )
            except Exception as e:
                logger.warning(
                    f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
                    f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}"
                )
                loaded_config = None  # Fall back to the default paths below
        else:
            # A config object was passed directly
            loaded_config = config

        # --- 2. Determine component paths ---
        llm_tokenizer_path_or_id = "./LLM"  # Default relative path
        w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53"  # Default relative path

        if loaded_config:
            llm_tokenizer_path_or_id = getattr(loaded_config, "llm_model_name_or_path", llm_tokenizer_path_or_id)
            w2v_processor_path_or_id = getattr(loaded_config, "wav2vec2_model_name_or_path", w2v_processor_path_or_id)

        # The component `from_pretrained` methods resolve these values whether they
        # are subfolders of `pretrained_model_name_or_path` or separate Hub IDs.

        # --- 3. Load components ---
        # Pass down the relevant kwargs for loading components
        component_loading_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "local_files_only": local_files_only,
            "token": token,
            "revision": revision,
            **kwargs,  # Pass through remaining user kwargs
        }
        try:
            # The tokenizer may require trust_remote_code if its class is custom
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path,  # Main path
                subfolder=llm_tokenizer_path_or_id.lstrip("./"),  # Subfolder relative to the main path
                trust_remote_code=trust_remote_code,
                **component_loading_kwargs,
            )
        except Exception as e:
            # Fallback: try loading directly from the path/ID in the config, if different
            if llm_tokenizer_path_or_id != "./LLM":
                try:
                    logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}")
                    tokenizer = AutoTokenizer.from_pretrained(
                        llm_tokenizer_path_or_id,
                        trust_remote_code=trust_remote_code,
                        **component_loading_kwargs,
                    )
                except Exception as e2:
                    raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e
            else:
                raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")

        try:
            # The feature extractor usually does not need trust_remote_code
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                pretrained_model_name_or_path,  # Main path
                subfolder=w2v_processor_path_or_id.lstrip("./"),  # Subfolder relative to the main path
                **component_loading_kwargs,
            )
        except Exception as e:
            # Fallback: try loading directly from the path/ID in the config, if different
            if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53":
                try:
                    logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}")
                    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                        w2v_processor_path_or_id,
                        **component_loading_kwargs,
                    )
                except Exception as e2:
                    raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e
            else:
                raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")

        # --- 4. Instantiate the processor ---
        # Pass the loaded config object (or None)
        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config)

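    # Example usage (sketch; the repo ID is a placeholder):
    #
    #     processor = SparkTTSProcessor.from_pretrained(
    #         "path/or/repo-id", trust_remote_code=True
    #     )
    #     processor.link_model(model)  # Required before calling decode()
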
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the processor's state (tokenizer and feature extractor files) to a directory.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor files will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it.
            **kwargs:
                Additional keyword arguments passed along to the `push_to_hub` method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Save tokenizer
        self.tokenizer.save_pretrained(str(save_directory), **kwargs)

        # Save feature extractor
        self.feature_extractor.save_pretrained(str(save_directory), **kwargs)

        # Note: the SparkTTSConfig is usually saved with the *model*, not the
        # processor. If the processor ever needs its own config to reload itself,
        # it could be saved here; relying on the model's config is normally enough.
        # if self.config:
        #     self.config.save_pretrained(str(save_directory))

        logger.info(f"Processor components saved in {save_directory}")

        if push_to_hub:
            # Commit message and other Hub kwargs can be passed via **kwargs
            commit_message = kwargs.pop("commit_message", "Save processor")
            return self.push_to_hub(save_directory, commit_message=commit_message, **kwargs)

        return str(save_directory)  # Return the path, consistent with the Mixin
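
    # Example usage (sketch; the directory name is a placeholder):
    #
    #     processor.save_pretrained("./spark_tts_processor")
    #     # or save and push in one step:
    #     processor.save_pretrained("./spark_tts_processor", push_to_hub=True,
    #                               commit_message="Save processor")
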
wav2vec2-large-xlsr-53/README.md
ADDED
@@ -0,0 +1,29 @@
---
language: multilingual
datasets:
- common_voice
tags:
- speech
license: apache-2.0
---

# Wav2Vec2-XLSR-53

[Facebook's XLSR-Wav2Vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)

The base model, pretrained on 16 kHz sampled speech audio. When using the model, make sure that your speech input is also sampled at 16 kHz. Note that this model should be fine-tuned on a downstream task, like Automatic Speech Recognition. Check out [this blog](https://huggingface.co/blog/fine-tune-wav2vec2-english) for more information.
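
A minimal resampling sketch (assuming `torchaudio` is installed; the file path is a placeholder):

```python
import torchaudio

# Load the audio at its native rate, then resample to the 16 kHz the model expects
waveform, sample_rate = torchaudio.load("speech.wav")
if sample_rate != 16000:
    waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
```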

[Paper](https://arxiv.org/abs/2006.13979)

Authors: Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli

**Abstract**
This paper presents XLSR, which learns cross-lingual speech representations by pretraining a single model from the raw waveform of speech in multiple languages. We build on wav2vec 2.0, which is trained by solving a contrastive task over masked latent speech representations, and jointly learns a quantization of the latents shared across languages. The resulting model is fine-tuned on labeled data, and experiments show that cross-lingual pretraining significantly outperforms monolingual pretraining. On the CommonVoice benchmark, XLSR shows a relative phoneme error rate reduction of 72% compared to the best known results. On BABEL, our approach improves word error rate by 16% relative to a comparable system. Our approach enables a single multilingual speech recognition model that is competitive with strong individual models. Analysis shows that the latent discrete speech representations are shared across languages, with increased sharing for related languages. We hope to catalyze research in low-resource speech understanding by releasing XLSR-53, a large model pretrained in 53 languages.

The original model can be found at https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.

# Usage

See [this notebook](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb) for more information on how to fine-tune the model.

![model image](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/xlsr_wav2vec2.png)
wav2vec2-large-xlsr-53/config.json
ADDED
@@ -0,0 +1,83 @@
{
  "activation_dropout": 0.0,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForPreTraining"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_channel_length": 10,
  "mask_channel_min_space": 1,
  "mask_channel_other": 0.0,
  "mask_channel_prob": 0.0,
  "mask_channel_selection": "static",
  "mask_feature_length": 10,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_space": 1,
  "mask_time_other": 0.0,
  "mask_time_prob": 0.075,
  "mask_time_selection": "static",
  "model_type": "wav2vec2",
  "num_attention_heads": 16,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 24,
  "num_negatives": 100,
  "pad_token_id": 0,
  "proj_codevector_dim": 768,
  "transformers_version": "4.7.0.dev0",
  "vocab_size": 32
}
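
A sketch of loading these weights as an encoder for feature extraction (assumptions: the relative path points at this subfolder, and `output_hidden_states=True` is used because downstream feature extraction typically reads intermediate layers; check `modeling_spark_tts.py` for the exact layers SparkTTS uses):

```python
import torch
from transformers import Wav2Vec2Model

# Load the bundled XLSR-53 encoder and expose all hidden states
model = Wav2Vec2Model.from_pretrained(
    "./wav2vec2-large-xlsr-53", output_hidden_states=True
)
model.eval()

with torch.no_grad():
    # One second of 16 kHz audio -> input of shape [1, 16000]
    outputs = model(torch.zeros(1, 16000))

# Tuple of [1, T', 1024] tensors: the initial feature projection plus one
# entry per transformer layer (num_hidden_layers = 24, hidden_size = 1024)
print(len(outputs.hidden_states), outputs.hidden_states[-1].shape)
```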
wav2vec2-large-xlsr-53/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}
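
A minimal sketch of the feature extractor these settings configure (the input array is synthetic):

```python
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

fe = Wav2Vec2FeatureExtractor.from_pretrained("./wav2vec2-large-xlsr-53")
# Applies zero-mean/unit-variance normalization (do_normalize) at 16 kHz
features = fe(np.random.randn(16000), sampling_rate=16000, return_tensors="pt")
print(features.input_values.shape)  # torch.Size([1, 16000])
```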
wav2vec2-large-xlsr-53/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:314340227371a608f71adcd5f0de5933824fe77e55822aa4b24dba9c1c364dcb
size 1269737156