Step-Audio (code, dataset, demo, paper, tools)
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +3 -0
- Step-Audio-AQAA. A Fully End-to-End Expressive Large Audio Language Model.pdf +3 -0
- Step-Audio-EditX Technical Report.pdf +3 -0
- Step-Audio. Unified Understanding and Generation in Intelligent Speech Interaction.pdf +3 -0
- code/ComfyUI_StepAudioTTS.zip +3 -0
- code/Step-Audio [intervitens].zip +3 -0
- code/Step-Audio-EditX.zip +3 -0
- code/Step-Audio-tts.zip +3 -0
- code/Step-Audio.zip +3 -0
- code/Step-Audio2.zip +3 -0
- code/StepAudioInfer.zip +3 -0
- code/astrbot_plugin_tts_Step_Audio.zip +3 -0
- dataset/StepEval-Audio-360/.gitattributes +59 -0
- dataset/StepEval-Audio-360/README.md +79 -0
- dataset/StepEval-Audio-360/audios.tar.gz +3 -0
- dataset/StepEval-Audio-360/data/test-00000-of-00001.parquet +3 -0
- dataset/StepEval-Audio-360/source.txt +1 -0
- demo/Step-Audio-EditX/.gitattributes +4 -0
- demo/Step-Audio-EditX/.gitignore +2 -0
- demo/Step-Audio-EditX/LICENSE +201 -0
- demo/Step-Audio-EditX/README.md +13 -0
- demo/Step-Audio-EditX/__init__.py +0 -0
- demo/Step-Audio-EditX/app.py +505 -0
- demo/Step-Audio-EditX/config/__init__.py +12 -0
- demo/Step-Audio-EditX/config/edit_config.py +32 -0
- demo/Step-Audio-EditX/config/prompts.py +23 -0
- demo/Step-Audio-EditX/funasr_detach/__init__.py +38 -0
- demo/Step-Audio-EditX/funasr_detach/auto/__init__.py +0 -0
- demo/Step-Audio-EditX/funasr_detach/auto/auto_frontend.py +90 -0
- demo/Step-Audio-EditX/funasr_detach/auto/auto_model.py +575 -0
- demo/Step-Audio-EditX/funasr_detach/auto/auto_tokenizer.py +7 -0
- demo/Step-Audio-EditX/funasr_detach/bin/__init__.py +0 -0
- demo/Step-Audio-EditX/funasr_detach/bin/compute_audio_cmvn.py +152 -0
- demo/Step-Audio-EditX/funasr_detach/bin/inference.py +33 -0
- demo/Step-Audio-EditX/funasr_detach/bin/tokenize_text.py +281 -0
- demo/Step-Audio-EditX/funasr_detach/bin/train.py +227 -0
- demo/Step-Audio-EditX/funasr_detach/datasets/__init__.py +0 -0
- demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/__init__.py +0 -0
- demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/datasets.py +112 -0
- demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/index_ds.py +150 -0
- demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/preprocessor.py +55 -0
- demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/samplers.py +306 -0
- demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/scp2jsonl.py +116 -0
- demo/Step-Audio-EditX/funasr_detach/download/__init__.py +0 -0
- demo/Step-Audio-EditX/funasr_detach/download/download_dataset_from_hub.py +19 -0
- demo/Step-Audio-EditX/funasr_detach/download/download_from_hub.py +231 -0
- demo/Step-Audio-EditX/funasr_detach/download/file.py +335 -0
- demo/Step-Audio-EditX/funasr_detach/download/name_maps_from_hub.py +13 -0
- demo/Step-Audio-EditX/funasr_detach/download/runtime_sdk_download_tool.py +60 -0
- demo/Step-Audio-EditX/funasr_detach/frontends/__init__.py +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Step-Audio-AQAA.[[:space:]]A[[:space:]]Fully[[:space:]]End-to-End[[:space:]]Expressive[[:space:]]Large[[:space:]]Audio[[:space:]]Language[[:space:]]Model.pdf filter=lfs diff=lfs merge=lfs -text
+Step-Audio-EditX[[:space:]]Technical[[:space:]]Report.pdf filter=lfs diff=lfs merge=lfs -text
+Step-Audio.[[:space:]]Unified[[:space:]]Understanding[[:space:]]and[[:space:]]Generation[[:space:]]in[[:space:]]Intelligent[[:space:]]Speech[[:space:]]Interaction.pdf filter=lfs diff=lfs merge=lfs -text
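Note: `git lfs track` escapes literal spaces in `.gitattributes` patterns as the POSIX character class `[[:space:]]`, which is why the three PDF filenames above look mangled. A minimal sketch (assuming space escaping is the only transformation applied) to recover the plain filename:

```python
# Recover a filename from an LFS-tracked .gitattributes pattern line.
# Assumes git-lfs's [[:space:]] space escaping is the only encoding present.
def unescape_lfs_pattern(line: str) -> str:
    pattern = line.split(" filter=")[0]         # drop the attribute list
    return pattern.replace("[[:space:]]", " ")  # undo the space escaping

print(unescape_lfs_pattern(
    "Step-Audio-EditX[[:space:]]Technical[[:space:]]Report.pdf "
    "filter=lfs diff=lfs merge=lfs -text"
))  # -> Step-Audio-EditX Technical Report.pdf
```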
Step-Audio-AQAA. A Fully End-to-End Expressive Large Audio Language Model.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4290ba946aaf9ebc8a1df00a905cbafb19f18ca3bcf9a38389716602ee5f7d7e
+size 1203894
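Each binary asset below is checked in as a Git LFS pointer like the one above: three text lines giving the pointer spec version, the SHA-256 digest of the real payload, and its size in bytes. A minimal sketch for reading such a pointer file:

```python
# Parse a Git LFS pointer file into a dict, e.g.
# {"version": "https://git-lfs.github.com/spec/v1",
#  "oid": "sha256:4290ba...", "size": "1203894"}
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields
```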
Step-Audio-EditX Technical Report.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1ff7493dcedd3e506b8de85860b0c608d06f0392245fb5385b7fa8231234e50
+size 786245
Step-Audio. Unified Understanding and Generation in Intelligent Speech Interaction.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce5f6d5b9575f4552c970f118d3191ff49f3e509a847f0fff58c23aa7b510b3f
+size 6952309
code/ComfyUI_StepAudioTTS.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17e36cf8812529f50c8b80b72d9111e20c476a2694365ad4f9049f019106b38b
+size 14201121
code/Step-Audio [intervitens].zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88549705b04c2bbde00c6e5c67b8966c0aac195c1605a756abd893c53d690e00
+size 37467537
code/Step-Audio-EditX.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6376c6fe2c68201749f7dee3717eab495e5630d3611511e3afaa9b1fe265afcd
+size 5979796
code/Step-Audio-tts.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bba78bbc9039e5b3d7df4f77917119bbacc6d8aa9baf4d8114edfceda83fc624
+size 3827854
code/Step-Audio.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4aaa4b50011e9c82ac51020de7177b43a07e5f47cbd7e8bb55e80929cac5d7a
+size 55625681
code/Step-Audio2.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30f63a2dc8c598cc9c968c1a6cca2bc8150beeeec45b8b081843ac2580388dc9
+size 26895459
code/StepAudioInfer.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1bf64260061b9cfb68fc673770b78f561a783a889dc01a3346d5bb12c1f8bf25
+size 39121775
code/astrbot_plugin_tts_Step_Audio.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b6d563e44b30e27b1c51a05316e262ab6483db0cf6d42a43f5a5407ba9206380
+size 6151313
dataset/StepEval-Audio-360/.gitattributes
ADDED
@@ -0,0 +1,59 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.lz4 filter=lfs diff=lfs merge=lfs -text
+*.mds filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Audio files - uncompressed
+*.pcm filter=lfs diff=lfs merge=lfs -text
+*.sam filter=lfs diff=lfs merge=lfs -text
+*.raw filter=lfs diff=lfs merge=lfs -text
+# Audio files - compressed
+*.aac filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.ogg filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+# Image files - uncompressed
+*.bmp filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.tiff filter=lfs diff=lfs merge=lfs -text
+# Image files - compressed
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
+# Video files - compressed
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.webm filter=lfs diff=lfs merge=lfs -text
dataset/StepEval-Audio-360/README.md
ADDED
@@ -0,0 +1,79 @@
+---
+license: apache-2.0
+---
+# StepEval-Audio-360
+
+## Dataset Description
+StepEval Audio 360 is a comprehensive dataset that evaluates the ability of multi-modal large language models (MLLMs) in human-AI audio interaction. This audio benchmark dataset, sourced from professional human annotators, covers a full spectrum of capabilities: singing, creativity, role-playing, logical reasoning, voice understanding, voice instruction following, gaming, speech emotion control, and language ability.
+
+## Languages
+StepEval Audio 360 comprises human voice recordings in several languages and dialects, including Chinese (Sichuan dialect and Cantonese), English, and Japanese. It contains both audio and transcription data.
+
+## Links
+- Homepage: [Step-Audio](https://github.com/stepfun-ai/Step-Audio)
+- Paper: [Step-Audio: Unified Understanding and Generation in Intelligent Speech Interaction](https://arxiv.org/abs/2502.11946)
+- ModelScope: https://modelscope.cn/datasets/stepfun-ai/StepEval-Audio-360
+- Step-Audio Model Suite:
+  - Step-Audio-Tokenizer:
+    - Hugging Face: https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer
+    - ModelScope: https://modelscope.cn/models/stepfun-ai/Step-Audio-Tokenizer
+  - Step-Audio-Chat:
+    - Hugging Face: https://huggingface.co/stepfun-ai/Step-Audio-Chat
+    - ModelScope: https://modelscope.cn/models/stepfun-ai/Step-Audio-Chat
+  - Step-Audio-TTS-3B:
+    - Hugging Face: https://huggingface.co/stepfun-ai/Step-Audio-TTS-3B
+    - ModelScope: https://modelscope.cn/models/stepfun-ai/Step-Audio-TTS-3B
+
+## User Manual
+* Download the dataset
+```
+# Make sure you have git-lfs installed (https://git-lfs.com)
+git lfs install
+git clone https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360
+cd StepEval-Audio-360
+git lfs pull
+```
+
+* Decompress audio data
+```
+mkdir audios
+tar -xvf audios.tar.gz -C audios
+```
+
+* How to use
+```
+from datasets import load_dataset
+
+dataset = load_dataset("stepfun-ai/StepEval-Audio-360")
+dataset = dataset["test"]
+for item in dataset:
+    conversation_id = item["conversation_id"]
+    category = item["category"]
+    conversation = item["conversation"]
+
+    # parse multi-turn dialogue data
+    for turn in conversation:
+        role = turn["role"]
+        text = turn["text"]
+        audio_filename = turn["audio_filename"]  # refers to a decompressed audio file
+        if audio_filename is not None:
+            print(role, text, audio_filename)
+        else:
+            print(role, text)
+```
+
+## Licensing
+This dataset project is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
+
+## Citation
+If you use this dataset, please cite it with the BibTeX entry below.
+```
+@misc{stepfun_2025,
+    author    = { {StepFun} },
+    title     = { StepEval-Audio-360 (Revision 72a072e) },
+    year      = 2025,
+    url       = { https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360 },
+    doi       = { 10.57967/hf/4528 },
+    publisher = { Hugging Face }
+}
+```
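For completeness, a minimal sketch combining the loader above with the decompressed `audios/` directory; it assumes each `audio_filename` is a path relative to that directory:

```python
import os
import librosa
from datasets import load_dataset

AUDIO_DIR = "audios"  # created by `tar -xvf audios.tar.gz -C audios` above

dataset = load_dataset("stepfun-ai/StepEval-Audio-360")["test"]
for item in dataset:
    for turn in item["conversation"]:
        if turn["audio_filename"] is not None:
            path = os.path.join(AUDIO_DIR, turn["audio_filename"])
            waveform, sr = librosa.load(path, sr=None)  # keep native sample rate
            print(turn["role"], sr, waveform.shape)
```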
dataset/StepEval-Audio-360/audios.tar.gz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7e9f043765500c6f6940ae58a55cf226ddfdde533099f4765bc40d2710d82d3
+size 166398432
dataset/StepEval-Audio-360/data/test-00000-of-00001.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2990e77433b866431bbad8adc27b3aebee77046ceca5d265113994fedf2eaff
+size 69065
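The parquet split can also be inspected without the `datasets` library; a small sketch assuming the column names shown in the README above:

```python
import pandas as pd

# Hypothetical local path; adjust to wherever the dataset repo was cloned.
df = pd.read_parquet("StepEval-Audio-360/data/test-00000-of-00001.parquet")
print(df.columns.tolist())  # expected: ['conversation_id', 'category', 'conversation']
print(len(df), "conversations")
```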
dataset/StepEval-Audio-360/source.txt
ADDED
@@ -0,0 +1 @@
+https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360
demo/Step-Audio-EditX/.gitattributes
ADDED
@@ -0,0 +1,4 @@
+examples filter=lfs diff=lfs merge=lfs -text
+speakers/nezha_prompt.wav filter=lfs diff=lfs merge=lfs -text
+speakers/nezhaRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
+speakers/nezha哼唱_prompt.wav filter=lfs diff=lfs merge=lfs -text
demo/Step-Audio-EditX/.gitignore
ADDED
@@ -0,0 +1,2 @@
+__pycache__/
+output/
demo/Step-Audio-EditX/LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+[Standard Apache License 2.0 text: Sections 1-9 (Definitions through Accepting Warranty or Additional Liability) plus the appendix on how to apply the license, reproduced verbatim in the file.]
demo/Step-Audio-EditX/README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Step-Audio-EditX
+emoji: 🚀
+colorFrom: red
+colorTo: red
+sdk: gradio
+sdk_version: 5.49.1
+app_file: app.py
+pinned: true
+short_description: Try out Step-Audio-EditX
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
demo/Step-Audio-EditX/__init__.py
ADDED
File without changes
demo/Step-Audio-EditX/app.py
ADDED
@@ -0,0 +1,505 @@
+import gradio as gr
+import os
+import argparse
+import torch
+import logging
+import threading
+from datetime import datetime
+import torchaudio
+import librosa
+import soundfile as sf
+
+# ZeroGPU support
+try:
+    import spaces
+    ZEROGPU_AVAILABLE = True
+except ImportError:
+    ZEROGPU_AVAILABLE = False
+    # Create a dummy decorator for non-ZeroGPU environments
+    class spaces:
+        @staticmethod
+        def GPU(duration=10):
+            def decorator(func):
+                return func
+            return decorator
+
+# Project imports
+from tokenizer import StepAudioTokenizer
+from tts import StepAudioTTS
+from model_loader import ModelSource
+from config.edit_config import get_supported_edit_types
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# Global variables for ZeroGPU-optimized loading
+encoder = None
+common_tts_engine = None
+args_global = None
+_model_lock = threading.Lock()  # Thread lock for model initialization
+
+def initialize_models():
+    """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
+    global encoder, common_tts_engine, args_global
+
+    # Fast path: check if already initialized (without lock)
+    if common_tts_engine is not None:
+        return  # Already initialized
+
+    # Slow path: acquire lock and double-check
+    with _model_lock:
+        # Double-check pattern: another thread might have initialized while waiting for lock
+        if common_tts_engine is not None:
+            return  # Already initialized by another thread
+
+        if args_global is None:
+            raise RuntimeError("Global args not set. Cannot initialize models.")
+
+        try:
+            logger.info("🚀 Initializing models inside GPU context (first call)...")
+
+            # Determine model source
+            source_mapping = {
+                "auto": ModelSource.AUTO,
+                "local": ModelSource.LOCAL,
+                "modelscope": ModelSource.MODELSCOPE,
+                "huggingface": ModelSource.HUGGINGFACE
+            }
+            model_source = source_mapping[args_global.model_source]
+
+            # Load StepAudioTokenizer (avoid CUDA initialization in main process)
+            encoder = StepAudioTokenizer(
+                os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
+                model_source=model_source,
+                funasr_model_id=args_global.tokenizer_model_id
+            )
+            logger.info("✓ StepAudioTokenizer loaded")
+
+            # Initialize common TTS engine (avoid CUDA initialization in main process)
+            common_tts_engine = StepAudioTTS(
+                os.path.join(args_global.model_path, "Step-Audio-EditX"),
+                encoder,
+                model_source=model_source,
+                tts_model_id=args_global.tts_model_id
+            )
+            logger.info("✓ StepCommonAudioTTS loaded")
+            print("Models initialized inside GPU context.")
+
+            if ZEROGPU_AVAILABLE:
+                logger.info("💡 Models loaded inside GPU context - ready for inference")
+            else:
+                logger.info("💡 Models loaded - ready for inference")
+
+        except Exception as e:
+            logger.error(f"❌ Error loading models: {e}")
+            raise
+
+def get_model_config():
+    """Get model configuration without initializing GPU models"""
+    if args_global is None:
+        raise RuntimeError("Global args not set. Cannot get model config.")
+
+    return {
+        "encoder_path": os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
+        "tts_path": os.path.join(args_global.model_path, "Step-Audio-EditX"),
+        "model_source": args_global.model_source,
+        "tokenizer_model_id": args_global.tokenizer_model_id,
+        "tts_model_id": args_global.tts_model_id
+    }
+
+def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
+    """Dynamic GPU duration based on whether models need initialization"""
+    global common_tts_engine
+
+    if common_tts_engine is None:
+        # First call - need time for model loading (up to 5 minutes)
+        return 300  # Maximum allowed duration for model initialization
+    else:
+        # Subsequent calls - only inference time needed
+        return 120  # Standard inference duration
+
+@spaces.GPU(duration=get_gpu_duration)  # Dynamic duration based on model state
+def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
+    """Process audio using GPU (models are loaded inside GPU context to avoid main process errors)"""
+    global common_tts_engine
+
+    # Initialize models if not already loaded (inside GPU context to avoid main process errors)
+    if common_tts_engine is None:
+        print("Initializing common_tts_engine inside GPU context...")
+        logger.info("🎯 GPU allocated for 300s (first call with model loading)...")
+        initialize_models()
+        logger.info("✅ Models loaded successfully inside GPU context")
+    else:
+        print("common_tts_engine already initialized.")
+        logger.info("🎯 GPU allocated for 120s (inference with loaded models)...")
+
+    try:
+        # Use loaded models (first call may include loading time, subsequent calls are fast)
+        if task_type == "clone":
+            output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
+        else:
+            output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)
+
+        logger.info("✅ Audio processing completed")
+        return output_audio, sr
+
+    except Exception as e:
+        logger.error(f"❌ Audio processing failed: {e}")
+        raise
+    # GPU automatically deallocated when function exits
+
+# Save audio to temporary directory
+def save_audio(audio_type, audio_data, sr, tmp_dir):
+    """Save audio data to a temporary file with timestamp"""
+    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    try:
+        if isinstance(audio_data, torch.Tensor):
+            torchaudio.save(save_path, audio_data, sr)
+        else:
+            sf.write(save_path, audio_data, sr)
+        logger.debug(f"Audio saved to: {save_path}")
+        return save_path
+    except Exception as e:
+        logger.error(f"Failed to save audio: {e}")
+        raise
+
+
+class EditxTab:
+    """Audio editing and voice cloning interface tab"""
+
+    def __init__(self, args):
+        self.args = args
+        self.edit_type_list = list(get_supported_edit_types().keys())
+        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+
+    def history_messages_to_show(self, messages):
+        """Convert message history to gradio chatbot format"""
+        show_msgs = []
+        for message in messages:
+            edit_type = message['edit_type']
+            edit_info = message['edit_info']
+            source_text = message['source_text']
+            target_text = message['target_text']
+            raw_audio_part = message['raw_wave']
+            edit_audio_part = message['edit_wave']
+            type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
+            show_msgs.extend([
+                {"role": "user", "content": f"Task type: {type_str}\nText: {source_text}"},
+                {"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
+                {"role": "assistant", "content": f"Output audio:\nText: {target_text}"},
+                {"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
+            ])
+        return show_msgs
+
+    def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
+        """Generate cloned audio (models are loaded on first GPU call)"""
+        self.logger.info("Starting voice cloning process")
+        state['history_audio'] = []
+        state['history_messages'] = []
+
+        # Input validation
+        if not prompt_text_input or prompt_text_input.strip() == "":
+            error_msg = "[Error] Uploaded text cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+        if not prompt_audio_input:
+            error_msg = "[Error] Uploaded audio cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+        if not generated_text or generated_text.strip() == "":
+            error_msg = "[Error] Clone content cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+        if edit_type != "clone":
+            error_msg = "[Error] CLONE button must use clone task."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+        try:
+            # Use GPU inference with models loaded inside GPU context
+            output_audio, output_sr = process_audio_with_gpu(
+                prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
+            )
+
+            if output_audio is not None and output_sr is not None:
+                # Convert tensor to numpy if needed
+                if isinstance(output_audio, torch.Tensor):
+                    audio_numpy = output_audio.cpu().numpy().squeeze()
+                else:
+                    audio_numpy = output_audio
+
+                # Load original audio for comparison
+                input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
+
+                # Create message for history
+                cur_assistant_msg = {
+                    "edit_type": edit_type,
+                    "edit_info": edit_info,
+                    "source_text": prompt_text_input,
+                    "target_text": generated_text,
+                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
+                    "edit_wave": (output_sr, audio_numpy),
+                }
+                state["history_audio"].append((output_sr, audio_numpy, generated_text))
+                state["history_messages"].append(cur_assistant_msg)
+
+                show_msgs = self.history_messages_to_show(state["history_messages"])
+                self.logger.info("Voice cloning completed successfully")
+                return show_msgs, state
+            else:
+                error_msg = "[Error] Clone failed"
+                self.logger.error(error_msg)
+                return [{"role": "user", "content": error_msg}], state
+
+        except Exception as e:
+            error_msg = f"[Error] Clone failed: {str(e)}"
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+    def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
+        """Generate edited audio (models are loaded on first GPU call)"""
+        self.logger.info("Starting audio editing process")
+
+        # Input validation
+        if not prompt_audio_input:
+            error_msg = "[Error] Uploaded audio cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+        try:
+            # Determine which audio to use
+            if len(state["history_audio"]) == 0:
+                # First edit - use uploaded audio
+                audio_to_edit = prompt_audio_input
+                text_to_use = prompt_text_input
+                self.logger.debug("Using prompt audio, no history found")
+            else:
+                # Use previous edited audio - save it to temp file first
+                sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
+                temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
+                audio_to_edit = temp_path
+                text_to_use = previous_text
+                self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")
+
+            # For para-linguistic, use generated_text; otherwise use source text
+            if edit_type not in {"paralinguistic"}:
+                generated_text = text_to_use
+
+            # Use GPU inference with models loaded inside GPU context
+            output_audio, output_sr = process_audio_with_gpu(
+                audio_to_edit, text_to_use, generated_text, edit_type, edit_info
+            )
+
+            if output_audio is not None and output_sr is not None:
+                # Convert tensor to numpy if needed
+                if isinstance(output_audio, torch.Tensor):
+                    audio_numpy = output_audio.cpu().numpy().squeeze()
+                else:
+                    audio_numpy = output_audio
+
+                # Load original audio for comparison
+                if len(state["history_audio"]) == 0:
+                    input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
+                else:
+                    input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]
+
+                # Create message for history
+                cur_assistant_msg = {
+                    "edit_type": edit_type,
+                    "edit_info": edit_info,
+                    "source_text": text_to_use,
+                    "target_text": generated_text,
+                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
+                    "edit_wave": (output_sr, audio_numpy),
+                }
+                state["history_audio"].append((output_sr, audio_numpy, generated_text))
+                state["history_messages"].append(cur_assistant_msg)
+
+                show_msgs = self.history_messages_to_show(state["history_messages"])
+                self.logger.info("Audio editing completed successfully")
+                return show_msgs, state
+            else:
+                error_msg = "[Error] Edit failed"
+                self.logger.error(error_msg)
+                return [{"role": "user", "content": error_msg}], state
+
+        except Exception as e:
+            error_msg = f"[Error] Edit failed: {str(e)}"
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+    def clear_history(self, state):
+        """Clear conversation history"""
+        state["history_messages"] = []
+        state["history_audio"] = []
+        return [], state
+
+    def init_state(self):
+        """Initialize conversation state"""
+        return {
+            "history_messages": [],
+            "history_audio": []
+        }
+
+    def register_components(self):
+        """Register gradio components - maintaining exact layout from original"""
+        with gr.Tab("Editx"):
+            with gr.Row():
+                with gr.Column():
+                    self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
+                    self.prompt_text_input = gr.Textbox(label="Prompt Text", value="", scale=1)
+                    self.prompt_audio_input = gr.Audio(
+                        sources=["upload", "microphone"],
+                        format="wav",
+                        type="filepath",
+                        label="Input Audio",
+                    )
+                    self.generated_text = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=1000)
+                with gr.Column():
+                    with gr.Row():
+                        self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
+                        self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
+                    self.chat_box = gr.Chatbot(label="History", type="messages", height=480*1)
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        self.button_tts = gr.Button("CLONE", variant="primary")
+                        self.button_edit = gr.Button("EDIT", variant="primary")
+                with gr.Column():
+                    self.clean_history_submit = gr.Button("Clear History", variant="primary")
+
+            gr.Markdown("---")
+            gr.Markdown("""
+            **Button Description:**
+            - CLONE: Synthesizes audio from the uploaded audio and text; used only for the clone task, and clears history when used.
+            - EDIT: Edits the uploaded audio, or keeps stacking edit effects on the audio generated in the previous round.
+            """)
+            gr.Markdown("""
+            **Operation Workflow:**
+            - Upload the audio to be edited on the left side and fill in the corresponding text content of the audio;
+            - If the task requires modifying text content (such as clone, para-linguistic), fill in the text to be synthesized in the "clone text" field. For all other tasks, keep the uploaded audio text content unchanged;
+            - Select tasks and subtasks on the right side (some tasks have no subtasks, such as vad);
+            - Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side.
+            """)
+            gr.Markdown("""
+            **Para-linguistic Description:**
+            - Supported tags include: [Breathing] [Laughter] [Surprise-oh] [Confirmation-en] [Uhm] [Surprise-ah] [Surprise-wa] [Sigh] [Question-ei] [Dissatisfaction-hnn]
+            - Example:
+                - Fill in the "clone text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio.
+                - Change the "clone text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio.
+            """)
+
+    def register_events(self):
+        """Register event handlers"""
+        # Create independent state for each session
+        state = gr.State(self.init_state())
+
+        self.button_tts.click(self.generate_clone,
+                              inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
+                              outputs=[self.chat_box, state])
+        self.button_edit.click(self.generate_edit,
+                               inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
+                               outputs=[self.chat_box, state])
+
+        self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
+        self.edit_type.change(
+            fn=self.update_edit_info,
+            inputs=self.edit_type,
+            outputs=self.edit_info,
+        )
+
+    def update_edit_info(self, category):
+        """Update sub-task dropdown based on main task selection"""
+        category_items = get_supported_edit_types()
+        choices = category_items.get(category, [])
+        value = None if len(choices) == 0 else choices[0]
+        return gr.Dropdown(label="Sub-task", choices=choices, value=value)
+
+
+def launch_demo(args, editx_tab):
+    """Launch the gradio demo"""
+    with gr.Blocks(
+        theme=gr.themes.Soft(),
+        title="🎙️ Step-Audio-EditX",
+        css="""
+        :root {
+            --font: "Helvetica Neue", Helvetica, Arial, sans-serif;
+            --font-mono: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
+        }
+        """) as demo:
+        gr.Markdown("## 🎙️ Step-Audio-EditX")
+        gr.Markdown("Audio Editing and Zero-Shot Cloning using Step-Audio-EditX")
+
+        # Register components
+        editx_tab.register_components()
+
+        # Register events
+        editx_tab.register_events()
+
+    # Launch demo
+    demo.queue().launch(
+        server_name=args.server_name,
+        server_port=args.server_port,
+        share=args.share if hasattr(args, 'share') else False
+    )
+
+
+if __name__ == "__main__":
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
+    parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
+    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
+    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
+    parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
+    parser.add_argument("--share", action="store_true", help="Share gradio app.")
+
+    # Multi-source loading support parameters
+    parser.add_argument(
+        "--model-source",
+        type=str,
+        default="huggingface",
+        choices=["auto", "local", "modelscope", "huggingface"],
+        help="Model source: auto (detect automatically), local, modelscope, or huggingface"
+    )
+    parser.add_argument(
+        "--tokenizer-model-id",
+        type=str,
+        default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
+        help="Tokenizer model ID for online loading"
+    )
+    parser.add_argument(
+        "--tts-model-id",
+        type=str,
+        default=None,
+        help="TTS model ID for online loading (if different from model-path)"
+    )
+
+    args = parser.parse_args()
+
+    # Store args globally for model configuration
+    args_global = args
+
+    logger.info("Configuration loaded:")
+    logger.info(f"Model source: {args.model_source}")
+    logger.info(f"Model path: {args.model_path}")
+    logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
+    if args.tts_model_id:
+        logger.info(f"TTS model ID: {args.tts_model_id}")
+
+    # Models will be initialized on first GPU call to avoid ZeroGPU main process errors
+
+    if ZEROGPU_AVAILABLE:
+        logger.info("🎉 ZeroGPU detected - using dynamic GPU duration management!")
+        logger.info("💡 First call: 300s (model loading), subsequent calls: 120s (inference only)")
+    else:
+        logger.info("💻 Running in local mode - models will be loaded on first call")
+
+    # Create EditxTab instance
+    editx_tab = EditxTab(args)
+
+    # Launch demo
+    launch_demo(args, editx_tab)
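A minimal sketch of driving the same engine outside Gradio, mirroring the `clone` branch of `process_audio_with_gpu` above. It assumes the Step-Audio-Tokenizer and Step-Audio-EditX checkpoints are reachable under the default `--model-path` and that the keyword arguments shown in `app.py` can be left at their defaults; the prompt transcript and target text are made-up placeholders:

```python
import os
import torch
import torchaudio
import soundfile as sf
from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS

model_path = "stepfun-ai"  # default --model-path in app.py

# Assumption: model_source / funasr_model_id / tts_model_id have usable
# defaults, as the keyword-argument style in app.py suggests.
encoder = StepAudioTokenizer(os.path.join(model_path, "Step-Audio-Tokenizer"))
engine = StepAudioTTS(os.path.join(model_path, "Step-Audio-EditX"), encoder)

# Same call shape as the "clone" branch of process_audio_with_gpu()
audio, sr = engine.clone(
    "speakers/nezha_prompt.wav",               # reference audio shipped with this demo
    "Transcript of the reference clip.",       # placeholder prompt text
    "Text to synthesize in the cloned voice."  # placeholder target text
)

# app.py handles both tensor and ndarray outputs; do the same here
if isinstance(audio, torch.Tensor):
    torchaudio.save("cloned.wav", audio, sr)
else:
    sf.write("cloned.wav", audio, sr)
```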
demo/Step-Audio-EditX/config/__init__.py
ADDED
@@ -0,0 +1,12 @@
"""
Configuration module for Step-Audio
"""

from .prompts import AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL, AUDIO_EDIT_SYSTEM_PROMPT
from .edit_config import get_supported_edit_types

__all__ = [
    'AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL',
    'AUDIO_EDIT_SYSTEM_PROMPT',
    'get_supported_edit_types'
]
demo/Step-Audio-EditX/config/edit_config.py
ADDED
@@ -0,0 +1,32 @@
"""
Audio editing configuration module.
Contains the supported edit types and related configuration.
"""

def get_supported_edit_types():
    """
    Get the supported edit types and their options.

    Returns:
        Dict[str, list]: Dictionary of edit types and their options
    """
    return {
        "clone": [],
        "emotion": [
            'happy', 'angry', 'sad', 'humour', 'confusion', 'disgusted',
            'empathy', 'embarrass', 'fear', 'surprised', 'excited',
            'depressed', 'coldness', 'admiration', 'remove'
        ],
        "style": [
            'serious', 'arrogant', 'child', 'older', 'girl', 'pure',
            'sister', 'sweet', 'ethereal', 'whisper', 'gentle', 'recite',
            'generous', 'act_coy', 'warm', 'shy', 'comfort', 'authority',
            'chat', 'radio', 'soulful', 'story', 'vivid', 'program',
            'news', 'advertising', 'roar', 'murmur', 'shout', 'deeply', 'loudly',
            'remove', 'exaggerated'
        ],
        "vad": [],
        "denoise": [],
        "paralinguistic": [],
        # Option strings are kept verbatim (including "more faster"/"more slower");
        # they appear to be model-facing literals.
        "speed": ["faster", "slower", "more faster", "more slower"],
    }
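A minimal sketch of how a caller might validate an edit request against this config before building a prompt; the helper name and its signature are hypothetical, not taken from the demo code:

    from typing import Optional

    from config.edit_config import get_supported_edit_types

    def validate_edit_request(edit_type: str, edit_info: Optional[str] = None) -> None:
        """Raise ValueError if the requested edit is not supported (hypothetical helper)."""
        supported = get_supported_edit_types()
        if edit_type not in supported:
            raise ValueError(f"unknown edit type: {edit_type!r}")
        options = supported[edit_type]
        # Types with an empty option list ("clone", "vad", "denoise", ...) take no sub-option.
        if options and edit_info not in options:
            raise ValueError(f"{edit_type!r} does not support option {edit_info!r}")

    validate_edit_request("emotion", "happy")  # ok
    validate_edit_request("speed", "faster")   # ok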
demo/Step-Audio-EditX/config/prompts.py
ADDED
@@ -0,0 +1,23 @@
"""
System prompt configuration module.
Contains all TTS- and editing-related system prompts.
"""

AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL = """Generate audio with the following timbre, prosody and speaking style

[speaker_start]
speaker name: {speaker}
speaker prompt text:
{prompt_text}
speaker audio tokens:
{prompt_wav_tokens}
[speaker_end]
"""

AUDIO_EDIT_SYSTEM_PROMPT = """As a highly skilled audio editing and tuning specialist, you excel in interpreting user instructions and applying precise adjustments to meet their needs. Your expertise spans a wide range of enhancement capabilities, including but not limited to:
# Emotional Enhancement
# Speaking Style Transfer
# Non-linguistic Adjustments
# Audio Tuning & Editing
Note: You will receive instructions in natural language and are expected to accurately interpret and execute the most suitable audio edits and enhancements.
"""
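For orientation, a minimal sketch of filling the clone template; all three values are placeholders, and real speaker audio tokens would come from the audio tokenizer:

    from config.prompts import AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL

    # Hypothetical values for illustration only.
    system_prompt = AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL.format(
        speaker="demo_speaker",
        prompt_text="Hello, this is a reference utterance.",
        prompt_wav_tokens="<audio_0><audio_1><audio_2>",
    )
    print(system_prompt)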
demo/Step-Audio-EditX/funasr_detach/__init__.py
ADDED
@@ -0,0 +1,38 @@
"""Initialize funasr package."""

import os
import pkgutil
import importlib

dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
    __version__ = f.read().strip()


def import_submodules(package, recursive=True):
    """Recursively import every submodule of a package, skipping broken ones."""
    if isinstance(package, str):
        package = importlib.import_module(package)
    results = {}
    for loader, name, is_pkg in pkgutil.walk_packages(
        package.__path__, package.__name__ + "."
    ):
        try:
            results[name] = importlib.import_module(name)
        except Exception as e:
            # To see details of import failures, uncomment the line below:
            # print(f"Failed to import {name}: {e}")
            pass
        if recursive and is_pkg:
            results.update(import_submodules(name))
    return results


import_submodules(__name__)

from funasr_detach.auto.auto_model import AutoModel
from funasr_detach.auto.auto_frontend import AutoFrontend
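The eager `import_submodules(__name__)` walk is what populates the class registries in `funasr_detach.register.tables` that `AutoModel.build_model` and `AutoFrontend` later look classes up in. A minimal sketch of that relationship (the "Paraformer" key is illustrative; any registered model name works):

    import funasr_detach  # triggers import_submodules(), which fills the registries
    from funasr_detach.register import tables

    # Registered classes can then be fetched by name, as build_model() does:
    model_class = tables.model_classes.get("Paraformer")  # illustrative name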
demo/Step-Audio-EditX/funasr_detach/auto/__init__.py
ADDED
File without changes
demo/Step-Audio-EditX/funasr_detach/auto/auto_frontend.py
ADDED
@@ -0,0 +1,90 @@
import time
import logging
from tqdm import tqdm

from funasr_detach.register import tables
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr_detach.auto.auto_model import prepare_data_iterator


class AutoFrontend:
    def __init__(self, **kwargs):
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info(
                "download models from model hub: {}".format(
                    kwargs.get("model_hub", "ms")
                )
            )
            kwargs = download_model(**kwargs)

        # build frontend
        frontend = kwargs.get("frontend", None)
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs["frontend_conf"])

        self.frontend = frontend
        if "frontend" in kwargs:
            del kwargs["frontend"]
        self.kwargs = kwargs

    def __call__(self, input, input_len=None, kwargs=None, **cfg):

        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)

        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
        device = kwargs.get("device", "cpu")
        if device == "cpu":
            batch_size = 1

        meta_data = {}

        result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)

        time0 = time.perf_counter()
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]

            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list,
                data_type=kwargs.get("data_type", "sound"),
                frontend=self.frontend,
                **kwargs,
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item()
                * self.frontend.frame_shift
                * self.frontend.lfr_n
                / 1000
            )

            # .to() returns a new tensor, so assign the results (the original
            # discarded them, leaving the batch on the source device)
            speech = speech.to(device=device)
            speech_lengths = speech_lengths.to(device=device)
            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
            result_list.append(batch)

            pbar.update(1)
            description = f"{meta_data}, "
            pbar.set_description(description)

        time_end = time.perf_counter()
        pbar.set_description(f"time elapsed total: {time_end - time0:0.3f}")

        return result_list
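A hedged usage sketch: the model id is illustrative (it must resolve on the hub that `download_model` is pointed at), and the batch layout follows `__call__` above:

    from funasr_detach.auto.auto_frontend import AutoFrontend

    # Illustrative model id; substitute one available on your model hub.
    frontend = AutoFrontend(
        model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        device="cpu",
    )
    batches = frontend("path/to/audio.wav", batch_size=1)
    for batch in batches:
        print(batch["key"], batch["input"].shape)  # fbank features per utterance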
demo/Step-Audio-EditX/funasr_detach/auto/auto_model.py
ADDED
@@ -0,0 +1,575 @@
import json
import time
import copy
import torch
import random
import string
import logging
import os.path
import numpy as np
from tqdm import tqdm

from funasr_detach.register import tables
from funasr_detach.utils.load_utils import load_bytes
from funasr_detach.download.file import download_from_url
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.utils.vad_utils import slice_padding_audio_samples
from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
from funasr_detach.utils.load_utils import load_audio_text_image_video
from funasr_detach.utils.timestamp_tools import timestamp_sentence
from funasr_detach.models.campplus.utils import sv_chunk, postprocess, distribute_spk

try:
    from funasr_detach.models.campplus.cluster_backend import ClusterBackend
except Exception:
    print("If you want to use the speaker diarization, please `pip install hdbscan`")


def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
    :param data_in: raw input (path, url, filelist, bytes, list, ...)
    :param input_len:
    :param data_type:
    :param key:
    :return: (key_list, data_list)
    """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl"]

    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str) and data_in.startswith("http"):  # url
        data_in = download_from_url(data_in)
    if isinstance(data_in, str) and os.path.exists(
        data_in
    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
            with open(data_in, encoding="utf-8") as fin:
                for line in fin:
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                    if data_in.endswith(
                        ".jsonl"
                    ):  # file.jsonl: json.dumps({"source": data})
                        lines = json.loads(line.strip())
                        data = lines["source"]
                        key = data["key"] if "key" in data else key
                    else:  # filelist, wav.scp, text.txt: id \t data or data
                        lines = line.strip().split(maxsplit=1)
                        data = lines[1] if len(lines) > 1 else lines[0]
                        key = lines[0] if len(lines) > 1 else key

                    data_list.append(data)
                    key_list.append(key)
        else:
            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
            data_list = [data_in]
            key_list = [key]
    elif isinstance(data_in, (list, tuple)):
        if data_type is not None and isinstance(
            data_type, (list, tuple)
        ):  # multiple inputs
            data_list_tmp = []
            for data_in_i, data_type_i in zip(data_in, data_type):
                key_list, data_list_i = prepare_data_iterator(
                    data_in=data_in_i, data_type=data_type_i
                )
                data_list_tmp.append(data_list_i)
            data_list = []
            for item in zip(*data_list_tmp):
                data_list.append(item)
        else:
            # [audio sample point, fbank, text]
            data_list = data_in
            key_list = [
                "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                for _ in range(len(data_in))
            ]
    else:  # raw text; audio sample point, fbank; bytes
        if isinstance(data_in, bytes):  # audio bytes
            data_in = load_bytes(data_in)
        if key is None:
            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
        data_list = [data_in]
        key_list = [key]

    return key_list, data_list


class AutoModel:

    def __init__(self, **kwargs):
        if not kwargs.get("disable_log", False):
            tables.print()

        model, kwargs = self.build_model(**kwargs)

        # if vad_model is not None, build vad model else None
        vad_model = kwargs.get("vad_model", None)
        vad_kwargs = kwargs.get("vad_model_revision", None)
        if vad_model is not None:
            logging.info("Building VAD model.")
            vad_kwargs = {
                "model": vad_model,
                "model_revision": vad_kwargs,
                "device": kwargs["device"],
            }
            vad_model, vad_kwargs = self.build_model(**vad_kwargs)

        # if punc_model is not None, build punc model else None
        punc_model = kwargs.get("punc_model", None)
        punc_kwargs = kwargs.get("punc_model_revision", None)
        if punc_model is not None:
            logging.info("Building punc model.")
            punc_kwargs = {
                "model": punc_model,
                "model_revision": punc_kwargs,
                "device": kwargs["device"],
            }
            punc_model, punc_kwargs = self.build_model(**punc_kwargs)

        # if spk_model is not None, build spk model else None
        spk_model = kwargs.get("spk_model", None)
        spk_kwargs = kwargs.get("spk_model_revision", None)
        if spk_model is not None:
            logging.info("Building SPK model.")
            spk_kwargs = {
                "model": spk_model,
                "model_revision": spk_kwargs,
                "device": kwargs["device"],
            }
            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
            self.cb_model = ClusterBackend().to(kwargs["device"])
            spk_mode = kwargs.get("spk_mode", "punc_segment")
            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
                logging.error(
                    "spk_mode should be one of default, vad_segment and punc_segment."
                )
            self.spk_mode = spk_mode

        self.kwargs = kwargs
        self.model = model
        self.vad_model = vad_model
        self.vad_kwargs = vad_kwargs
        self.punc_model = punc_model
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        self.model_path = kwargs.get("model_path")
        self.repo_path = kwargs.get("repo_path")

    def build_model(self, **kwargs):
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info(
                "download models from model hub: {}".format(
                    kwargs.get("model_hub", "ms")
                )
            )
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        if kwargs.get("ncpu", None):
            torch.set_num_threads(kwargs.get("ncpu"))

        # build tokenizer
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
            kwargs["tokenizer"] = tokenizer
            kwargs["token_list"] = tokenizer.token_list
            vocab_size = len(tokenizer.token_list)
        else:
            vocab_size = -1

        # build frontend
        frontend = kwargs.get("frontend", None)
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs["frontend_conf"])
            kwargs["frontend"] = frontend
            kwargs["input_size"] = frontend.output_size()

        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)

        model.to(device)

        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            logging.info(f"Loading pretrained params from {init_param}")
            load_pretrained_model(
                model=model,
                path=init_param,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                oss_bucket=kwargs.get("oss_bucket", None),
                scope_map=kwargs.get("scope_map", None),
                excludes=kwargs.get("excludes", None),
            )

        return model, kwargs

    def __call__(self, *args, **cfg):
        kwargs = self.kwargs
        kwargs.update(cfg)
        # unpack kwargs; the original passed the dict as a positional argument
        res = self.model(*args, **kwargs)
        return res

    def generate(self, input, input_len=None, **cfg):
        if self.vad_model is None:
            return self.inference(input, input_len=input_len, **cfg)
        else:
            return self.inference_with_vad(input, input_len=input_len, **cfg)

    def inference(
        self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
    ):
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        model = self.model if model is None else model
        model = model.cuda()
        model.eval()

        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
        )

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        disable_pbar = kwargs.get("disable_pbar", False)
        pbar = (
            tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
            if not disable_pbar
            else None
        )
        time_speech_total = 0.0
        time_escape_total = 0.0
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}
            if (end_idx - beg_idx) == 1 and kwargs.get(
                "data_type", None
            ) == "fbank":  # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len

            time1 = time.perf_counter()
            with torch.no_grad():
                results, meta_data = model.inference(**batch, **kwargs)
            time2 = time.perf_counter()

            asr_result_list.extend(results)

            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
            speed_stats["forward"] = f"{time_escape:0.3f}"
            speed_stats["batch_size"] = f"{len(results)}"
            speed_stats["time_cost"] = f"{(time_escape)}"
            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
            description = f"{speed_stats}, "
            if pbar:
                pbar.update(1)
                pbar.set_description(description)
            time_speech_total += batch_data_time
            time_escape_total += time_escape

        if pbar:
            # pbar.update(1)
            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        torch.cuda.empty_cache()
        return asr_result_list

    def inference_with_vad(self, input, input_len=None, **cfg):

        # step.1: compute the vad model
        self.vad_kwargs.update(cfg)
        beg_vad = time.time()
        res = self.inference(
            input,
            input_len=input_len,
            model=self.vad_model,
            kwargs=self.vad_kwargs,
            **cfg,
        )
        end_vad = time.time()
        print(f"time cost vad: {end_vad - beg_vad:0.3f}")

        # step.2 compute asr model
        model = self.model
        kwargs = self.kwargs
        kwargs.update(cfg)
        batch_size = int(kwargs.get("batch_size_s", 300)) * 1000
        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
        kwargs["batch_size"] = batch_size

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None)
        )
        results_ret_list = []
        time_speech_total_all_samples = 1e-6

        beg_total = time.time()
        pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
        for i in range(len(res)):
            key = res[i]["key"]
            vadsegments = res[i]["value"]
            input_i = data_list[i]
            speech = load_audio_text_image_video(
                input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000)
            )
            speech_lengths = len(speech)
            n = len(vadsegments)
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []

            if not len(sorted_data):
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue

            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                batch_size = max(
                    batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
                )

            batch_size_ms_cum = 0
            beg_idx = 0
            beg_asr_total = time.time()
            time_speech_total_per_sample = speech_lengths / 16000
            time_speech_total_all_samples += time_speech_total_per_sample

            all_segments = []
            for j in range(n):
                # pbar_sample.update(1)
                batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
                if (
                    j < n - 1
                    and (
                        batch_size_ms_cum
                        + sorted_data[j + 1][0][1]
                        - sorted_data[j + 1][0][0]
                    )
                    < batch_size
                    and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
                    < batch_size_threshold_ms
                ):
                    continue
                batch_size_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_audio_samples(
                    speech, speech_lengths, sorted_data[beg_idx:end_idx]
                )
                results = self.inference(
                    speech_j,
                    input_len=None,
                    model=model,
                    kwargs=kwargs,
                    disable_pbar=True,
                    **cfg,
                )
                if self.spk_model is not None:
                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                    for _b in range(len(speech_j)):
                        vad_segments = [
                            [
                                sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
                                sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
                                np.array(speech_j[_b]),
                            ]
                        ]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
                        spk_res = self.inference(
                            speech_b,
                            input_len=None,
                            model=self.spk_model,
                            kwargs=kwargs,
                            disable_pbar=True,
                            **cfg,
                        )
                        results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
                beg_idx = end_idx
                if len(results) < 1:
                    continue
                results_sorted.extend(results)

            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                restored_data[index] = results_sorted[j]
            result = {}

            # results combine for texts, timestamps, speaker embeddings and others
            # TODO: rewrite for clean code
            for j in range(n):
                for k, v in restored_data[j].items():
                    if k.startswith("timestamp"):
                        if k not in result:
                            result[k] = []
                        for t in restored_data[j][k]:
                            t[0] += vadsegments[j][0]
                            t[1] += vadsegments[j][0]
                        result[k].extend(restored_data[j][k])
                    elif k == "spk_embedding":
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] = torch.cat(
                                [result[k], restored_data[j][k]], dim=0
                            )
                    elif "text" in k:
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += " " + restored_data[j][k]
                    else:
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += restored_data[j][k]

            return_raw_text = kwargs.get("return_raw_text", False)
            # step.3 compute punc model
            if self.punc_model is not None:
                self.punc_kwargs.update(cfg)
                punc_res = self.inference(
                    result["text"],
                    model=self.punc_model,
                    kwargs=self.punc_kwargs,
                    disable_pbar=True,
                    **cfg,
                )
                raw_text = copy.copy(result["text"])
                if return_raw_text:
                    result["raw_text"] = raw_text
                result["text"] = punc_res[0]["text"]
            else:
                raw_text = None

            # speaker embedding cluster after resorted
            if self.spk_model is not None and kwargs.get("return_spk_res", True):
                if raw_text is None:
                    logging.error("Missing punc_model, which is required by spk_model.")
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result["spk_embedding"]
                labels = self.cb_model(
                    spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
                )
                # del result['spk_embedding']
                sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
                if self.spk_mode == "vad_segment":  # recover sentence_list
                    sentence_list = []
                    for res, vadsegment in zip(restored_data, vadsegments):
                        if "timestamp" not in res:
                            logging.error(
                                "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
                                and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                can predict timestamp, and speaker diarization relies on timestamps."
                            )
                        sentence_list.append(
                            {
                                "start": vadsegment[0],
                                "end": vadsegment[1],
                                "sentence": res["text"],
                                "timestamp": res["timestamp"],
                            }
                        )
                elif self.spk_mode == "punc_segment":
                    if "timestamp" not in result:
                        logging.error(
                            "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
                            and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                            can predict timestamp, and speaker diarization relies on timestamps."
                        )
                    sentence_list = timestamp_sentence(
                        punc_res[0]["punc_array"],
                        result["timestamp"],
                        raw_text,
                        return_raw_text=return_raw_text,
                    )
                distribute_spk(sentence_list, sv_output)
                result["sentence_info"] = sentence_list
            elif kwargs.get("sentence_timestamp", False):
                sentence_list = timestamp_sentence(
                    punc_res[0]["punc_array"],
                    result["timestamp"],
                    raw_text,
                    return_raw_text=return_raw_text,
                )
                result["sentence_info"] = sentence_list
            if "spk_embedding" in result:
                del result["spk_embedding"]

            result["key"] = key
            results_ret_list.append(result)
            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            pbar_total.update(1)
            pbar_total.set_description(
                f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                f"time_speech: {time_speech_total_per_sample: 0.3f}, "
                f"time_escape: {time_escape_total_per_sample:0.3f}"
            )

        return results_ret_list

    def infer_encoder(
        self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
    ):
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        model = self.model if model is None else model
        model = model.cuda()
        model.eval()

        batch_size = kwargs.get("batch_size", 1)

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
        )

        asr_result_list = []
        num_samples = len(data_list)
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}
            if (end_idx - beg_idx) == 1 and kwargs.get(
                "data_type", None
            ) == "fbank":  # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len

            with torch.no_grad():
                results, meta_data, cache = model.infer_encoder(**batch, **kwargs)
            asr_result_list.extend(results)

        torch.cuda.empty_cache()
        return asr_result_list, cache
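For orientation, a minimal usage sketch consistent with bin/inference.py below; the model id is illustrative (it is the one named in the timestamp error messages above) and must exist on the configured model hub:

    from funasr_detach import AutoModel

    model = AutoModel(
        model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",  # illustrative
    )
    # No vad_model was given, so generate() routes straight to inference().
    res = model.generate(input="path/to/audio.wav")
    print(res)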
demo/Step-Audio-EditX/funasr_detach/auto/auto_tokenizer.py
ADDED
@@ -0,0 +1,7 @@
class AutoTokenizer:
    """
    Undo
    """

    def __init__(self):
        pass
demo/Step-Audio-EditX/funasr_detach/bin/__init__.py
ADDED
File without changes
demo/Step-Audio-EditX/funasr_detach/bin/compute_audio_cmvn.py
ADDED
@@ -0,0 +1,152 @@
import os
import json
import numpy as np
import torch
import hydra
import logging
from omegaconf import DictConfig

from funasr_detach.register import tables
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info(
            "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
        )
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    print(kwargs)
    # set random seed
    tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get(
        "cudnn_enabled", torch.backends.cudnn.enabled
    )
    torch.backends.cudnn.benchmark = kwargs.get(
        "cudnn_benchmark", torch.backends.cudnn.benchmark
    )
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    tokenizer = kwargs.get("tokenizer", None)

    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=None,
        is_training=False,
        **kwargs.get("dataset_conf")
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get(
        "batch_sampler", "DynamicBatchLocalShuffleSampler"
    )
    batch_sampler_train = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        dataset_conf = kwargs.get("dataset_conf")
        dataset_conf["batch_type"] = "example"
        dataset_conf["batch_size"] = 1
        batch_sampler_train = batch_sampler_class(
            dataset_train, is_training=False, **dataset_conf
        )

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        collate_fn=dataset_train.collator,
        batch_sampler=batch_sampler_train,
        num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
        pin_memory=True,
    )

    iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))

    # accumulate per-dimension sum and sum-of-squares over all frames
    total_frames = 0
    for batch_idx, batch in enumerate(dataloader_train):
        if batch_idx >= iter_stop:
            break

        fbank = batch["speech"].numpy()[0, :, :]
        if total_frames == 0:
            mean_stats = np.sum(fbank, axis=0)
            var_stats = np.sum(np.square(fbank), axis=0)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
        total_frames += fbank.shape[0]

    cmvn_info = {
        "mean_stats": mean_stats.tolist(),
        "var_stats": var_stats.tolist(),
        "total_frames": total_frames,
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    with open(cmvn_file, "w") as fout:
        fout.write(json.dumps(cmvn_info))

    # Kaldi-style am.mvn: AddShift carries the negated mean, Rescale carries 1/std
    mean = -1.0 * mean_stats / total_frames
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
    dims = mean.shape[0]
    am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
    with open(am_mvn, "w") as fout:
        fout.write(
            "<Nnet>"
            + "\n"
            + "<Splice> "
            + str(dims)
            + " "
            + str(dims)
            + "\n"
            + "[ 0 ]"
            + "\n"
            + "<AddShift> "
            + str(dims)
            + " "
            + str(dims)
            + "\n"
        )
        mean_str = (
            str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        )
        fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
        fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
        var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        fout.write("<LearnRateCoef> 0 " + var_str + "\n")
        fout.write("</Nnet>" + "\n")


"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=0
"""
if __name__ == "__main__":
    main_hydra()
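For clarity, a minimal sketch (not part of the repo) of how the accumulated stats written to cmvn.json map onto the normalization the am.mvn shift/rescale encodes:

    import json
    import numpy as np

    with open("cmvn.json") as f:
        info = json.load(f)

    n = info["total_frames"]
    mean = np.array(info["mean_stats"]) / n                    # per-dim mean E[x]
    std = np.sqrt(np.array(info["var_stats"]) / n - mean**2)   # per-dim std sqrt(E[x^2] - E[x]^2)

    def apply_cmvn(fbank: np.ndarray) -> np.ndarray:
        """Zero-mean, unit-variance normalization: shift by -mean, rescale by 1/std."""
        return (fbank - mean) / std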
demo/Step-Audio-EditX/funasr_detach/bin/inference.py
ADDED
@@ -0,0 +1,33 @@
import hydra
import logging
from omegaconf import DictConfig, OmegaConf, ListConfig

from funasr_detach.auto.auto_model import AutoModel


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    model = AutoModel(**kwargs)
    res = model.generate(input=kwargs["input"])
    print(res)


if __name__ == "__main__":
    main_hydra()
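A hypothetical invocation, following the same Hydra `++key=value` override style as the compute_audio_cmvn usage string above (model id and input path are placeholders):

    python funasr_detach/bin/inference.py \
        ++model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
        ++input="path/to/audio.wav"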
demo/Step-Audio-EditX/funasr_detach/bin/tokenize_text.py
ADDED
@@ -0,0 +1,281 @@
#!/usr/bin/env python3
import argparse
from collections import Counter
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional


from funasr_detach.utils.cli_utils import get_commandline_args
from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
from funasr_detach.tokenizer.cleaner import TextCleaner
from funasr_detach.tokenizer.phoneme_tokenizer import g2p_classes
from funasr_detach.utils.types import str2bool
from funasr_detach.utils.types import str_or_none


def field2slice(field: Optional[str]) -> slice:
    """Convert field string to slice

    Note that field string accepts 1-based integer.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
    """
    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-" or "2-5" or "-7"
            s1, s2 = field.split("-", maxsplit=1)
            if s1.strip() == "":
                s1 = None
            else:
                s1 = int(s1)
                if s1 == 0:
                    raise ValueError("1-based string")
            if s2.strip() == "":
                s2 = None
            else:
                s2 = int(s2)
        else:
            # e.g. "2"
            s1 = int(field)
            s2 = s1 + 1
            if s1 == 0:
                raise ValueError("must be 1 or more value")
    except ValueError:
        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")

    if s1 is None:
        slic = slice(None, s2)
    else:
        # -1 because of 1-based integer following "cut" command
        # e.g. "1-3" -> slice(0, 3)
        slic = slice(s1 - 1, s2)
    return slic


def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")
    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    cleaner = TextCleaner(cleaner)
    tokenizer = build_tokenizer(
        token_type=token_type,
        bpemodel=bpemodel,
        delimiter=delimiter,
        space_symbol=space_symbol,
        non_linguistic_symbols=non_linguistic_symbols,
        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
        g2p_type=g2p,
    )

    counter = Counter()
    if field is not None:
        field = field2slice(field)

    for line in fin:
        line = line.rstrip()
        if field is not None:
            # e.g. field="2-"
            # uttidA hello world!! -> hello world!!
            tokens = line.split(delimiter)
            tokens = tokens[field]
            if delimiter is None:
                line = " ".join(tokens)
            else:
                line = delimiter.join(tokens)

        line = cleaner(line)
        tokens = tokenizer.text2tokens(line)
        if not write_vocabulary:
            fout.write(" ".join(tokens) + "\n")
        else:
            for t in tokens:
                counter[t] += 1

    if not write_vocabulary:
        return

    # FIXME:
    # delete duplicate add_symbols from the counter
    for symbol_and_id in add_symbol:
        # e.g. symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()
        if symbol in counter:
            del counter[symbol]

    # ======= write_vocabulary mode from here =======
    # Sort by the number of occurrences in descending order
    # and filter lower frequency words than cutoff value
    words_and_counts = list(
        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
    )
    # Restrict the vocabulary size
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # Parse the values of --add_symbol
    for symbol_and_id in add_symbol:
        # e.g. symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
            idx = int(idx)
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()

        # e.g. idx=0 -> append as the first symbol
        # e.g. idx=-1 -> append as the last symbol
        if idx < 0:
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    # Write words
    for w, c in words_and_counts:
        fout.write(w + "\n")

    # Logging
    total_count = sum(counter.values())
    invocab_count = sum(c for w, c in words_and_counts if c is not None)
    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument(
        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integer. e.g 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-language-symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_classes,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    group = parser.add_argument_group("write_vocabulary mode related")
    group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )

    return parser


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    tokenize(**kwargs)


if __name__ == "__main__":
    main()
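A hypothetical invocation built only from the flags defined in get_parser above (file names are placeholders):

    python funasr_detach/bin/tokenize_text.py \
        -i text.txt -o vocab.txt --token_type char --field 2- \
        --write_vocabulary true --vocabulary_size 5000 \
        --add_symbol '<blank>:0' --add_symbol '<unk>:1'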
demo/Step-Audio-EditX/funasr_detach/bin/train.py
ADDED
|
@@ -0,0 +1,227 @@
|
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

import os
import sys
import torch
import hydra
import logging
import argparse
from io import BytesIO
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from funasr_detach.register import tables
from funasr_detach.optimizers import optim_classes
from funasr_detach.train_utils.trainer import Trainer
from funasr_detach.schedulers import scheduler_classes
from funasr_detach.train_utils.initialize import initialize
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.models.lora.utils import mark_only_lora_as_trainable
from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model

# from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
# from funasr_detach.tokenizer.token_id_converter import TokenIDConverter
# from funasr_detach.tokenizer.funtoken import build_tokenizer


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info(
            "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
        )
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    print(kwargs)

    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get(
        "cudnn_enabled", torch.backends.cudnn.enabled
    )
    torch.backends.cudnn.benchmark = kwargs.get(
        "cudnn_benchmark", torch.backends.cudnn.benchmark
    )
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()
    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(
            backend=kwargs.get("backend", "nccl"), init_method="env://"
        )
        torch.cuda.set_device(local_rank)

    # save config.yaml on the rank-0 process only
    if (
        (use_ddp or use_fsdp)
        and dist.get_rank() == 0
        or not (use_ddp or use_fsdp)
        and local_rank == 0
    ):
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)

    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer

    # build frontend if one is configured
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(
        **kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)
    )

    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                path=p,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                oss_bucket=kwargs.get("oss_bucket", None),
                scope_map=kwargs.get("scope_map", None),
                excludes=kwargs.get("excludes", None),
            )
    else:
        initialize(model, kwargs.get("init", "kaiming_normal"))

    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        freeze_param = eval(freeze_param)
        if isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get(
                "find_unused_parameters", False
            ),
        )
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))

    # optim
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get(
        "batch_sampler", "DynamicBatchLocalShuffleSampler"
    )
    batch_sampler_val = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(
            dataset_val, is_training=False, **kwargs.get("dataset_conf")
        )
    dataloader_tr = torch.utils.data.DataLoader(
        dataset_tr,
        collate_fn=dataset_tr.collator,
        batch_sampler=batch_sampler,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )

    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        collate_fn=dataset_val.collator,
        batch_sampler=batch_sampler_val,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=dataloader_val,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        output_dir=kwargs.get("output_dir", "./exp"),
        resume=kwargs.get("resume", True),
        **kwargs.get("train_conf"),
    )
    trainer.run()

    if use_ddp or use_fsdp:
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main_hydra()
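
Not part of the diff: a hypothetical minimal kwargs dict for main() above, useful for seeing which configuration keys the function actually reads. In practice the dict comes from hydra plus the model's config.yaml, and every class-name value below (model, tokenizer, frontend, sampler) is an assumption that would have to match an entry registered in funasr_detach's tables.

# Sketch only; values are placeholders, keys mirror the kwargs.get(...) calls above.
cfg = dict(
    model="Transformer",                 # must exist in tables.model_classes
    model_conf={},                       # forwarded to the model constructor
    tokenizer="CharTokenizer",           # optional; looked up in tokenizer_classes
    tokenizer_conf={"token_list": "tokens.txt"},
    frontend="WavFrontend",              # optional; sets input_size via output_size()
    frontend_conf={},
    dataset="AudioDataset",
    dataset_conf={"batch_sampler": "DynamicBatchLocalShuffleSampler"},
    optim="adam", optim_conf={"lr": 1e-3},
    scheduler="warmuplr", scheduler_conf={"warmup_steps": 1000},
    train_conf={"max_epoch": 10},
    train_data_set_list="train.jsonl",   # consumed by the index dataset below
    valid_data_set_list="val.jsonl",
    output_dir="./exp",
)
# main(**cfg)  # would build model/optim/scheduler/dataloaders and run Trainer
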
demo/Step-Audio-EditX/funasr_detach/datasets/__init__.py
ADDED
File without changes

demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/__init__.py
ADDED
File without changes

demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/datasets.py
ADDED
@@ -0,0 +1,112 @@
import torch

from funasr_detach.register import tables
from funasr_detach.utils.load_utils import extract_fbank, load_audio_text_image_video


@tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset):
    """
    AudioDataset
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(
                preprocessor_speech
            )
            preprocessor_speech = preprocessor_speech_class(
                **kwargs.get("preprocessor_speech_conf")
            )
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(
                **kwargs.get("preprocessor_text_conf")
            )
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer

        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]

        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        if self.tokenizer:
            ids = self.tokenizer.encode(target)
            text = torch.tensor(ids, dtype=torch.int64)
        else:
            ids = target
            text = ids
        ids_lengths = len(ids)
        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)

        return {
            "speech": speech[0, :, :],
            "speech_lengths": speech_lengths,
            "text": text,
            "text_lengths": text_lengths,
        }

    def collator(self, samples: list = None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value

                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return outputs
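
Not part of the diff: the collator above pads each tensor field to the batch maximum, choosing the pad value by dtype (int64 fields get int_pad_value, float fields get float_pad_value). A self-contained sketch of that padding behavior:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two variable-length "text" sequences, as __getitem__ would return them.
a = torch.tensor([5, 9, 2], dtype=torch.int64)
b = torch.tensor([7, 1], dtype=torch.int64)

# int64 fields are padded with int_pad_value (-1 by default above).
padded = pad_sequence([a, b], batch_first=True, padding_value=-1)
print(padded)  # tensor([[ 5,  9,  2], [ 7,  1, -1]])
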
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/index_ds.py
ADDED
@@ -0,0 +1,150 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist

from funasr_detach.register import tables


@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):

    def __init__(self, path):
        super().__init__()

        contents = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    contents.append(data["text"])
                if "source" in data:  # for speech lab pretrain
                    prompt = data["prompt"]
                    source = data["source"]
                    target = data["target"]
                    source_len = data["source_len"]
                    target_len = data["target_len"]

                    contents.append(
                        {
                            "source": source,
                            "prompt": prompt,
                            "target": target,
                            "source_len": source_len,
                            "target_len": target_len,
                        }
                    )

        total_num = len(contents)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size

        # each rank keeps only its own contiguous shard of the data list
        self.contents = contents[rank * num_per_rank : (rank + 1) * num_per_rank]

        logging.info(
            "in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(
                rank, len(self.contents), len(contents)
            )
        )

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        try:
            data = self.contents[index]
        except:
            print(index)
            raise
        return data

    def get_source_len(self, data_dict):
        return data_dict["source_len"]

    def get_target_len(self, data_dict):
        return data_dict["target_len"] if "target_len" in data_dict else 0


@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):

    def __init__(self, path: str, **kwargs):
        super().__init__()

        if isinstance(path, (list, tuple)):  # wav.scp, text.txt/text.trans
            from funasr_detach.datasets.audio_datasets.scp2jsonl import (
                gen_jsonl_from_wav_text_list,
            )

            jsonl_outdir = os.path.dirname(path[0])
            jsonl_name = (
                "datalist_train.jsonl"
                if kwargs.get("is_training", True)
                else "datalist_val.jsonl"
            )
            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
            if not os.path.exists(jsonl_file_out):
                print(f"datalist is: {path}, generate jsonl from it")
                gen_jsonl_from_wav_text_list(
                    path, jsonl_file_out=jsonl_file_out, **kwargs
                )
            path = jsonl_file_out

        contents = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    contents.append(data["text"])
                if "source" in data:  # for speech lab pretrain
                    prompt = data.get("prompt", "<ASR>")
                    source = data["source"]
                    target = data["target"]
                    source_len = data.get("source_len", 1)
                    target_len = data.get("target_len", 0)

                    contents.append(
                        {
                            "source": source,
                            "prompt": prompt,
                            "target": target,
                            "source_len": source_len,
                            "target_len": target_len,
                        }
                    )

        self.contents = contents

        logging.info(
            "total_num of samplers across ranks: {}".format(len(self.contents))
        )

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        try:
            data = self.contents[index]
        except:
            print(index)
            raise
        return data

    def get_source_len(self, data_dict):
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        return data_dict.get("target_len", 0)
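
Not part of the diff: both index datasets consume one JSON object per line. A record carrying a "source" key is treated as an audio/text pair, and IndexDSJsonlRankFull backfills prompt and the *_len fields when they are missing. A hypothetical line (the path and lengths are placeholders):

import json

record = {
    "key": "utt001",
    "source": "/data/wavs/utt001.wav",  # placeholder path
    "target": "hello world",
    "source_len": 300,  # 10 ms frames, floored to whole seconds by scp2jsonl
    "target_len": 2,    # word count for space-delimited text
    "prompt": "<ASR>",  # optional; IndexDSJsonlRankFull defaults to "<ASR>"
}
print(json.dumps(record, ensure_ascii=False))
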
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/preprocessor.py
ADDED
@@ -0,0 +1,55 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torchaudio
from torch import nn
import random
import re
from funasr_detach.tokenizer.cleaner import TextCleaner
from funasr_detach.register import tables


@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        if self.speed_perturb is None:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
            )
            waveform = waveform.view(-1)

        return waveform


@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    def __init__(
        self,
        seg_dict: str = None,
        text_cleaner: Collection[str] = None,
        split_with_space: bool = False,
        **kwargs
    ):
        super().__init__()

        self.text_cleaner = TextCleaner(text_cleaner)

    def forward(self, text, **kwargs):
        text = self.text_cleaner(text)

        return text
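
Not part of the diff: a minimal sketch of applying the speed-perturb preprocessor, assuming torchaudio's sox backend is available. The three-way setting [0.9, 1.0, 1.1] is a common choice, not a value taken from this diff.

import torch
from funasr_detach.datasets.audio_datasets.preprocessor import (
    SpeechPreprocessSpeedPerturb,
)

perturb = SpeechPreprocessSpeedPerturb(speed_perturb=[0.9, 1.0, 1.1])
wav = torch.randn(16000)       # one second of audio at 16 kHz
out = perturb(wav, fs=16000)   # length shrinks/grows by the sampled factor
print(wav.shape, out.shape)
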
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/samplers.py
ADDED
@@ -0,0 +1,306 @@
import torch
import numpy as np
import logging
import torch.distributed as dist

from funasr_detach.register import tables


@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = False,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs
    ):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

    def __len__(self):
        return (self.total_samples - 1) // self.batch_size + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                target_len = (
                    self.dataset.get_target_len(idx_map)
                    if self.batch_type == "length"
                    else 0.0
                )
                source_len = (
                    self.dataset.get_source_len(idx_map) / self.length_scale_source
                )
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            # sort the local buffer by sample length so each batch packs
            # similar-length samples together
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type != "example":
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1


@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = True,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs
    ):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size

    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        batch_size_total = self.batch_size * self.world_size

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]

                source_len = (
                    self.dataset.get_source_len(idx_map) / self.length_scale_source
                )
                target_len = (
                    self.dataset.get_target_len(idx_map)
                    if self.batch_type == "length"
                    else 0.0
                )
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if max_token_padding <= batch_size_total:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    # emit only this rank's slice of the accumulated global batch
                    batch_rank = batch[
                        self.rank * self.batch_size : (self.rank + 1) * self.batch_size
                    ]
                    yield batch_rank
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1


@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = True,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs
    ):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size

    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch_list_all_rank = []
        batch_list_cur = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]

                source_len = (
                    self.dataset.get_source_len(idx_map) / self.length_scale_source
                )
                target_len = (
                    self.dataset.get_target_len(idx_map)
                    if self.batch_type == "length"
                    else 0.0
                )
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for ii, item in enumerate(datalen_with_index_sort):
                # flags the final sample of the final buffer (computed but unused)
                is_last_batch = (
                    iter == iter_num - 1 and ii == len(datalen_with_index_sort) - 1
                )
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample

                if self.batch_type != "example":
                    max_token_padding *= max_token_cur
                if len(batch_list_all_rank) < self.world_size:
                    if max_token_padding <= self.batch_size:
                        batch_list_cur.append(idx)
                        max_token = max_token_cur
                        num_sample += 1
                    else:
                        batch_list_all_rank.append(batch_list_cur)
                        batch_list_cur = []
                else:
                    # one per-rank batch has been collected for every rank:
                    # emit this rank's batch and start accumulating again
                    batch_rank = batch_list_all_rank[self.rank]
                    yield batch_rank
                    batch_list_all_rank = []
                    batch_list_cur = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
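
Not part of the diff: all three samplers share the same packing rule in the token-budget branch (batch_type != "example"): sort a local buffer by length, then keep adding samples while the padded cost (num_sample + 1) * max_len stays within the budget. A standalone sketch of that rule:

def pack_by_token_budget(lengths, budget):
    """Greedy packing over length-sorted samples, mirroring the samplers above."""
    batches, batch, max_len = [], [], 0
    for n in sorted(lengths):
        padded = (len(batch) + 1) * max(max_len, n)  # cost after padding to max
        if padded <= budget:
            batch.append(n)
            max_len = max(max_len, n)
        else:
            batches.append(batch)
            batch, max_len = [n], n
    if batch:
        batches.append(batch)
    return batches

print(pack_by_token_budget([120, 80, 300, 90, 310], budget=600))
# [[80, 90, 120], [300], [310]]: short samples pack together, long ones go alone
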
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/scp2jsonl.py
ADDED
@@ -0,0 +1,116 @@
import os
import json
import torch
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist


def gen_jsonl_from_wav_text_list(
    path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs
):
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=cpu_cores
                ) as executor:
                    futures = [
                        executor.submit(
                            parse_context_length,
                            data_file_lists[
                                i * lines_for_each_th : (i + 1) * lines_for_each_th
                            ],
                            data_type,
                        )
                        for i in range(task_num)
                    ]

                    for future in concurrent.futures.as_completed(futures):
                        json_dict[data_type].update(future.result())

        with open(jsonl_file_out, "w") as f:
            for key in json_dict[data_type_list[0]].keys():
                jsonl_line = {"key": key}
                for data_file in data_type_list:
                    jsonl_line.update(json_dict[data_file][key])
                jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
                f.write(jsonl_line + "\n")
                f.flush()

    if world_size > 1:
        dist.barrier()


def parse_context_length(data_list: list, data_type: str):

    res = {}
    for i, line in enumerate(data_list):
        key, line = line.strip().split(maxsplit=1)
        line = line.strip()
        if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
            sample_num = len(waveform)
            # audio length in 10 ms frames (floored to whole seconds)
            context_len = int(sample_num // 16000 * 1000 / 10)
        else:
            # text length: word count if space-delimited, else character count
            context_len = len(line.split()) if " " in line else len(line)
        res[key] = {data_type: line, f"{data_type}_len": context_len}
    return res


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):

    kwargs = OmegaConf.to_container(cfg, resolve=True)

    scp_file_list = kwargs.get(
        "scp_file_list",
        (
            "/Users/zhifu/funasr1.0/test_local/wav.scp",
            "/Users/zhifu/funasr1.0/test_local/text.txt",
        ),
    )
    if isinstance(scp_file_list, str):
        scp_file_list = eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file_out = kwargs.get(
        "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    gen_jsonl_from_wav_text_list(
        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
    )


"""
python -m funasr_detach.datasets.audio_datasets.scp2jsonl \
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
"""

if __name__ == "__main__":
    main_hydra()
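
Not part of the diff: the two input files follow the usual Kaldi-style layout, one "key value" pair per line, which parse_context_length splits with split(maxsplit=1). Hypothetical contents (paths are placeholders):

# Hypothetical wav.scp / text.txt contents; keys must match across the two files.
wav_scp = "utt001 /data/wavs/utt001.wav\nutt002 /data/wavs/utt002.wav\n"
text_txt = "utt001 hello world\nutt002 good morning\n"

# Each jsonl output line then merges both entries under the shared key, e.g.:
# {"key": "utt001", "source": "/data/wavs/utt001.wav", "source_len": ...,
#  "target": "hello world", "target_len": 2}
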
demo/Step-Audio-EditX/funasr_detach/download/__init__.py
ADDED
File without changes

demo/Step-Audio-EditX/funasr_detach/download/download_dataset_from_hub.py
ADDED
@@ -0,0 +1,19 @@
def download_dataset():
    pass


def download_dataset_from_ms(**kwargs):
    from modelscope.msdatasets import MsDataset

    dataset_name = kwargs.get(
        "dataset_name", "speech_asr/speech_asr_aishell1_trainsets"
    )
    subset_name = kwargs.get("subset_name", "default")
    split = kwargs.get("split", "train")
    data_dump_dir = kwargs.get("data_dump_dir", None)
    ds = MsDataset.load(
        dataset_name=dataset_name,
        subset_name=subset_name,
        split=split,
        cache_dir=data_dump_dir,
    )
demo/Step-Audio-EditX/funasr_detach/download/download_from_hub.py
ADDED
@@ -0,0 +1,231 @@
import os
import json
import threading
from omegaconf import OmegaConf

from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf

# Global cache for downloaded models to avoid repeated downloads
# Key: (repo_id, model_revision, model_hub)
# Value: repo_cache_dir
_model_cache = {}
_cache_lock = threading.Lock()


def download_model(**kwargs):
    model_hub = kwargs.get("model_hub", "ms")
    model_or_path = kwargs.get("model")
    repo_path = kwargs.get("repo_path", "")

    # Handle name mapping based on model_hub
    if model_hub == "ms" and model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    elif model_hub == "hf" and model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]

    model_revision = kwargs.get("model_revision")

    # Download model if it doesn't exist locally
    if not os.path.exists(model_or_path):
        if model_hub == "local":
            # For local models, the path should already exist
            raise FileNotFoundError(f"Local model path does not exist: {model_or_path}")
        elif model_hub in ["ms", "hf"]:
            repo_path, model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
                model_hub=model_hub,
            )
        else:
            raise ValueError(f"Unsupported model_hub: {model_hub}")

    print(f"Using model path: {model_or_path}")
    kwargs["model_path"] = model_or_path
    kwargs["repo_path"] = repo_path

    # Common logic for processing configuration files (same for all model hubs)
    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(
            os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
        ) as f:
            conf_json = json.load(f)
        cfg = {}
        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        cfg.update(kwargs)
        config = OmegaConf.load(cfg["config"])
        kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
        os.path.join(model_or_path, "model.pt")
    ):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)
        init_param = os.path.join(model_or_path, "model.pt")
        kwargs["init_param"] = init_param
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.txt"
            )
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.json"
            )
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
                model_or_path, "seg_dict"
            )
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
                model_or_path, "bpe.model"
            )
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")

    return OmegaConf.to_container(kwargs, resolve=True)


def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):

    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                if k not in cfg:
                    cfg[k] = {}
                add_file_root_path(model_or_path, v, cfg[k])

    return cfg


def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
    model_hub="ms",
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
            For HF subfolders, use format: "repo_id/subfolder_path"
        model_revision (str, optional): model version number.
        is_training (bool): Whether this is for training
        check_latest (bool): Whether to check for latest version
        model_hub (str): Model hub type ("ms" for ModelScope, "hf" for HuggingFace)
    """
    # Extract repo_id for caching (handle subfolder case)
    if "/" in model and len(model.split("/")) > 2:
        parts = model.split("/")
        repo_id = "/".join(parts[:2])  # e.g., "organization/repo" or "stepfun-ai/Step-Audio-EditX"
        subfolder = "/".join(parts[2:])  # e.g., "subfolder/model"
    else:
        repo_id = model
        subfolder = None

    # Create cache key
    cache_key = (repo_id, model_revision, model_hub)

    # Check cache first
    with _cache_lock:
        if cache_key in _model_cache:
            cached_repo_dir = _model_cache[cache_key]
            print(f"Using cached model for {repo_id}: {cached_repo_dir}")

            # For subfolder case, construct the model_cache_dir from cached repo
            if subfolder:
                model_cache_dir = os.path.join(cached_repo_dir, subfolder)
                if not os.path.exists(model_cache_dir):
                    raise FileNotFoundError(
                        f"Subfolder {subfolder} not found in cached repo {repo_id}"
                    )
            else:
                model_cache_dir = cached_repo_dir

            return cached_repo_dir, model_cache_dir

    # Cache miss, need to download
    if model_hub == "ms":
        # ModelScope download
        from modelscope.hub.snapshot_download import snapshot_download
        from modelscope.utils.constant import Invoke, ThirdParty

        key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE

        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id,
            revision=model_revision,
            user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)

        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(
                    f"Subfolder {subfolder} not found in downloaded repo {repo_id}"
                )
        else:
            model_cache_dir = normalize_cache_path(repo_cache_dir)

    elif model_hub == "hf":
        # HuggingFace download
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            raise ImportError(
                "huggingface_hub is required for downloading from HuggingFace. "
                "Please install it with: pip install huggingface_hub"
            )

        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id=repo_id,
            revision=model_revision,
            allow_patterns=None,  # Download all files to ensure resource files are available
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)

        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(
                    f"Subfolder {subfolder} not found in downloaded repo {repo_id}"
                )
        else:
            model_cache_dir = normalize_cache_path(repo_cache_dir)
    else:
        raise ValueError(f"Unsupported model_hub: {model_hub}")

    # Cache the result before returning
    with _cache_lock:
        _model_cache[cache_key] = repo_cache_dir

    print(f"Model downloaded to: {model_cache_dir}")
    return repo_cache_dir, model_cache_dir


def normalize_cache_path(cache_path):
    """Normalize cache path to ensure consistent format with snapshots/{commit_id}."""
    # Check if the cache_path directory contains a snapshots folder
    snapshots_dir = os.path.join(cache_path, "snapshots")
    if os.path.exists(snapshots_dir) and os.path.isdir(snapshots_dir):
        # Find the commit_id subdirectory in snapshots
        try:
            snapshot_items = os.listdir(snapshots_dir)
            # Look for the first directory (should be the commit_id)
            for item in snapshot_items:
                item_path = os.path.join(snapshots_dir, item)
                if os.path.isdir(item_path):
                    # Found commit_id directory, return the full path
                    return os.path.join(cache_path, "snapshots", item)
        except OSError:
            pass

    # If no snapshots directory found or error occurred, return original path
    return cache_path
demo/Step-Audio-EditX/funasr_detach/download/file.py
ADDED
@@ -0,0 +1,335 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import contextlib
+import os
+import tempfile
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Generator, Union
+
+import requests
+from urllib.parse import urlparse
+
+
+def download_from_url(url):
+    result = urlparse(url)
+    file_path = None
+    if result.scheme is not None and len(result.scheme) > 0:
+        storage = HTTPStorage()
+        # bytes
+        data = storage.read(url)
+        work_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(work_dir):
+            os.makedirs(work_dir)
+        file_path = os.path.join(work_dir, os.path.basename(url))
+        with open(file_path, "wb") as fb:
+            fb.write(data)
+    assert file_path is not None, f"failed to download: {url}"
+    return file_path
+
+
+class Storage(metaclass=ABCMeta):
+    """Abstract class of storage.
+
+    All backends need to implement two apis: ``read()`` and ``read_text()``.
+    ``read()`` reads the file as a byte stream and ``read_text()`` reads
+    the file as texts.
+    """
+
+    @abstractmethod
+    def read(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def read_text(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        pass
+
+    @abstractmethod
+    def write_text(
+        self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
+    ) -> None:
+        pass
+
+
+class LocalStorage(Storage):
+    """Local hard disk storage"""
+
+    def read(self, filepath: Union[str, Path]) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        with open(filepath, "rb") as f:
+            content = f.read()
+        return content
+
+    def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        with open(filepath, "r", encoding=encoding) as f:
+            value_buf = f.read()
+        return value_buf
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, "wb") as f:
+            f.write(obj)
+
+    def write_text(
+        self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
+    ) -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, "w", encoding=encoding) as f:
+            f.write(obj)
+
+    @contextlib.contextmanager
+    def as_local_path(
+        self, filepath: Union[str, Path]
+    ) -> Generator[Union[str, Path], None, None]:
+        """Only for unified API and do nothing."""
+        yield filepath
+
+
+class HTTPStorage(Storage):
+    """HTTP and HTTPS storage."""
+
+    def read(self, url):
+        # TODO @wenmeng.zwm add progress bar if file is too large
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.content
+
+    def read_text(self, url):
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.text
+
+    @contextlib.contextmanager
+    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with the ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = HTTPStorage()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.as_local_path('http://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, url: Union[str, Path]) -> None:
+        raise NotImplementedError("write is not supported by HTTP Storage")
+
+    def write_text(
+        self, obj: str, url: Union[str, Path], encoding: str = "utf-8"
+    ) -> None:
+        raise NotImplementedError("write_text is not supported by HTTP Storage")
+
+
+class OSSStorage(Storage):
+    """OSS storage."""
+
+    def __init__(self, oss_config_file=None):
+        # read from config file or env var
+        raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")
+
+    def read(self, filepath):
+        raise NotImplementedError("OSSStorage.read to be implemented in the future")
+
+    def read_text(self, filepath, encoding="utf-8"):
+        raise NotImplementedError(
+            "OSSStorage.read_text to be implemented in the future"
+        )
+
+    @contextlib.contextmanager
+    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with the ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = OSSStorage()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.as_local_path('oss://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        raise NotImplementedError("OSSStorage.write to be implemented in the future")
+
+    def write_text(
+        self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
+    ) -> None:
+        raise NotImplementedError(
+            "OSSStorage.write_text to be implemented in the future"
+        )
+
+
+G_STORAGES = {}
+
+
+class File(object):
+    _prefix_to_storage: dict = {
+        "oss": OSSStorage,
+        "http": HTTPStorage,
+        "https": HTTPStorage,
+        "local": LocalStorage,
+    }
+
+    @staticmethod
+    def _get_storage(uri):
+        assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"
+
+        if "://" not in uri:
+            # local path
+            storage_type = "local"
+        else:
+            prefix, _ = uri.split("://")
+            storage_type = prefix
+
+        assert storage_type in File._prefix_to_storage, (
+            f"Unsupported uri {uri}, valid prefixes: "
+            f"{list(File._prefix_to_storage.keys())}"
+        )
+
+        if storage_type not in G_STORAGES:
+            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
+
+        return G_STORAGES[storage_type]
+
+    @staticmethod
+    def read(uri: str) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        storage = File._get_storage(uri)
+        return storage.read(uri)
+
+    @staticmethod
+    def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        storage = File._get_storage(uri)
+        return storage.read_text(uri)
+
+    @staticmethod
+    def write(obj: bytes, uri: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        storage = File._get_storage(uri)
+        return storage.write(obj, uri)
+
+    @staticmethod
+    def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        storage = File._get_storage(uri)
+        return storage.write_text(obj, uri)
+
+    @contextlib.contextmanager
+    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
+        """Dispatch to the storage backend for ``uri`` and yield a local path."""
+        storage = File._get_storage(uri)
+        with storage.as_local_path(uri) as local_path:
+            yield local_path
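file.py thus gives funasr_detach one entry point, File, that dispatches on the URI scheme: bare paths go to LocalStorage, http:// and https:// to HTTPStorage, and oss:// is reserved but not yet implemented. A short usage sketch against the API defined above (the URL is only a placeholder):

from funasr_detach.download.file import File

# A bare path (no "://") is routed to LocalStorage.
File.write_text("hello", "/tmp/file_py_demo.txt")
assert File.read_text("/tmp/file_py_demo.txt") == "hello"

# http/https URIs are routed to HTTPStorage, which is read-only:
# File.write(...) on an http URI raises NotImplementedError.
# data = File.read("https://example.com/model.bin")  # placeholder URL

# as_local_path gives a uniform "real file on disk" view: a no-op for
# local paths, a temporary download (removed on exit) for HTTP.
with File.as_local_path("/tmp/file_py_demo.txt") as path:
    print(path)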
demo/Step-Audio-EditX/funasr_detach/download/name_maps_from_hub.py
ADDED
@@ -0,0 +1,13 @@
+name_maps_ms = {
+    "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
+    "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
+    "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
+}
+
+name_maps_hf = {}
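name_maps_ms maps short model aliases to full ModelScope repository ids; name_maps_hf, its Hugging Face counterpart, is left empty here. A sketch of how such a table is typically consulted (resolve_model_id is a hypothetical helper, not part of this file):

from funasr_detach.download.name_maps_from_hub import name_maps_ms

def resolve_model_id(name: str) -> str:
    # Hypothetical helper: fall back to the raw name when no alias matches.
    return name_maps_ms.get(name, name)

print(resolve_model_id("fsmn-vad"))
# -> damo/speech_fsmn_vad_zh-cn-16k-common-pytorch
print(resolve_model_id("my-org/my-model"))  # unchanged passthrough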
demo/Step-Audio-EditX/funasr_detach/download/runtime_sdk_download_tool.py
ADDED
@@ -0,0 +1,60 @@
+import os
+import argparse
+from pathlib import Path
+
+from funasr_detach.utils.types import str2bool
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-name", type=str, required=True)
+    parser.add_argument("--export-dir", type=str, required=True)
+    parser.add_argument(
+        "--export", type=str2bool, default=True, help="whether to export model"
+    )
+    parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torch"]')
+    parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
+    parser.add_argument(
+        "--quantize", type=str2bool, default=False, help="export quantized model"
+    )
+    parser.add_argument(
+        "--fallback-num", type=int, default=0, help="amp fallback number"
+    )
+    parser.add_argument("--audio_in", type=str, default=None, help='["wav", "wav.scp"]')
+    parser.add_argument(
+        "--model_revision", type=str, default=None, help="model_revision"
+    )
+    parser.add_argument("--calib_num", type=int, default=200, help="calib max num")
+    args = parser.parse_args()
+
+    model_dir = args.model_name
+    if not Path(args.model_name).exists():
+        from modelscope.hub.snapshot_download import snapshot_download
+
+        try:
+            model_dir = snapshot_download(
+                args.model_name, cache_dir=args.export_dir, revision=args.model_revision
+            )
+        except Exception as e:
+            raise ValueError(
+                "model_dir must be a model name on ModelScope or a local path "
+                "downloaded from ModelScope, but is {}".format(model_dir)
+            ) from e
+    if args.export:
+        model_file = os.path.join(model_dir, "model.onnx")
+        if args.quantize:
+            model_file = os.path.join(model_dir, "model_quant.onnx")
+        if not os.path.exists(model_file):
+            print("model .onnx does not exist, begin to export onnx")
+            from funasr_detach.bin.export_model import ModelExport
+
+            export_model = ModelExport(
+                cache_dir=args.export_dir,
+                onnx=True,
+                device="cpu",
+                quant=args.quantize,
+            )
+            export_model.export(model_dir)
+
+
+if __name__ == "__main__":
+    main()
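The script is a thin CLI around a ModelScope snapshot download followed by funasr_detach's ONNX export. One way to drive it from Python rather than a shell is to patch sys.argv before calling main(); the flag names below are taken from the argparse setup above, while the model name and export directory are only illustrations:

import sys
from funasr_detach.download import runtime_sdk_download_tool

sys.argv = [
    "runtime_sdk_download_tool.py",
    "--model-name", "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "--export-dir", "./exported_models",
    "--quantize", "false",  # parsed by str2bool
]
runtime_sdk_download_tool.main()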
demo/Step-Audio-EditX/funasr_detach/frontends/__init__.py
ADDED
File without changes