niobures committed on
Commit bfc4d3c · verified · 1 Parent(s): f4280c0

Step-Audio (code, dataset, demo, paper, tools)

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +3 -0
  2. Step-Audio-AQAA. A Fully End-to-End Expressive Large Audio Language Model.pdf +3 -0
  3. Step-Audio-EditX Technical Report.pdf +3 -0
  4. Step-Audio. Unified Understanding and Generation in Intelligent Speech Interaction.pdf +3 -0
  5. code/ComfyUI_StepAudioTTS.zip +3 -0
  6. code/Step-Audio [intervitens].zip +3 -0
  7. code/Step-Audio-EditX.zip +3 -0
  8. code/Step-Audio-tts.zip +3 -0
  9. code/Step-Audio.zip +3 -0
  10. code/Step-Audio2.zip +3 -0
  11. code/StepAudioInfer.zip +3 -0
  12. code/astrbot_plugin_tts_Step_Audio.zip +3 -0
  13. dataset/StepEval-Audio-360/.gitattributes +59 -0
  14. dataset/StepEval-Audio-360/README.md +79 -0
  15. dataset/StepEval-Audio-360/audios.tar.gz +3 -0
  16. dataset/StepEval-Audio-360/data/test-00000-of-00001.parquet +3 -0
  17. dataset/StepEval-Audio-360/source.txt +1 -0
  18. demo/Step-Audio-EditX/.gitattributes +4 -0
  19. demo/Step-Audio-EditX/.gitignore +2 -0
  20. demo/Step-Audio-EditX/LICENSE +201 -0
  21. demo/Step-Audio-EditX/README.md +13 -0
  22. demo/Step-Audio-EditX/__init__.py +0 -0
  23. demo/Step-Audio-EditX/app.py +505 -0
  24. demo/Step-Audio-EditX/config/__init__.py +12 -0
  25. demo/Step-Audio-EditX/config/edit_config.py +32 -0
  26. demo/Step-Audio-EditX/config/prompts.py +23 -0
  27. demo/Step-Audio-EditX/funasr_detach/__init__.py +38 -0
  28. demo/Step-Audio-EditX/funasr_detach/auto/__init__.py +0 -0
  29. demo/Step-Audio-EditX/funasr_detach/auto/auto_frontend.py +90 -0
  30. demo/Step-Audio-EditX/funasr_detach/auto/auto_model.py +575 -0
  31. demo/Step-Audio-EditX/funasr_detach/auto/auto_tokenizer.py +7 -0
  32. demo/Step-Audio-EditX/funasr_detach/bin/__init__.py +0 -0
  33. demo/Step-Audio-EditX/funasr_detach/bin/compute_audio_cmvn.py +152 -0
  34. demo/Step-Audio-EditX/funasr_detach/bin/inference.py +33 -0
  35. demo/Step-Audio-EditX/funasr_detach/bin/tokenize_text.py +281 -0
  36. demo/Step-Audio-EditX/funasr_detach/bin/train.py +227 -0
  37. demo/Step-Audio-EditX/funasr_detach/datasets/__init__.py +0 -0
  38. demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/__init__.py +0 -0
  39. demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/datasets.py +112 -0
  40. demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/index_ds.py +150 -0
  41. demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/preprocessor.py +55 -0
  42. demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/samplers.py +306 -0
  43. demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/scp2jsonl.py +116 -0
  44. demo/Step-Audio-EditX/funasr_detach/download/__init__.py +0 -0
  45. demo/Step-Audio-EditX/funasr_detach/download/download_dataset_from_hub.py +19 -0
  46. demo/Step-Audio-EditX/funasr_detach/download/download_from_hub.py +231 -0
  47. demo/Step-Audio-EditX/funasr_detach/download/file.py +335 -0
  48. demo/Step-Audio-EditX/funasr_detach/download/name_maps_from_hub.py +13 -0
  49. demo/Step-Audio-EditX/funasr_detach/download/runtime_sdk_download_tool.py +60 -0
  50. demo/Step-Audio-EditX/funasr_detach/frontends/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Step-Audio-AQAA.[[:space:]]A[[:space:]]Fully[[:space:]]End-to-End[[:space:]]Expressive[[:space:]]Large[[:space:]]Audio[[:space:]]Language[[:space:]]Model.pdf filter=lfs diff=lfs merge=lfs -text
+ Step-Audio-EditX[[:space:]]Technical[[:space:]]Report.pdf filter=lfs diff=lfs merge=lfs -text
+ Step-Audio.[[:space:]]Unified[[:space:]]Understanding[[:space:]]and[[:space:]]Generation[[:space:]]in[[:space:]]Intelligent[[:space:]]Speech[[:space:]]Interaction.pdf filter=lfs diff=lfs merge=lfs -text
Step-Audio-AQAA. A Fully End-to-End Expressive Large Audio Language Model.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4290ba946aaf9ebc8a1df00a905cbafb19f18ca3bcf9a38389716602ee5f7d7e
+ size 1203894
Step-Audio-EditX Technical Report.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1ff7493dcedd3e506b8de85860b0c608d06f0392245fb5385b7fa8231234e50
+ size 786245
Step-Audio. Unified Understanding and Generation in Intelligent Speech Interaction.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce5f6d5b9575f4552c970f118d3191ff49f3e509a847f0fff58c23aa7b510b3f
+ size 6952309
code/ComfyUI_StepAudioTTS.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:17e36cf8812529f50c8b80b72d9111e20c476a2694365ad4f9049f019106b38b
+ size 14201121
code/Step-Audio [intervitens].zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:88549705b04c2bbde00c6e5c67b8966c0aac195c1605a756abd893c53d690e00
+ size 37467537
code/Step-Audio-EditX.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6376c6fe2c68201749f7dee3717eab495e5630d3611511e3afaa9b1fe265afcd
+ size 5979796
code/Step-Audio-tts.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bba78bbc9039e5b3d7df4f77917119bbacc6d8aa9baf4d8114edfceda83fc624
+ size 3827854
code/Step-Audio.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4aaa4b50011e9c82ac51020de7177b43a07e5f47cbd7e8bb55e80929cac5d7a
+ size 55625681
code/Step-Audio2.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:30f63a2dc8c598cc9c968c1a6cca2bc8150beeeec45b8b081843ac2580388dc9
+ size 26895459
code/StepAudioInfer.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bf64260061b9cfb68fc673770b78f561a783a889dc01a3346d5bb12c1f8bf25
+ size 39121775
code/astrbot_plugin_tts_Step_Audio.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6d563e44b30e27b1c51a05316e262ab6483db0cf6d42a43f5a5407ba9206380
+ size 6151313
dataset/StepEval-Audio-360/.gitattributes ADDED
@@ -0,0 +1,59 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.lz4 filter=lfs diff=lfs merge=lfs -text
+ *.mds filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # Audio files - uncompressed
+ *.pcm filter=lfs diff=lfs merge=lfs -text
+ *.sam filter=lfs diff=lfs merge=lfs -text
+ *.raw filter=lfs diff=lfs merge=lfs -text
+ # Audio files - compressed
+ *.aac filter=lfs diff=lfs merge=lfs -text
+ *.flac filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
+ *.ogg filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ # Image files - uncompressed
+ *.bmp filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.tiff filter=lfs diff=lfs merge=lfs -text
+ # Image files - compressed
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
+ # Video files - compressed
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.webm filter=lfs diff=lfs merge=lfs -text
dataset/StepEval-Audio-360/README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ license: apache-2.0
+ ---
+ # StepEval-Audio-360
+ ## Dataset Description
+ StepEval Audio 360 is a comprehensive dataset for evaluating the abilities of multi-modal large language models (MLLMs) in human-AI audio interaction. This audio benchmark, sourced from professional human annotators, covers a full spectrum of capabilities: singing, creativity, role-playing, logical reasoning, voice understanding, voice instruction following, gaming, speech emotion control, and language ability.
+
+ ## Languages
+ StepEval Audio 360 comprises human voice recordings in several languages and dialects, including Chinese (Sichuan dialect and Cantonese), English, and Japanese. It contains both audio and transcription data.
+
+ ## Links
+ - Homepage: [Step-Audio](https://github.com/stepfun-ai/Step-Audio)
+ - Paper: [Step-Audio: Unified Understanding and Generation in Intelligent Speech Interaction](https://arxiv.org/abs/2502.11946)
+ - ModelScope: https://modelscope.cn/datasets/stepfun-ai/StepEval-Audio-360
+ - Step-Audio Model Suite:
+   - Step-Audio-Tokenizer:
+     - Hugging Face: https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer
+     - ModelScope: https://modelscope.cn/models/stepfun-ai/Step-Audio-Tokenizer
+   - Step-Audio-Chat:
+     - Hugging Face: https://huggingface.co/stepfun-ai/Step-Audio-Chat
+     - ModelScope: https://modelscope.cn/models/stepfun-ai/Step-Audio-Chat
+   - Step-Audio-TTS-3B:
+     - Hugging Face: https://huggingface.co/stepfun-ai/Step-Audio-TTS-3B
+     - ModelScope: https://modelscope.cn/models/stepfun-ai/Step-Audio-TTS-3B
+
+ ## User Manual
+ * Download the dataset
+ ```
+ # Make sure you have git-lfs installed (https://git-lfs.com)
+ git lfs install
+ git clone https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360
+ cd StepEval-Audio-360
+ git lfs pull
+ ```
+
+ * Decompress the audio data
+ ```
+ mkdir audios
+ tar -xvf audios.tar.gz -C audios
+ ```
+
+ * How to use
+ ```
+ from datasets import load_dataset
+
+ dataset = load_dataset("stepfun-ai/StepEval-Audio-360")
+ dataset = dataset["test"]
+ for item in dataset:
+     conversation_id = item["conversation_id"]
+     category = item["category"]
+     conversation = item["conversation"]
+
+     # parse multi-turn dialogue data
+     for turn in conversation:
+         role = turn["role"]
+         text = turn["text"]
+         audio_filename = turn["audio_filename"]  # refers to a decompressed audio file
+         if audio_filename is not None:
+             print(role, text, audio_filename)
+         else:
+             print(role, text)
+ ```
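To work with the referenced audio itself, each turn's `audio_filename` can be joined with the `audios/` directory created by the decompression step above. A minimal sketch (the relative path layout and the use of `soundfile` are assumptions, not part of the dataset card):

```python
import os
import soundfile as sf  # assumed dependency for reading the wav files

AUDIO_ROOT = "audios"  # directory produced by: tar -xvf audios.tar.gz -C audios

def load_turn_audio(turn, audio_root=AUDIO_ROOT):
    """Return (waveform, sample_rate) for a turn, or None for text-only turns."""
    audio_filename = turn["audio_filename"]
    if audio_filename is None:
        return None
    # Assumed layout: extracted files sit directly under audios/
    return sf.read(os.path.join(audio_root, audio_filename))
```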
+
+ ## Licensing
+ This dataset project is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
+
+ ## Citation
+ If you use this dataset, please cite it with the BibTeX entry below.
+ ```
+ @misc{stepfun_2025,
+     author    = { {StepFun} },
+     title     = { StepEval-Audio-360 (Revision 72a072e) },
+     year      = 2025,
+     url       = { https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360 },
+     doi       = { 10.57967/hf/4528 },
+     publisher = { Hugging Face }
+ }
+ ```
dataset/StepEval-Audio-360/audios.tar.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7e9f043765500c6f6940ae58a55cf226ddfdde533099f4765bc40d2710d82d3
+ size 166398432
dataset/StepEval-Audio-360/data/test-00000-of-00001.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d2990e77433b866431bbad8adc27b3aebee77046ceca5d265113994fedf2eaff
+ size 69065
dataset/StepEval-Audio-360/source.txt ADDED
@@ -0,0 +1 @@
+ https://huggingface.co/datasets/stepfun-ai/StepEval-Audio-360
demo/Step-Audio-EditX/.gitattributes ADDED
@@ -0,0 +1,4 @@
+ examples filter=lfs diff=lfs merge=lfs -text
+ speakers/nezha_prompt.wav filter=lfs diff=lfs merge=lfs -text
+ speakers/nezhaRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
+ speakers/nezha哼唱_prompt.wav filter=lfs diff=lfs merge=lfs -text
demo/Step-Audio-EditX/.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ output/
demo/Step-Audio-EditX/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
demo/Step-Audio-EditX/README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Step-Audio-EditX
+ emoji: 🚀
+ colorFrom: red
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app.py
+ pinned: true
+ short_description: Try out Step-Audio-EditX
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
demo/Step-Audio-EditX/__init__.py ADDED
File without changes
demo/Step-Audio-EditX/app.py ADDED
@@ -0,0 +1,505 @@
1
+ import gradio as gr
2
+ import os
3
+ import argparse
4
+ import torch
5
+ import logging
6
+ import threading
7
+ from datetime import datetime
8
+ import torchaudio
9
+ import librosa
10
+ import soundfile as sf
11
+
12
+ # ZeroGPU support
13
+ try:
14
+ import spaces
15
+ ZEROGPU_AVAILABLE = True
16
+ except ImportError:
17
+ ZEROGPU_AVAILABLE = False
18
+ # Create a dummy decorator for non-ZeroGPU environments
19
+ class spaces:
20
+ @staticmethod
21
+ def GPU(duration=10):
22
+ def decorator(func):
23
+ return func
24
+ return decorator
25
+
26
+ # Project imports
27
+ from tokenizer import StepAudioTokenizer
28
+ from tts import StepAudioTTS
29
+ from model_loader import ModelSource
30
+ from config.edit_config import get_supported_edit_types
31
+
32
+ # Configure logging
33
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # Global variables for ZeroGPU-optimized loading
37
+ encoder = None
38
+ common_tts_engine = None
39
+ args_global = None
40
+ _model_lock = threading.Lock() # Thread lock for model initialization
41
+
42
+ def initialize_models():
43
+ """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
44
+ global encoder, common_tts_engine, args_global
45
+
46
+ # Fast path: check if already initialized (without lock)
47
+ if common_tts_engine is not None:
48
+ return # Already initialized
49
+
50
+ # Slow path: acquire lock and double-check
51
+ with _model_lock:
52
+ # Double-check pattern: another thread might have initialized while waiting for lock
53
+ if common_tts_engine is not None:
54
+ return # Already initialized by another thread
55
+
56
+ if args_global is None:
57
+ raise RuntimeError("Global args not set. Cannot initialize models.")
58
+
59
+ try:
60
+ logger.info("🚀 Initializing models inside GPU context (first call)...")
61
+
62
+ # Determine model source
63
+ source_mapping = {
64
+ "auto": ModelSource.AUTO,
65
+ "local": ModelSource.LOCAL,
66
+ "modelscope": ModelSource.MODELSCOPE,
67
+ "huggingface": ModelSource.HUGGINGFACE
68
+ }
69
+ model_source = source_mapping[args_global.model_source]
70
+
71
+ # Load StepAudioTokenizer (avoid CUDA initialization in main process)
72
+ encoder = StepAudioTokenizer(
73
+ os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
74
+ model_source=model_source,
75
+ funasr_model_id=args_global.tokenizer_model_id
76
+ )
77
+ logger.info("✓ StepAudioTokenizer loaded")
78
+
79
+ # Initialize common TTS engine (avoid CUDA initialization in main process)
80
+ common_tts_engine = StepAudioTTS(
81
+ os.path.join(args_global.model_path, "Step-Audio-EditX"),
82
+ encoder,
83
+ model_source=model_source,
84
+ tts_model_id=args_global.tts_model_id
85
+ )
86
+ logger.info("✓ StepCommonAudioTTS loaded")
87
+ print("Models initialized inside GPU context.")
88
+
89
+ if ZEROGPU_AVAILABLE:
90
+ logger.info("💡 Models loaded inside GPU context - ready for inference")
91
+ else:
92
+ logger.info("💡 Models loaded - ready for inference")
93
+
94
+ except Exception as e:
95
+ logger.error(f"❌ Error loading models: {e}")
96
+ raise
97
+
98
+ def get_model_config():
99
+ """Get model configuration without initializing GPU models"""
100
+ if args_global is None:
101
+ raise RuntimeError("Global args not set. Cannot get model config.")
102
+
103
+ return {
104
+ "encoder_path": os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
105
+ "tts_path": os.path.join(args_global.model_path, "Step-Audio-EditX"),
106
+ "model_source": args_global.model_source,
107
+ "tokenizer_model_id": args_global.tokenizer_model_id,
108
+ "tts_model_id": args_global.tts_model_id
109
+ }
110
+
111
+ def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
112
+ """Dynamic GPU duration based on whether models need initialization"""
113
+ global common_tts_engine
114
+
115
+ if common_tts_engine is None:
116
+ # First call - need time for model loading (up to 5 minutes)
117
+ return 300 # Maximum allowed duration for model initialization
118
+ else:
119
+ # Subsequent calls - only inference time needed
120
+ return 120 # Standard inference duration
121
+
122
+ @spaces.GPU(duration=get_gpu_duration) # Dynamic duration based on model state
123
+ def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
124
+ """Process audio using GPU (models are loaded inside GPU context to avoid main process errors)"""
125
+ global common_tts_engine
126
+
127
+ # Initialize models if not already loaded (inside GPU context to avoid main process errors)
128
+ if common_tts_engine is None:
129
+ print("Initializing common_tts_engine inside GPU context...")
130
+ logger.info("🎯 GPU allocated for 300s (first call with model loading)...")
131
+ initialize_models()
132
+ logger.info("✅ Models loaded successfully inside GPU context")
133
+ else:
134
+ print("common_tts_engine already initialized.")
135
+ logger.info("🎯 GPU allocated for 120s (inference with loaded models)...")
136
+
137
+ try:
138
+ # Use loaded models (first call may include loading time, subsequent calls are fast)
139
+ if task_type == "clone":
140
+ output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
141
+ else:
142
+ output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)
143
+
144
+ logger.info("✅ Audio processing completed")
145
+ return output_audio, sr
146
+
147
+ except Exception as e:
148
+ logger.error(f"❌ Audio processing failed: {e}")
149
+ raise
150
+ # GPU automatically deallocated when function exits
151
+
152
+ # Save audio to temporary directory
153
+ def save_audio(audio_type, audio_data, sr, tmp_dir):
154
+ """Save audio data to a temporary file with timestamp"""
155
+ current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
156
+ save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
157
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
158
+
159
+ try:
160
+ if isinstance(audio_data, torch.Tensor):
161
+ torchaudio.save(save_path, audio_data, sr)
162
+ else:
163
+ sf.write(save_path, audio_data, sr)
164
+ logger.debug(f"Audio saved to: {save_path}")
165
+ return save_path
166
+ except Exception as e:
167
+ logger.error(f"Failed to save audio: {e}")
168
+ raise
169
+
170
+
171
+ class EditxTab:
172
+ """Audio editing and voice cloning interface tab"""
173
+
174
+ def __init__(self, args):
175
+ self.args = args
176
+ self.edit_type_list = list(get_supported_edit_types().keys())
177
+ self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
178
+
179
+ def history_messages_to_show(self, messages):
180
+ """Convert message history to gradio chatbot format"""
181
+ show_msgs = []
182
+ for message in messages:
183
+ edit_type = message['edit_type']
184
+ edit_info = message['edit_info']
185
+ source_text = message['source_text']
186
+ target_text = message['target_text']
187
+ raw_audio_part = message['raw_wave']
188
+ edit_audio_part = message['edit_wave']
189
+ type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
190
+ show_msgs.extend([
191
+ {"role": "user", "content": f"任务类型:{type_str}\n文本:{source_text}"},
192
+ {"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
193
+ {"role": "assistant", "content": f"输出音频:\n文本:{target_text}"},
194
+ {"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
195
+ ])
196
+ return show_msgs
197
+
198
+ def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
199
+ """Generate cloned audio (models are loaded on first GPU call)"""
200
+ self.logger.info("Starting voice cloning process")
201
+ state['history_audio'] = []
202
+ state['history_messages'] = []
203
+
204
+ # Input validation
205
+ if not prompt_text_input or prompt_text_input.strip() == "":
206
+ error_msg = "[Error] Uploaded text cannot be empty."
207
+ self.logger.error(error_msg)
208
+ return [{"role": "user", "content": error_msg}], state
209
+ if not prompt_audio_input:
210
+ error_msg = "[Error] Uploaded audio cannot be empty."
211
+ self.logger.error(error_msg)
212
+ return [{"role": "user", "content": error_msg}], state
213
+ if not generated_text or generated_text.strip() == "":
214
+ error_msg = "[Error] Clone content cannot be empty."
215
+ self.logger.error(error_msg)
216
+ return [{"role": "user", "content": error_msg}], state
217
+ if edit_type != "clone":
218
+ error_msg = "[Error] CLONE button must use clone task."
219
+ self.logger.error(error_msg)
220
+ return [{"role": "user", "content": error_msg}], state
221
+
222
+ try:
223
+ # Use GPU inference with models loaded inside GPU context
224
+ output_audio, output_sr = process_audio_with_gpu(
225
+ prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
226
+ )
227
+
228
+ if output_audio is not None and output_sr is not None:
229
+ # Convert tensor to numpy if needed
230
+ if isinstance(output_audio, torch.Tensor):
231
+ audio_numpy = output_audio.cpu().numpy().squeeze()
232
+ else:
233
+ audio_numpy = output_audio
234
+
235
+ # Load original audio for comparison
236
+ input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
237
+
238
+ # Create message for history
239
+ cur_assistant_msg = {
240
+ "edit_type": edit_type,
241
+ "edit_info": edit_info,
242
+ "source_text": prompt_text_input,
243
+ "target_text": generated_text,
244
+ "raw_wave": (input_sample_rate, input_audio_data_numpy),
245
+ "edit_wave": (output_sr, audio_numpy),
246
+ }
247
+ state["history_audio"].append((output_sr, audio_numpy, generated_text))
248
+ state["history_messages"].append(cur_assistant_msg)
249
+
250
+ show_msgs = self.history_messages_to_show(state["history_messages"])
251
+ self.logger.info("Voice cloning completed successfully")
252
+ return show_msgs, state
253
+ else:
254
+ error_msg = "[Error] Clone failed"
255
+ self.logger.error(error_msg)
256
+ return [{"role": "user", "content": error_msg}], state
257
+
258
+ except Exception as e:
259
+ error_msg = f"[Error] Clone failed: {str(e)}"
260
+ self.logger.error(error_msg)
261
+ return [{"role": "user", "content": error_msg}], state
262
+
263
+ def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
264
+ """Generate edited audio (models are loaded on first GPU call)"""
265
+ self.logger.info("Starting audio editing process")
266
+
267
+ # Input validation
268
+ if not prompt_audio_input:
269
+ error_msg = "[Error] Uploaded audio cannot be empty."
270
+ self.logger.error(error_msg)
271
+ return [{"role": "user", "content": error_msg}], state
272
+
273
+ try:
274
+ # Determine which audio to use
275
+ if len(state["history_audio"]) == 0:
276
+ # First edit - use uploaded audio
277
+ audio_to_edit = prompt_audio_input
278
+ text_to_use = prompt_text_input
279
+ self.logger.debug("Using prompt audio, no history found")
280
+ else:
281
+ # Use previous edited audio - save it to temp file first
282
+ sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
283
+ temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
284
+ audio_to_edit = temp_path
285
+ text_to_use = previous_text
286
+ self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")
287
+
288
+ # For para-linguistic, use generated_text; otherwise use source text
289
+ if edit_type not in {"paralinguistic"}:
290
+ generated_text = text_to_use
291
+
292
+ # Use GPU inference with models loaded inside GPU context
293
+ output_audio, output_sr = process_audio_with_gpu(
294
+ audio_to_edit, text_to_use, generated_text, edit_type, edit_info
295
+ )
296
+
297
+ if output_audio is not None and output_sr is not None:
298
+ # Convert tensor to numpy if needed
299
+ if isinstance(output_audio, torch.Tensor):
300
+ audio_numpy = output_audio.cpu().numpy().squeeze()
301
+ else:
302
+ audio_numpy = output_audio
303
+
304
+ # Load original audio for comparison
305
+ if len(state["history_audio"]) == 0:
306
+ input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
307
+ else:
308
+ input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]
309
+
310
+ # Create message for history
311
+ cur_assistant_msg = {
312
+ "edit_type": edit_type,
313
+ "edit_info": edit_info,
314
+ "source_text": text_to_use,
315
+ "target_text": generated_text,
316
+ "raw_wave": (input_sample_rate, input_audio_data_numpy),
317
+ "edit_wave": (output_sr, audio_numpy),
318
+ }
319
+ state["history_audio"].append((output_sr, audio_numpy, generated_text))
320
+ state["history_messages"].append(cur_assistant_msg)
321
+
322
+ show_msgs = self.history_messages_to_show(state["history_messages"])
323
+ self.logger.info("Audio editing completed successfully")
324
+ return show_msgs, state
325
+ else:
326
+ error_msg = "[Error] Edit failed"
327
+ self.logger.error(error_msg)
328
+ return [{"role": "user", "content": error_msg}], state
329
+
330
+ except Exception as e:
331
+ error_msg = f"[Error] Edit failed: {str(e)}"
332
+ self.logger.error(error_msg)
333
+ return [{"role": "user", "content": error_msg}], state
334
+
335
+ def clear_history(self, state):
336
+ """Clear conversation history"""
337
+ state["history_messages"] = []
338
+ state["history_audio"] = []
339
+ return [], state
340
+
341
+ def init_state(self):
342
+ """Initialize conversation state"""
343
+ return {
344
+ "history_messages": [],
345
+ "history_audio": []
346
+ }
347
+
348
+ def register_components(self):
349
+ """Register gradio components - maintaining exact layout from original"""
350
+ with gr.Tab("Editx"):
351
+ with gr.Row():
352
+ with gr.Column():
353
+ self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
354
+ self.prompt_text_input = gr.Textbox(label="Prompt Text", value="", scale=1)
355
+ self.prompt_audio_input = gr.Audio(
356
+ sources=["upload", "microphone"],
357
+ format="wav",
358
+ type="filepath",
359
+ label="Input Audio",
360
+ )
361
+ self.generated_text = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=1000)
362
+ with gr.Column():
363
+ with gr.Row():
364
+ self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
365
+ self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
366
+ self.chat_box = gr.Chatbot(label="History", type="messages", height=480*1)
367
+ with gr.Row():
368
+ with gr.Column():
369
+ with gr.Row():
370
+ self.button_tts = gr.Button("CLONE", variant="primary")
371
+ self.button_edit = gr.Button("EDIT", variant="primary")
372
+ with gr.Column():
373
+ self.clean_history_submit = gr.Button("Clear History", variant="primary")
374
+
375
+ gr.Markdown("---")
376
+ gr.Markdown("""
377
+ **Button Description:**
378
+ - CLONE: Synthesizes audio from the uploaded audio and text; used only in clone mode, and clears the history when used.
379
+ - EDIT: Edits the uploaded audio, or keeps stacking edit effects on top of the previously generated audio.
380
+ """)
381
+ gr.Markdown("""
382
+ **Operation Workflow:**
383
+ - Upload the audio to be edited on the left side and fill in the corresponding text content of the audio;
384
+ - If the task requires modifying text content (such as clone, para-linguistic), fill in the text to be synthesized in the "clone text" field. For all other tasks, keep the uploaded audio text content unchanged;
385
+ - Select tasks and subtasks on the right side (some tasks have no subtasks, such as vad, etc.);
386
+ - Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side.
387
+ """)
388
+ gr.Markdown("""
389
+ **Para-linguistic Description:**
390
+ - Supported tags include: [Breathing] [Laughter] [Surprise-oh] [Confirmation-en] [Uhm] [Surprise-ah] [Surprise-wa] [Sigh] [Question-ei] [Dissatisfaction-hnn]
391
+ - Example:
392
+ - Fill in "clone text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio.
393
+ - Change "clone text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio.
394
+ """)
395
+
396
+ def register_events(self):
397
+ """Register event handlers"""
398
+ # Create independent state for each session
399
+ state = gr.State(self.init_state())
400
+
401
+ self.button_tts.click(self.generate_clone,
402
+ inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
403
+ outputs=[self.chat_box, state])
404
+ self.button_edit.click(self.generate_edit,
405
+ inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
406
+ outputs=[self.chat_box, state])
407
+
408
+ self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
409
+ self.edit_type.change(
410
+ fn=self.update_edit_info,
411
+ inputs=self.edit_type,
412
+ outputs=self.edit_info,
413
+ )
414
+
415
+ def update_edit_info(self, category):
416
+ """Update sub-task dropdown based on main task selection"""
417
+ category_items = get_supported_edit_types()
418
+ choices = category_items.get(category, [])
419
+ value = None if len(choices) == 0 else choices[0]
420
+ return gr.Dropdown(label="Sub-task", choices=choices, value=value)
421
+
422
+
423
+ def launch_demo(args, editx_tab):
424
+ """Launch the gradio demo"""
425
+ with gr.Blocks(
426
+ theme=gr.themes.Soft(),
427
+ title="🎙️ Step-Audio-EditX",
428
+ css="""
429
+ :root {
430
+ --font: "Helvetica Neue", Helvetica, Arial, sans-serif;
431
+ --font-mono: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
432
+ }
433
+ """) as demo:
434
+ gr.Markdown("## 🎙️ Step-Audio-EditX")
435
+ gr.Markdown("Audio Editing and Zero-Shot Cloning using Step-Audio-EditX")
436
+
437
+ # Register components
438
+ editx_tab.register_components()
439
+
440
+ # Register events
441
+ editx_tab.register_events()
442
+
443
+ # Launch demo
444
+ demo.queue().launch(
445
+ server_name=args.server_name,
446
+ server_port=args.server_port,
447
+ share=args.share if hasattr(args, 'share') else False
448
+ )
449
+
450
+
451
+ if __name__ == "__main__":
452
+ # Parse command line arguments
453
+ parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
454
+ parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
455
+ parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
456
+ parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
457
+ parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
458
+ parser.add_argument("--share", action="store_true", help="Share gradio app.")
459
+
460
+ # Multi-source loading support parameters
461
+ parser.add_argument(
462
+ "--model-source",
463
+ type=str,
464
+ default="huggingface",
465
+ choices=["auto", "local", "modelscope", "huggingface"],
466
+ help="Model source: auto (detect automatically), local, modelscope, or huggingface"
467
+ )
468
+ parser.add_argument(
469
+ "--tokenizer-model-id",
470
+ type=str,
471
+ default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
472
+ help="Tokenizer model ID for online loading"
473
+ )
474
+ parser.add_argument(
475
+ "--tts-model-id",
476
+ type=str,
477
+ default=None,
478
+ help="TTS model ID for online loading (if different from model-path)"
479
+ )
480
+
481
+ args = parser.parse_args()
482
+
483
+ # Store args globally for model configuration
484
+ args_global = args
485
+
486
+ logger.info(f"Configuration loaded:")
487
+ logger.info(f"Model source: {args.model_source}")
488
+ logger.info(f"Model path: {args.model_path}")
489
+ logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
490
+ if args.tts_model_id:
491
+ logger.info(f"TTS model ID: {args.tts_model_id}")
492
+
493
+ # Models will be initialized on first GPU call to avoid ZeroGPU main process errors
494
+
495
+ if ZEROGPU_AVAILABLE:
496
+ logger.info("🎉 ZeroGPU detected - using dynamic GPU duration management!")
497
+ logger.info("💡 First call: 300s (model loading), subsequent calls: 120s (inference only)")
498
+ else:
499
+ logger.info("💻 Running in local mode - models will be loaded on first call")
500
+
501
+ # Create EditxTab instance
502
+ editx_tab = EditxTab(args)
503
+
504
+ # Launch demo
505
+ launch_demo(args, editx_tab)
demo/Step-Audio-EditX/config/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """
+ Configuration module for Step-Audio
+ """
+
+ from .prompts import AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL, AUDIO_EDIT_SYSTEM_PROMPT
+ from .edit_config import get_supported_edit_types
+
+ __all__ = [
+     'AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL',
+     'AUDIO_EDIT_SYSTEM_PROMPT',
+     'get_supported_edit_types'
+ ]
demo/Step-Audio-EditX/config/edit_config.py ADDED
@@ -0,0 +1,32 @@
+ """
+ Audio editing configuration module.
+ Contains the supported edit types and related configuration.
+ """
+
+ def get_supported_edit_types():
+     """
+     Get the supported edit types and their options.
+
+     Returns:
+         Dict[str, list]: Dictionary of edit types and their options
+     """
+     return {
+         "clone": [],
+         "emotion": [
+             'happy', 'angry', 'sad', 'humour', 'confusion', 'disgusted',
+             'empathy', 'embarrass', 'fear', 'surprised', 'excited',
+             'depressed', 'coldness', 'admiration', 'remove'
+         ],
+         "style": [
+             'serious', 'arrogant', 'child', 'older', 'girl', 'pure',
+             'sister', 'sweet', 'ethereal', 'whisper', 'gentle', 'recite',
+             'generous', 'act_coy', 'warm', 'shy', 'comfort', 'authority',
+             'chat', 'radio', 'soulful', 'story', 'vivid', 'program',
+             'news', 'advertising', 'roar', 'murmur', 'shout', 'deeply', 'loudly',
+             'remove', 'exaggerated'
+         ],
+         "vad": [],
+         "denoise": [],
+         "paralinguistic": [],
+         "speed": ["faster", "slower", "more faster", "more slower"],
+     }
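For orientation, the keys of this mapping feed the "Task" dropdown in the demo and the value lists feed the "Sub-task" dropdown (see `update_edit_info` in app.py). A small usage sketch, assuming the demo's `config` package is on the import path:

```python
from config.edit_config import get_supported_edit_types

edit_types = get_supported_edit_types()
print(list(edit_types.keys()))  # task choices: clone, emotion, style, vad, denoise, paralinguistic, speed
print(edit_types["speed"])      # sub-task choices for "speed": faster, slower, more faster, more slower
```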
demo/Step-Audio-EditX/config/prompts.py ADDED
@@ -0,0 +1,23 @@
+ """
+ System prompt configuration module.
+ Contains all TTS- and editing-related system prompts.
+ """
+
+ AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL = """Generate audio with the following timbre, prosody and speaking style
+
+ [speaker_start]
+ speaker name: {speaker}
+ speaker prompt text:
+ {prompt_text}
+ speaker audio tokens:
+ {prompt_wav_tokens}
+ [speaker_end]
+ """
+
+ AUDIO_EDIT_SYSTEM_PROMPT = """As a highly skilled audio editing and tuning specialist, you excel in interpreting user instructions and applying precise adjustments to meet their needs. Your expertise spans a wide range of enhancement capabilities, including but not limited to:
+ # Emotional Enhancement
+ # Speaking Style Transfer
+ # Non-linguistic Adjustments
+ # Audio Tuning & Editing
+ Note: You will receive instructions in natural language and are expected to accurately interpret and execute the most suitable audio edits and enhancements.
+ """
demo/Step-Audio-EditX/funasr_detach/__init__.py ADDED
@@ -0,0 +1,38 @@
+ """Initialize funasr package."""
+
+ import os
+ import pkgutil
+ import importlib
+
+ dirname = os.path.dirname(__file__)
+ version_file = os.path.join(dirname, "version.txt")
+ with open(version_file, "r") as f:
+     __version__ = f.read().strip()
+
+
+ import importlib
+ import pkgutil
+
+
+ def import_submodules(package, recursive=True):
+     if isinstance(package, str):
+         package = importlib.import_module(package)
+     results = {}
+     for loader, name, is_pkg in pkgutil.walk_packages(
+         package.__path__, package.__name__ + "."
+     ):
+         try:
+             results[name] = importlib.import_module(name)
+         except Exception as e:
+             # To see details of import failures, uncomment the line below
+             # print(f"Failed to import {name}: {e}")
+             pass
+         if recursive and is_pkg:
+             results.update(import_submodules(name))
+     return results
+
+
+ import_submodules(__name__)
+
+ from funasr_detach.auto.auto_model import AutoModel
+ from funasr_detach.auto.auto_frontend import AutoFrontend
demo/Step-Audio-EditX/funasr_detach/auto/__init__.py ADDED
File without changes
demo/Step-Audio-EditX/funasr_detach/auto/auto_frontend.py ADDED
@@ -0,0 +1,90 @@
1
+ import time
2
+ import logging
3
+ from tqdm import tqdm
4
+
5
+ from funasr_detach.register import tables
6
+ from funasr_detach.download.download_from_hub import download_model
7
+ from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank
8
+ from funasr_detach.auto.auto_model import prepare_data_iterator
9
+ from funasr_detach.auto.auto_model import prepare_data_iterator
10
+
11
+
12
+ class AutoFrontend:
13
+ def __init__(self, **kwargs):
14
+ assert "model" in kwargs
15
+ if "model_conf" not in kwargs:
16
+ logging.info(
17
+ "download models from model hub: {}".format(
18
+ kwargs.get("model_hub", "ms")
19
+ )
20
+ )
21
+ kwargs = download_model(**kwargs)
22
+
23
+ # build frontend
24
+ frontend = kwargs.get("frontend", None)
25
+ if frontend is not None:
26
+ frontend_class = tables.frontend_classes.get(frontend)
27
+ frontend = frontend_class(**kwargs["frontend_conf"])
28
+
29
+ self.frontend = frontend
30
+ if "frontend" in kwargs:
31
+ del kwargs["frontend"]
32
+ self.kwargs = kwargs
33
+
34
+ def __call__(self, input, input_len=None, kwargs=None, **cfg):
35
+
36
+ kwargs = self.kwargs if kwargs is None else kwargs
37
+ kwargs.update(cfg)
38
+
39
+ key_list, data_list = prepare_data_iterator(input, input_len=input_len)
40
+ batch_size = kwargs.get("batch_size", 1)
41
+ device = kwargs.get("device", "cpu")
42
+ if device == "cpu":
43
+ batch_size = 1
44
+
45
+ meta_data = {}
46
+
47
+ result_list = []
48
+ num_samples = len(data_list)
49
+ pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
50
+
51
+ time0 = time.perf_counter()
52
+ for beg_idx in range(0, num_samples, batch_size):
53
+ end_idx = min(num_samples, beg_idx + batch_size)
54
+ data_batch = data_list[beg_idx:end_idx]
55
+ key_batch = key_list[beg_idx:end_idx]
56
+
57
+ # extract fbank feats
58
+ time1 = time.perf_counter()
59
+ audio_sample_list = load_audio_text_image_video(
60
+ data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
61
+ )
62
+ time2 = time.perf_counter()
63
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
64
+ speech, speech_lengths = extract_fbank(
65
+ audio_sample_list,
66
+ data_type=kwargs.get("data_type", "sound"),
67
+ frontend=self.frontend,
68
+ **kwargs,
69
+ )
70
+ time3 = time.perf_counter()
71
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
72
+ meta_data["batch_data_time"] = (
73
+ speech_lengths.sum().item()
74
+ * self.frontend.frame_shift
75
+ * self.frontend.lfr_n
76
+ / 1000
77
+ )
78
+
79
+ speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)  # .to() is not in-place; reassign so the batch tensors actually move to `device`
80
+ batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
81
+ result_list.append(batch)
82
+
83
+ pbar.update(1)
84
+ description = f"{meta_data}, "
85
+ pbar.set_description(description)
86
+
87
+ time_end = time.perf_counter()
88
+ pbar.set_description(f"time elapsed total: {time_end - time0:0.3f}")
89
+
90
+ return result_list
demo/Step-Audio-EditX/funasr_detach/auto/auto_model.py ADDED
@@ -0,0 +1,575 @@
1
+ import json
2
+ import time
3
+ import copy
4
+ import torch
5
+ import random
6
+ import string
7
+ import logging
8
+ import os.path
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from funasr_detach.register import tables
13
+ from funasr_detach.utils.load_utils import load_bytes
14
+ from funasr_detach.download.file import download_from_url
15
+ from funasr_detach.download.download_from_hub import download_model
16
+ from funasr_detach.utils.vad_utils import slice_padding_audio_samples
17
+ from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
18
+ from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
19
+ from funasr_detach.utils.load_utils import load_audio_text_image_video
20
+ from funasr_detach.utils.timestamp_tools import timestamp_sentence
21
+ from funasr_detach.models.campplus.utils import sv_chunk, postprocess, distribute_spk
22
+
23
+ try:
24
+ from funasr_detach.models.campplus.cluster_backend import ClusterBackend
25
+ except ImportError:
26
+ print("If you want to use the speaker diarization, please `pip install hdbscan`")
27
+
28
+
29
+ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
30
+ """
31
+
32
+ :param input:
33
+ :param input_len:
34
+ :param data_type:
35
+ :param frontend:
36
+ :return:
37
+ """
38
+ data_list = []
39
+ key_list = []
40
+ filelist = [".scp", ".txt", ".json", ".jsonl"]
41
+
42
+ chars = string.ascii_letters + string.digits
43
+ if isinstance(data_in, str) and data_in.startswith("http"): # url
44
+ data_in = download_from_url(data_in)
45
+ if isinstance(data_in, str) and os.path.exists(
46
+ data_in
47
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
48
+ _, file_extension = os.path.splitext(data_in)
49
+ file_extension = file_extension.lower()
50
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
51
+ with open(data_in, encoding="utf-8") as fin:
52
+ for line in fin:
53
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
54
+ if data_in.endswith(
55
+ ".jsonl"
56
+ ): # file.jsonl: json.dumps({"source": data})
57
+ lines = json.loads(line.strip())
58
+ data = lines["source"]
59
+ key = data["key"] if "key" in data else key
60
+ else: # filelist, wav.scp, text.txt: id \t data or data
61
+ lines = line.strip().split(maxsplit=1)
62
+ data = lines[1] if len(lines) > 1 else lines[0]
63
+ key = lines[0] if len(lines) > 1 else key
64
+
65
+ data_list.append(data)
66
+ key_list.append(key)
67
+ else:
68
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
69
+ data_list = [data_in]
70
+ key_list = [key]
71
+ elif isinstance(data_in, (list, tuple)):
72
+ if data_type is not None and isinstance(
73
+ data_type, (list, tuple)
74
+ ): # multiple inputs
75
+ data_list_tmp = []
76
+ for data_in_i, data_type_i in zip(data_in, data_type):
77
+ key_list, data_list_i = prepare_data_iterator(
78
+ data_in=data_in_i, data_type=data_type_i
79
+ )
80
+ data_list_tmp.append(data_list_i)
81
+ data_list = []
82
+ for item in zip(*data_list_tmp):
83
+ data_list.append(item)
84
+ else:
85
+ # [audio sample point, fbank, text]
86
+ data_list = data_in
87
+ key_list = [
88
+ "rand_key_" + "".join(random.choice(chars) for _ in range(13))
89
+ for _ in range(len(data_in))
90
+ ]
91
+ else: # raw text; audio sample point, fbank; bytes
92
+ if isinstance(data_in, bytes): # audio bytes
93
+ data_in = load_bytes(data_in)
94
+ if key is None:
95
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
96
+ data_list = [data_in]
97
+ key_list = [key]
98
+
99
+ return key_list, data_list
100
+
101
+
102
+ class AutoModel:
103
+
104
+ def __init__(self, **kwargs):
105
+ if not kwargs.get("disable_log", False):
106
+ tables.print()
107
+
108
+ model, kwargs = self.build_model(**kwargs)
109
+
110
+ # if vad_model is not None, build vad model else None
111
+ vad_model = kwargs.get("vad_model", None)
112
+ vad_kwargs = kwargs.get("vad_model_revision", None)
113
+ if vad_model is not None:
114
+ logging.info("Building VAD model.")
115
+ vad_kwargs = {
116
+ "model": vad_model,
117
+ "model_revision": vad_kwargs,
118
+ "device": kwargs["device"],
119
+ }
120
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
121
+
122
+ # if punc_model is not None, build punc model else None
123
+ punc_model = kwargs.get("punc_model", None)
124
+ punc_kwargs = kwargs.get("punc_model_revision", None)
125
+ if punc_model is not None:
126
+ logging.info("Building punc model.")
127
+ punc_kwargs = {
128
+ "model": punc_model,
129
+ "model_revision": punc_kwargs,
130
+ "device": kwargs["device"],
131
+ }
132
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
133
+
134
+ # if spk_model is not None, build spk model else None
135
+ spk_model = kwargs.get("spk_model", None)
136
+ spk_kwargs = kwargs.get("spk_model_revision", None)
137
+ if spk_model is not None:
138
+ logging.info("Building SPK model.")
139
+ spk_kwargs = {
140
+ "model": spk_model,
141
+ "model_revision": spk_kwargs,
142
+ "device": kwargs["device"],
143
+ }
144
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
145
+ self.cb_model = ClusterBackend().to(kwargs["device"])
146
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
147
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
148
+ logging.error(
149
+ "spk_mode should be one of default, vad_segment and punc_segment."
150
+ )
151
+ self.spk_mode = spk_mode
152
+
153
+ self.kwargs = kwargs
154
+ self.model = model
155
+ self.vad_model = vad_model
156
+ self.vad_kwargs = vad_kwargs
157
+ self.punc_model = punc_model
158
+ self.punc_kwargs = punc_kwargs
159
+ self.spk_model = spk_model
160
+ self.spk_kwargs = spk_kwargs
161
+ self.model_path = kwargs.get("model_path")
162
+ self.repo_path = kwargs.get("repo_path")
163
+
164
+
165
+ def build_model(self, **kwargs):
166
+ assert "model" in kwargs
167
+ if "model_conf" not in kwargs:
168
+ logging.info(
169
+ "download models from model hub: {}".format(
170
+ kwargs.get("model_hub", "ms")
171
+ )
172
+ )
173
+ kwargs = download_model(**kwargs)
174
+
175
+ set_all_random_seed(kwargs.get("seed", 0))
176
+
177
+ device = kwargs.get("device", "cuda")
178
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
179
+ device = "cpu"
180
+ kwargs["batch_size"] = 1
181
+ kwargs["device"] = device
182
+
183
+ if kwargs.get("ncpu", None):
184
+ torch.set_num_threads(kwargs.get("ncpu"))
185
+
186
+ # build tokenizer
187
+ tokenizer = kwargs.get("tokenizer", None)
188
+ if tokenizer is not None:
189
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
190
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
191
+ kwargs["tokenizer"] = tokenizer
192
+ kwargs["token_list"] = tokenizer.token_list
193
+ vocab_size = len(tokenizer.token_list)
194
+ else:
195
+ vocab_size = -1
196
+
197
+ # build frontend
198
+ frontend = kwargs.get("frontend", None)
199
+ if frontend is not None:
200
+ frontend_class = tables.frontend_classes.get(frontend)
201
+ frontend = frontend_class(**kwargs["frontend_conf"])
202
+ kwargs["frontend"] = frontend
203
+ kwargs["input_size"] = frontend.output_size()
204
+
205
+ # build model
206
+ model_class = tables.model_classes.get(kwargs["model"])
207
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
208
+
209
+ model.to(device)
210
+
211
+ # init_param
212
+ init_param = kwargs.get("init_param", None)
213
+ if init_param is not None:
214
+ logging.info(f"Loading pretrained params from {init_param}")
215
+ load_pretrained_model(
216
+ model=model,
217
+ path=init_param,
218
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
219
+ oss_bucket=kwargs.get("oss_bucket", None),
220
+ scope_map=kwargs.get("scope_map", None),
221
+ excludes=kwargs.get("excludes", None),
222
+ )
223
+
224
+ return model, kwargs
225
+
226
+ def __call__(self, *args, **cfg):
227
+ kwargs = self.kwargs
228
+ kwargs.update(cfg)
229
+ res = self.model(*args, **kwargs)
230
+ return res
231
+
232
+ def generate(self, input, input_len=None, **cfg):
233
+ if self.vad_model is None:
234
+ return self.inference(input, input_len=input_len, **cfg)
235
+
236
+ else:
237
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
238
+
239
+ def inference(
240
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
241
+ ):
242
+ kwargs = self.kwargs if kwargs is None else kwargs
243
+ kwargs.update(cfg)
244
+ model = self.model if model is None else model
245
+ model = model.cuda()
246
+ model.eval()
247
+
248
+ batch_size = kwargs.get("batch_size", 1)
249
+ # if kwargs.get("device", "cpu") == "cpu":
250
+ # batch_size = 1
251
+
252
+ key_list, data_list = prepare_data_iterator(
253
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
254
+ )
255
+
256
+ speed_stats = {}
257
+ asr_result_list = []
258
+ num_samples = len(data_list)
259
+ disable_pbar = kwargs.get("disable_pbar", False)
260
+ pbar = (
261
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
262
+ if not disable_pbar
263
+ else None
264
+ )
265
+ time_speech_total = 0.0
266
+ time_escape_total = 0.0
267
+ for beg_idx in range(0, num_samples, batch_size):
268
+ end_idx = min(num_samples, beg_idx + batch_size)
269
+ data_batch = data_list[beg_idx:end_idx]
270
+ key_batch = key_list[beg_idx:end_idx]
271
+ batch = {"data_in": data_batch, "key": key_batch}
272
+ if (end_idx - beg_idx) == 1 and kwargs.get(
273
+ "data_type", None
274
+ ) == "fbank": # fbank
275
+ batch["data_in"] = data_batch[0]
276
+ batch["data_lengths"] = input_len
277
+
278
+ time1 = time.perf_counter()
279
+ with torch.no_grad():
280
+ results, meta_data = model.inference(**batch, **kwargs)
281
+ time2 = time.perf_counter()
282
+
283
+ asr_result_list.extend(results)
284
+
285
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
286
+ batch_data_time = meta_data.get("batch_data_time", -1)
287
+ time_escape = time2 - time1
288
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
289
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
290
+ speed_stats["forward"] = f"{time_escape:0.3f}"
291
+ speed_stats["batch_size"] = f"{len(results)}"
292
+ speed_stats["time_cost"] = f"{(time_escape)}"
293
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
294
+ description = f"{speed_stats}, "
295
+ if pbar:
296
+ pbar.update(1)
297
+ pbar.set_description(description)
298
+ time_speech_total += batch_data_time
299
+ time_escape_total += time_escape
300
+
301
+ if pbar:
302
+ # pbar.update(1)
303
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
304
+ torch.cuda.empty_cache()
305
+ return asr_result_list
306
+
307
+ def inference_with_vad(self, input, input_len=None, **cfg):
308
+
309
+ # step.1: compute the vad model
310
+ self.vad_kwargs.update(cfg)
311
+ beg_vad = time.time()
312
+ res = self.inference(
313
+ input,
314
+ input_len=input_len,
315
+ model=self.vad_model,
316
+ kwargs=self.vad_kwargs,
317
+ **cfg,
318
+ )
319
+ end_vad = time.time()
320
+ print(f"time cost vad: {end_vad - beg_vad:0.3f}")
321
+
322
+ # step.2 compute asr model
323
+ model = self.model
324
+ kwargs = self.kwargs
325
+ kwargs.update(cfg)
326
+ batch_size = int(kwargs.get("batch_size_s", 300)) * 1000
327
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
328
+ kwargs["batch_size"] = batch_size
329
+
330
+ key_list, data_list = prepare_data_iterator(
331
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
332
+ )
333
+ results_ret_list = []
334
+ time_speech_total_all_samples = 1e-6
335
+
336
+ beg_total = time.time()
337
+ pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
338
+ for i in range(len(res)):
339
+ key = res[i]["key"]
340
+ vadsegments = res[i]["value"]
341
+ input_i = data_list[i]
342
+ speech = load_audio_text_image_video(
343
+ input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000)
344
+ )
345
+ speech_lengths = len(speech)
346
+ n = len(vadsegments)
347
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
348
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
349
+ results_sorted = []
350
+
351
+ if not len(sorted_data):
352
+ logging.info("decoding, utt: {}, empty speech".format(key))
353
+ continue
354
+
355
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
356
+ batch_size = max(
357
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
358
+ )
359
+
360
+ batch_size_ms_cum = 0
361
+ beg_idx = 0
362
+ beg_asr_total = time.time()
363
+ time_speech_total_per_sample = speech_lengths / 16000
364
+ time_speech_total_all_samples += time_speech_total_per_sample
365
+
366
+ all_segments = []
367
+ for j in range(n):
368
+ # pbar_sample.update(1)
369
+ batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
370
+ if (
371
+ j < n - 1
372
+ and (
373
+ batch_size_ms_cum
374
+ + sorted_data[j + 1][0][1]
375
+ - sorted_data[j + 1][0][0]
376
+ )
377
+ < batch_size
378
+ and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
379
+ < batch_size_threshold_ms
380
+ ):
381
+ continue
382
+ batch_size_ms_cum = 0
383
+ end_idx = j + 1
384
+ speech_j, speech_lengths_j = slice_padding_audio_samples(
385
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
386
+ )
387
+ results = self.inference(
388
+ speech_j,
389
+ input_len=None,
390
+ model=model,
391
+ kwargs=kwargs,
392
+ disable_pbar=True,
393
+ **cfg,
394
+ )
395
+ if self.spk_model is not None:
396
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
397
+ for _b in range(len(speech_j)):
398
+ vad_segments = [
399
+ [
400
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
401
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
402
+ np.array(speech_j[_b]),
403
+ ]
404
+ ]
405
+ segments = sv_chunk(vad_segments)
406
+ all_segments.extend(segments)
407
+ speech_b = [i[2] for i in segments]
408
+ spk_res = self.inference(
409
+ speech_b,
410
+ input_len=None,
411
+ model=self.spk_model,
412
+ kwargs=kwargs,
413
+ disable_pbar=True,
414
+ **cfg,
415
+ )
416
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
417
+ beg_idx = end_idx
418
+ if len(results) < 1:
419
+ continue
420
+ results_sorted.extend(results)
421
+
422
+ restored_data = [0] * n
423
+ for j in range(n):
424
+ index = sorted_data[j][1]
425
+ restored_data[index] = results_sorted[j]
426
+ result = {}
427
+
428
+ # results combine for texts, timestamps, speaker embeddings and others
429
+ # TODO: rewrite for clean code
430
+ for j in range(n):
431
+ for k, v in restored_data[j].items():
432
+ if k.startswith("timestamp"):
433
+ if k not in result:
434
+ result[k] = []
435
+ for t in restored_data[j][k]:
436
+ t[0] += vadsegments[j][0]
437
+ t[1] += vadsegments[j][0]
438
+ result[k].extend(restored_data[j][k])
439
+ elif k == "spk_embedding":
440
+ if k not in result:
441
+ result[k] = restored_data[j][k]
442
+ else:
443
+ result[k] = torch.cat(
444
+ [result[k], restored_data[j][k]], dim=0
445
+ )
446
+ elif "text" in k:
447
+ if k not in result:
448
+ result[k] = restored_data[j][k]
449
+ else:
450
+ result[k] += " " + restored_data[j][k]
451
+ else:
452
+ if k not in result:
453
+ result[k] = restored_data[j][k]
454
+ else:
455
+ result[k] += restored_data[j][k]
456
+
457
+ return_raw_text = kwargs.get("return_raw_text", False)
458
+ # step.3 compute punc model
459
+ if self.punc_model is not None:
460
+ self.punc_kwargs.update(cfg)
461
+ punc_res = self.inference(
462
+ result["text"],
463
+ model=self.punc_model,
464
+ kwargs=self.punc_kwargs,
465
+ disable_pbar=True,
466
+ **cfg,
467
+ )
468
+ raw_text = copy.copy(result["text"])
469
+ if return_raw_text:
470
+ result["raw_text"] = raw_text
471
+ result["text"] = punc_res[0]["text"]
472
+ else:
473
+ raw_text = None
474
+
475
+ # speaker embedding cluster after resorted
476
+ if self.spk_model is not None and kwargs.get("return_spk_res", True):
477
+ if raw_text is None:
478
+ logging.error("Missing punc_model, which is required by spk_model.")
479
+ all_segments = sorted(all_segments, key=lambda x: x[0])
480
+ spk_embedding = result["spk_embedding"]
481
+ labels = self.cb_model(
482
+ spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
483
+ )
484
+ # del result['spk_embedding']
485
+ sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
486
+ if self.spk_mode == "vad_segment": # recover sentence_list
487
+ sentence_list = []
488
+ for res, vadsegment in zip(restored_data, vadsegments):
489
+ if "timestamp" not in res:
490
+ logging.error(
491
+ "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
492
+ and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
493
+ can predict timestamp, and speaker diarization relies on timestamps."
494
+ )
495
+ sentence_list.append(
496
+ {
497
+ "start": vadsegment[0],
498
+ "end": vadsegment[1],
499
+ "sentence": res["text"],
500
+ "timestamp": res["timestamp"],
501
+ }
502
+ )
503
+ elif self.spk_mode == "punc_segment":
504
+ if "timestamp" not in result:
505
+ logging.error(
506
+ "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
507
+ and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
508
+ can predict timestamp, and speaker diarization relies on timestamps."
509
+ )
510
+ sentence_list = timestamp_sentence(
511
+ punc_res[0]["punc_array"],
512
+ result["timestamp"],
513
+ raw_text,
514
+ return_raw_text=return_raw_text,
515
+ )
516
+ distribute_spk(sentence_list, sv_output)
517
+ result["sentence_info"] = sentence_list
518
+ elif kwargs.get("sentence_timestamp", False):
519
+ sentence_list = timestamp_sentence(
520
+ punc_res[0]["punc_array"],
521
+ result["timestamp"],
522
+ raw_text,
523
+ return_raw_text=return_raw_text,
524
+ )
525
+ result["sentence_info"] = sentence_list
526
+ if "spk_embedding" in result:
527
+ del result["spk_embedding"]
528
+
529
+ result["key"] = key
530
+ results_ret_list.append(result)
531
+ end_asr_total = time.time()
532
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
533
+ pbar_total.update(1)
534
+ pbar_total.set_description(
535
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
536
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
537
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
538
+ )
539
+
540
+ return results_ret_list
541
+
542
+ def infer_encoder(
543
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
544
+ ):
545
+ kwargs = self.kwargs if kwargs is None else kwargs
546
+ kwargs.update(cfg)
547
+ model = self.model if model is None else model
548
+ model = model.cuda()
549
+ model.eval()
550
+
551
+ batch_size = kwargs.get("batch_size", 1)
552
+
553
+ key_list, data_list = prepare_data_iterator(
554
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
555
+ )
556
+
557
+ asr_result_list = []
558
+ num_samples = len(data_list)
559
+ for beg_idx in range(0, num_samples, batch_size):
560
+ end_idx = min(num_samples, beg_idx + batch_size)
561
+ data_batch = data_list[beg_idx:end_idx]
562
+ key_batch = key_list[beg_idx:end_idx]
563
+ batch = {"data_in": data_batch, "key": key_batch}
564
+ if (end_idx - beg_idx) == 1 and kwargs.get(
565
+ "data_type", None
566
+ ) == "fbank": # fbank
567
+ batch["data_in"] = data_batch[0]
568
+ batch["data_lengths"] = input_len
569
+
570
+ with torch.no_grad():
571
+ results, meta_data, cache = model.infer_encoder(**batch, **kwargs)
572
+ asr_result_list.extend(results)
573
+
574
+ torch.cuda.empty_cache()
575
+ return asr_result_list, cache
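A rough usage sketch for the AutoModel class added above; the model identifiers and the wav path are illustrative placeholders (assumptions), not values defined in this commit:

```python
# Hypothetical usage of the AutoModel pipeline defined in auto_model.py.
# Model names and "example.wav" are placeholders resolved via download_model().
from funasr_detach.auto.auto_model import AutoModel

model = AutoModel(
    model="paraformer-zh",   # placeholder ASR model id
    vad_model="fsmn-vad",    # optional: routes generate() through inference_with_vad()
    punc_model="ct-punc",    # optional: punctuation restoration on the merged text
    device="cuda",
)
res = model.generate(input="example.wav", batch_size_s=300)
print(res[0]["text"])
```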
demo/Step-Audio-EditX/funasr_detach/auto/auto_tokenizer.py ADDED
@@ -0,0 +1,7 @@
1
+ class AutoTokenizer:
2
+ """
3
+ Placeholder; not implemented in this detached copy.
4
+ """
5
+
6
+ def __init__(self):
7
+ pass
demo/Step-Audio-EditX/funasr_detach/bin/__init__.py ADDED
File without changes
demo/Step-Audio-EditX/funasr_detach/bin/compute_audio_cmvn.py ADDED
@@ -0,0 +1,152 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import torch
5
+ import hydra
6
+ import logging
7
+ from omegaconf import DictConfig, OmegaConf
8
+
9
+ from funasr_detach.register import tables
10
+ from funasr_detach.download.download_from_hub import download_model
11
+ from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
12
+
13
+
14
+ @hydra.main(config_name=None, version_base=None)
15
+ def main_hydra(kwargs: DictConfig):
16
+ if kwargs.get("debug", False):
17
+ import pdb
18
+
19
+ pdb.set_trace()
20
+
21
+ assert "model" in kwargs
22
+ if "model_conf" not in kwargs:
23
+ logging.info(
24
+ "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
25
+ )
26
+ kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
27
+
28
+ main(**kwargs)
29
+
30
+
31
+ def main(**kwargs):
32
+ print(kwargs)
33
+ # set random seed
34
+ tables.print()
35
+ set_all_random_seed(kwargs.get("seed", 0))
36
+ torch.backends.cudnn.enabled = kwargs.get(
37
+ "cudnn_enabled", torch.backends.cudnn.enabled
38
+ )
39
+ torch.backends.cudnn.benchmark = kwargs.get(
40
+ "cudnn_benchmark", torch.backends.cudnn.benchmark
41
+ )
42
+ torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
43
+
44
+ tokenizer = kwargs.get("tokenizer", None)
45
+
46
+ # build frontend if frontend is not None
47
+ frontend = kwargs.get("frontend", None)
48
+ if frontend is not None:
49
+ frontend_class = tables.frontend_classes.get(frontend)
50
+ frontend = frontend_class(**kwargs["frontend_conf"])
51
+ kwargs["frontend"] = frontend
52
+ kwargs["input_size"] = frontend.output_size()
53
+
54
+ # dataset
55
+ dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
56
+ dataset_train = dataset_class(
57
+ kwargs.get("train_data_set_list"),
58
+ frontend=frontend,
59
+ tokenizer=None,
60
+ is_training=False,
61
+ **kwargs.get("dataset_conf")
62
+ )
63
+
64
+ # dataloader
65
+ batch_sampler = kwargs["dataset_conf"].get(
66
+ "batch_sampler", "DynamicBatchLocalShuffleSampler"
67
+ )
68
+ batch_sampler_train = None
69
+ if batch_sampler is not None:
70
+ batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
71
+ dataset_conf = kwargs.get("dataset_conf")
72
+ dataset_conf["batch_type"] = "example"
73
+ dataset_conf["batch_size"] = 1
74
+ batch_sampler_train = batch_sampler_class(
75
+ dataset_train, is_training=False, **dataset_conf
76
+ )
77
+
78
+ dataloader_train = torch.utils.data.DataLoader(
79
+ dataset_train,
80
+ collate_fn=dataset_train.collator,
81
+ batch_sampler=batch_sampler_train,
82
+ num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
83
+ pin_memory=True,
84
+ )
85
+
86
+ iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))
87
+
88
+ total_frames = 0
89
+ for batch_idx, batch in enumerate(dataloader_train):
90
+ if batch_idx >= iter_stop:
91
+ break
92
+
93
+ fbank = batch["speech"].numpy()[0, :, :]
94
+ if total_frames == 0:
95
+ mean_stats = np.sum(fbank, axis=0)
96
+ var_stats = np.sum(np.square(fbank), axis=0)
97
+ else:
98
+ mean_stats += np.sum(fbank, axis=0)
99
+ var_stats += np.sum(np.square(fbank), axis=0)
100
+ total_frames += fbank.shape[0]
101
+
102
+ cmvn_info = {
103
+ "mean_stats": list(mean_stats.tolist()),
104
+ "var_stats": list(var_stats.tolist()),
105
+ "total_frames": total_frames,
106
+ }
107
+ cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
108
+ # import pdb;pdb.set_trace()
109
+ with open(cmvn_file, "w") as fout:
110
+ fout.write(json.dumps(cmvn_info))
111
+
112
+ mean = -1.0 * mean_stats / total_frames
113
+ var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
114
+ dims = mean.shape[0]
115
+ am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
116
+ with open(am_mvn, "w") as fout:
117
+ fout.write(
118
+ "<Nnet>"
119
+ + "\n"
120
+ + "<Splice> "
121
+ + str(dims)
122
+ + " "
123
+ + str(dims)
124
+ + "\n"
125
+ + "[ 0 ]"
126
+ + "\n"
127
+ + "<AddShift> "
128
+ + str(dims)
129
+ + " "
130
+ + str(dims)
131
+ + "\n"
132
+ )
133
+ mean_str = (
134
+ str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
135
+ )
136
+ fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
137
+ fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
138
+ var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
139
+ fout.write("<LearnRateCoef> 0 " + var_str + "\n")
140
+ fout.write("</Nnet>" + "\n")
141
+
142
+
143
+ """
144
+ python funasr/bin/compute_audio_cmvn.py \
145
+ --config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
146
+ --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
147
+ ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
148
+ ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
149
+ ++dataset_conf.num_workers=0
150
+ """
151
+ if __name__ == "__main__":
152
+ main_hydra()
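A small sketch of how the statistics accumulated above translate into the shift/scale normalization written to am.mvn; the function and variable names here are illustrative, not part of the file:

```python
import numpy as np

def apply_cmvn(fbank, mean_stats, var_stats, total_frames):
    """Apply the per-dimension normalization that main() encodes in am.mvn."""
    mean = mean_stats / total_frames                              # E[x]
    scale = 1.0 / np.sqrt(var_stats / total_frames - mean ** 2)   # 1 / std
    return (fbank - mean) * scale                                 # zero mean, unit variance
```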
demo/Step-Audio-EditX/funasr_detach/bin/inference.py ADDED
@@ -0,0 +1,33 @@
1
+ import hydra
2
+ import logging
3
+ from omegaconf import DictConfig, OmegaConf, ListConfig
4
+
5
+ from funasr_detach.auto.auto_model import AutoModel
6
+
7
+
8
+ @hydra.main(config_name=None, version_base=None)
9
+ def main_hydra(cfg: DictConfig):
10
+ def to_plain_list(cfg_item):
11
+ if isinstance(cfg_item, ListConfig):
12
+ return OmegaConf.to_container(cfg_item, resolve=True)
13
+ elif isinstance(cfg_item, DictConfig):
14
+ return {k: to_plain_list(v) for k, v in cfg_item.items()}
15
+ else:
16
+ return cfg_item
17
+
18
+ kwargs = to_plain_list(cfg)
19
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
20
+
21
+ logging.basicConfig(level=log_level)
22
+
23
+ if kwargs.get("debug", False):
24
+ import pdb
25
+
26
+ pdb.set_trace()
27
+ model = AutoModel(**kwargs)
28
+ res = model.generate(input=kwargs["input"])
29
+ print(res)
30
+
31
+
32
+ if __name__ == "__main__":
33
+ main_hydra()
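An illustrative invocation of the hydra entry point above; the model name, input path, and device are placeholders:

```python
# Example command (placeholders only):
#   python funasr_detach/bin/inference.py ++model="paraformer-zh" ++input="example.wav" ++device="cuda"
#
# which is roughly equivalent to the programmatic call inside main_hydra():
#   model = AutoModel(model="paraformer-zh", device="cuda")
#   print(model.generate(input="example.wav"))
```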
demo/Step-Audio-EditX/funasr_detach/bin/tokenize_text.py ADDED
@@ -0,0 +1,281 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ from collections import Counter
4
+ import logging
5
+ from pathlib import Path
6
+ import sys
7
+ from typing import List
8
+ from typing import Optional
9
+
10
+
11
+ from funasr_detach.utils.cli_utils import get_commandline_args
12
+ from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
13
+ from funasr_detach.tokenizer.cleaner import TextCleaner
14
+ from funasr_detach.tokenizer.phoneme_tokenizer import g2p_classes
15
+ from funasr_detach.utils.types import str2bool
16
+ from funasr_detach.utils.types import str_or_none
17
+
18
+
19
+ def field2slice(field: Optional[str]) -> slice:
20
+ """Convert field string to slice
21
+
22
+ Note that field string accepts 1-based integer.
23
+
24
+ Examples:
25
+ >>> field2slice("1-")
26
+ slice(0, None, None)
27
+ >>> field2slice("1-3")
28
+ slice(0, 3, None)
29
+ >>> field2slice("-3")
30
+ slice(None, 3, None)
31
+ """
32
+ field = field.strip()
33
+ try:
34
+ if "-" in field:
35
+ # e.g. "2-" or "2-5" or "-7"
36
+ s1, s2 = field.split("-", maxsplit=1)
37
+ if s1.strip() == "":
38
+ s1 = None
39
+ else:
40
+ s1 = int(s1)
41
+ if s1 == 0:
42
+ raise ValueError("1-based string")
43
+ if s2.strip() == "":
44
+ s2 = None
45
+ else:
46
+ s2 = int(s2)
47
+ else:
48
+ # e.g. "2"
49
+ s1 = int(field)
50
+ s2 = s1 + 1
51
+ if s1 == 0:
52
+ raise ValueError("must be 1 or more value")
53
+ except ValueError:
54
+ raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
55
+
56
+ if s1 is None:
57
+ slic = slice(None, s2)
58
+ else:
59
+ # -1 because of 1-based integer following "cut" command
60
+ # e.g "1-3" -> slice(0, 3)
61
+ slic = slice(s1 - 1, s2)
62
+ return slic
63
+
64
+
65
+ def tokenize(
66
+ input: str,
67
+ output: str,
68
+ field: Optional[str],
69
+ delimiter: Optional[str],
70
+ token_type: str,
71
+ space_symbol: str,
72
+ non_linguistic_symbols: Optional[str],
73
+ bpemodel: Optional[str],
74
+ log_level: str,
75
+ write_vocabulary: bool,
76
+ vocabulary_size: int,
77
+ remove_non_linguistic_symbols: bool,
78
+ cutoff: int,
79
+ add_symbol: List[str],
80
+ cleaner: Optional[str],
81
+ g2p: Optional[str],
82
+ ):
83
+
84
+ logging.basicConfig(
85
+ level=log_level,
86
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
87
+ )
88
+ if input == "-":
89
+ fin = sys.stdin
90
+ else:
91
+ fin = Path(input).open("r", encoding="utf-8")
92
+ if output == "-":
93
+ fout = sys.stdout
94
+ else:
95
+ p = Path(output)
96
+ p.parent.mkdir(parents=True, exist_ok=True)
97
+ fout = p.open("w", encoding="utf-8")
98
+
99
+ cleaner = TextCleaner(cleaner)
100
+ tokenizer = build_tokenizer(
101
+ token_type=token_type,
102
+ bpemodel=bpemodel,
103
+ delimiter=delimiter,
104
+ space_symbol=space_symbol,
105
+ non_linguistic_symbols=non_linguistic_symbols,
106
+ remove_non_linguistic_symbols=remove_non_linguistic_symbols,
107
+ g2p_type=g2p,
108
+ )
109
+
110
+ counter = Counter()
111
+ if field is not None:
112
+ field = field2slice(field)
113
+
114
+ for line in fin:
115
+ line = line.rstrip()
116
+ if field is not None:
117
+ # e.g. field="2-"
118
+ # uttidA hello world!! -> hello world!!
119
+ tokens = line.split(delimiter)
120
+ tokens = tokens[field]
121
+ if delimiter is None:
122
+ line = " ".join(tokens)
123
+ else:
124
+ line = delimiter.join(tokens)
125
+
126
+ line = cleaner(line)
127
+ tokens = tokenizer.text2tokens(line)
128
+ if not write_vocabulary:
129
+ fout.write(" ".join(tokens) + "\n")
130
+ else:
131
+ for t in tokens:
132
+ counter[t] += 1
133
+
134
+ if not write_vocabulary:
135
+ return
136
+
137
+ ## FIXME
138
+ ## del duplicate add_symbols in counter
139
+ for symbol_and_id in add_symbol:
140
+ # e.g symbol="<blank>:0"
141
+ try:
142
+ symbol, idx = symbol_and_id.split(":")
143
+ except ValueError:
144
+ raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
145
+ symbol = symbol.strip()
146
+ if symbol in counter:
147
+ del counter[symbol]
148
+
149
+ # ======= write_vocabulary mode from here =======
150
+ # Sort by the number of occurrences in descending order
151
+ # and filter lower frequency words than cutoff value
152
+ words_and_counts = list(
153
+ filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
154
+ )
155
+ # Restrict the vocabulary size
156
+ if vocabulary_size > 0:
157
+ if vocabulary_size < len(add_symbol):
158
+ raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
159
+ words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
160
+
161
+ # Parse the values of --add_symbol
162
+ for symbol_and_id in add_symbol:
163
+ # e.g symbol="<blank>:0"
164
+ try:
165
+ symbol, idx = symbol_and_id.split(":")
166
+ idx = int(idx)
167
+ except ValueError:
168
+ raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
169
+ symbol = symbol.strip()
170
+
171
+ # e.g. idx=0 -> append as the first symbol
172
+ # e.g. idx=-1 -> append as the last symbol
173
+ if idx < 0:
174
+ idx = len(words_and_counts) + 1 + idx
175
+ words_and_counts.insert(idx, (symbol, None))
176
+
177
+ # Write words
178
+ for w, c in words_and_counts:
179
+ fout.write(w + "\n")
180
+
181
+ # Logging
182
+ total_count = sum(counter.values())
183
+ invocab_count = sum(c for w, c in words_and_counts if c is not None)
184
+ logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
185
+
186
+
187
+ def get_parser() -> argparse.ArgumentParser:
188
+ parser = argparse.ArgumentParser(
189
+ description="Tokenize texts",
190
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
191
+ )
192
+ parser.add_argument(
193
+ "--log_level",
194
+ type=lambda x: x.upper(),
195
+ default="INFO",
196
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
197
+ help="The verbose level of logging",
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
202
+ )
203
+ parser.add_argument(
204
+ "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
205
+ )
206
+ parser.add_argument(
207
+ "--field",
208
+ "-f",
209
+ help="The target columns of the input text as 1-based integer. e.g 2-",
210
+ )
211
+ parser.add_argument(
212
+ "--token_type",
213
+ "-t",
214
+ default="char",
215
+ choices=["char", "bpe", "word", "phn"],
216
+ help="Token type",
217
+ )
218
+ parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
219
+ parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
220
+ parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
221
+ parser.add_argument(
222
+ "--non_linguistic_symbols",
223
+ type=str_or_none,
224
+ help="non_linguistic_symbols file path",
225
+ )
226
+ parser.add_argument(
227
+ "--remove_non_linguistic_symbols",
228
+ type=str2bool,
229
+ default=False,
230
+ help="Remove non-language-symbols from tokens",
231
+ )
232
+ parser.add_argument(
233
+ "--cleaner",
234
+ type=str_or_none,
235
+ choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
236
+ default=None,
237
+ help="Apply text cleaning",
238
+ )
239
+ parser.add_argument(
240
+ "--g2p",
241
+ type=str_or_none,
242
+ choices=g2p_classes,
243
+ default=None,
244
+ help="Specify g2p method if --token_type=phn",
245
+ )
246
+
247
+ group = parser.add_argument_group("write_vocabulary mode related")
248
+ group.add_argument(
249
+ "--write_vocabulary",
250
+ type=str2bool,
251
+ default=False,
252
+ help="Write tokens list instead of tokenized text per line",
253
+ )
254
+ group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
255
+ group.add_argument(
256
+ "--cutoff",
257
+ default=0,
258
+ type=int,
259
+ help="cut-off frequency used for write-vocabulary mode",
260
+ )
261
+ group.add_argument(
262
+ "--add_symbol",
263
+ type=str,
264
+ default=[],
265
+ action="append",
266
+ help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
267
+ )
268
+
269
+ return parser
270
+
271
+
272
+ def main(cmd=None):
273
+ print(get_commandline_args(), file=sys.stderr)
274
+ parser = get_parser()
275
+ args = parser.parse_args(cmd)
276
+ kwargs = vars(args)
277
+ tokenize(**kwargs)
278
+
279
+
280
+ if __name__ == "__main__":
281
+ main()
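Two illustrative invocations of the tokenizer CLI above (file paths are placeholders); the second shows the write_vocabulary mode:

```python
# Tokenize column 2 onwards of a kaldi-style "uttid transcript" file into characters:
#   python funasr_detach/bin/tokenize_text.py -i data/text -o data/tokens.txt -f 2- -t char
#
# Build a vocabulary instead, reserving <blank> and <unk> at fixed positions:
#   python funasr_detach/bin/tokenize_text.py -i data/text -o data/vocab.txt -f 2- -t char \
#       --write_vocabulary true --add_symbol "<blank>:0" --add_symbol "<unk>:1"
```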
demo/Step-Audio-EditX/funasr_detach/bin/train.py ADDED
@@ -0,0 +1,227 @@
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import os
5
+ import sys
6
+ import torch
7
+ import hydra
8
+ import logging
9
+ import argparse
10
+ from io import BytesIO
11
+ import torch.distributed as dist
12
+ from collections.abc import Sequence
13
+ from omegaconf import DictConfig, OmegaConf
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16
+
17
+ from funasr_detach.register import tables
18
+ from funasr_detach.optimizers import optim_classes
19
+ from funasr_detach.train_utils.trainer import Trainer
20
+ from funasr_detach.schedulers import scheduler_classes
21
+ from funasr_detach.train_utils.initialize import initialize
22
+ from funasr_detach.download.download_from_hub import download_model
23
+ from funasr_detach.models.lora.utils import mark_only_lora_as_trainable
24
+ from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
25
+ from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
26
+
27
+ # from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
28
+ # from funasr_detach.tokenizer.token_id_converter import TokenIDConverter
29
+ # from funasr_detach.tokenizer.funtoken import build_tokenizer
30
+
31
+
32
+ @hydra.main(config_name=None, version_base=None)
33
+ def main_hydra(kwargs: DictConfig):
34
+ if kwargs.get("debug", False):
35
+ import pdb
36
+
37
+ pdb.set_trace()
38
+
39
+ assert "model" in kwargs
40
+ if "model_conf" not in kwargs:
41
+ logging.info(
42
+ "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
43
+ )
44
+ kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
45
+
46
+ main(**kwargs)
47
+
48
+
49
+ def main(**kwargs):
50
+ print(kwargs)
51
+
52
+ # set random seed
53
+ set_all_random_seed(kwargs.get("seed", 0))
54
+ torch.backends.cudnn.enabled = kwargs.get(
55
+ "cudnn_enabled", torch.backends.cudnn.enabled
56
+ )
57
+ torch.backends.cudnn.benchmark = kwargs.get(
58
+ "cudnn_benchmark", torch.backends.cudnn.benchmark
59
+ )
60
+ torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
61
+
62
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
63
+ if local_rank == 0:
64
+ tables.print()
65
+ # Check if we are using DDP or FSDP
66
+ use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
67
+ use_fsdp = kwargs.get("use_fsdp", None)
68
+ if use_ddp or use_fsdp:
69
+ dist.init_process_group(
70
+ backend=kwargs.get("backend", "nccl"), init_method="env://"
71
+ )
72
+ torch.cuda.set_device(local_rank)
73
+
74
+ # save config.yaml
75
+ if (
76
+ (use_ddp or use_fsdp)
77
+ and dist.get_rank() == 0
78
+ or not (use_ddp or use_fsdp)
79
+ and local_rank == 0
80
+ ):
81
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
82
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
83
+ OmegaConf.save(config=kwargs, f=yaml_file)
84
+ logging.info("config.yaml is saved to: %s", yaml_file)
85
+
86
+ tokenizer = kwargs.get("tokenizer", None)
87
+ if tokenizer is not None:
88
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
89
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
90
+ kwargs["tokenizer"] = tokenizer
91
+
92
+ # build frontend if frontend is not None
93
+ frontend = kwargs.get("frontend", None)
94
+ if frontend is not None:
95
+ frontend_class = tables.frontend_classes.get(frontend)
96
+ frontend = frontend_class(**kwargs["frontend_conf"])
97
+ kwargs["frontend"] = frontend
98
+ kwargs["input_size"] = frontend.output_size()
99
+
100
+ # build model
101
+ model_class = tables.model_classes.get(kwargs["model"])
102
+ model = model_class(
103
+ **kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)
104
+ )
105
+
106
+ # init_param
107
+ init_param = kwargs.get("init_param", None)
108
+ if init_param is not None:
109
+ if not isinstance(init_param, (list, tuple)):
110
+ init_param = (init_param,)
111
+ logging.info("init_param is not None: %s", init_param)
112
+ for p in init_param:
113
+ logging.info(f"Loading pretrained params from {p}")
114
+ load_pretrained_model(
115
+ model=model,
116
+ path=p,
117
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
118
+ oss_bucket=kwargs.get("oss_bucket", None),
119
+ scope_map=kwargs.get("scope_map", None),
120
+ excludes=kwargs.get("excludes", None),
121
+ )
122
+ else:
123
+ initialize(model, kwargs.get("init", "kaiming_normal"))
124
+
125
+ # freeze_param
126
+ freeze_param = kwargs.get("freeze_param", None)
127
+ if freeze_param is not None:
128
+ freeze_param = eval(freeze_param)
129
+ if not isinstance(freeze_param, (list, tuple)):  # wrap a single name so the loop below iterates over strings
130
+ freeze_param = (freeze_param,)
131
+ logging.info("freeze_param is not None: %s", freeze_param)
132
+ for t in freeze_param:
133
+ for k, p in model.named_parameters():
134
+ if k.startswith(t + ".") or k == t:
135
+ logging.info(f"Setting {k}.requires_grad = False")
136
+ p.requires_grad = False
137
+
138
+ if use_ddp:
139
+ model = model.cuda(local_rank)
140
+ model = DDP(
141
+ model,
142
+ device_ids=[local_rank],
143
+ find_unused_parameters=kwargs.get("train_conf", {}).get(
144
+ "find_unused_parameters", False
145
+ ),
146
+ )
147
+ elif use_fsdp:
148
+ model = FSDP(model).cuda(local_rank)
149
+ else:
150
+ model = model.to(device=kwargs.get("device", "cuda"))
151
+
152
+ # optim
153
+ optim = kwargs.get("optim", "adam")
154
+ assert optim in optim_classes
155
+ optim_class = optim_classes.get(optim)
156
+ optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
157
+
158
+ # scheduler
159
+ scheduler = kwargs.get("scheduler", "warmuplr")
160
+ assert scheduler in scheduler_classes
161
+ scheduler_class = scheduler_classes.get(scheduler)
162
+ scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
163
+
164
+ # dataset
165
+ dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
166
+ dataset_tr = dataset_class(
167
+ kwargs.get("train_data_set_list"),
168
+ frontend=frontend,
169
+ tokenizer=tokenizer,
170
+ is_training=True,
171
+ **kwargs.get("dataset_conf"),
172
+ )
173
+ dataset_val = dataset_class(
174
+ kwargs.get("valid_data_set_list"),
175
+ frontend=frontend,
176
+ tokenizer=tokenizer,
177
+ is_training=False,
178
+ **kwargs.get("dataset_conf"),
179
+ )
180
+
181
+ # dataloader
182
+ batch_sampler = kwargs["dataset_conf"].get(
183
+ "batch_sampler", "DynamicBatchLocalShuffleSampler"
184
+ )
185
+ batch_sampler_val = None
186
+ if batch_sampler is not None:
187
+ batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
188
+ batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
189
+ batch_sampler_val = batch_sampler_class(
190
+ dataset_val, is_training=False, **kwargs.get("dataset_conf")
191
+ )
192
+ dataloader_tr = torch.utils.data.DataLoader(
193
+ dataset_tr,
194
+ collate_fn=dataset_tr.collator,
195
+ batch_sampler=batch_sampler,
196
+ num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
197
+ pin_memory=True,
198
+ )
199
+
200
+ dataloader_val = torch.utils.data.DataLoader(
201
+ dataset_val,
202
+ collate_fn=dataset_val.collator,
203
+ batch_sampler=batch_sampler_val,
204
+ num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
205
+ pin_memory=True,
206
+ )
207
+ trainer = Trainer(
208
+ model=model,
209
+ optim=optim,
210
+ scheduler=scheduler,
211
+ dataloader_train=dataloader_tr,
212
+ dataloader_val=dataloader_val,
213
+ local_rank=local_rank,
214
+ use_ddp=use_ddp,
215
+ use_fsdp=use_fsdp,
216
+ output_dir=kwargs.get("output_dir", "./exp"),
217
+ resume=kwargs.get("resume", True),
218
+ **kwargs.get("train_conf"),
219
+ )
220
+ trainer.run()
221
+
222
+ if use_ddp or use_fsdp:
223
+ torch.distributed.destroy_process_group()
224
+
225
+
226
+ if __name__ == "__main__":
227
+ main_hydra()
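An illustrative multi-GPU launch of the training entry point above; the config name, data lists, GPU count, and output directory are all placeholders:

```python
# torchrun sets WORLD_SIZE / LOCAL_RANK, which main() uses to enable DDP:
#   torchrun --nproc_per_node=2 funasr_detach/bin/train.py \
#       --config-path conf --config-name train_asr.yaml \
#       ++train_data_set_list=data/train.jsonl \
#       ++valid_data_set_list=data/val.jsonl \
#       ++output_dir=./exp
```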
demo/Step-Audio-EditX/funasr_detach/datasets/__init__.py ADDED
File without changes
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/__init__.py ADDED
File without changes
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/datasets.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+
3
+ from funasr_detach.register import tables
4
+ from funasr_detach.utils.load_utils import extract_fbank, load_audio_text_image_video
5
+
6
+
7
+ @tables.register("dataset_classes", "AudioDataset")
8
+ class AudioDataset(torch.utils.data.Dataset):
9
+ """
10
+ AudioDataset
11
+ """
12
+
13
+ def __init__(
14
+ self,
15
+ path,
16
+ index_ds: str = None,
17
+ frontend=None,
18
+ tokenizer=None,
19
+ int_pad_value: int = -1,
20
+ float_pad_value: float = 0.0,
21
+ **kwargs
22
+ ):
23
+ super().__init__()
24
+ index_ds_class = tables.index_ds_classes.get(index_ds)
25
+ self.index_ds = index_ds_class(path, **kwargs)
26
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
27
+ if preprocessor_speech:
28
+ preprocessor_speech_class = tables.preprocessor_classes.get(
29
+ preprocessor_speech
30
+ )
31
+ preprocessor_speech = preprocessor_speech_class(
32
+ **kwargs.get("preprocessor_speech_conf")
33
+ )
34
+ self.preprocessor_speech = preprocessor_speech
35
+ preprocessor_text = kwargs.get("preprocessor_text", None)
36
+ if preprocessor_text:
37
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
38
+ preprocessor_text = preprocessor_text_class(
39
+ **kwargs.get("preprocessor_text_conf")
40
+ )
41
+ self.preprocessor_text = preprocessor_text
42
+
43
+ self.frontend = frontend
44
+ self.fs = 16000 if frontend is None else frontend.fs
45
+ self.data_type = "sound"
46
+ self.tokenizer = tokenizer
47
+
48
+ self.int_pad_value = int_pad_value
49
+ self.float_pad_value = float_pad_value
50
+
51
+ def get_source_len(self, index):
52
+ item = self.index_ds[index]
53
+ return self.index_ds.get_source_len(item)
54
+
55
+ def get_target_len(self, index):
56
+ item = self.index_ds[index]
57
+ return self.index_ds.get_target_len(item)
58
+
59
+ def __len__(self):
60
+ return len(self.index_ds)
61
+
62
+ def __getitem__(self, index):
63
+ item = self.index_ds[index]
64
+ # import pdb;
65
+ # pdb.set_trace()
66
+ source = item["source"]
67
+ data_src = load_audio_text_image_video(source, fs=self.fs)
68
+ if self.preprocessor_speech:
69
+ data_src = self.preprocessor_speech(data_src, fs=self.fs)
70
+ speech, speech_lengths = extract_fbank(
71
+ data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
72
+ ) # speech: [b, T, d]
73
+
74
+ target = item["target"]
75
+ if self.preprocessor_text:
76
+ target = self.preprocessor_text(target)
77
+ if self.tokenizer:
78
+ ids = self.tokenizer.encode(target)
79
+ text = torch.tensor(ids, dtype=torch.int64)
80
+ else:
81
+ ids = target
82
+ text = ids
83
+ ids_lengths = len(ids)
84
+ text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
85
+
86
+ return {
87
+ "speech": speech[0, :, :],
88
+ "speech_lengths": speech_lengths,
89
+ "text": text,
90
+ "text_lengths": text_lengths,
91
+ }
92
+
93
+ def collator(self, samples: list = None):
94
+ outputs = {}
95
+ for sample in samples:
96
+ for key in sample.keys():
97
+ if key not in outputs:
98
+ outputs[key] = []
99
+ outputs[key].append(sample[key])
100
+
101
+ for key, data_list in outputs.items():
102
+ if isinstance(data_list[0], torch.Tensor):
103
+ if data_list[0].dtype == torch.int64:
104
+
105
+ pad_value = self.int_pad_value
106
+ else:
107
+ pad_value = self.float_pad_value
108
+
109
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(
110
+ data_list, batch_first=True, padding_value=pad_value
111
+ )
112
+ return outputs
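A minimal sketch of wiring the dataset above into a DataLoader via its collator; the jsonl path is a placeholder, and real frontend/tokenizer objects (built from the surrounding config) would normally replace the None stand-ins:

```python
import torch
from funasr_detach.datasets.audio_datasets.datasets import AudioDataset

frontend = None   # stand-in: a registered frontend (e.g. built from frontend_conf) is expected here
tokenizer = None  # stand-in: a registered tokenizer turns "target" text into ids

dataset = AudioDataset(
    "data/train.jsonl",          # placeholder jsonl path (see index_ds.py / scp2jsonl.py)
    index_ds="IndexDSJsonl",     # registered index dataset (see index_ds.py)
    frontend=frontend,
    tokenizer=tokenizer,
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    collate_fn=dataset.collator,  # pads speech/text tensors to the longest item in the batch
)
```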
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/index_ds.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ import concurrent.futures
6
+ import librosa
7
+ import torch.distributed as dist
8
+
9
+ from funasr_detach.register import tables
10
+
11
+
12
+ @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
13
+ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
14
+
15
+ def __init__(self, path):
16
+ super().__init__()
17
+
18
+ contents = []
19
+ with open(path, encoding="utf-8") as fin:
20
+ for line in fin:
21
+ data = json.loads(line.strip())
22
+ if "text" in data: # for sft
23
+ contents.append(data["text"])  # accumulate locally; self.contents is built below
24
+ if "source" in data: # for speech lab pretrain
25
+ prompt = data["prompt"]
26
+ source = data["source"]
27
+ target = data["target"]
28
+ source_len = data["source_len"]
29
+ target_len = data["target_len"]
30
+
31
+ contents.append(
32
+ {
33
+ "source": source,
34
+ "prompt": prompt,
35
+ "target": target,
36
+ "source_len": source_len,
37
+ "target_len": target_len,
38
+ }
39
+ )
40
+
41
+ self.contents = []
42
+ total_num = len(contents)
43
+ try:
44
+ rank = dist.get_rank()
45
+ world_size = dist.get_world_size()
46
+ except:
47
+ rank = 0
48
+ world_size = 1
49
+ logging.warning("distributed is not initialized, only single shard")
50
+ num_per_rank = total_num // world_size
51
+
52
+ # rank = 0
53
+ # import ipdb; ipdb.set_trace()
54
+ self.contents = contents[rank * num_per_rank : (rank + 1) * num_per_rank]
55
+
56
+ logging.info(
57
+ "in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(
58
+ rank, len(self.contents), len(contents)
59
+ )
60
+ )
61
+
62
+ def __len__(self):
63
+ return len(self.contents)
64
+
65
+ def __getitem__(self, index):
66
+ try:
67
+ data = self.contents[index]
68
+ except IndexError:
69
+ print(index)
+ raise
70
+ return data
71
+
72
+ def get_source_len(self, data_dict):
73
+ return data_dict["source_len"]
74
+
75
+ def get_target_len(self, data_dict):
76
+
77
+ return data_dict["target_len"] if "target_len" in data_dict else 0
78
+
79
+
80
+ @tables.register("index_ds_classes", "IndexDSJsonl")
81
+ @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
82
+ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
83
+
84
+ def __init__(self, path: str, **kwargs):
85
+ super().__init__()
86
+
87
+ if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans
88
+ from funasr_detach.datasets.audio_datasets.scp2jsonl import (
89
+ gen_jsonl_from_wav_text_list,
90
+ )
91
+
92
+ jsonl_outdir = os.path.dirname(path[0])
93
+ jsonl_name = (
94
+ "datalist_train.jsonl"
95
+ if kwargs.get("is_training", True)
96
+ else "datalist_val.jsonl"
97
+ )
98
+ jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
99
+ if not os.path.exists(jsonl_file_out):
100
+ print(f"datalist is: {path}, generate jsonl from it")
101
+ gen_jsonl_from_wav_text_list(
102
+ path, jsonl_file_out=jsonl_file_out, **kwargs
103
+ )
104
+ path = jsonl_file_out
105
+
106
+ contents = []
107
+ with open(path, encoding="utf-8") as fin:
108
+ for line in fin:
109
+ data = json.loads(line.strip())
110
+ if "text" in data: # for sft
111
+ contents.append(data["text"])  # accumulate locally; self.contents is assigned below
112
+ if "source" in data: # for speech lab pretrain
113
+ prompt = data.get("prompt", "<ASR>")
114
+ source = data["source"]
115
+ target = data["target"]
116
+ source_len = data.get("source_len", 1)
117
+ target_len = data.get("target_len", 0)
118
+
119
+ contents.append(
120
+ {
121
+ "source": source,
122
+ "prompt": prompt,
123
+ "target": target,
124
+ "source_len": source_len,
125
+ "target_len": target_len,
126
+ }
127
+ )
128
+
129
+ self.contents = contents
130
+
131
+ logging.info(
132
+ "total_num of samplers across ranks: {}".format(len(self.contents))
133
+ )
134
+
135
+ def __len__(self):
136
+ return len(self.contents)
137
+
138
+ def __getitem__(self, index):
139
+ try:
140
+ data = self.contents[index]
141
+ except IndexError:
142
+ print(index)
+ raise
143
+ return data
144
+
145
+ def get_source_len(self, data_dict):
146
+ return data_dict.get("source_len", 1)
147
+
148
+ def get_target_len(self, data_dict):
149
+
150
+ return data_dict.get("target_len", 0)
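A sketch of the jsonl record consumed by the index datasets above; the values are illustrative, and IndexDSJsonlRankFull only strictly needs "source" and "target", falling back to defaults for the other fields:

```python
import json

record = {
    "source": "/data/wav/utt_0001.wav",  # audio path (any input load_audio_text_image_video accepts)
    "target": "hello world",             # transcript
    "source_len": 560,                   # lengths consumed by the batch samplers
    "target_len": 2,
    "prompt": "<ASR>",
}
with open("data/train.jsonl", "a", encoding="utf-8") as fout:
    fout.write(json.dumps(record, ensure_ascii=False) + "\n")
```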
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/preprocessor.py ADDED
@@ -0,0 +1,55 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ import concurrent.futures
6
+ import librosa
7
+ import torch.distributed as dist
8
+ from typing import Collection
9
+ import torch
10
+ import torchaudio
11
+ from torch import nn
12
+ import random
13
+ import re
14
+ from funasr_detach.tokenizer.cleaner import TextCleaner
15
+ from funasr_detach.register import tables
16
+
17
+
18
+ @tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
19
+ class SpeechPreprocessSpeedPerturb(nn.Module):
20
+ def __init__(self, speed_perturb: list = None, **kwargs):
21
+ super().__init__()
22
+ self.speed_perturb = speed_perturb
23
+
24
+ def forward(self, waveform, fs, **kwargs):
25
+ if self.speed_perturb is None:
26
+ return waveform
27
+ speed = random.choice(self.speed_perturb)
28
+ if speed != 1.0:
29
+ if not isinstance(waveform, torch.Tensor):
30
+ waveform = torch.tensor(waveform)
31
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
32
+ waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
33
+ )
34
+ waveform = waveform.view(-1)
35
+
36
+ return waveform
37
+
38
+
39
+ @tables.register("preprocessor_classes", "TextPreprocessSegDict")
40
+ class TextPreprocessSegDict(nn.Module):
41
+ def __init__(
42
+ self,
43
+ seg_dict: str = None,
44
+ text_cleaner: Collection[str] = None,
45
+ split_with_space: bool = False,
46
+ **kwargs
47
+ ):
48
+ super().__init__()
49
+
50
+ self.text_cleaner = TextCleaner(text_cleaner)
51
+
52
+ def forward(self, text, **kwargs):
53
+ text = self.text_cleaner(text)
54
+
55
+ return text
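A minimal sketch of the speed-perturbation preprocessor above on synthetic audio; the perturbation factors are common choices rather than values fixed by this file, and torchaudio's sox backend must be available:

```python
import torch
from funasr_detach.datasets.audio_datasets.preprocessor import SpeechPreprocessSpeedPerturb

perturb = SpeechPreprocessSpeedPerturb(speed_perturb=[0.9, 1.0, 1.1])
waveform = torch.randn(16000)      # one second of fake 16 kHz audio
out = perturb(waveform, fs=16000)  # randomly ~0.9x / 1.0x / 1.1x the original duration
print(out.shape)
```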
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/samplers.py ADDED
@@ -0,0 +1,306 @@
1
+ import torch
2
+ import numpy as np
3
+ import logging
4
+ import torch.distributed as dist
5
+
6
+ from funasr_detach.register import tables
7
+
8
+
9
+ @tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
10
+ class BatchSampler(torch.utils.data.BatchSampler):
11
+
12
+ def __init__(
13
+ self,
14
+ dataset,
15
+ batch_type: str = "example",
16
+ batch_size: int = 100,
17
+ buffer_size: int = 30,
18
+ drop_last: bool = False,
19
+ shuffle: bool = True,
20
+ is_training: bool = True,
21
+ **kwargs
22
+ ):
23
+
24
+ self.drop_last = drop_last
25
+ self.pre_idx = -1
26
+ self.dataset = dataset
27
+ self.total_samples = len(dataset)
28
+ self.batch_type = batch_type
29
+ self.batch_size = int(batch_size)
30
+ self.buffer_size = buffer_size
31
+ self.max_token_length = kwargs.get("max_token_length", 5000)
32
+ self.shuffle_idx = np.arange(self.total_samples)
33
+ self.shuffle = shuffle and is_training
34
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
35
+
36
+ def __len__(self):
37
+ return (self.total_samples - 1) // self.batch_size + 1
38
+
39
+ def set_epoch(self, epoch):
40
+ np.random.seed(epoch)
41
+
42
+ def __iter__(self):
43
+
44
+ if self.shuffle:
45
+ np.random.shuffle(self.shuffle_idx)
46
+
47
+ batch = []
48
+ max_token = 0
49
+ num_sample = 0
50
+
51
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
52
+ # print("iter_num: ", iter_num)
53
+ for iter in range(self.pre_idx + 1, iter_num):
54
+ datalen_with_index = []
55
+ for i in range(self.buffer_size):
56
+ idx = iter * self.buffer_size + i
57
+ if idx >= self.total_samples:
58
+ continue
59
+
60
+ idx_map = self.shuffle_idx[idx]
61
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
62
+ target_len = (
63
+ self.dataset.get_target_len(idx_map)
64
+ if self.batch_type == "length"
65
+ else 0.0
66
+ )
67
+ source_len = (
68
+ self.dataset.get_source_len(idx_map) / self.length_scale_source
69
+ )
70
+ sample_len_cur = source_len + target_len
71
+
72
+ datalen_with_index.append([idx, sample_len_cur])
73
+
74
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
75
+ for item in datalen_with_index_sort:
76
+ idx, sample_len_cur_raw = item
77
+ if sample_len_cur_raw > self.max_token_length:
78
+ continue
79
+
80
+ max_token_cur = max(max_token, sample_len_cur_raw)
81
+ max_token_padding = 1 + num_sample
82
+ if self.batch_type != "example":
83
+ max_token_padding *= max_token_cur
84
+ if max_token_padding <= self.batch_size:
85
+ batch.append(idx)
86
+ max_token = max_token_cur
87
+ num_sample += 1
88
+ else:
89
+ yield batch
90
+ batch = [idx]
91
+ max_token = sample_len_cur_raw
92
+ num_sample = 1
93
+
94
+
95
+ @tables.register("batch_sampler_classes", "BatchSampler")
96
+ @tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
97
+ class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
98
+
99
+ def __init__(
100
+ self,
101
+ dataset,
102
+ batch_type: str = "example",
103
+ batch_size: int = 100,
104
+ buffer_size: int = 30,
105
+ drop_last: bool = True,
106
+ shuffle: bool = True,
107
+ is_training: bool = True,
108
+ **kwargs
109
+ ):
110
+
111
+ self.drop_last = drop_last
112
+ self.pre_idx = -1
113
+ self.dataset = dataset
114
+ self.total_samples = len(dataset)
115
+ self.batch_type = batch_type
116
+ self.batch_size = int(batch_size)
117
+ self.buffer_size = buffer_size
118
+ self.max_token_length = kwargs.get("max_token_length", 1500)
119
+ self.shuffle_idx = np.arange(self.total_samples)
120
+ self.shuffle = shuffle and is_training
121
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
122
+
123
+ try:
124
+ rank = dist.get_rank()
125
+ world_size = dist.get_world_size()
126
+ except:
127
+ rank = 0
128
+ world_size = 1
129
+ self.rank = rank
130
+ self.world_size = world_size
131
+
132
+ def __len__(self):
133
+ return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
134
+
135
+ def set_epoch(self, epoch):
136
+ np.random.seed(epoch)
137
+
138
+ def __iter__(self):
139
+
140
+ batch_size_total = self.batch_size * self.world_size
141
+
142
+ if self.shuffle:
143
+ np.random.shuffle(self.shuffle_idx)
144
+
145
+ batch = []
146
+ max_token = 0
147
+ num_sample = 0
148
+
149
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
150
+ # print("iter_num: ", iter_num)
151
+ for iter in range(self.pre_idx + 1, iter_num):
152
+ # if iter == iter_num -1 and self.drop_last:
153
+ # continue
154
+ datalen_with_index = []
155
+ for i in range(self.buffer_size):
156
+ idx = iter * self.buffer_size + i
157
+ if idx >= self.total_samples:
158
+ continue
159
+
160
+ idx_map = self.shuffle_idx[idx]
161
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
162
+
163
+ source_len = (
164
+ self.dataset.get_source_len(idx_map) / self.length_scale_source
165
+ )
166
+ target_len = (
167
+ self.dataset.get_target_len(idx_map)
168
+ if self.batch_type == "length"
169
+ else 0.0
170
+ )
171
+ sample_len_cur = source_len + target_len
172
+
173
+ datalen_with_index.append([idx, sample_len_cur])
174
+
175
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
176
+ for item in datalen_with_index_sort:
177
+ idx, sample_len_cur_raw = item
178
+ if sample_len_cur_raw > self.max_token_length:
179
+ continue
180
+
181
+ max_token_cur = max(max_token, sample_len_cur_raw)
182
+ max_token_padding = 1 + num_sample
183
+ # if self.batch_type != 'example':
184
+ # max_token_padding *= max_token_cur
185
+ if max_token_padding <= batch_size_total:
186
+ batch.append(idx)
187
+ max_token = max_token_cur
188
+ num_sample += 1
189
+ else:
190
+ batch_rank = batch[
191
+ self.rank * self.batch_size : (self.rank + 1) * self.batch_size
192
+ ]
193
+ yield batch_rank
194
+ batch = [idx]
195
+ max_token = sample_len_cur_raw
196
+ num_sample = 1
197
+
198
+
199
+ @tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
200
+ class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
201
+
202
+ def __init__(
203
+ self,
204
+ dataset,
205
+ batch_type: str = "example",
206
+ batch_size: int = 100,
207
+ buffer_size: int = 30,
208
+ drop_last: bool = True,
209
+ shuffle: bool = True,
210
+ is_training: bool = True,
211
+ **kwargs
212
+ ):
213
+
214
+ self.drop_last = drop_last
215
+ self.pre_idx = -1
216
+ self.dataset = dataset
217
+ self.total_samples = len(dataset)
218
+ self.batch_type = batch_type
219
+ self.batch_size = int(batch_size)
220
+ self.buffer_size = buffer_size
221
+ self.max_token_length = kwargs.get("max_token_length", 1500)
222
+ self.shuffle_idx = np.arange(self.total_samples)
223
+ self.shuffle = shuffle and is_training
224
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
225
+
226
+ try:
227
+ rank = dist.get_rank()
228
+ world_size = dist.get_world_size()
229
+ except:
230
+ rank = 0
231
+ world_size = 1
232
+ self.rank = rank
233
+ self.world_size = world_size
234
+
235
+ def __len__(self):
236
+ return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
237
+
238
+ def set_epoch(self, epoch):
239
+ np.random.seed(epoch)
240
+
241
+ def __iter__(self):
242
+
243
+ batch_size_total = self.batch_size * self.world_size
244
+ if self.shuffle:
245
+ np.random.shuffle(self.shuffle_idx)
246
+
247
+ batch_list_all_rank = []
248
+ batch_list_cur = []
249
+ max_token = 0
250
+ num_sample = 0
251
+
252
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
253
+ # print("iter_num: ", iter_num)
254
+ for iter in range(self.pre_idx + 1, iter_num):
255
+ # if iter == iter_num - 1 and self.drop_last:
256
+ # continue
257
+ datalen_with_index = []
258
+ for i in range(self.buffer_size):
259
+ idx = iter * self.buffer_size + i
260
+ if idx >= self.total_samples:
261
+ continue
262
+
263
+ idx_map = self.shuffle_idx[idx]
264
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
265
+
266
+ source_len = (
267
+ self.dataset.get_source_len(idx_map) / self.length_scale_source
268
+ )
269
+ target_len = (
270
+ self.dataset.get_target_len(idx_map)
271
+ if self.batch_type == "length"
272
+ else 0.0
273
+ )
274
+ sample_len_cur = source_len + target_len
275
+
276
+ datalen_with_index.append([idx, sample_len_cur])
277
+
278
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
279
+ for ii, item in enumerate(datalen_with_index_sort):
280
+ is_last_batch = iter == iter_num - 1 and ii == len(
281
+ datalen_with_index_sort
282
+ )
283
+ idx, sample_len_cur_raw = item
284
+ if sample_len_cur_raw > self.max_token_length:
285
+ continue
286
+
287
+ max_token_cur = max(max_token, sample_len_cur_raw)
288
+ max_token_padding = 1 + num_sample
289
+
290
+ if self.batch_type != "example":
291
+ max_token_padding *= max_token_cur
292
+ if len(batch_list_all_rank) < self.world_size:
293
+
294
+ if max_token_padding <= self.batch_size:
295
+ batch_list_cur.append(idx)
296
+ max_token = max_token_cur
297
+ num_sample += 1
298
+ else:
299
+ batch_list_all_rank.append(batch_list_cur)
300
+ batch_list_cur = []
301
+ else:
302
+ batch_rank = batch_list_all_rank[self.rank]
303
+ yield batch_rank
304
+ batch_list_all_rank = [idx]
305
+ max_token = sample_len_cur_raw
306
+ num_sample = 1
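The samplers above pack variable-length utterances into batches under a token budget instead of a fixed example count. A minimal, hedged sketch of how such a sampler is typically attached to a torch DataLoader follows; the toy dataset, its lengths, and the collate function are illustrative only, and the import path simply mirrors the file above (a usage sketch, not part of the shipped demo):

import torch
from torch.utils.data import DataLoader, Dataset

# Assumption: funasr_detach is importable; BatchSampler is the class registered
# above as "DynamicBatchLocalShuffleSampler".
from funasr_detach.datasets.audio_datasets.samplers import BatchSampler


class ToyLengthDataset(Dataset):
    """Minimal dataset exposing the length hooks the sampler calls."""

    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        return torch.zeros(self.lengths[idx])

    def get_source_len(self, idx):
        return self.lengths[idx]

    def get_target_len(self, idx):
        return 0


dataset = ToyLengthDataset([120, 80, 400, 950, 60, 300])
# With batch_type="length", batch_size acts as a padded-token budget.
sampler = BatchSampler(dataset, batch_type="length", batch_size=1000, buffer_size=4)
sampler.set_epoch(0)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=lambda items: items)
for batch in loader:
    print(len(batch), [x.numel() for x in batch])

Each yielded batch keeps len(batch) * longest_sample_len at or below batch_size, which is the padding-aware bound the check in __iter__ enforces.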
demo/Step-Audio-EditX/funasr_detach/datasets/audio_datasets/scp2jsonl.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ import hydra
6
+ from omegaconf import DictConfig, OmegaConf
7
+ import concurrent.futures
8
+ import librosa
9
+ import torch.distributed as dist
10
+
11
+
12
+ def gen_jsonl_from_wav_text_list(
13
+ path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs
14
+ ):
15
+ try:
16
+ rank = dist.get_rank()
17
+ world_size = dist.get_world_size()
18
+ except:
19
+ rank = 0
20
+ world_size = 1
21
+
22
+ cpu_cores = os.cpu_count() or 1
23
+ print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
24
+ if rank == 0:
25
+ json_dict = {}
26
+ for data_type, data_file in zip(data_type_list, path):
27
+ json_dict[data_type] = {}
28
+ with open(data_file, "r") as f:
29
+
30
+ data_file_lists = f.readlines()
31
+ lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
32
+ task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
33
+ with concurrent.futures.ThreadPoolExecutor(
34
+ max_workers=cpu_cores
35
+ ) as executor:
36
+
37
+ futures = [
38
+ executor.submit(
39
+ parse_context_length,
40
+ data_file_lists[
41
+ i * lines_for_each_th : (i + 1) * lines_for_each_th
42
+ ],
43
+ data_type,
44
+ )
45
+ for i in range(task_num)
46
+ ]
47
+
48
+ for future in concurrent.futures.as_completed(futures):
49
+
50
+ json_dict[data_type].update(future.result())
51
+ # print(json_dict)
52
+
53
+ with open(jsonl_file_out, "w") as f:
54
+ for key in json_dict[data_type_list[0]].keys():
55
+ jsonl_line = {"key": key}
56
+ for data_file in data_type_list:
57
+ jsonl_line.update(json_dict[data_file][key])
58
+ jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
59
+ f.write(jsonl_line + "\n")
60
+ f.flush()
61
+
62
+ else:
63
+ pass
64
+
65
+ if world_size > 1:
66
+ dist.barrier()
67
+
68
+
69
+ def parse_context_length(data_list: list, data_type: str):
70
+
71
+ res = {}
72
+ for i, line in enumerate(data_list):
73
+ key, line = line.strip().split(maxsplit=1)
74
+ line = line.strip()
75
+ if os.path.exists(line):
76
+ waveform, _ = librosa.load(line, sr=16000)
77
+ sample_num = len(waveform)
78
+ context_len = int(sample_num // 16000 * 1000 / 10)
79
+ else:
80
+ context_len = len(line.split()) if " " in line else len(line)
81
+ res[key] = {data_type: line, f"{data_type}_len": context_len}
82
+ return res
83
+
84
+
85
+ @hydra.main(config_name=None, version_base=None)
86
+ def main_hydra(cfg: DictConfig):
87
+
88
+ kwargs = OmegaConf.to_container(cfg, resolve=True)
89
+
90
+ scp_file_list = kwargs.get(
91
+ "scp_file_list",
92
+ (
93
+ "/Users/zhifu/funasr1.0/test_local/wav.scp",
94
+ "/Users/zhifu/funasr1.0/test_local/text.txt",
95
+ ),
96
+ )
97
+ if isinstance(scp_file_list, str):
98
+ scp_file_list = eval(scp_file_list)
99
+ data_type_list = kwargs.get("data_type_list", ("source", "target"))
100
+ jsonl_file_out = kwargs.get(
101
+ "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
102
+ )
103
+ gen_jsonl_from_wav_text_list(
104
+ scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
105
+ )
106
+
107
+
108
+ """
109
+ python -m funasr_detach.datasets.audio_datasets.scp2jsonl \
110
+ ++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
111
+ ++data_type_list='["source", "target"]' \
112
+ ++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
113
+ """
114
+
115
+ if __name__ == "__main__":
116
+ main_hydra()
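For reference, a hedged sketch of the expected inputs and output of gen_jsonl_from_wav_text_list; the file names, utterance key, and transcript are placeholders, and the module's own dependencies (librosa, hydra) must be installed:

import json
from funasr_detach.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list

# wav.scp holds "<key> <audio-path>" per line; text.txt holds "<key> <transcript>".
with open("wav.scp", "w") as f:
    f.write("utt001 /data/audio/utt001.wav\n")
with open("text.txt", "w") as f:
    f.write("utt001 hello world\n")

gen_jsonl_from_wav_text_list(
    ("wav.scp", "text.txt"),
    data_type_list=("source", "target"),
    jsonl_file_out="audio_datasets.jsonl",
)

# Each jsonl line pairs the key with a source/target entry plus a length field:
# source_len is the audio duration in 10 ms frames when the wav exists,
# otherwise a text-based fallback; target_len counts whitespace tokens here.
with open("audio_datasets.jsonl") as f:
    print(json.loads(f.readline()))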
demo/Step-Audio-EditX/funasr_detach/download/__init__.py ADDED
File without changes
demo/Step-Audio-EditX/funasr_detach/download/download_dataset_from_hub.py ADDED
@@ -0,0 +1,19 @@
1
+ def download_dataset():
2
+ pass
3
+
4
+
5
+ def download_dataset_from_ms(**kwargs):
6
+ from modelscope.msdatasets import MsDataset
7
+
8
+ dataset_name = kwargs.get(
9
+ "dataset_name", "speech_asr/speech_asr_aishell1_trainsets"
10
+ )
11
+ subset_name = kwargs.get("subset_name", "default")
12
+ split = kwargs.get("split", "train")
13
+ data_dump_dir = kwargs.get("data_dump_dir", None)
14
+ ds = MsDataset.load(
15
+ dataset_name=dataset_name,
16
+ subset_name=subset_name,
17
+ split=split,
18
+ cache_dir=data_dump_dir,
19
+ )
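A minimal call sketch for the ModelScope path (the dataset id is the helper's default, the dump directory is a placeholder, and the modelscope package is required):

from funasr_detach.download.download_dataset_from_hub import download_dataset_from_ms

download_dataset_from_ms(
    dataset_name="speech_asr/speech_asr_aishell1_trainsets",
    subset_name="default",
    split="train",
    data_dump_dir="./data_cache",  # placeholder local cache directory
)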
demo/Step-Audio-EditX/funasr_detach/download/download_from_hub.py ADDED
@@ -0,0 +1,231 @@
1
+ import os
2
+ import json
3
+ import threading
4
+ from omegaconf import OmegaConf
5
+
6
+ from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf
7
+
8
+ # Global cache for downloaded models to avoid repeated downloads
9
+ # Key: (repo_id, model_revision, model_hub)
10
+ # Value: repo_cache_dir
11
+ _model_cache = {}
12
+ _cache_lock = threading.Lock()
13
+
14
+
15
+ def download_model(**kwargs):
16
+ model_hub = kwargs.get("model_hub", "ms")
17
+ model_or_path = kwargs.get("model")
18
+ repo_path = kwargs.get("repo_path", "")
19
+
20
+ # Handle name mapping based on model_hub
21
+ if model_hub == "ms" and model_or_path in name_maps_ms:
22
+ model_or_path = name_maps_ms[model_or_path]
23
+ elif model_hub == "hf" and model_or_path in name_maps_hf:
24
+ model_or_path = name_maps_hf[model_or_path]
25
+
26
+ model_revision = kwargs.get("model_revision")
27
+
28
+ # Download model if it doesn't exist locally
29
+ if not os.path.exists(model_or_path):
30
+ if model_hub == "local":
31
+ # For local models, the path should already exist
32
+ raise FileNotFoundError(f"Local model path does not exist: {model_or_path}")
33
+ elif model_hub in ["ms", "hf"]:
34
+ repo_path, model_or_path = get_or_download_model_dir(
35
+ model_or_path,
36
+ model_revision,
37
+ is_training=kwargs.get("is_training"),
38
+ check_latest=kwargs.get("check_latest", True),
39
+ model_hub=model_hub,
40
+ )
41
+ else:
42
+ raise ValueError(f"Unsupported model_hub: {model_hub}")
43
+
44
+ print(f"Using model path: {model_or_path}")
45
+ kwargs["model_path"] = model_or_path
46
+ kwargs["repo_path"] = repo_path
47
+
48
+ # Common logic for processing configuration files (same for all model hubs)
49
+ if os.path.exists(os.path.join(model_or_path, "configuration.json")):
50
+ with open(
51
+ os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
52
+ ) as f:
53
+ conf_json = json.load(f)
54
+ cfg = {}
55
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
56
+ cfg.update(kwargs)
57
+ config = OmegaConf.load(cfg["config"])
58
+ kwargs = OmegaConf.merge(config, cfg)
59
+ kwargs["model"] = config["model"]
60
+ elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
61
+ os.path.join(model_or_path, "model.pt")
62
+ ):
63
+ config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
64
+ kwargs = OmegaConf.merge(config, kwargs)
65
+ init_param = os.path.join(model_or_path, "model.pt")
66
+ kwargs["init_param"] = init_param
67
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
68
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(
69
+ model_or_path, "tokens.txt"
70
+ )
71
+ if os.path.exists(os.path.join(model_or_path, "tokens.json")):
72
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(
73
+ model_or_path, "tokens.json"
74
+ )
75
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
76
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
77
+ model_or_path, "seg_dict"
78
+ )
79
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
80
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
81
+ model_or_path, "bpe.model"
82
+ )
83
+ kwargs["model"] = config["model"]
84
+ if os.path.exists(os.path.join(model_or_path, "am.mvn")):
85
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
86
+ if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
87
+ kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
88
+
89
+ return OmegaConf.to_container(kwargs, resolve=True)
90
+
91
+
92
+ def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
93
+
94
+ if isinstance(file_path_metas, dict):
95
+ for k, v in file_path_metas.items():
96
+ if isinstance(v, str):
97
+ p = os.path.join(model_or_path, v)
98
+ if os.path.exists(p):
99
+ cfg[k] = p
100
+ elif isinstance(v, dict):
101
+ if k not in cfg:
102
+ cfg[k] = {}
103
+ add_file_root_path(model_or_path, v, cfg[k])
104
+
105
+ return cfg
106
+
107
+
108
+ def get_or_download_model_dir(
109
+ model,
110
+ model_revision=None,
111
+ is_training=False,
112
+ check_latest=True,
113
+ model_hub="ms",
114
+ ):
115
+ """Get local model directory or download model if necessary.
116
+
117
+ Args:
118
+ model (str): model id or path to local model directory.
119
+ For HF subfolders, use format: "repo_id/subfolder_path"
120
+ model_revision (str, optional): model version number.
121
+ is_training (bool): Whether this is for training
122
+ check_latest (bool): Whether to check for latest version
123
+ model_hub (str): Model hub type ("ms" for ModelScope, "hf" for HuggingFace)
124
+ """
125
+ # Extract repo_id for caching (handle subfolder case)
126
+ if "/" in model and len(model.split("/")) > 2:
127
+ parts = model.split("/")
128
+ repo_id = "/".join(parts[:2]) # e.g., "organization/repo" or "stepfun-ai/Step-Audio-EditX"
129
+ subfolder = "/".join(parts[2:]) # e.g., "subfolder/model"
130
+ else:
131
+ repo_id = model
132
+ subfolder = None
133
+
134
+ # Create cache key
135
+ cache_key = (repo_id, model_revision, model_hub)
136
+
137
+ # Check cache first
138
+ with _cache_lock:
139
+ if cache_key in _model_cache:
140
+ cached_repo_dir = _model_cache[cache_key]
141
+ print(f"Using cached model for {repo_id}: {cached_repo_dir}")
142
+
143
+ # For subfolder case, construct the model_cache_dir from cached repo
144
+ if subfolder:
145
+ model_cache_dir = os.path.join(cached_repo_dir, subfolder)
146
+ if not os.path.exists(model_cache_dir):
147
+ raise FileNotFoundError(f"Subfolder {subfolder} not found in cached repo {repo_id}")
148
+ else:
149
+ model_cache_dir = cached_repo_dir
150
+
151
+ return cached_repo_dir, model_cache_dir
152
+
153
+ # Cache miss, need to download
154
+ if model_hub == "ms":
155
+ # ModelScope download
156
+ from modelscope.hub.snapshot_download import snapshot_download
157
+ from modelscope.utils.constant import Invoke, ThirdParty
158
+
159
+ key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
160
+
161
+ # Download the repo (use repo_id, not the full model path with subfolder)
162
+ repo_cache_dir = snapshot_download(
163
+ repo_id,
164
+ revision=model_revision,
165
+ user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
166
+ )
167
+ repo_cache_dir = normalize_cache_path(repo_cache_dir)
168
+
169
+ # Construct model_cache_dir
170
+ if subfolder:
171
+ model_cache_dir = os.path.join(repo_cache_dir, subfolder)
172
+ if not os.path.exists(model_cache_dir):
173
+ raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
174
+ else:
175
+ model_cache_dir = normalize_cache_path(repo_cache_dir)
176
+
177
+ elif model_hub == "hf":
178
+ # HuggingFace download
179
+ try:
180
+ from huggingface_hub import snapshot_download
181
+ except ImportError:
182
+ raise ImportError(
183
+ "huggingface_hub is required for downloading from HuggingFace. "
184
+ "Please install it with: pip install huggingface_hub"
185
+ )
186
+
187
+ # Download the repo (use repo_id, not the full model path with subfolder)
188
+ repo_cache_dir = snapshot_download(
189
+ repo_id=repo_id,
190
+ revision=model_revision,
191
+ allow_patterns=None, # Download all files to ensure resource files are available
192
+ )
193
+ repo_cache_dir = normalize_cache_path(repo_cache_dir)
194
+
195
+ # Construct model_cache_dir
196
+ if subfolder:
197
+ model_cache_dir = os.path.join(repo_cache_dir, subfolder)
198
+ if not os.path.exists(model_cache_dir):
199
+ raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
200
+ else:
201
+ model_cache_dir = normalize_cache_path(repo_cache_dir)
202
+ else:
203
+ raise ValueError(f"Unsupported model_hub: {model_hub}")
204
+
205
+ # Cache the result before returning
206
+ with _cache_lock:
207
+ _model_cache[cache_key] = repo_cache_dir
208
+
209
+ print(f"Model downloaded to: {model_cache_dir}")
210
+ return repo_cache_dir, model_cache_dir
211
+
212
+ def normalize_cache_path(cache_path):
213
+ """Normalize cache path to ensure consistent format with snapshots/{commit_id}."""
214
+ # Check if the cache_path directory contains a snapshots folder
215
+ snapshots_dir = os.path.join(cache_path, "snapshots")
216
+ if os.path.exists(snapshots_dir) and os.path.isdir(snapshots_dir):
217
+ # Find the commit_id subdirectory in snapshots
218
+ try:
219
+ snapshot_items = os.listdir(snapshots_dir)
220
+ # Look for the first directory (should be the commit_id)
221
+ for item in snapshot_items:
222
+ item_path = os.path.join(snapshots_dir, item)
223
+ if os.path.isdir(item_path):
224
+ # Found commit_id directory, return the full path
225
+ return os.path.join(cache_path, "snapshots", item)
226
+ except OSError:
227
+ pass
228
+
229
+ # If no snapshots directory found or error occurred, return original path
230
+ return cache_path
231
+
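A hedged sketch of resolving a model directory with the helper above; the repo id is the one mentioned in the code comments, the revision is left at its default, and a HuggingFace download needs huggingface_hub plus network access:

from funasr_detach.download.download_from_hub import get_or_download_model_dir

# The first call downloads (or reuses the local HF cache); later calls with the
# same (repo_id, revision, hub) tuple are served from the in-process _model_cache.
repo_dir, model_dir = get_or_download_model_dir(
    "stepfun-ai/Step-Audio-EditX", model_revision=None, model_hub="hf"
)
print(repo_dir, model_dir)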
demo/Step-Audio-EditX/funasr_detach/download/file.py ADDED
@@ -0,0 +1,335 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import contextlib
4
+ import os
5
+ import tempfile
6
+ from abc import ABCMeta, abstractmethod
7
+ from pathlib import Path
8
+ from typing import Generator, Union
9
+
10
+ import requests
11
+ from urllib.parse import urlparse
12
+
13
+
14
+ def download_from_url(url):
15
+ result = urlparse(url)
16
+ file_path = None
17
+ if result.scheme is not None and len(result.scheme) > 0:
18
+ storage = HTTPStorage()
19
+ # bytes
20
+ data = storage.read(url)
21
+ work_dir = tempfile.TemporaryDirectory().name
22
+ if not os.path.exists(work_dir):
23
+ os.makedirs(work_dir)
24
+ file_path = os.path.join(work_dir, os.path.basename(url))
25
+ with open(file_path, "wb") as fb:
26
+ fb.write(data)
27
+ assert file_path is not None, f"failed to download: {url}"
28
+ return file_path
29
+
30
+
31
+ class Storage(metaclass=ABCMeta):
32
+ """Abstract class of storage.
33
+
34
+ All backends need to implement two apis: ``read()`` and ``read_text()``.
35
+ ``read()`` reads the file as a byte stream and ``read_text()`` reads
36
+ the file as texts.
37
+ """
38
+
39
+ @abstractmethod
40
+ def read(self, filepath: str):
41
+ pass
42
+
43
+ @abstractmethod
44
+ def read_text(self, filepath: str):
45
+ pass
46
+
47
+ @abstractmethod
48
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
49
+ pass
50
+
51
+ @abstractmethod
52
+ def write_text(
53
+ self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
54
+ ) -> None:
55
+ pass
56
+
57
+
58
+ class LocalStorage(Storage):
59
+ """Local hard disk storage"""
60
+
61
+ def read(self, filepath: Union[str, Path]) -> bytes:
62
+ """Read data from a given ``filepath`` with 'rb' mode.
63
+
64
+ Args:
65
+ filepath (str or Path): Path to read data.
66
+
67
+ Returns:
68
+ bytes: Expected bytes object.
69
+ """
70
+ with open(filepath, "rb") as f:
71
+ content = f.read()
72
+ return content
73
+
74
+ def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
75
+ """Read data from a given ``filepath`` with 'r' mode.
76
+
77
+ Args:
78
+ filepath (str or Path): Path to read data.
79
+ encoding (str): The encoding format used to open the ``filepath``.
80
+ Default: 'utf-8'.
81
+
82
+ Returns:
83
+ str: Expected text reading from ``filepath``.
84
+ """
85
+ with open(filepath, "r", encoding=encoding) as f:
86
+ value_buf = f.read()
87
+ return value_buf
88
+
89
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
90
+ """Write data to a given ``filepath`` with 'wb' mode.
91
+
92
+ Note:
93
+ ``write`` will create a directory if the directory of ``filepath``
94
+ does not exist.
95
+
96
+ Args:
97
+ obj (bytes): Data to be written.
98
+ filepath (str or Path): Path to write data.
99
+ """
100
+ dirname = os.path.dirname(filepath)
101
+ if dirname and not os.path.exists(dirname):
102
+ os.makedirs(dirname, exist_ok=True)
103
+
104
+ with open(filepath, "wb") as f:
105
+ f.write(obj)
106
+
107
+ def write_text(
108
+ self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
109
+ ) -> None:
110
+ """Write data to a given ``filepath`` with 'w' mode.
111
+
112
+ Note:
113
+ ``write_text`` will create a directory if the directory of
114
+ ``filepath`` does not exist.
115
+
116
+ Args:
117
+ obj (str): Data to be written.
118
+ filepath (str or Path): Path to write data.
119
+ encoding (str): The encoding format used to open the ``filepath``.
120
+ Default: 'utf-8'.
121
+ """
122
+ dirname = os.path.dirname(filepath)
123
+ if dirname and not os.path.exists(dirname):
124
+ os.makedirs(dirname, exist_ok=True)
125
+
126
+ with open(filepath, "w", encoding=encoding) as f:
127
+ f.write(obj)
128
+
129
+ @contextlib.contextmanager
130
+ def as_local_path(
131
+ self, filepath: Union[str, Path]
132
+ ) -> Generator[Union[str, Path], None, None]:
133
+ """Only for unified API and do nothing."""
134
+ yield filepath
135
+
136
+
137
+ class HTTPStorage(Storage):
138
+ """HTTP and HTTPS storage."""
139
+
140
+ def read(self, url):
141
+ # TODO @wenmeng.zwm add progress bar if file is too large
142
+ r = requests.get(url)
143
+ r.raise_for_status()
144
+ return r.content
145
+
146
+ def read_text(self, url):
147
+ r = requests.get(url)
148
+ r.raise_for_status()
149
+ return r.text
150
+
151
+ @contextlib.contextmanager
152
+ def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
153
+ """Download a file from ``filepath``.
154
+
155
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
156
+ can be called with a ``with`` statement, and when exiting the
157
+ ``with`` statement, the temporary path will be released.
158
+
159
+ Args:
160
+ filepath (str): Download a file from ``filepath``.
161
+
162
+ Examples:
163
+ >>> storage = HTTPStorage()
164
+ >>> # After exiting the ``with`` clause,
165
+ >>> # the path will be removed
166
+ >>> with storage.as_local_path('http://path/to/file') as path:
167
+ ... # do something here
168
+ """
169
+ try:
170
+ f = tempfile.NamedTemporaryFile(delete=False)
171
+ f.write(self.read(filepath))
172
+ f.close()
173
+ yield f.name
174
+ finally:
175
+ os.remove(f.name)
176
+
177
+ def write(self, obj: bytes, url: Union[str, Path]) -> None:
178
+ raise NotImplementedError("write is not supported by HTTP Storage")
179
+
180
+ def write_text(
181
+ self, obj: str, url: Union[str, Path], encoding: str = "utf-8"
182
+ ) -> None:
183
+ raise NotImplementedError("write_text is not supported by HTTP Storage")
184
+
185
+
186
+ class OSSStorage(Storage):
187
+ """OSS storage."""
188
+
189
+ def __init__(self, oss_config_file=None):
190
+ # read from config file or env var
191
+ raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")
192
+
193
+ def read(self, filepath):
194
+ raise NotImplementedError("OSSStorage.read to be implemented in the future")
195
+
196
+ def read_text(self, filepath, encoding="utf-8"):
197
+ raise NotImplementedError(
198
+ "OSSStorage.read_text to be implemented in the future"
199
+ )
200
+
201
+ @contextlib.contextmanager
202
+ def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
203
+ """Download a file from ``filepath``.
204
+
205
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
206
+ can be called with a ``with`` statement, and when exiting the
207
+ ``with`` statement, the temporary path will be released.
208
+
209
+ Args:
210
+ filepath (str): Download a file from ``filepath``.
211
+
212
+ Examples:
213
+ >>> storage = OSSStorage()
214
+ >>> # After exiting the ``with`` clause,
215
+ >>> # the path will be removed
216
+ >>> with storage.as_local_path('http://path/to/file') as path:
217
+ ... # do something here
218
+ """
219
+ try:
220
+ f = tempfile.NamedTemporaryFile(delete=False)
221
+ f.write(self.read(filepath))
222
+ f.close()
223
+ yield f.name
224
+ finally:
225
+ os.remove(f.name)
226
+
227
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
228
+ raise NotImplementedError("OSSStorage.write to be implemented in the future")
229
+
230
+ def write_text(
231
+ self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
232
+ ) -> None:
233
+ raise NotImplementedError(
234
+ "OSSStorage.write_text to be implemented in the future"
235
+ )
236
+
237
+
238
+ G_STORAGES = {}
239
+
240
+
241
+ class File(object):
242
+ _prefix_to_storage: dict = {
243
+ "oss": OSSStorage,
244
+ "http": HTTPStorage,
245
+ "https": HTTPStorage,
246
+ "local": LocalStorage,
247
+ }
248
+
249
+ @staticmethod
250
+ def _get_storage(uri):
251
+ assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"
252
+
253
+ if "://" not in uri:
254
+ # local path
255
+ storage_type = "local"
256
+ else:
257
+ prefix, _ = uri.split("://")
258
+ storage_type = prefix
259
+
260
+ assert storage_type in File._prefix_to_storage, (
261
+ f"Unsupported uri {uri}, valid prefixs: "
262
+ f"{list(File._prefix_to_storage.keys())}"
263
+ )
264
+
265
+ if storage_type not in G_STORAGES:
266
+ G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
267
+
268
+ return G_STORAGES[storage_type]
269
+
270
+ @staticmethod
271
+ def read(uri: str) -> bytes:
272
+ """Read data from a given ``filepath`` with 'rb' mode.
273
+
274
+ Args:
275
+ filepath (str or Path): Path to read data.
276
+
277
+ Returns:
278
+ bytes: Expected bytes object.
279
+ """
280
+ storage = File._get_storage(uri)
281
+ return storage.read(uri)
282
+
283
+ @staticmethod
284
+ def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
285
+ """Read data from a given ``filepath`` with 'r' mode.
286
+
287
+ Args:
288
+ filepath (str or Path): Path to read data.
289
+ encoding (str): The encoding format used to open the ``filepath``.
290
+ Default: 'utf-8'.
291
+
292
+ Returns:
293
+ str: Expected text reading from ``filepath``.
294
+ """
295
+ storage = File._get_storage(uri)
296
+ return storage.read_text(uri)
297
+
298
+ @staticmethod
299
+ def write(obj: bytes, uri: Union[str, Path]) -> None:
300
+ """Write data to a given ``filepath`` with 'wb' mode.
301
+
302
+ Note:
303
+ ``write`` will create a directory if the directory of ``filepath``
304
+ does not exist.
305
+
306
+ Args:
307
+ obj (bytes): Data to be written.
308
+ filepath (str or Path): Path to write data.
309
+ """
310
+ storage = File._get_storage(uri)
311
+ return storage.write(obj, uri)
312
+
313
+ @staticmethod
314
+ def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
315
+ """Write data to a given ``filepath`` with 'w' mode.
316
+
317
+ Note:
318
+ ``write_text`` will create a directory if the directory of
319
+ ``filepath`` does not exist.
320
+
321
+ Args:
322
+ obj (str): Data to be written.
323
+ filepath (str or Path): Path to write data.
324
+ encoding (str): The encoding format used to open the ``filepath``.
325
+ Default: 'utf-8'.
326
+ """
327
+ storage = File._get_storage(uri)
328
+ return storage.write_text(obj, uri)
329
+
330
+ @contextlib.contextmanager
331
+ def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
332
+ """Only for unified API and do nothing."""
333
+ storage = File._get_storage(uri)
334
+ with storage.as_local_path(uri) as local_path:
335
+ yield local_path
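A short usage sketch of the unified File facade; the local path and URL below are placeholders:

from funasr_detach.download.file import File, download_from_url

text = File.read_text("/tmp/example.txt")            # no scheme -> LocalStorage
blob = File.read("https://example.com/audio.wav")    # https -> HTTPStorage
local_copy = download_from_url("https://example.com/audio.wav")  # temp local copy
print(len(blob), local_copy)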
demo/Step-Audio-EditX/funasr_detach/download/name_maps_from_hub.py ADDED
@@ -0,0 +1,13 @@
1
+ name_maps_ms = {
2
+ "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
3
+ "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
4
+ "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
5
+ "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
6
+ "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
7
+ "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
8
+ "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
9
+ "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
10
+ "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
11
+ }
12
+
13
+ name_maps_hf = {}
demo/Step-Audio-EditX/funasr_detach/download/runtime_sdk_download_tool.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+ from funasr_detach.utils.types import str2bool
6
+
7
+
8
+ def main():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--model-name", type=str, required=True)
11
+ parser.add_argument("--export-dir", type=str, required=True)
12
+ parser.add_argument(
13
+ "--export", type=str2bool, default=True, help="whether to export model"
14
+ )
15
+ parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torch"]')
16
+ parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
17
+ parser.add_argument(
18
+ "--quantize", type=str2bool, default=False, help="export quantized model"
19
+ )
20
+ parser.add_argument(
21
+ "--fallback-num", type=int, default=0, help="amp fallback number"
22
+ )
23
+ parser.add_argument("--audio_in", type=str, default=None, help='["wav", "wav.scp"]')
24
+ parser.add_argument(
25
+ "--model_revision", type=str, default=None, help="model_revision"
26
+ )
27
+ parser.add_argument("--calib_num", type=int, default=200, help="calib max num")
28
+ args = parser.parse_args()
29
+
30
+ model_dir = args.model_name
31
+ if not Path(args.model_name).exists():
32
+ from modelscope.hub.snapshot_download import snapshot_download
33
+
34
+ try:
35
+ model_dir = snapshot_download(
36
+ args.model_name, cache_dir=args.export_dir, revision=args.model_revision
37
+ )
38
+ except:
39
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
40
+ model_dir
41
+ )
42
+ if args.export:
43
+ model_file = os.path.join(model_dir, "model.onnx")
44
+ if args.quantize:
45
+ model_file = os.path.join(model_dir, "model_quant.onnx")
46
+ if not os.path.exists(model_file):
47
+ print(".onnx is not exist, begin to export onnx")
48
+ from funasr_detach.bin.export_model import ModelExport
49
+
50
+ export_model = ModelExport(
51
+ cache_dir=args.export_dir,
52
+ onnx=True,
53
+ device="cpu",
54
+ quant=args.quantize,
55
+ )
56
+ export_model.export(model_dir)
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
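An example invocation of the export tool (the model id comes from the alias table above, the export directory is a placeholder, and modelscope plus the ONNX export dependencies are assumed to be installed):

python -m funasr_detach.download.runtime_sdk_download_tool \
    --model-name damo/speech_fsmn_vad_zh-cn-16k-common-pytorch \
    --export-dir ./export \
    --export true \
    --quantize false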
demo/Step-Audio-EditX/funasr_detach/frontends/__init__.py ADDED
File without changes