qinxy commited on
Commit
21c58e8
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. CODE_OF_CONDUCT.md +76 -0
  3. FAQ.md +16 -0
  4. LICENSE +201 -0
  5. README.md +241 -0
  6. asset/dingding.png +0 -0
  7. cosyvoice/__init__.py +0 -0
  8. cosyvoice/bin/average_model.py +93 -0
  9. cosyvoice/bin/export_jit.py +103 -0
  10. cosyvoice/bin/export_onnx.py +120 -0
  11. cosyvoice/bin/inference.py +125 -0
  12. cosyvoice/bin/train.py +175 -0
  13. cosyvoice/bin/train_dpo.py +187 -0
  14. cosyvoice/cli/__init__.py +0 -0
  15. cosyvoice/cli/cosyvoice.py +190 -0
  16. cosyvoice/cli/frontend.py +215 -0
  17. cosyvoice/cli/model.py +395 -0
  18. cosyvoice/dataset/__init__.py +0 -0
  19. cosyvoice/dataset/dataset.py +164 -0
  20. cosyvoice/dataset/processor.py +441 -0
  21. cosyvoice/dataset/processor_dpo.py +443 -0
  22. cosyvoice/flow/decoder.py +494 -0
  23. cosyvoice/flow/flow.py +281 -0
  24. cosyvoice/flow/flow_matching.py +224 -0
  25. cosyvoice/flow/length_regulator.py +70 -0
  26. cosyvoice/hifigan/discriminator.py +230 -0
  27. cosyvoice/hifigan/f0_predictor.py +58 -0
  28. cosyvoice/hifigan/generator.py +582 -0
  29. cosyvoice/hifigan/hifigan.py +67 -0
  30. cosyvoice/llm/llm.py +520 -0
  31. cosyvoice/llm/llm_dpo.py +556 -0
  32. cosyvoice/llm/llm_vllm.py +212 -0
  33. cosyvoice/llm/vllm_use_cosyvoice2_model.py +263 -0
  34. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
  35. cosyvoice/tokenizer/tokenizer.py +281 -0
  36. cosyvoice/transformer/__init__.py +0 -0
  37. cosyvoice/transformer/activation.py +84 -0
  38. cosyvoice/transformer/attention.py +330 -0
  39. cosyvoice/transformer/convolution.py +145 -0
  40. cosyvoice/transformer/decoder.py +396 -0
  41. cosyvoice/transformer/decoder_layer.py +132 -0
  42. cosyvoice/transformer/embedding.py +302 -0
  43. cosyvoice/transformer/encoder.py +474 -0
  44. cosyvoice/transformer/encoder_layer.py +236 -0
  45. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  46. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  47. cosyvoice/transformer/subsampling.py +383 -0
  48. cosyvoice/transformer/upsample_encoder.py +320 -0
  49. cosyvoice/utils/__init__.py +0 -0
  50. cosyvoice/utils/class_utils.py +83 -0
.gitattributes ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ *.bin filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
3
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
4
+ *.onnx filter=lfs diff=lfs merge=lfs -text
5
+ *.zip filter=lfs diff=lfs merge=lfs -text
6
+ asset/*.wav filter=lfs diff=lfs merge=lfs -text
7
+ pretrained_models/**/*.whl filter=lfs diff=lfs merge=lfs -text
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to making participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies both within project spaces and in public spaces
49
+ when an individual is representing the project or its community. Examples of
50
+ representing a project or community include using an official project e-mail
51
+ address, posting via an official social media account, or acting as an appointed
52
+ representative at an online or offline event. Representation of a project may be
53
+ further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at mikelei@mobvoi.com. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
FAQ.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## ModuleNotFoundError: No module named 'matcha'
2
+
3
+ Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`.
4
+
5
+ run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script.
6
+
7
+ ## cannot find resource.zip or cannot unzip resource.zip
8
+
9
+ Please make sure you have git-lfs installed. Execute
10
+
11
+ ```sh
12
+ git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
13
+ cd pretrained_models/CosyVoice-ttsfrd/
14
+ unzip resource.zip -d .
15
+ pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
16
+ ```
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
2
+
3
+ ## 👉🏻 CosyVoice 👈🏻
4
+ **CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
5
+
6
+ **CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
7
+
8
+ ## Highlight🔥
9
+
10
+ **CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
11
+ ### Multilingual
12
+ - **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
13
+ - **Crosslingual & Mixlingual**: Support zero-shot voice cloning for cross-lingual and code-switching scenarios.
14
+ ### Ultra-Low Latency
15
+ - **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
16
+ - **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
17
+ ### High Accuracy
18
+ - **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
19
+ - **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
20
+ ### Strong Stability
21
+ - **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
22
+ - **Cross-language Synthesis**: Marked improvements compared to version 1.0.
23
+ ### Natural Experience
24
+ - **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
25
+ - **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
26
+
27
+ ## Roadmap
28
+
29
+ - [x] 2024/12
30
+
31
+ - [x] 25hz cosyvoice 2.0 released
32
+
33
+ - [x] 2024/09
34
+
35
+ - [x] 25hz cosyvoice base model
36
+ - [x] 25hz cosyvoice voice conversion model
37
+
38
+ - [x] 2024/08
39
+
40
+ - [x] Repetition Aware Sampling(RAS) inference for llm stability
41
+ - [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
42
+
43
+ - [x] 2024/07
44
+
45
+ - [x] Flow matching training support
46
+ - [x] WeTextProcessing support when ttsfrd is not available
47
+ - [x] Fastapi server and client
48
+
49
+
50
+ ## Install
51
+
52
+ ### Clone and install
53
+
54
+ - Clone the repo
55
+ ``` sh
56
+ git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
57
+ # If you fail to clone the submodule due to network failures, please run the following command until it succeeds
58
+ cd CosyVoice
59
+ git submodule update --init --recursive
60
+ ```
61
+
62
+ - Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
63
+ - Create Conda env:
64
+
65
+ ``` sh
66
+ conda create -n cosyvoice -y python=3.10
67
+ conda activate cosyvoice
68
+ # pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
69
+ conda install -y -c conda-forge pynini==2.1.5
70
+ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
71
+
72
+ # If you encounter sox compatibility issues
73
+ # ubuntu
74
+ sudo apt-get install sox libsox-dev
75
+ # centos
76
+ sudo yum install sox sox-devel
77
+ ```
78
+
79
+ ### Model download
80
+
81
+ We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
82
+
83
+ ``` python
84
+ # SDK模型下载
85
+ from modelscope import snapshot_download
86
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
87
+ snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
88
+ snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
89
+ snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
90
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
91
+ ```
92
+
93
+ ``` sh
94
+ # git模型下载,请确保已安装git lfs
95
+ mkdir -p pretrained_models
96
+ git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
97
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
98
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
99
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
100
+ git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
101
+ ```
102
+
103
+ Optionally, you can unzip the `ttsfrd` resource and install the `ttsfrd` package for better text normalization performance.
104
+
105
+ Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
106
+
107
+ ``` sh
108
+ cd pretrained_models/CosyVoice-ttsfrd/
109
+ unzip resource.zip -d .
110
+ pip install ttsfrd_dependency-0.1-py3-none-any.whl
111
+ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
112
+ ```
113
+
114
+ ### Basic Usage
115
+
116
+ We strongly recommend using `CosyVoice2-0.5B` for better performance.
117
+ Follow the code below for detailed usage of each model.
118
+
119
+ ``` python
120
+ import sys
121
+ sys.path.append('third_party/Matcha-TTS')
122
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
123
+ from cosyvoice.utils.file_utils import load_wav
124
+ import torchaudio
125
+ ```
126
+
127
+ #### CosyVoice2 Usage
128
+ ```python
129
+ cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
130
+
131
+ # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
132
+ # zero_shot usage
133
+ prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
134
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
135
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
136
+
137
+ # save zero_shot spk for future usage
138
+ assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
139
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
140
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
141
+ cosyvoice.save_spkinfo()
142
+
143
+ # fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
144
+ for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
145
+ torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
146
+
147
+ # instruct usage
148
+ for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
149
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
150
+
151
+ # bistream usage, you can use generator as input, this is useful when using text llm model as input
152
+ # NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
153
+ def text_generator():
154
+ yield '收到好友从远方寄来的生日礼物,'
155
+ yield '那份意外的惊喜与深深的祝福'
156
+ yield '让我心中充满了甜蜜的快乐,'
157
+ yield '笑容如花儿般绽放。'
158
+ for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
159
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
160
+ ```
161
+
162
+ #### CosyVoice Usage
163
+ ```python
164
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
165
+ # sft usage
166
+ print(cosyvoice.list_available_spks())
167
+ # change stream=True for chunk stream inference
168
+ for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
169
+ torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
170
+
171
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
172
+ # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
173
+ prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
174
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
175
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
176
+ # cross_lingual usage
177
+ prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
178
+ for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
179
+ torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
180
+ # vc usage
181
+ prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
182
+ source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
183
+ for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
184
+ torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
185
+
186
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
187
+ # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
188
+ for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
189
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
190
+ ```
191
+
192
+ #### Start web demo
193
+
194
+ You can use our web demo page to get familiar with CosyVoice quickly.
195
+
196
+ Please see the demo website for details.
197
+
198
+ ``` python
199
+ # change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
200
+ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice2-0.5B
201
+ ```
202
+
203
+ #### Advanced Usage
204
+
205
+ For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
206
+
207
+ #### Build for deployment
208
+
209
+ Optionally, if you want service deployment,
210
+ you can run the following steps.
211
+
212
+ ``` sh
213
+ cd runtime/python
214
+ docker build -t cosyvoice:v1.0 .
215
+ # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
216
+ # for grpc usage
217
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
218
+ cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
219
+ # for fastapi usage
220
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
221
+ cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
222
+ ```
223
+
224
+ ## Discussion & Communication
225
+
226
+ You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
227
+
228
+ You can also scan the QR code to join our official Dingding chat group.
229
+
230
+ <img src="./asset/dingding.png" width="250px">
231
+
232
+ ## Acknowledge
233
+
234
+ 1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
235
+ 2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
236
+ 3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
237
+ 4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
238
+ 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
239
+
240
+ ## Disclaimer
241
+ The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
asset/dingding.png ADDED
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/bin/average_model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import argparse
18
+ import glob
19
+
20
+ import yaml
21
+ import torch
22
+
23
+
24
def get_args():
    """Parse command-line options for checkpoint averaging.

    Returns:
        argparse.Namespace with dst_model, src_path, val_best and num.
    """
    parser = argparse.ArgumentParser(description='average model')
    # Output / input locations.
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path', required=True,
                        help='src model path for average')
    # Selection strategy and how many checkpoints to blend.
    parser.add_argument('--val_best', action="store_true",
                        help='averaged model')
    parser.add_argument('--num', default=5, type=int,
                        help='nums for averaged model')
    args = parser.parse_args()
    print(args)
    return args
41
+
42
+
43
def main():
    """Average the weights of several checkpoints into a single model file.

    With --val_best, checkpoints are ranked by the validation loss recorded
    in their side-car yaml files; otherwise the most recent --num checkpoints
    in --src_path are used.
    """
    args = get_args()
    if args.val_best:
        val_scores = []
        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
        # Skip configuration yamls that carry no validation score.
        yamls = [
            f for f in yamls
            if not (os.path.basename(f).startswith('train')
                    or os.path.basename(f).startswith('init'))
        ]
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
                loss = float(dic_yaml['loss_dict']['loss'])
                epoch = int(dic_yaml['epoch'])
                step = int(dic_yaml['step'])
                tag = dic_yaml['tag']
                val_scores += [[epoch, step, loss, tag]]
        # Ascending loss: best checkpoints first.
        sorted_val_scores = sorted(val_scores,
                                   key=lambda x: x[2],
                                   reverse=False)
        print("best val (epoch, step, loss, tag) = " +
              str(sorted_val_scores[:args.num]))
        path_list = [
            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
            for score in sorted_val_scores[:args.num]
        ]
    else:
        # BUGFIX: path_list was previously only defined when --val_best was
        # given, so running without it raised NameError. Fall back to the
        # most recently written checkpoints in src_path.
        path_list = [
            p for p in glob.glob('{}/*.pt'.format(args.src_path))
            if not os.path.basename(p).startswith('init')
        ]
        path_list = sorted(path_list, key=os.path.getmtime)[-args.num:]
    print(path_list)
    avg = {}
    num = args.num
    assert num == len(path_list)
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k in states.keys():
            # Bookkeeping entries are not weights; do not average them.
            if k not in ['step', 'epoch']:
                if k not in avg.keys():
                    avg[k] = states[k].clone()
                else:
                    avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()
cosyvoice/bin/export_jit.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ import torch
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/../..'.format(ROOT_DIR))
25
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
27
+ from cosyvoice.utils.file_utils import logging
28
+
29
+
30
def get_args():
    """Parse command-line options for TorchScript export.

    Returns:
        argparse.Namespace with model_dir.
    """
    parser = argparse.ArgumentParser(
        description='export your model for deployment')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice-300M',
                        help='local path')
    parsed = parser.parse_args()
    print(parsed)
    return parsed
39
+
40
+
41
def get_optimized_script(model, preserved_attrs=None):
    """TorchScript-compile, freeze and optimize a module for inference.

    Args:
        model: torch.nn.Module to compile (must be in eval mode for
            freeze/optimize_for_inference).
        preserved_attrs: optional list of attribute/method names to keep
            through freezing (e.g. ['forward_chunk']).

    Returns:
        An optimized torch.jit.ScriptModule.
    """
    # BUGFIX: the original used a mutable default argument ([]), which is a
    # single shared list across all calls; use None as the sentinel instead.
    script = torch.jit.script(model)
    if preserved_attrs:
        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
    else:
        script = torch.jit.freeze(script)
    script = torch.jit.optimize_for_inference(script)
    return script
49
+
50
+
51
def main():
    """Export CosyVoice submodules to optimized TorchScript archives.

    Saves fp32 and fp16 variants of each exported module next to the model
    files in --model_dir.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    # Static fusion, no profiling executor: deterministic scripted graphs.
    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    # Try CosyVoice first, then CosyVoice2 (their yaml layouts differ).
    try:
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
            model = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')

    def _export(module, name, preserved_attrs=None):
        # Export fp32 first: .half() mutates the module in place.
        script = get_optimized_script(module, preserved_attrs or [])
        script.save('{}/{}.fp32.zip'.format(args.model_dir, name))
        script = get_optimized_script(module.half(), preserved_attrs or [])
        script.save('{}/{}.fp16.zip'.format(args.model_dir, name))
        logging.info('successfully export %s', name.replace('.', '_'))

    if not isinstance(model, CosyVoice2):
        # 1. export llm text_encoder
        _export(model.model.llm.text_encoder, 'llm.text_encoder')
        # 2. export llm llm (keep the streaming entry point through freezing)
        _export(model.model.llm.llm, 'llm.llm', ['forward_chunk'])
    # 3. export flow encoder (previously duplicated verbatim in both branches)
    _export(model.model.flow.encoder, 'flow.encoder')


if __name__ == '__main__':
    main()
cosyvoice/bin/export_onnx.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import logging
20
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
21
+ import os
22
+ import sys
23
+ import onnxruntime
24
+ import random
25
+ import torch
26
+ from tqdm import tqdm
27
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.append('{}/../..'.format(ROOT_DIR))
29
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
31
+ from cosyvoice.utils.file_utils import logging
32
+
33
+
34
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
35
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
36
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
37
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
38
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
39
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
40
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
41
+ return x, mask, mu, t, spks, cond
42
+
43
+
44
+ def get_args():
45
+ parser = argparse.ArgumentParser(description='export your model for deployment')
46
+ parser.add_argument('--model_dir',
47
+ type=str,
48
+ default='pretrained_models/CosyVoice-300M',
49
+ help='local path')
50
+ args = parser.parse_args()
51
+ print(args)
52
+ return args
53
+
54
+
55
+ @torch.no_grad()
56
+ def main():
57
+ args = get_args()
58
+ logging.basicConfig(level=logging.DEBUG,
59
+ format='%(asctime)s %(levelname)s %(message)s')
60
+
61
+ try:
62
+ model = CosyVoice(args.model_dir)
63
+ except Exception:
64
+ try:
65
+ model = CosyVoice2(args.model_dir)
66
+ except Exception:
67
+ raise TypeError('no valid model_type!')
68
+
69
+ # 1. export flow decoder estimator
70
+ estimator = model.model.flow.decoder.estimator
71
+ estimator.eval()
72
+
73
+ device = model.model.device
74
+ batch_size, seq_len = 2, 256
75
+ out_channels = model.model.flow.decoder.estimator.out_channels
76
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
77
+ torch.onnx.export(
78
+ estimator,
79
+ (x, mask, mu, t, spks, cond),
80
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
81
+ export_params=True,
82
+ opset_version=18,
83
+ do_constant_folding=True,
84
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
85
+ output_names=['estimator_out'],
86
+ dynamic_axes={
87
+ 'x': {2: 'seq_len'},
88
+ 'mask': {2: 'seq_len'},
89
+ 'mu': {2: 'seq_len'},
90
+ 'cond': {2: 'seq_len'},
91
+ 'estimator_out': {2: 'seq_len'},
92
+ }
93
+ )
94
+
95
+ # 2. test computation consistency
96
+ option = onnxruntime.SessionOptions()
97
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
98
+ option.intra_op_num_threads = 1
99
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
100
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
101
+ sess_options=option, providers=providers)
102
+
103
+ for _ in tqdm(range(10)):
104
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
105
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
106
+ ort_inputs = {
107
+ 'x': x.cpu().numpy(),
108
+ 'mask': mask.cpu().numpy(),
109
+ 'mu': mu.cpu().numpy(),
110
+ 't': t.cpu().numpy(),
111
+ 'spks': spks.cpu().numpy(),
112
+ 'cond': cond.cpu().numpy()
113
+ }
114
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
115
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
116
+ logging.info('successfully export estimator')
117
+
118
+
119
+ if __name__ == "__main__":
120
+ main()
cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import torch
22
+ from torch.utils.data import DataLoader
23
+ import torchaudio
24
+ from hyperpyyaml import load_hyperpyyaml
25
+ from tqdm import tqdm
26
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
27
+ from cosyvoice.dataset.dataset import Dataset
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser(description='inference with your model')
32
+ parser.add_argument('--config', required=True, help='config file')
33
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
34
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
35
+ parser.add_argument('--tts_text', required=True, help='tts input file')
36
+ parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
37
+ parser.add_argument('--llm_model', required=True, help='llm model file')
38
+ parser.add_argument('--flow_model', required=True, help='flow model file')
39
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
40
+ parser.add_argument('--gpu',
41
+ type=int,
42
+ default=-1,
43
+ help='gpu id for this rank, -1 for cpu')
44
+ parser.add_argument('--mode',
45
+ default='sft',
46
+ choices=['sft', 'zero_shot'],
47
+ help='inference mode')
48
+ parser.add_argument('--result_dir', required=True, help='asr result file')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
59
+
60
+ # Init cosyvoice models from configs
61
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
62
+ device = torch.device('cuda' if use_cuda else 'cpu')
63
+ try:
64
+ with open(args.config, 'r') as f:
65
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
66
+ model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
67
+ except Exception:
68
+ try:
69
+ with open(args.config, 'r') as f:
70
+ configs = load_hyperpyyaml(f)
71
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
72
+ except Exception:
73
+ raise TypeError('no valid model_type!')
74
+
75
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
76
+
77
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
78
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
79
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
80
+
81
+ sample_rate = configs['sample_rate']
82
+ del configs
83
+ os.makedirs(args.result_dir, exist_ok=True)
84
+ fn = os.path.join(args.result_dir, 'wav.scp')
85
+ f = open(fn, 'w')
86
+ with torch.no_grad():
87
+ for _, batch in tqdm(enumerate(test_data_loader)):
88
+ utts = batch["utts"]
89
+ assert len(utts) == 1, "inference mode only support batchsize 1"
90
+ text_token = batch["text_token"].to(device)
91
+ text_token_len = batch["text_token_len"].to(device)
92
+ tts_index = batch["tts_index"]
93
+ tts_text_token = batch["tts_text_token"].to(device)
94
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
95
+ speech_token = batch["speech_token"].to(device)
96
+ speech_token_len = batch["speech_token_len"].to(device)
97
+ speech_feat = batch["speech_feat"].to(device)
98
+ speech_feat_len = batch["speech_feat_len"].to(device)
99
+ utt_embedding = batch["utt_embedding"].to(device)
100
+ spk_embedding = batch["spk_embedding"].to(device)
101
+ if args.mode == 'sft':
102
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
103
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
104
+ else:
105
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
106
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
107
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
108
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
109
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
110
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
111
+ tts_speeches = []
112
+ for model_output in model.tts(**model_input):
113
+ tts_speeches.append(model_output['tts_speech'])
114
+ tts_speeches = torch.concat(tts_speeches, dim=1)
115
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
116
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
117
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
118
+ f.write('{} {}\n'.format(tts_key, tts_fn))
119
+ f.flush()
120
+ f.close()
121
+ logging.info('Result wav.scp saved in {}'.format(fn))
122
+
123
+
124
+ if __name__ == '__main__':
125
+ main()
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import os
22
+ import torch
23
+ import torch.distributed as dist
24
+ import deepspeed
25
+
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ from torch.distributed.elastic.multiprocessing.errors import record
29
+
30
+ from cosyvoice.utils.executor import Executor
31
+ from cosyvoice.utils.train_utils import (
32
+ init_distributed,
33
+ init_dataset_and_dataloader,
34
+ init_optimizer_and_scheduler,
35
+ init_summarywriter, save_model,
36
+ wrap_cuda_model, check_modify_and_save_config)
37
+
38
+
39
+ def get_args():
40
+ parser = argparse.ArgumentParser(description='training your network')
41
+ parser.add_argument('--train_engine',
42
+ default='torch_ddp',
43
+ choices=['torch_ddp', 'deepspeed'],
44
+ help='Engine for paralleled training')
45
+ parser.add_argument('--model', required=True, help='model which will be trained')
46
+ parser.add_argument('--config', required=True, help='config file')
47
+ parser.add_argument('--train_data', required=True, help='train data file')
48
+ parser.add_argument('--cv_data', required=True, help='cv data file')
49
+ parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
50
+ parser.add_argument('--checkpoint', help='checkpoint model')
51
+ parser.add_argument('--model_dir', required=True, help='save model dir')
52
+ parser.add_argument('--tensorboard_dir',
53
+ default='tensorboard',
54
+ help='tensorboard log dir')
55
+ parser.add_argument('--ddp.dist_backend',
56
+ dest='dist_backend',
57
+ default='nccl',
58
+ choices=['nccl', 'gloo'],
59
+ help='distributed backend')
60
+ parser.add_argument('--num_workers',
61
+ default=0,
62
+ type=int,
63
+ help='num of subprocess workers for reading')
64
+ parser.add_argument('--prefetch',
65
+ default=100,
66
+ type=int,
67
+ help='prefetch number')
68
+ parser.add_argument('--pin_memory',
69
+ action='store_true',
70
+ default=False,
71
+ help='Use pinned memory buffers used for reading')
72
+ parser.add_argument('--use_amp',
73
+ action='store_true',
74
+ default=False,
75
+ help='Use automatic mixed precision training')
76
+ parser.add_argument('--deepspeed.save_states',
77
+ dest='save_states',
78
+ default='model_only',
79
+ choices=['model_only', 'model+optimizer'],
80
+ help='save model/optimizer states')
81
+ parser.add_argument('--timeout',
82
+ default=60,
83
+ type=int,
84
+ help='timeout (in seconds) of cosyvoice_join.')
85
+ parser = deepspeed.add_config_arguments(parser)
86
+ args = parser.parse_args()
87
+ return args
88
+
89
+
90
+ @record
91
+ def main():
92
+ args = get_args()
93
+ logging.basicConfig(level=logging.DEBUG,
94
+ format='%(asctime)s %(levelname)s %(message)s')
95
+ # gan train has some special initialization logic
96
+ gan = True if args.model == 'hifigan' else False
97
+
98
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
99
+ if gan is True:
100
+ override_dict.pop('hift')
101
+ try:
102
+ with open(args.config, 'r') as f:
103
+ configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
104
+ except Exception:
105
+ with open(args.config, 'r') as f:
106
+ configs = load_hyperpyyaml(f, overrides=override_dict)
107
+ if gan is True:
108
+ configs['train_conf'] = configs['train_conf_gan']
109
+ configs['train_conf'].update(vars(args))
110
+
111
+ # Init env for ddp
112
+ init_distributed(args)
113
+
114
+ # Get dataset & dataloader
115
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
116
+ init_dataset_and_dataloader(args, configs, gan)
117
+
118
+ # Do some sanity checks and save config to arsg.model_dir
119
+ configs = check_modify_and_save_config(args, configs)
120
+
121
+ # Tensorboard summary
122
+ writer = init_summarywriter(args)
123
+
124
+ # load checkpoint
125
+ model = configs[args.model]
126
+ start_step, start_epoch = 0, -1
127
+ if args.checkpoint is not None:
128
+ if os.path.exists(args.checkpoint):
129
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
130
+ model.load_state_dict(state_dict, strict=False)
131
+ if 'step' in state_dict:
132
+ start_step = state_dict['step']
133
+ if 'epoch' in state_dict:
134
+ start_epoch = state_dict['epoch']
135
+ else:
136
+ logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
137
+
138
+ # Dispatch model from cpu to gpu
139
+ model = wrap_cuda_model(args, model)
140
+
141
+ # Get optimizer & scheduler
142
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
143
+ scheduler.set_step(start_step)
144
+ if scheduler_d is not None:
145
+ scheduler_d.set_step(start_step)
146
+
147
+ # Save init checkpoints
148
+ info_dict = deepcopy(configs['train_conf'])
149
+ info_dict['step'] = start_step
150
+ info_dict['epoch'] = start_epoch
151
+ save_model(model, 'init', info_dict)
152
+
153
+ # Get executor
154
+ executor = Executor(gan=gan)
155
+ executor.step = start_step
156
+
157
+ # Init scaler, used for pytorch amp mixed precision training
158
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
159
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
160
+ # Start training loop
161
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
162
+ executor.epoch = epoch
163
+ train_dataset.set_epoch(epoch)
164
+ dist.barrier()
165
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
166
+ if gan is True:
167
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
168
+ writer, info_dict, scaler, group_join)
169
+ else:
170
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
171
+ dist.destroy_process_group(group_join)
172
+
173
+
174
+ if __name__ == '__main__':
175
+ main()
cosyvoice/bin/train_dpo.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import os
22
+ import torch
23
+ import torch.distributed as dist
24
+ import deepspeed
25
+
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ from torch.distributed.elastic.multiprocessing.errors import record
29
+
30
+ from cosyvoice.utils.executor_dpo import Executor
31
+ from cosyvoice.utils.train_utils_dpo import (
32
+ init_distributed,
33
+ init_dataset_and_dataloader,
34
+ init_optimizer_and_scheduler,
35
+ init_summarywriter, save_model,
36
+ wrap_cuda_model, check_modify_and_save_config)
37
+
38
+
39
+ def get_args():
40
+ parser = argparse.ArgumentParser(description='training your network')
41
+ parser.add_argument('--train_engine',
42
+ default='torch_ddp',
43
+ choices=['torch_ddp', 'deepspeed'],
44
+ help='Engine for paralleled training')
45
+ parser.add_argument('--model', required=True, help='model which will be trained')
46
+ parser.add_argument('--config', required=True, help='config file')
47
+ parser.add_argument('--train_data', required=True, help='train data file')
48
+ parser.add_argument('--cv_data', required=True, help='cv data file')
49
+ parser.add_argument('--checkpoint', help='checkpoint model')
50
+ parser.add_argument('--model_dir', required=True, help='save model dir')
51
+ parser.add_argument('--tensorboard_dir',
52
+ default='tensorboard',
53
+ help='tensorboard log dir')
54
+ parser.add_argument('--ddp.dist_backend',
55
+ dest='dist_backend',
56
+ default='nccl',
57
+ choices=['nccl', 'gloo'],
58
+ help='distributed backend')
59
+ parser.add_argument('--num_workers',
60
+ default=0,
61
+ type=int,
62
+ help='num of subprocess workers for reading')
63
+ parser.add_argument('--prefetch',
64
+ default=100,
65
+ type=int,
66
+ help='prefetch number')
67
+ parser.add_argument('--pin_memory',
68
+ action='store_true',
69
+ default=False,
70
+ help='Use pinned memory buffers used for reading')
71
+ parser.add_argument('--use_amp',
72
+ action='store_true',
73
+ default=False,
74
+ help='Use automatic mixed precision training')
75
+ parser.add_argument('--deepspeed.save_states',
76
+ dest='save_states',
77
+ default='model_only',
78
+ choices=['model_only', 'model+optimizer'],
79
+ help='save model/optimizer states')
80
+ parser.add_argument('--timeout',
81
+ default=60,
82
+ type=int,
83
+ help='timeout (in seconds) of cosyvoice_join.')
84
+ parser.add_argument('--dpo',
85
+ action='store_true',
86
+ default=False,
87
+ help='Use Direct Preference Optimization')
88
+ parser.add_argument('--beta',
89
+ default=0.01,
90
+ type=float,
91
+ help='beta of dpo training')
92
+ parser = deepspeed.add_config_arguments(parser)
93
+ args = parser.parse_args()
94
+ return args
95
+
96
+
97
+ @record
98
+ def main():
99
+ args = get_args()
100
+ logging.basicConfig(level=logging.DEBUG,
101
+ format='%(asctime)s %(levelname)s %(message)s')
102
+ # gan train has some special initialization logic
103
+ gan = True if args.model == 'hifigan' else False
104
+
105
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
106
+ if gan is True:
107
+ override_dict.pop('hift')
108
+ with open(args.config, 'r') as f:
109
+ configs = load_hyperpyyaml(f, overrides=override_dict)
110
+ if gan is True:
111
+ configs['train_conf'] = configs['train_conf_gan']
112
+ configs['train_conf'].update(vars(args))
113
+
114
+ # Init env for ddp
115
+ init_distributed(args)
116
+
117
+ # Get dataset & dataloader
118
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
119
+ init_dataset_and_dataloader(args, configs, gan)
120
+
121
+ # Do some sanity checks and save config to arsg.model_dir
122
+ configs = check_modify_and_save_config(args, configs)
123
+
124
+ # Tensorboard summary
125
+ writer = init_summarywriter(args)
126
+
127
+ # load checkpoint
128
+ model = configs[args.model]
129
+ ref_model = None
130
+ if args.dpo:
131
+ ref_model = deepcopy(model)
132
+ start_step, start_epoch = 0, -1
133
+ if args.checkpoint is not None:
134
+ if os.path.exists(args.checkpoint):
135
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
136
+ model.load_state_dict(state_dict, strict=False)
137
+ if args.dpo:
138
+ ref_model.load_state_dict(state_dict, strict=False)
139
+ if 'step' in state_dict:
140
+ start_step = state_dict['step']
141
+ if 'epoch' in state_dict:
142
+ start_epoch = state_dict['epoch']
143
+ else:
144
+ logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
145
+
146
+ # Dispatch model from cpu to gpu
147
+ model = wrap_cuda_model(args, model)
148
+ if args.dpo:
149
+ ref_model = wrap_cuda_model(args, ref_model)
150
+
151
+ # Get optimizer & scheduler
152
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
153
+ if args.dpo:
154
+ ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan)
155
+ scheduler.set_step(start_step)
156
+ if scheduler_d is not None:
157
+ scheduler_d.set_step(start_step)
158
+
159
+ # Save init checkpoints
160
+ info_dict = deepcopy(configs['train_conf'])
161
+ info_dict['step'] = start_step
162
+ info_dict['epoch'] = start_epoch
163
+ save_model(model, 'init', info_dict)
164
+
165
+ # Get executor
166
+ executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta)
167
+ executor.step = start_step
168
+
169
+ # Init scaler, used for pytorch amp mixed precision training
170
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
171
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
172
+ # Start training loop
173
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
174
+ executor.epoch = epoch
175
+ train_dataset.set_epoch(epoch)
176
+ dist.barrier()
177
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
178
+ if gan is True:
179
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
180
+ writer, info_dict, scaler, group_join)
181
+ else:
182
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model)
183
+ dist.destroy_process_group(group_join)
184
+
185
+
186
+ if __name__ == '__main__':
187
+ main()
cosyvoice/cli/__init__.py ADDED
File without changes
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from typing import Generator
17
+ from tqdm import tqdm
18
+ from hyperpyyaml import load_hyperpyyaml
19
+ from modelscope import snapshot_download
20
+ import torch
21
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
23
+ from cosyvoice.utils.file_utils import logging
24
+ from cosyvoice.utils.class_utils import get_model_type
25
+
26
+
27
class CosyVoice:
    """High-level CosyVoice TTS entry point.

    Loads configs, the text/speech frontend and model weights from
    ``model_dir`` (downloading from modelscope when the path does not exist
    locally) and exposes streaming generators for sft / zero-shot /
    cross-lingual / instruct / voice-conversion inference. Every
    ``inference_*`` method yields dicts containing a ``tts_speech``
    waveform tensor at ``self.sample_rate``.
    """

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
        """Load configs, frontend and model weights.

        Args:
            model_dir: local path or modelscope model id.
            load_jit: load TorchScript modules for llm/flow when True.
            load_trt: load a TensorRT plan for the flow decoder when True.
            fp16: run llm/flow in half precision (requires CUDA).
            trt_concurrent: number of concurrent TensorRT contexts.

        Raises:
            ValueError: if ``cosyvoice.yaml`` is missing in the model dir.
        """
        # '-Instruct' in the checkpoint name marks an instruct-capable model
        self.instruct = '-Instruct' in model_dir
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f)
        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        # jit / trt / fp16 all require CUDA; fall back gracefully on CPU-only hosts
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        # precision tag used in exported-artifact filenames
        precision = 'fp16' if self.fp16 is True else 'fp32'
        if load_jit:
            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, precision),
                                '{}/llm.llm.{}.zip'.format(model_dir, precision),
                                '{}/flow.encoder.{}.zip'.format(model_dir, precision))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, precision),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                self.fp16)
        del configs

    def _tts_with_logging(self, model_input, stream, speed):
        """Run ``self.model.tts`` on one prepared input, yielding each output
        chunk and logging its speech length and real-time factor."""
        start_time = time.time()
        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            start_time = time.time()

    def list_available_spks(self):
        """Return the pretrained speaker ids usable with ``inference_sft``."""
        return list(self.frontend.spk2info.keys())

    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
        """Register a reusable zero-shot speaker profile under ``zero_shot_spk_id``."""
        assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
        # text fields belong to the utterance, not to the stored speaker profile
        del model_input['text']
        del model_input['text_len']
        self.frontend.spk2info[zero_shot_spk_id] = model_input
        return True

    def save_spkinfo(self):
        """Persist the (possibly updated) speaker table back into the model dir."""
        torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))

    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
        """Synthesize ``tts_text`` with a pretrained speaker id; yields speech chunks."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_sft(i, spk_id)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_logging(model_input, stream, speed)

    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        """Zero-shot voice cloning from a 16 kHz prompt recording plus its transcript."""
        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            # very short synthesis text relative to the prompt tends to degrade quality
            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_logging(model_input, stream, speed)

    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        """Clone a voice across languages; no prompt transcript is required."""
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_logging(model_input, stream, speed)

    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
        """Instruction-controlled synthesis (CosyVoice-Instruct checkpoints only).

        Raises:
            ValueError: when the loaded model is not an instruct checkpoint.
        """
        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
        if self.instruct is False:
            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            logging.info('synthesis text {}'.format(i))
            yield from self._tts_with_logging(model_input, stream, speed)

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        """Voice conversion: re-render source speech in the prompt speaker's voice."""
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
        yield from self._tts_with_logging(model_input, stream, speed)
140
+
141
class CosyVoice2(CosyVoice):
    """CosyVoice2 front door.

    Shares the public inference API with ``CosyVoice`` but loads the v2
    configuration (cosyvoice2.yaml), the v2 speech tokenizer and a
    ``CosyVoice2Model`` backend; instruction-driven synthesis goes through
    ``inference_instruct2`` instead of ``inference_instruct``.
    """

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
        self.instruct = True if '-Instruct' in model_dir else False
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        # v2 config needs the Qwen pretrain path injected as an override
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        # jit / trt / fp16 all depend on a CUDA device being present
        if (load_jit is True or load_trt is True or fp16 is True) and torch.cuda.is_available() is False:
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        precision = 'fp16' if self.fp16 is True else 'fp32'
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, precision))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, precision),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                self.fp16)
        del configs

    def inference_instruct(self, *args, **kwargs):
        # the v1 spk-id based instruct API has no CosyVoice2 counterpart
        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')

    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        """Natural-language-instruction synthesis driven by a 16 kHz speech prompt."""
        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
        for segment in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            frontend_input = self.frontend.frontend_instruct2(segment, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
            chunk_start = time.time()
            logging.info('synthesis text {}'.format(segment))
            for model_output in self.model.tts(**frontend_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - chunk_start) / speech_len))
                yield model_output
                chunk_start = time.time()
cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ from typing import Generator
16
+ import json
17
+ import onnxruntime
18
+ import torch
19
+ import numpy as np
20
+ import whisper
21
+ from typing import Callable
22
+ import torchaudio.compliance.kaldi as kaldi
23
+ import torchaudio
24
+ import os
25
+ import re
26
+ import inflect
27
+ try:
28
+ import ttsfrd
29
+ use_ttsfrd = True
30
+ except ImportError:
31
+ print("failed to import ttsfrd, use WeTextProcessing instead")
32
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
33
+ from tn.english.normalizer import Normalizer as EnNormalizer
34
+ use_ttsfrd = False
35
+ from cosyvoice.utils.file_utils import logging
36
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
37
+
38
+
39
class CosyVoiceFrontEnd:
    """Text/speech frontend shared by CosyVoice and CosyVoice2.

    Tokenizes text, extracts speech tokens / mel features / speaker
    embeddings via ONNX sessions, and applies text normalization, then
    packages everything into the model-input dicts consumed by the
    ``CosyVoiceModel.tts`` call.
    """

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 allowed_special: str = 'all'):
        """Build tokenizer, ONNX sessions and the text-normalization backend.

        Args:
            get_tokenizer: factory returning the text tokenizer.
            feat_extractor: callable producing mel features from a waveform.
            campplus_model: path to the speaker-embedding ONNX model.
            speech_tokenizer_model: path to the speech-tokenizer ONNX model.
            spk2info: optional path to a saved speaker table (torch file).
            allowed_special: special-token policy passed to tokenizer.encode.
        """
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # speaker-embedding model always runs on CPU; the speech tokenizer
        # prefers CUDA when available
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
        else:
            self.spk2info = {}
        self.allowed_special = allowed_special
        # use ttsfrd when importable, otherwise the WeTextProcessing normalizers
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyinvg')
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
            self.en_tn_model = EnNormalizer()
            self.inflect_parser = inflect.engine()

    def _extract_text_token(self, text):
        """Tokenize ``text``; returns (token tensor [1, T], length tensor [1]).

        When ``text`` is a generator (streaming input) a token generator is
        returned instead, together with a dummy zero length.
        """
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will return _extract_text_token_generator!')
            # NOTE add a dummy text_token_len for compatibility
            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
        else:
            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
            return text_token, text_token_len

    def _extract_text_token_generator(self, text_generator):
        """Yield one [1, 1] token tensor at a time from a streaming text source."""
        for text in text_generator:
            text_token, _ = self._extract_text_token(text)
            for i in range(text_token.shape[1]):
                yield text_token[:, i: i + 1]

    def _extract_speech_token(self, speech):
        """Run the ONNX speech tokenizer on a waveform (assumed 16 kHz — the
        30 s guard below divides by 16000); returns (tokens [1, T], length [1])."""
        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,
                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech):
        """Compute a speaker embedding from a 16 kHz waveform via campplus ONNX."""
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        # cepstral mean normalization before the embedding model
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None,
                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech):
        """Extract mel features; returns (feat [1, T, n_mels], length [1])."""
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def text_normalize(self, text, split=True, text_frontend=True):
        """Normalize ``text`` and optionally split it into synthesis segments.

        Returns a list of segments when ``split`` is True, else the normalized
        string. Generators are passed through untouched (streaming mode), and
        ``text_frontend=False`` skips normalization entirely.
        """
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will skip text_normalize!')
            return [text]
        if text_frontend is False or text == '':
            return [text] if split is True else text
        text = text.strip()
        if self.use_ttsfrd:
            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
            text = ''.join(texts)
        else:
            if contains_chinese(text):
                # Chinese path: TN, punctuation cleanup, then paragraph split
                text = self.zh_tn_model.normalize(text)
                text = text.replace("\n", "")
                text = replace_blank(text)
                text = replace_corner_mark(text)
                text = text.replace(".", "。")
                text = text.replace(" - ", ",")
                text = remove_bracket(text)
                text = re.sub(r'[,,、]+$', '。', text)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
            else:
                # English path: TN plus digit spell-out
                text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
        # drop segments that carry no speakable content
        texts = [i for i in texts if not is_only_punctuation(i)]
        return texts if split is True else text

    def frontend_sft(self, tts_text, spk_id):
        """Build model input for sft inference with a pretrained speaker."""
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        """Build model input for zero-shot cloning.

        With an empty ``zero_shot_spk_id`` all prompt features are extracted
        here; otherwise the cached speaker profile is reused.
        """
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        if zero_shot_spk_id == '':
            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
            if resample_rate == 24000:
                # cosyvoice2, force speech_feat % speech_token = 2
                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
            embedding = self._extract_spk_embedding(prompt_speech_16k)
            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                           'llm_embedding': embedding, 'flow_embedding': embedding}
        else:
            # NOTE(review): this returns the cached spk2info entry itself, so the
            # 'text'/'text_len' assignments below mutate the shared cache and
            # callers that `del` keys will strip them from the cache too — confirm intended
            model_input = self.spk2info[zero_shot_spk_id]
        model_input['text'] = tts_text_token
        model_input['text_len'] = tts_text_token_len
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        """Build model input for cross-lingual cloning (no llm text prompt)."""
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        """Build model input for v1 instruct inference (instruct text as llm prompt)."""
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        """Build model input for CosyVoice2 instruct inference (speech-prompted)."""
        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
        """Build model input for voice conversion (flow-only, no llm text)."""
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                       'flow_embedding': embedding}
        return model_input
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ from typing import Generator
17
+ import queue
18
+ import torch
19
+ import numpy as np
20
+ import threading
21
+ import time
22
+ from torch.nn import functional as F
23
+ from contextlib import nullcontext
24
+ import uuid
25
+ from cosyvoice.utils.common import fade_in_out
26
+ from cosyvoice.utils.file_utils import convert_onnx_to_trt
27
+ from cosyvoice.utils.common import TrtContextWrapper
28
+
29
+
30
+ class CosyVoiceModel:
31
+
32
+ def __init__(self,
33
+ llm: torch.nn.Module,
34
+ flow: torch.nn.Module,
35
+ hift: torch.nn.Module,
36
+ fp16: bool = False,
37
+ trt_concurrent: int = 1):
38
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ self.llm = llm
40
+ self.flow = flow
41
+ self.hift = hift
42
+ self.fp16 = fp16
43
+ self.trt_concurrent = trt_concurrent
44
+ if self.fp16 is True:
45
+ self.llm.half()
46
+ self.flow.half()
47
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
48
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
49
+ self.token_overlap_len = 20
50
+ # mel fade in out
51
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
52
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
53
+ # hift cache
54
+ self.mel_cache_len = 20
55
+ self.source_cache_len = int(self.mel_cache_len * 256)
56
+ # speech fade in out
57
+ self.speech_window = np.hamming(2 * self.source_cache_len)
58
+ # rtf and decoding related
59
+ self.stream_scale_factor = 1
60
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
61
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
62
+ self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
63
+ for _ in range(trt_concurrent):
64
+ self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
65
+ self.lock = threading.Lock()
66
+ # dict used to store session related variable
67
+ self.tts_speech_token_dict = {}
68
+ self.llm_end_dict = {}
69
+ self.mel_overlap_dict = {}
70
+ self.flow_cache_dict = {}
71
+ self.hift_cache_dict = {}
72
+ self.trt_context_dict = {}
73
+
74
+ def load(self, llm_model, flow_model, hift_model):
75
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
76
+ self.llm.to(self.device).eval()
77
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
78
+ self.flow.to(self.device).eval()
79
+ # in case hift_model is a hifigan model
80
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
81
+ self.hift.load_state_dict(hift_state_dict, strict=True)
82
+ self.hift.to(self.device).eval()
83
+
84
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
85
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
86
+ self.llm.text_encoder = llm_text_encoder
87
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
88
+ self.llm.llm = llm_llm
89
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
90
+ self.flow.encoder = flow_encoder
91
+
92
+ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
93
+ assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
94
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
95
+ convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
96
+ del self.flow.decoder.estimator
97
+ import tensorrt as trt
98
+ with open(flow_decoder_estimator_model, 'rb') as f:
99
+ estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
100
+ assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
101
+ self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
102
+
103
+ def get_trt_kwargs(self):
104
+ min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
105
+ opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
106
+ max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
107
+ input_names = ["x", "mask", "mu", "cond"]
108
+ return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
109
+
110
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
111
+ with self.llm_context, torch.cuda.amp.autocast(self.fp16):
112
+ if isinstance(text, Generator):
113
+ assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
114
+ for i in self.llm.inference_bistream(text=text,
115
+ prompt_text=prompt_text.to(self.device),
116
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
117
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
118
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
119
+ embedding=llm_embedding.to(self.device)):
120
+ self.tts_speech_token_dict[uuid].append(i)
121
+ else:
122
+ for i in self.llm.inference(text=text.to(self.device),
123
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
124
+ prompt_text=prompt_text.to(self.device),
125
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
126
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
127
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
128
+ embedding=llm_embedding.to(self.device)):
129
+ self.tts_speech_token_dict[uuid].append(i)
130
+ self.llm_end_dict[uuid] = True
131
+
132
+ def vc_job(self, source_speech_token, uuid):
133
+ self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
134
+ self.llm_end_dict[uuid] = True
135
+
136
    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        """Convert speech tokens to a waveform chunk (flow matching mel + hift vocoder).

        Args:
            token: speech token ids for this chunk, shape (1, T) — assumed int tensor; TODO confirm
            prompt_token: prompt speech tokens conditioning the flow model
            prompt_feat: prompt mel features — presumably (1, T, 80); verify against caller
            embedding: speaker embedding for the flow model
            uuid: session id keying the per-session cache dicts
            finalize: False while streaming (keeps overlap/cache for the next chunk),
                True on the last chunk
            speed: playback speed factor; only valid in non-stream (finalize) mode
        Returns:
            tts_speech: generated waveform chunk (overlap region stripped when streaming)
        """
        with torch.cuda.amp.autocast(self.fp16):
            # flow inference is stateful across chunks via flow_cache_dict[uuid]
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      flow_cache=self.flow_cache_dict[uuid])

        # mel overlap fade in out: cross-fade this chunk with the tail kept from the previous one
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache: prepend cached mel so the vocoder sees continuous context
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            # empty source cache on the first chunk of a session
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            # hold back the overlap tail for the next chunk's cross-fade
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            # cache mel/source/speech tails so the next chunk can splice seamlessly
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                # time-stretch the mel to change speed without re-running the LLM
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
176
    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        """Top-level synthesis generator.

        Spawns a producer thread (LLM token generation, or direct token copy for
        voice conversion when source_speech_token is non-empty) and consumes the
        tokens into waveform chunks via token2wav. Yields dicts of the form
        {'tts_speech': Tensor} — one chunk per hop in stream mode, a single
        full utterance otherwise.

        NOTE(review): default tensor arguments are module-level singletons; the
        code never mutates them, which is why this is safe here.
        """
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        # blocks until a TensorRT/stream context is free, bounding concurrency
        this_trt_context = self.trt_context_pool.get()
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
            self.trt_context_dict[this_uuid] = this_trt_context
        if source_speech_token.shape[1] == 0:
            # normal TTS: LLM produces speech tokens in the background
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            # voice conversion: tokens come directly from the source speech
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                # poll the producer; 100ms granularity
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        # consume the hop; the overlap tail stays for the next chunk
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens: wait for the producer, vocode in one shot
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            # release all per-session state and return the trt context to the pool
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
            self.trt_context_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()
250
class CosyVoice2Model(CosyVoiceModel):
    """CosyVoice 2 inference model.

    Differs from CosyVoiceModel in its streaming scheme: instead of mel-overlap
    cross-fading, the flow model is chunked with a fixed token hop and a
    pre-lookahead window, and token2wav takes an explicit token_offset.
    """

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False,
                 trt_concurrent: int = 1):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.trt_concurrent = trt_concurrent
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        # NOTE must matching training static_chunk_size
        self.token_hop_len = 25
        # hift cache
        self.mel_cache_len = 8
        # 480 audio samples per mel frame — presumably matches hift hop size; verify against config
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        # bounded pool of CUDA-stream contexts limits concurrent flow inference
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        for _ in range(trt_concurrent):
            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
        self.lock = threading.Lock()
        # dict used to store session related variable
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}
        self.trt_context_dict = {}

    def load_jit(self, flow_encoder_model):
        """Replace the flow encoder with a TorchScript-compiled version."""
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        """Vocode the token sequence, keeping only mel frames after token_offset.

        Unlike CosyVoiceModel.token2wav, the full token prefix is re-fed to the
        flow model each call and the already-emitted region is sliced off using
        the flow model's token-to-mel ratio.
        """
        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
            tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             streaming=stream,
                                             finalize=finalize)
        # drop mel frames that were already emitted in previous chunks
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        """Synthesis generator (CosyVoice 2 chunked-streaming variant).

        Yields {'tts_speech': Tensor} chunks. Producer thread and session-state
        bookkeeping mirror CosyVoiceModel.tts; streaming works by growing a
        token_offset in fixed hops rather than mel cross-fading.
        """
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        this_trt_context = self.trt_context_pool.get()
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.trt_context_dict[this_uuid] = this_trt_context
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_offset = 0
            # pad the first hop so (prompt + hop) aligns to a token_hop_len boundary
            prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
            while True:
                time.sleep(0.1)
                this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
                # need hop + lookahead tokens available before vocoding a chunk
                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     token_offset=token_offset,
                                                     uuid=this_uuid,
                                                     stream=stream,
                                                     finalize=False)
                    token_offset += this_token_hop_len
                    yield {'tts_speech': this_tts_speech.cpu()}
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=token_offset,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=0,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
            self.trt_context_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()
cosyvoice/dataset/__init__.py ADDED
File without changes
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
class Processor(IterableDataset):
    """Lazily applies a processing function to an upstream iterable dataset.

    Each Processor wraps a source dataset and a callable ``f``; iterating the
    Processor iterates the source through ``f``. Stages chain via ``apply``.
    """

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        # Forward the epoch to the upstream source (drives deterministic shuffling).
        self.source.set_epoch(epoch)

    def __iter__(self):
        """Return an iterator over the source dataset transformed by ``f``."""
        assert self.source is not None
        assert callable(self.f)
        upstream = iter(self.source)
        return self.f(upstream, *self.args, **self.kw)

    def apply(self, f):
        """Chain another processing stage on top of this one."""
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)
+
52
class DistributedSampler:
    """Tracks distributed rank and dataloader-worker identity and slices an
    index list so every (rank, worker) pair sees a disjoint, non-empty share.
    """

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        """Refresh rank/world_size and worker info; return them as a dict."""
        assert dist.is_available()
        if dist.is_initialized():
            self.rank, self.world_size = dist.get_rank(), dist.get_world_size()
        else:
            self.rank, self.world_size = 0, 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.worker_id, self.num_workers = worker_info.id, worker_info.num_workers
        else:
            self.worker_id, self.num_workers = 0, 1
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """Return the index subset of ``data`` owned by this rank and worker.

        Args:
            data(List): input data list

        Returns:
            List: index list after rank- and worker-level slicing
        """
        indices = list(range(len(data)))
        if self.partition:
            # rank-level split; shuffle deterministically per epoch so all
            # ranks agree on the permutation
            if self.shuffle:
                random.Random(self.epoch).shuffle(indices)
            if len(indices) < self.world_size:
                # replicate so every rank gets at least one shard
                indices = indices * math.ceil(self.world_size / len(indices))
                indices = indices[:self.world_size]
            indices = indices[self.rank::self.world_size]
        # worker-level split inside one dataloader process
        if len(indices) < self.num_workers:
            indices = indices * math.ceil(self.num_workers / len(indices))
            indices = indices[:self.num_workers]
        indices = indices[self.worker_id::self.num_workers]
        return indices
+
108
class DataList(IterableDataset):
    """Iterable over a list of data shards, sliced per rank/worker each epoch."""

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        for index in self.sampler.sample(self.lists):
            # each item carries its shard path plus rank/worker info downstream
            item = dict(src=self.lists[index])
            item.update(sampler_info)
            yield item
+
126
def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            shuffle=True,
            partition=True,
            tts_file='',
            prompt_utt2data=''):
    """ Construct dataset from arguments

    We have two shuffle stage in the Dataset. The first is global
    shuffle at shards tar/raw file level. The second is global shuffle
    at training samples level.

    Args:
        data_list_file(str): file listing one data shard per line
        data_pipeline(List[callable]): processor functions applied in order;
            NOTE: mutated in place for inference/gan modes
        mode(str): 'train' or 'inference'
        gan(bool): prepare batches for GAN (vocoder) training
        shuffle(bool): shuffle shard list per epoch
        partition(bool): whether to do data partition in terms of rank
        tts_file(str): json of utt -> tts texts, inference mode only
        prompt_utt2data(str): json mapping utt -> shard file, inference mode only
    """
    assert mode in ['train', 'inference']
    lists = read_lists(data_list_file)
    if mode == 'inference':
        with open(tts_file) as f:
            tts_data = json.load(f)
        utt2lists = read_json_lists(prompt_utt2data)
        # filter unnecessary file in inference mode: keep only shards that
        # actually contain a requested utterance
        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
    dataset = DataList(lists,
                       shuffle=shuffle,
                       partition=partition)
    if mode == 'inference':
        # map partial arg to parquet_opener func in inference mode
        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
    if gan is True:
        # map partial arg to padding func in gan mode
        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
    # chain the processors lazily: each stage wraps the previous one
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+ import pyworld as pw
24
+
25
+
26
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27
+
28
+
29
def parquet_opener(data, mode='train', tts_data={}):
    """ Give url or local file, return file descriptor
    Inplace operation.

    Args:
        data(Iterable[str]): url or local file list
        mode(str): 'train' yields one dict per row; 'inference' fans each
            utterance out over its requested tts texts
        tts_data(dict): utt -> list of tts texts, inference mode only
            NOTE(review): mutable default arg — safe only because it is never
            mutated here; confirm before changing

    Returns:
        Iterable[{src, stream}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            # stream parquet rows in small batches to bound memory
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            # a bad shard is skipped, not fatal — training continues
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
+
59
# NOTE(review): shadows the builtin `filter` within this module — keep the name
# for config compatibility (data pipelines reference it by name).
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter sample according to feature and label length
    Inplace operation.

    Args::
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterance which is greater than max_length(10ms)
        min_length: drop utterance which is less than min_length(10ms)
        token_max_length: drop utterance which is greater than
            token_max_length, especially when use char unit for
            english modeling
        token_min_length: drop utterance which is
            less than token_max_length
        min_output_input_ratio: minimal ration of
            token_length / feats_length(10ms)
        max_output_input_ratio: maximum ration of
            token_length / feats_length(10ms)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        # decode audio bytes here (first pipeline stage that needs waveforms)
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        # downmix to mono
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['wav'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
+
111
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate
        min_sample_rate: samples below this native rate are dropped entirely
            (upsampling from too-low rates would degrade quality)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            # NOTE(review): a Resample transform is constructed per sample;
            # caching per (orig, new) pair would be cheaper — confirm before changing
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        # peak-normalize only when clipping would occur
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample
+
139
def truncate(data, truncate_length=24576, mode='train'):
    """Crop or zero-pad each waveform to exactly ``truncate_length`` samples.

    Longer waveforms are cropped at a random offset (keeps training windows
    diverse); shorter ones are right-padded with zeros.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        truncate_length: target length in samples

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        wav = sample['speech']
        length = wav.shape[1]
        if length > truncate_length:
            offset = random.randint(0, length - truncate_length)
            wav = wav[:, offset: offset + truncate_length]
        else:
            pad = torch.zeros(1, truncate_length - length)
            wav = torch.concat([wav, pad], dim=1)
        sample['speech'] = wav
        yield sample
+
160
def compute_fbank(data,
                  feat_extractor,
                  token_mel_ratio=0,
                  mode='train'):
    """Extract mel/fbank features from each waveform.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        feat_extractor: callable mapping waveform -> (1, n_mels, T) features
        token_mel_ratio: when non-zero, trim features and speech tokens so
            that feat length == token_mel_ratio * token length

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        # (1, n_mels, T) -> (T, n_mels)
        mel = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
        if token_mel_ratio != 0:
            # trim to align speech_token and speech_feat
            aligned_tokens = int(min(mel.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
            mel = mel[:token_mel_ratio * aligned_tokens]
            sample["speech_token"] = sample["speech_token"][:aligned_tokens]
        sample['speech_feat'] = mel
        yield sample
+
188
def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0 (pitch contour) with pyworld, aligned to the mel frames.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        sample_rate: audio sample rate expected by pyworld
        hop_size: analysis hop in samples; converted to a frame period in ms

    Returns:
        Iterable[{key, feat, label}]
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
        # refine the raw f0 estimate
        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
        # resample the contour to exactly the mel length so they align 1:1
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample
+
213
def parse_embedding(data, normalize, mode='train'):
    """Convert utt/spk embeddings to float32 tensors, optionally L2-normalized.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        normalize: apply F.normalize along dim 0 when True

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        for key in ('utt_embedding', 'spk_embedding'):
            emb = torch.tensor(sample[key], dtype=torch.float32)
            if normalize:
                emb = F.normalize(emb, dim=0)
            sample[key] = emb
        yield sample
+
231
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """Encode text into token ids (chars or BPE). Inplace operation.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]
        get_tokenizer: zero-arg factory returning a tokenizer with .encode()
        allowed_special: passed through to the tokenizer

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    # instantiate once, outside the loop — tokenizer construction is not free
    tokenizer = get_tokenizer()
    encode = tokenizer.encode
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample
+
250
def shuffle(data, shuffle_size=10000, mode='train'):
    """Locally shuffle the stream through a bounded buffer.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffle

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    # flush whatever is left in the tail
    random.shuffle(buf)
    yield from buf
+
274
def sort(data, sort_size=500, mode='train'):
    """Sort the stream by mel length within a bounded buffer.

    Used after shuffle and before batch, so similar-length utterances land in
    the same batch; ``sort_size`` should be less than ``shuffle_size``.

    Args:
        data: Iterable[{key, feat, label}]
        sort_size: buffer size for sort

    Returns:
        Iterable[{key, feat, label}]
    """
    feat_frames = lambda x: x['speech_feat'].size(0)
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=feat_frames)
            yield from buf
            buf = []
    # flush and sort the leftover tail
    buf.sort(key=feat_frames)
    yield from buf
+
302
def static_batch(data, batch_size=16):
    """Group the stream into fixed-size batches; the final batch may be short.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    pending = []
    for sample in data:
        pending.append(sample)
        if len(pending) >= batch_size:
            yield pending
            pending = []
    if pending:
        yield pending
+
322
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """Batch the stream until the padded frame total would exceed the budget.

    Every item in a batch is padded to the longest item, so the cost of a
    batch is ``longest_frames * len(batch)``; a new batch starts whenever
    adding the next sample would push that product past ``max_frames_in_batch``.

    Fix over the original: when the very first sample of a batch already
    exceeds the budget on its own, the original yielded an empty list, which
    breaks downstream padding — empty batches are now suppressed (the
    oversized sample still goes out as a singleton batch).

    Args:
        data: Iterable[{key, feat, label}]
        max_frames_in_batch: max padded frames in one batch

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            if buf:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
+
351
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """Dispatch to static or dynamic batching.

    Inference always uses batch size 1. An unknown ``batch_type`` is logged
    as fatal and yields no batching (returns None, matching prior behavior).
    """
    if mode == 'inference':
        return static_batch(data, 1)
    if batch_type == 'static':
        return static_batch(data, batch_size)
    if batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal('Unsupported batch type {}'.format(batch_type))
365
def padding(data, use_spk_embedding, mode='train', gan=False):
    """ Padding the data into training data

    Collates a list of samples into one padded batch dict: sequences are
    zero-padded to the batch max, lengths are kept as int32 tensors.

    Args:
        data: Iterable[List[{key, feat, label}]]
        use_spk_embedding(bool): expose speaker-level embedding as "embedding"
            instead of utterance-level
        gan(bool): keep raw speech + pitch features (needed for vocoder training)

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        # NOTE(review): speech_feat is (T, n_mels) elsewhere in this file, so
        # .size(1) here is the constant mel dim and this argsort is effectively
        # a no-op ordering — looks like it was meant to be .size(0); confirm
        # upstream before changing, as it alters batch item order.
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            # -1 padding marks invalid positions for inference-time text tokens
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
cosyvoice/dataset/processor_dpo.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+ import pyworld as pw
24
+
25
+
26
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27
+
28
+
29
def parquet_opener(data, mode='train', tts_data=None):
    """ Give url or local file, return file descriptor
        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list
            mode: 'train' yields every parquet row; 'inference' yields only
                rows whose utt appears in tts_data, once per tts text
            tts_data: dict mapping utt -> list of tts texts (inference only)

        Returns:
            Iterable[{src, stream}]
    """
    # Fix: the default used to be a mutable dict literal shared across calls;
    # use the None-sentinel idiom instead (behaviorally identical for callers).
    if tts_data is None:
        tts_data = {}
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            # best-effort: a corrupt shard should not kill the whole pipeline
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
57
+
58
+
59
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter sample according to feature and label length.
        Inplace operation: decodes `audio_data` bytes into `speech` and
        drops the raw bytes to save memory.

        NOTE(review): this function shadows the builtin ``filter``; it is
        referenced by name from the data-pipeline config, so the name is kept.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length (10ms frames)
            min_length: drop utterance which is less than min_length (10ms frames)
            token_max_length: drop utterance whose token count is greater than
                token_max_length, especially when using char unit for
                english modeling
            token_min_length: drop utterance whose token count is less than
                token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        # downmix multi-channel audio to mono
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['wav'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
109
+
110
+
111
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """Resample each utterance to ``resample_rate`` and peak-normalize it.

    Utterances whose native rate is below ``min_sample_rate`` are dropped
    (too low quality to be worth upsampling). Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target sample rate
        min_sample_rate: drop samples recorded below this rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        src_rate, wav = sample['sample_rate'], sample['speech']
        if src_rate != resample_rate:
            if src_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            resampler = torchaudio.transforms.Resample(orig_freq=src_rate,
                                                       new_freq=resample_rate)
            sample['speech'] = resampler(wav)
        # keep the waveform inside [-1, 1]
        peak = sample['speech'].abs().max()
        if peak > 1:
            sample['speech'] /= peak
        yield sample
137
+
138
+
139
def truncate(data, truncate_length=24576, mode='train'):
    """Force every waveform to exactly ``truncate_length`` samples.

    Longer waveforms get a random crop; shorter ones are right-padded
    with zeros (silence).

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        truncate_length: target length in samples

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        wav = sample['speech']
        n = wav.shape[1]
        if n > truncate_length:
            # random crop so different epochs see different windows
            offset = random.randint(0, n - truncate_length)
            wav = wav[:, offset: offset + truncate_length]
        else:
            # right-pad with silence up to the target length
            wav = torch.concat([wav, torch.zeros(1, truncate_length - n)], dim=1)
        sample['speech'] = wav
        yield sample
158
+
159
+
160
def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """Extract mel filterbank features for each utterance.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        feat_extractor: callable mapping a waveform to features

    Returns:
        Iterable[{key, feat, label}] with an added `speech_feat` entry
    """
    for sample in data:
        for key in ('sample_rate', 'speech', 'utt', 'text_token'):
            assert key in sample
        # extractor output (1, n_mels, T) -> (n_mels, T) -> (T, n_mels)
        sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
        yield sample
180
+
181
+
182
def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract an F0 (pitch) track with pyworld, aligned to the mel length.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        sample_rate: audio sample rate assumed for F0 analysis
        hop_size: hop size in samples; sets the F0 frame period

    Returns:
        Iterable[{key, feat, label}] with an added `pitch_feat` entry
    """
    # frame period in milliseconds, matching the feature hop
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
        # refine the coarse F0 estimate
        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
        # linearly resample the F0 track to the mel-spectrogram frame count
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample
205
+
206
+
207
def parse_embedding(data, normalize, mode='train'):
    """Convert utt/spk embeddings to float32 tensors, optionally L2-normalized.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        normalize: if True, L2-normalize each embedding along dim 0

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        for key in ('utt_embedding', 'spk_embedding'):
            emb = torch.tensor(sample[key], dtype=torch.float32)
            if normalize:
                emb = F.normalize(emb, dim=0)
            sample[key] = emb
        yield sample
223
+
224
+
225
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """Encode text (and tts_text in inference mode) into token ids.

    Inplace operation.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]
        get_tokenizer: zero-arg factory returning the tokenizer (lazy so each
            worker builds its own instance)
        allowed_special: passed through to tokenizer.encode

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()

    def _encode(text):
        return tokenizer.encode(text, allowed_special=allowed_special)

    for sample in data:
        assert 'text' in sample
        sample['text_token'] = _encode(sample['text'])
        if mode == 'inference':
            sample['tts_text_token'] = _encode(sample['tts_text'])
        yield sample
242
+
243
+
244
def shuffle(data, shuffle_size=10000, mode='train'):
    """Locally shuffle the stream using a bounded buffer.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size; larger means better mixing, more memory

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    # flush whatever is left over
    random.shuffle(buf)
    yield from buf
266
+
267
+
268
def sort(data, sort_size=500, mode='train'):
    """Sort the stream by feature length within a bounded buffer.

    Used after shuffle and before batch, so utts with similar lengths end
    up in the same batch; `sort_size` should be less than `shuffle_size`.

    Args:
        data: Iterable[{key, feat, label}]
        sort_size: buffer size for sorting

    Returns:
        Iterable[{key, feat, label}]
    """
    frame_count = lambda s: s['speech_feat'].size(0)
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            yield from sorted(buf, key=frame_count)
            buf = []
    # flush the remainder, sorted as well
    yield from sorted(buf, key=frame_count)
294
+
295
+
296
def static_batch(data, batch_size=16):
    """Group the stream into fixed-size batches.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: number of samples per batch

    Returns:
        Iterable[List[{key, feat, label}]] (last batch may be short)
    """
    pending = []
    for sample in data:
        pending.append(sample)
        if len(pending) >= batch_size:
            yield pending
            pending = []
    if pending:
        yield pending
314
+
315
+
316
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """Group the stream into batches bounded by padded frame count.

    A batch is emitted as soon as adding one more sample would push
    `longest_len * batch_size` past `max_frames_in_batch`.

    Args:
        data: Iterable[{key, feat, label}]
        max_frames_in_batch: frame budget per (padded) batch

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    pending, longest = [], 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        frames = sample['speech_feat'].size(0)
        longest = max(longest, frames)
        # cost after zero-padding everything to the longest sample
        if longest * (len(pending) + 1) > max_frames_in_batch:
            yield pending
            pending, longest = [sample], frames
        else:
            pending.append(sample)
    if pending:
        yield pending
343
+
344
+
345
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """Dispatch to static or dynamic batching.

    In inference mode, always batch one utterance at a time.
    """
    if mode == 'inference':
        return static_batch(data, 1)
    if batch_type == 'static':
        return static_batch(data, batch_size)
    if batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal('Unsupported batch type {}'.format(batch_type))
357
+
358
+
359
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    """ Padding the data into training data.

    Collates a list of samples into one batch dict: variable-length token
    and feature sequences are zero-padded to the longest item in the batch,
    embeddings are stacked.

    Args:
        data: Iterable[List[{key, feat, label}]]
        use_spk_embedding: if True expose spk_embedding as batch["embedding"],
            otherwise expose utt_embedding
        mode: 'train' or 'inference'; inference adds tts_* entries
        gan: GAN training additionally needs pitch_feat and raw speech
        dpo: DPO training additionally needs reject_speech_token pairs

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        # NOTE(review): size(1) is the feature dimension, which is constant
        # across samples, so this argsort is effectively a no-op ordering;
        # size(0) (frame count) was presumably intended — TODO confirm
        # against the non-DPO processor.py.
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if dpo:
            # rejected token sequences form the negative side of DPO pairs
            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
            reject_speech_token = pad_sequence(reject_speech_token,
                                               batch_first=True,
                                               padding_value=0)
            batch['reject_speech_token'] = reject_speech_token
            batch['reject_speech_token_len'] = reject_speech_token_len
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            # -1 padding marks positions past the end of each tts text
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import pack, rearrange, repeat
19
+ from cosyvoice.utils.common import mask_to_bias
20
+ from cosyvoice.utils.mask import add_optional_chunk_mask
21
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
22
+ from matcha.models.components.transformer import BasicTransformerBlock
23
+
24
+
25
class Transpose(torch.nn.Module):
    """Swap two tensor dimensions.

    Wraps ``torch.transpose`` as a Module so it can live inside
    ``nn.Sequential`` (e.g. around a LayerNorm that needs channels last).
    """

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(self.dim0, self.dim1)
34
+
35
+
36
class CausalConv1d(torch.nn.Conv1d):
    """Conv1d that left-pads its input so no output frame sees the future.

    Output length equals input length (requires stride 1); padding is
    ``kernel_size - 1`` zeros on the left only.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding=0, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode,
                         device=device, dtype=dtype)
        # only unit stride keeps input and output lengths aligned
        assert stride == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pad on the left, then run the ordinary (unpadded) convolution
        padded = F.pad(x, (self.causal_padding, 0), value=0.0)
        return super().forward(padded)
63
+
64
+
65
class CausalBlock1D(Block1D):
    """Block1D variant whose convolution is causal (no future context).

    Replaces the parent's conv stack with CausalConv1d -> LayerNorm -> Mish
    so the block can be used for streaming synthesis.
    """

    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        # overwrite the parent's non-causal conv stack
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            # LayerNorm normalizes the last dim, so move channels there and back
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # mask input and output so padded frames stay zero
        # (annotation fixed: a single tensor is returned, not a tuple)
        output = self.block(x * mask)
        return output * mask
79
+
80
+
81
class CausalResnetBlock1D(ResnetBlock1D):
    """ResnetBlock1D whose two inner blocks are causal, for streaming use."""

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        # swap the parent's blocks for causal equivalents
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)
86
+
87
+
88
class ConditionalDecoder(nn.Module):
    """U-Net style 1D decoder conditioned on a diffusion/flow timestep.

    The input x is packed with the condition mu (and optionally a speaker
    embedding and cond features) along the channel dim, then passed through
    down / mid / up stacks of ResNet + transformer blocks with skip
    connections, and finally projected to `out_channels`.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.

        Args:
            in_channels: channel count of the packed input (x + mu [+ spks + cond])
            out_channels: channel count of the prediction (e.g. mel bins)
            channels: widths of the down/up stages
            dropout: transformer dropout
            attention_head_dim: per-head attention dimension
            n_blocks: transformer blocks per stage
            num_mid_blocks: number of middle stages
            num_heads: attention heads
            act_fn: transformer activation name
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # last down stage keeps resolution with a plain conv instead of halving
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this rebinds the local `out_channels` (shadowing the
            # constructor argument) and is never read afterwards; the resnet
            # below uses `output_channel`, which already equals channels[-1].
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # mirror the channel stack for the decoder half; the extra first entry
        # gives the final up stage a skip connection of matching width
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2  # doubled by skip-connection concat
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-init conv/linear weights, unit/zero-init group norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): condition, packed with x along channels
            t (torch.Tensor): shape (batch_size), diffusion/flow timestep
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): extra conditioning features. Defaults to None.
            streaming: unused here; accepted for interface parity with
                CausalConditionalDecoder

        Returns:
            torch.Tensor: masked prediction, shape (batch_size, out_channels, time)
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        # concatenate x with its conditions along the channel dim
        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            # broadcast the per-utterance embedding across time
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # full (non-chunked) self-attention mask over valid frames
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # downsampling halves the time axis, so subsample the mask too
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # trim to the skip's length (upsampling may overshoot by one frame)
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
292
+
293
+
294
class CausalConditionalDecoder(ConditionalDecoder):
    """ConditionalDecoder variant built from causal blocks for streaming.

    Uses CausalResnetBlock1D / CausalConv1d everywhere and, when
    ``streaming=True``, chunked attention masks so each frame only attends
    within its chunk (plus left context), enabling incremental synthesis.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.

        Args:
            static_chunk_size: chunk length (frames) for streaming attention
            num_decoding_left_chunks: stored for decoding; not read in this class's
                visible code
        """
        # bypass ConditionalDecoder.__init__ — the whole module tree is
        # rebuilt here with causal components
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # last stage keeps resolution with a causal conv instead of halving
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): rebinds the local `out_channels` (shadowing the
            # constructor argument) and is never read; the resnet below uses
            # `output_channel`, which already equals channels[-1].
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # mirror the channel stack for the decoder half
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2  # doubled by skip-connection concat
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): condition, packed with x along channels
            t (torch.Tensor): shape (batch_size), diffusion/flow timestep
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): extra conditioning features. Defaults to None.
            streaming: if True, use static-chunk attention masks so the model
                can be run incrementally; if False, full attention

        Returns:
            torch.Tensor: masked prediction, shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        # concatenate x with its conditions along the channel dim
        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            # broadcast the per-utterance embedding across time
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                # chunked attention: frames attend within their static chunk
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # downsampling halves the time axis, so subsample the mask too
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # trim to the skip's length (upsampling may overshoot by one frame)
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
class MaskedDiffWithXvec(torch.nn.Module):
    """Conditional flow-matching mel decoder conditioned on speech tokens
    and a speaker embedding (x-vector).

    Pipeline: token embedding -> encoder -> linear projection ->
    length regulation to mel length -> conditional flow-matching decoder
    that produces mel features.
    """

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 # NOTE(review): mutable dict defaults are evaluated once and
                 # shared across instances; they are only stored here, but a
                 # None sentinel would be safer -- confirm before changing.
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        # Projects the normalized x-vector to the decoder feature size.
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Args:
            batch: dict with 'speech_token', 'speech_token_len',
                'speech_feat', 'speech_feat_len' and 'embedding' tensors.
            device: target device for all tensors.

        Returns:
            dict with a single 'loss' tensor.
        """
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # Embed tokens; clamp guards against negative (padding) ids.
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # Teacher-forcing condition: with probability 0.5 expose a random
        # ground-truth mel prefix (up to 30% of the utterance).
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE this is unnecessary, feat/h already same shape
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        """Generate mel features for `token`, continuing from `prompt_feat`.

        Only batch size 1 is supported. `prompt_feat_len` is currently
        unused in this method.

        Returns:
            (generated mel with the prompt region stripped, updated flow cache)
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        # Target mel length: token duration at input_frame_rate mapped to
        # 22050 Hz audio with hop size 256.
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # Condition on the prompt mel; the generated region stays zero.
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        # Strip the prompt region, keep only newly generated frames.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache
149
+
150
+
151
class CausalMaskedDiffWithXvec(torch.nn.Module):
    """Causal (streaming-capable) variant of MaskedDiffWithXvec.

    Uses a fixed token-to-mel ratio instead of a learned length regulator
    and supports chunked/streaming encoding with a small token lookahead.
    """

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 # NOTE(review): mutable dict defaults are evaluated once and
                 # shared across instances; they are only stored here, but a
                 # None sentinel would be safer -- confirm before changing.
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                       'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        # Projects the normalized x-vector to the decoder feature size.
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the flow-matching training loss for one batch.

        Streaming and non-streaming modes are mixed 50/50 per batch so one
        model serves both ("unified training").
        """
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # Embed tokens; clamp guards against negative (padding) ids.
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)

        # Teacher-forcing condition: with probability 0.5 expose a random
        # ground-truth mel prefix (up to 30% of the utterance).
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
        """Generate mel features for `token`, continuing from `prompt_feat`.

        Only batch size 1 is supported. When `finalize` is False the last
        `pre_lookahead_len` tokens are passed to the encoder as lookahead
        context only. `prompt_feat_len` is currently unused.

        Returns:
            (generated mel with the prompt region stripped, None)
        """
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        else:
            # Hold back the lookahead tokens; feed them as encoder context.
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # Condition on the prompt mel; the generated region stays zero.
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        # Strip the prompt region, keep only newly generated frames.
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import threading
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from matcha.models.components.flow_matching import BASECFM
19
+
20
+
21
class ConditionalCFM(BASECFM):
    """Conditional flow matching decoder with classifier-free guidance (CFG).

    Wraps an ``estimator`` network that predicts the flow field and
    integrates the probability-flow ODE with a fixed-step Euler solver.
    The estimator may be a regular ``torch.nn.Module`` or a TensorRT
    engine wrapper (see :meth:`forward_estimator`).
    """

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator
        # NOTE(review): self.lock is not used inside this class -- presumably
        # callers use it to serialize access to a shared estimator; confirm.
        self.lock = threading.Lock()

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=None):
        """Generate a mel-spectrogram from encoder output via flow matching.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): prompt-mel condition.
            prompt_len (int): number of leading prompt frames to carry over
                in the streaming cache.
            cache (torch.Tensor, optional): stacked (noise, mu) prefix from the
                previous streaming chunk, shape (batch, n_feats, cache_len, 2).
                None (the default) means an empty cache.

        Returns:
            tuple: (generated mel of shape (batch_size, n_feats, mel_timesteps),
            updated cache for the next chunk)
        """
        # The previous signature used a tensor default (torch.zeros(1, 80, 0, 2)),
        # which is evaluated once at import time and shared between all calls;
        # build the empty cache per call on the right device/dtype instead.
        if cache is None:
            cache = torch.zeros(1, 80, 0, 2, device=mu.device, dtype=mu.dtype)

        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
        cache_size = cache.shape[2]
        # fix prompt and overlap part mu and z
        if cache_size != 0:
            z[:, :, :cache_size] = cache[:, :, :, 0]
            mu[:, :, :cache_size] = cache[:, :, :, 1]
        # Carry the prompt plus the last 34 frames as overlap for the next chunk.
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
        cache = torch.stack([z_cache, mu_cache], dim=-1)

        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor): prompt-mel condition; zeroed in the
                unconditional CFG branch below
            streaming (bool): forwarded to the estimator (chunked attention)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        # The conditional (row 0) and unconditional (row 1) CFG branches are
        # batched into a single estimator call; row 1 keeps zeros for
        # mu/spks/cond, which is the "condition dropped" branch.
        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in,
                streaming
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            # Mix conditional and unconditional flow predictions.
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1].float()

    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
        """Dispatch one estimator evaluation.

        Plain ``torch.nn.Module`` estimators are called directly; otherwise
        the estimator is assumed to be a TensorRT engine pool and is run via
        its C API. In the TRT path ``x`` doubles as the output buffer (it is
        overwritten in place).
        """
        if isinstance(self.estimator, torch.nn.Module):
            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
        else:
            estimator, trt_engine = self.estimator.acquire_estimator()
            estimator.set_input_shape('x', (2, 80, x.size(2)))
            estimator.set_input_shape('mask', (2, 1, x.size(2)))
            estimator.set_input_shape('mu', (2, 80, x.size(2)))
            estimator.set_input_shape('t', (2,))
            estimator.set_input_shape('spks', (2, 80))
            estimator.set_input_shape('cond', (2, 80, x.size(2)))
            # Last pointer is the output binding, aliased onto x.
            data_ptrs = [x.contiguous().data_ptr(),
                         mask.contiguous().data_ptr(),
                         mu.contiguous().data_ptr(),
                         t.contiguous().data_ptr(),
                         spks.contiguous().data_ptr(),
                         cond.contiguous().data_ptr(),
                         x.data_ptr()]
            for i, j in enumerate(data_ptrs):
                estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
            # run trt engine
            assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
            torch.cuda.current_stream().synchronize()
            self.estimator.release_estimator(estimator)
            return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): prompt-mel condition
            streaming (bool): forwarded to the estimator

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        # Linear interpolation point y and its target flow u.
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
        # Masked MSE, normalized by valid frames times feature dimension.
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
192
+
193
+
194
class CausalConditionalCFM(ConditionalCFM):
    """Causal variant of ConditionalCFM that draws its starting noise from a
    fixed pre-sampled buffer, so repeated/streaming calls of the same length
    start from identical noise."""

    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        # Fixed noise buffer: 80 mel bins x (50 frames/s * 300 s) frames max.
        self.rand_noise = torch.randn([1, 80, 50 * 300])

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
        """Generate a mel-spectrogram from encoder output via flow matching.

        Args:
            mu (torch.Tensor): encoder output, (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output mask, (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of Euler integration steps
            temperature (float, optional): noise scaling factor. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding, (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): prompt-mel condition
            streaming (bool): forwarded to the solver/estimator

        Returns:
            tuple: (generated mel of shape (batch_size, n_feats, mel_timesteps), None)
        """
        # Slice the fixed noise buffer to the requested length instead of
        # re-sampling, then move it onto mu's device/dtype.
        noise = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
        steps = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            steps = 1 - torch.cos(steps * 0.5 * torch.pi)
        mel = self.solve_euler(noise, t_span=steps, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming)
        return mel, None
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
class InterpolateRegulator(nn.Module):
    """Length regulator that linearly interpolates token-level features to a
    target mel length, followed by a small Conv1d/GroupNorm/Mish stack."""

    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            out_channels: int = None,
            groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            # One Conv1d -> GroupNorm -> Mish stage per sampling ratio entry.
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        # Final 1x1 projection to out_channels.
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)

    def forward(self, x, ylens=None):
        """Interpolate `x` (B, T, D) to length `ylens.max()` and refine.

        Returns the masked output (B, T', D) and the output lengths (ylens).
        """
        # x in (B, T, D)
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
        out = self.model(x).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        """Interpolate prompt features `x1` and generated features `x2`
        separately so the mel has a clean prompt/generation boundary.

        Long `x2` (> 40 tokens) is split into head/mid/tail so the 20-token
        overlap regions at each end map to an exact mel length.
        """
        # in inference mode, interpolate prompt token and token(head/mid/tail) separately, so we can get a clear separation point of mel
        # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
        # x in (B, T, D)
        if x2.shape[1] > 40:
            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
                                   mode='linear')
            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
        else:
            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
        if x1.shape[1] != 0:
            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
            x = torch.concat([x1, x2], dim=2)
        else:
            x = x2
        out = self.model(x).transpose(1, 2).contiguous()
        return out, mel_len1 + mel_len2
cosyvoice/hifigan/discriminator.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ try:
5
+ from torch.nn.utils.parametrizations import weight_norm, spectral_norm
6
+ except ImportError:
7
+ from torch.nn.utils import weight_norm, spectral_norm
8
+ from typing import List, Optional, Tuple
9
+ from einops import rearrange
10
+ from torchaudio.transforms import Spectrogram
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
class MultipleDiscriminator(nn.Module):
    """Combines a multi-period (MPD) and a multi-resolution (MRD)
    discriminator, concatenating their score and feature-map lists."""

    def __init__(
        self, mpd: nn.Module, mrd: nn.Module
    ):
        super().__init__()
        self.mpd = mpd
        self.mrd = mrd

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
        """Score real (`y`) and generated (`y_hat`) waveforms with both
        sub-discriminators; MPD receives an extra channel dimension."""
        real_scores, fake_scores, real_fmaps, fake_fmaps = [], [], [], []
        outputs = (
            self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1)),
            self.mrd(y, y_hat),
        )
        for scores_r, scores_g, fmaps_r, fmaps_g in outputs:
            real_scores.extend(scores_r)
            fake_scores.extend(scores_g)
            real_fmaps.extend(fmaps_r)
            fake_fmaps.extend(fmaps_g)
        return real_scores, fake_scores, real_fmaps, fake_fmaps
36
+
37
+
38
class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
        num_embeddings: Optional[int] = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """

        super().__init__()
        # One DiscriminatorR per STFT resolution.
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(window_length=size, num_embeddings=num_embeddings) for size in fft_sizes]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        """Score real (`y`) and generated (`y_hat`) audio at every resolution.

        Returns:
            (real scores, fake scores, real feature maps, fake feature maps)
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_real, fmap_real = disc(x=y, cond_embedding_id=bandwidth_id)
            score_fake, fmap_fake = disc(x=y_hat, cond_embedding_id=bandwidth_id)
            real_scores.append(score_real)
            real_fmaps.append(fmap_real)
            fake_scores.append(score_fake)
            fake_fmaps.append(fmap_fake)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
76
+
77
+
78
class DiscriminatorR(nn.Module):
    """Single-resolution complex-spectrogram discriminator operating on
    frequency sub-bands (adapted from descript-audio-codec)."""

    def __init__(
        self,
        window_length: int,
        num_embeddings: Optional[int] = None,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        # power=None keeps the complex STFT; real/imag later become 2 channels.
        self.spec_fn = Spectrogram(
            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
        )
        n_fft = window_length // 2 + 1
        # Convert relative band edges to absolute frequency-bin indices.
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
            ]
        )
        # Independent conv stack per frequency band.
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        # Optional conditional embedding table (e.g. a bandwidth id).
        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

    def spectrogram(self, x):
        """Return the complex spectrogram of `x` split into frequency bands,
        each band shaped (B, 2, T, band_bins)."""
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = rearrange(x, "b f t c -> b c t f")
        # Split into bands
        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
        """Return (score map, feature maps) for a batch of waveforms `x`."""
        x_bands = self.spectrogram(x)
        fmap = []
        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                # The first layer's activation is excluded from feature maps.
                if i > 0:
                    fmap.append(band)
            x.append(band)
        # Re-join per-band outputs along the frequency axis.
        x = torch.cat(x, dim=-1)
        if cond_embedding_id is not None:
            # Conditional score offset from the learned embedding.
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        # NOTE: += is in-place, so the map just appended to fmap also
        # reflects this conditional offset.
        x += h

        return x, fmap
147
+
148
+
149
class MultiResSpecDiscriminator(torch.nn.Module):
    """Ensemble of three SpecDiscriminators at different STFT resolutions."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):

        super(MultiResSpecDiscriminator, self).__init__()
        # Exactly three resolutions, built from the first three entries.
        self.discriminators = nn.ModuleList(
            [SpecDiscriminator(fft_sizes[i], hop_sizes[i], win_lengths[i], window) for i in range(3)]
        )

    def forward(self, y, y_hat):
        """Run each resolution on real (`y`) and generated (`y_hat`) audio.

        Returns:
            (real scores, fake scores, real feature maps, fake feature maps)
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_real, fmap_real = disc(y)
            score_fake, fmap_fake = disc(y_hat)
            real_scores.append(score_real)
            real_fmaps.append(fmap_real)
            fake_scores.append(score_fake)
            fake_fmaps.append(fmap_fake)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
177
+
178
+
179
+ def stft(x, fft_size, hop_size, win_length, window):
180
+ """Perform STFT and convert to magnitude spectrogram.
181
+ Args:
182
+ x (Tensor): Input signal tensor (B, T).
183
+ fft_size (int): FFT size.
184
+ hop_size (int): Hop size.
185
+ win_length (int): Window length.
186
+ window (str): Window function type.
187
+ Returns:
188
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
189
+ """
190
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
191
+
192
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
193
+ return torch.abs(x_stft).transpose(2, 1)
194
+
195
+
196
class SpecDiscriminator(nn.Module):
    """STFT-magnitude discriminator for a single (fft, hop, win) resolution."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
        super(SpecDiscriminator, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # `window` names a torch window factory, e.g. "hann_window".
        self.window = getattr(torch, window)(win_length)
        self.discriminators = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
        ])

        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        """Return (flattened score, feature maps) for a waveform batch `y`
        with a singleton channel dimension."""

        fmap = []
        y = y.squeeze(1)
        # self.window is a plain attribute (not a registered buffer), so it
        # is moved to the input device on every call.
        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
        y = y.unsqueeze(1)
        for _, d in enumerate(self.discriminators):
            y = d(y)
            y = F.leaky_relu(y, LRELU_SLOPE)
            fmap.append(y)

        y = self.out(y)
        fmap.append(y)

        return torch.flatten(y, 1, -1), fmap
cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ try:
17
+ from torch.nn.utils.parametrizations import weight_norm
18
+ except ImportError:
19
+ from torch.nn.utils import weight_norm
20
+
21
+
22
+ class ConvRNNF0Predictor(nn.Module):
23
+ def __init__(self,
24
+ num_class: int = 1,
25
+ in_channels: int = 80,
26
+ cond_channels: int = 512
27
+ ):
28
+ super().__init__()
29
+
30
+ self.num_class = num_class
31
+ self.condnet = nn.Sequential(
32
+ weight_norm(
33
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
34
+ ),
35
+ nn.ELU(),
36
+ weight_norm(
37
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
38
+ ),
39
+ nn.ELU(),
40
+ weight_norm(
41
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
42
+ ),
43
+ nn.ELU(),
44
+ weight_norm(
45
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
46
+ ),
47
+ nn.ELU(),
48
+ weight_norm(
49
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
50
+ ),
51
+ nn.ELU(),
52
+ )
53
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
54
+
55
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
56
+ x = self.condnet(x)
57
+ x = x.transpose(1, 2)
58
+ return torch.abs(self.classifier(x).squeeze(-1))
cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, Optional, List
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ try:
27
+ from torch.nn.utils.parametrizations import weight_norm
28
+ except ImportError:
29
+ from torch.nn.utils import weight_norm
30
+ from torch.distributions.uniform import Uniform
31
+
32
+ from cosyvoice.transformer.activation import Snake
33
+ from cosyvoice.utils.common import get_padding
34
+ from cosyvoice.utils.common import init_weights
35
+
36
+
37
+ """hifigan based generator implementation.
38
+
39
+ This code is modified from https://github.com/jik876/hifi-gan
40
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
41
+ https://github.com/NVIDIA/BigVGAN
42
+
43
+ """
44
+
45
+
46
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    Each step applies Snake -> dilated conv -> Snake -> plain conv, and
    adds the result back onto the input (one step per entry in
    ``dilations``).
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        # convs1 carry the dilation; convs2 are always dilation-1.
        # Built interleaved (pair by pair) to keep parameter creation order.
        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for act1, conv1, act2, conv2 in zip(self.activations1, self.convs1,
                                            self.activations2, self.convs2):
            residual = x
            out = conv1(act1(x))
            out = conv2(act2(out))
            x = out + residual
        return x

    def remove_weight_norm(self):
        # strip weight_norm reparametrization from every conv (inference-time)
        for conv1, conv2 in zip(self.convs1, self.convs2):
            remove_weight_norm(conv1)
            remove_weight_norm(conv2)
107
+
108
+
109
class SineGen(torch.nn.Module):
    """Harmonic sine-source generator.

    SineGen(samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of the sine waveform (default 0.1)
    noise_std: std of Gaussian noise in voiced regions (default 0.003)
    voiced_threshold: F0 threshold for voiced/unvoiced decision (default 0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # voiced/unvoiced flag: 1.0 where f0 exceeds the threshold, else 0.0
        return (f0 > self.voiced_threshold).type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise), each sample-rate tensors
        """
        # normalized per-harmonic frequency: harmonic i oscillates at (i+1)*f0
        harmonic_idx = torch.arange(
            1, self.harmonic_num + 2, device=f0.device, dtype=f0.dtype
        ).view(1, -1, 1)
        F_mat = f0 * harmonic_idx / self.sampling_rate

        # integrate frequency to phase; mod 1 before 2*pi keeps cumsum bounded
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        # random initial phase per harmonic, but the fundamental starts at 0
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        uv = self._f02uv(f0)

        # unvoiced noise amplitude ~ sine_amp (std = sine_amp / 3),
        # voiced noise amplitude = noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # zero out sines in unvoiced regions, then add noise everywhere
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
172
+
173
+
174
class SourceModuleHnNSF(torch.nn.Module):
    """Harmonic-plus-noise source module for NSF vocoders.

    SourceModuleHnNSF(sampling_rate, upsample_scale, harmonic_num=0,
                      sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0)

    sampling_rate: sampling rate in Hz
    upsample_scale: accepted for API parity with SourceModuleHnNSF2
                    (not used by the underlying SineGen)
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of the sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003);
                   unvoiced noise amplitude is governed by sine_amp
    voiced_threshod: F0 threshold for the voiced/unvoiced decision

    forward(F0) -> (sine_merge, noise, uv), each (batch, length, 1).
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # harmonic sine-source generator
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # collapse all harmonics into one excitation channel
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        x: sampled F0, shape (batch, length, 1)
        Returns (sine_merge, noise, uv), each (batch, length, 1).
        """
        # sine generation carries no gradient; only the merge is trained
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # noise branch, same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
224
+
225
+
226
class SineGen2(torch.nn.Module):
    """ Definition of sine generator (phase-accumulation variant).

    SineGen2(samp_rate, upsample_scale, harmonic_num=0,
             sine_amp=0.1, noise_std=0.003,
             voiced_threshold=0, flag_for_pulse=False)

    samp_rate: sampling rate in Hz
    upsample_scale: ratio between the output sample rate and the rate at
        which phase is accumulated (see _f02sine)
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen2, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1  # fundamental + overtones
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # voiced/unvoiced flag: 1.0 where f0 exceeds the threshold
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # accumulate phase at the (lower) frame rate, then upsample the
            # accumulated phase back to the sample rate. This avoids the
            # numerical drift of a very long cumsum at the sample rate.
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1 / self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            # multiply by upsample_scale to compensate for the 1/upsample_scale
            # shrinkage applied to rad_values above
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        # fundamental plus harmonics: column i oscillates at (i+1) * f0
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        # std = self.sine_amp/3 -> max value ~ self.sine_amp
        # . for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
340
+
341
+
342
class SourceModuleHnNSF2(torch.nn.Module):
    """Harmonic-plus-noise source module built on SineGen2.

    SourceModuleHnNSF2(sampling_rate, upsample_scale, harmonic_num=0,
                       sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0)

    sampling_rate: sampling rate in Hz
    upsample_scale: passed through to SineGen2's phase accumulation
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of the sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003);
                   unvoiced noise amplitude is governed by sine_amp
    voiced_threshod: F0 threshold for the voiced/unvoiced decision

    forward(F0) -> (sine_merge, noise, uv), each (batch, length, 1).
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF2, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # harmonic sine-source generator (phase-accumulation variant)
        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
                                  sine_amp, add_noise_std, voiced_threshod)

        # collapse all harmonics into one excitation channel
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        x: sampled F0, shape (batch, length, 1)
        Returns (sine_merge, noise, uv), each (batch, length, 1).
        """
        # sine generation carries no gradient; only the merge is trained
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # noise branch, same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
390
+
391
+
392
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Converts a mel spectrogram to a waveform: an F0 predictor drives a
    harmonic/noise source module, the source is fused into the upsampling
    stack, and the final waveform is produced by an inverse STFT.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: List[int] = [8, 8],
        upsample_kernel_sizes: List[int] = [16, 16],
        istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: List[int] = [3, 7, 11],
        resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: List[int] = [7, 11],
        source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
        # (the sampling rate selects which source-module variant is built)
        this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
        self.m_source = this_SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            # total frame->sample upsampling factor of the whole generator
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        # upsamples the frame-rate F0 track to sample rate for the source module
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up: transposed-conv upsampling stack, halving channels each stage
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: downsample the source STFT to match each upsampling stage's rate
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        # cumulative downsampling factor needed at each fusion point
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                # rates already match: 1x1 conv only adapts channel count
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        # multi-receptive-field resblocks after each upsampling stage
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        # output head predicts (n_fft/2 + 1) magnitudes + (n_fft/2 + 1) phases
        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        # NOTE(review): plain tensor attribute, not a registered buffer, so it
        # does not follow .to(device)/state_dict; callers move it explicitly below
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def remove_weight_norm(self):
        """Strip weight_norm from all layers (call before inference/export)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        self.m_source.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        # STFT of the source excitation; returns (real, imag), each [B, F, TT]
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        # clip magnitude to avoid overflow from exp() in decode()
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        """Fuse mel features x with source excitation s and synthesize audio.

        NOTE(review): the tensor default for ``s`` is created once at import
        time and shared across calls; it is not mutated here, but callers
        should always pass an explicit source.
        """
        # source excitation in STFT domain: [B, n_fft + 2, TT]
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion: inject the rate-matched source signal
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # average the parallel multi-kernel resblocks for this stage
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        # first half of channels -> log-magnitude, second half -> phase
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Training entry point: returns (generated_speech, predicted_f0)."""
        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel+source->speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        """Streaming inference: returns (generated_speech, source) where the
        source can be passed back as cache_source for the next chunk."""
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # use cache_source to avoid glitch at chunk boundaries: replay the
        # previous chunk's source samples so the excitation stays continuous
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, s
cosyvoice/hifigan/hifigan.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
6
+ from cosyvoice.utils.losses import tpr_loss, mel_loss
7
+
8
+
9
class HiFiGan(nn.Module):
    """Training wrapper pairing a HiFiGAN generator with its discriminator.

    Dispatches on ``batch['turn']`` to compute either the generator losses
    (adversarial + feature-matching + mel reconstruction + optional TPR +
    F0 L1) or the discriminator losses (adversarial + optional TPR).
    """

    def __init__(self, generator, discriminator, mel_spec_transform,
                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
        super(HiFiGan, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.mel_spec_transform = mel_spec_transform
        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
        self.feat_match_loss_weight = feat_match_loss_weight
        self.tpr_loss_weight = tpr_loss_weight
        self.tpr_loss_tau = tpr_loss_tau

    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        # dispatch on whose optimization turn it is
        step = self.forward_generator if batch['turn'] == 'generator' else self.forward_discriminator
        return step(batch, device)

    def _tpr_term(self, outputs_a, outputs_b, device):
        # optional TPR penalty; a zero tensor when the weight disables it
        if self.tpr_loss_weight != 0:
            return tpr_loss(outputs_a, outputs_b, self.tpr_loss_tau)
        return torch.zeros(1).to(device)

    def forward_generator(self, batch, device):
        real_speech = batch['speech'].to(device)
        pitch_feat = batch['pitch_feat'].to(device)
        # 1. generator pass
        generated_speech, generated_f0 = self.generator(batch, device)
        # 2. score real vs. generated speech
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
        # 3. adversarial, feature-matching, mel, TPR [optional] and F0 losses
        loss_gen, _ = generator_loss(y_d_gs)
        loss_fm = feature_loss(fmap_rs, fmap_gs)
        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
        loss_tpr = self._tpr_term(y_d_gs, y_d_rs, device)
        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
        loss = (loss_gen +
                self.feat_match_loss_weight * loss_fm +
                self.multi_mel_spectral_recon_loss_weight * loss_mel +
                self.tpr_loss_weight * loss_tpr +
                loss_f0)
        return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}

    def forward_discriminator(self, batch, device):
        real_speech = batch['speech'].to(device)
        # 1. generator pass (frozen on the discriminator's turn)
        with torch.no_grad():
            generated_speech, generated_f0 = self.generator(batch, device)
        # 2. score real vs. generated speech
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
        # 3. adversarial and TPR [optional] losses
        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
        loss_tpr = self._tpr_term(y_d_rs, y_d_gs, device)
        loss = loss_disc + self.tpr_loss_weight * loss_tpr
        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import random
15
+ from typing import Dict, Optional, Callable, List, Generator
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+ from transformers import Qwen2ForCausalLM
20
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
21
+ from cosyvoice.utils.common import IGNORE_ID
22
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
23
+ from cosyvoice.utils.common import th_accuracy
24
+ from cosyvoice.utils.file_utils import logging
25
+ from cosyvoice.utils.mask import make_pad_mask
26
+
27
+
28
class TransformerLM(torch.nn.Module):
    """Autoregressive speech-token language model.

    Text tokens are embedded and passed through a text encoder, then the LM
    consumes the sequence [sos/eos, speaker embedding, text, task-id, speech]
    and predicts discrete speech tokens.  The decoder vocabulary is
    ``speech_token_size + 1``: the extra index ``speech_token_size`` acts as
    the end-of-speech token.
    """

    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        # project text-encoder output dim onto the LM input dim
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )

        # 2. build speech token language model related modules
        self.sos_eos = 0  # row 0 of llm_embedding: combined sos/eos marker
        self.task_id = 1  # row 1 of llm_embedding: task-id marker
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        # +1 output class for the end-of-speech token (index == speech_token_size)
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        """Run the text encoder and project its output to the LM input size.

        Returns:
            (encoder_out, encoder_out_lens): projected encoder states and
            per-utterance valid lengths derived from the encoder mask.
        """
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        """Concatenate per-utterance [sos, spk, text, task, speech] embeddings and re-pad to a batch."""
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Compute the training loss and accuracy for one batch.

        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)

        # 1. prepare llm_target: the two leading marker positions and the text
        # span are ignored; the speech tokens plus a trailing eos are supervised
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                     [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 1. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)

        # 2. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)

        # 3. eos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

        # 4. encode speech_token
        speech_token = self.speech_embedding(speech_token)

        # 5. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)

        # 6. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        """Sample one token id; when ignore_eos, resample (up to 100 tries) until no eos appears."""
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        """Autoregressively decode speech tokens, yielding one token id per step.

        Decoding stops at the eos token or after
        ``(text_len - prompt_text_len) * max_token_text_ratio`` steps.
        """
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)

        # 1. encode text
        text, text_len = self.encode(text, text_len)

        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            # no speaker embedding supplied: use an empty (length-0) placeholder
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. cal min/max_length relative to the non-prompt part of the text
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode, carrying the encoder attention/conv caches forward
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # force continue decode first token
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
224
+
225
+
226
class Qwen2Encoder(torch.nn.Module):
    """Wrapper around a pretrained Qwen2 causal LM exposing full-sequence and stepwise forwards."""

    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        """Full-sequence forward; returns the last hidden states and a (B, 1, T) validity mask."""
        seq_len = xs.size(1)
        valid_mask = ~make_pad_mask(xs_lens, seq_len)
        result = self.model(
            inputs_embeds=xs,
            attention_mask=valid_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        return result.hidden_states[-1], valid_mask.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        """One incremental decoding step with KV cache; returns (hidden_states, new_cache)."""
        step_mask = masks[:, -1, :]
        result = self.model(
            inputs_embeds=xs,
            attention_mask=step_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        return result.hidden_states[-1], result.past_key_values
255
+
256
+
257
class Qwen2LM(TransformerLM):
    """Speech-token LM backed by a Qwen2 causal transformer.

    Unlike ``TransformerLM`` there is no separate text encoder; text tokens are
    embedded directly with the Qwen2 input embedding.  The decoder vocabulary
    is ``speech_token_size + 3``: eos (``speech_token_size``), and two extra
    control tokens, of which ``speech_token_size + 2`` is the bistream fill
    token used to interleave text and speech chunks at a ``mix_ratio`` of
    text:speech tokens.
    """

    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        # bypass TransformerLM.__init__ on purpose: no text encoder here
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2

        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        # +3 output classes: eos and two control tokens (incl. bistream fill token)
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)

        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio  # [text_chunk, speech_chunk] token counts

    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
        """Build LM inputs/targets, randomly mixing unistream and bistream layouts.

        Bistream interleaves chunks of ``mix_ratio[0]`` text tokens with
        ``mix_ratio[1]`` speech tokens (chosen with prob. 0.5 when the
        speech/text length ratio allows it); unistream is the plain
        [sos, text, task, speech] layout.
        """
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [], []
                this_lm_target.append(IGNORE_ID)
                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        # full text chunk: supervise the speech chunk plus a fill token
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.speech_token_size + 2)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        # final partial chunk: append task-id and the speech tail, end with eos
                        this_lm_target += [-1] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.speech_token_size)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Compute the training loss and accuracy for one batch.

        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)

        # 1. encode text_token with the Qwen2 input embedding
        text_token_emb = self.llm.model.model.embed_tokens(text_token)

        # 2. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)

        # 3. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)

        # 4. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target.to(device))
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        """Autoregressively decode speech tokens, yielding one token id per step.

        Note: ``embedding`` is unused here; it is kept for interface parity
        with ``TransformerLM.inference``.
        """
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. cal min/max_length relative to the non-prompt part of the text
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        out_tokens = []
        cache = None
        for i in range(max_len):
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            # skip control tokens (> speech_token_size): they are not speech output
            if top_ids > self.speech_token_size:
                continue
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        """Streaming decode: consume text chunks from a generator while yielding speech tokens.

        Text and speech are interleaved at ``mix_ratio`` (text:speech); the fill
        token (``speech_token_size + 2``) signals that more text must be appended
        before decoding can continue.
        """
        device = prompt_text.device
        # 1. prepare input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb], dim=1)

        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try append to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remain, can decode some speech token
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        # after a fill token the previous input was already consumed, so replace instead of append
                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    # seq_len counts tokens already in the KV cache plus the new input
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        # force the fill token at the scheduled position
                        top_ids = self.speech_token_size + 2
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                    if top_ids == self.speech_token_size + 2:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.speech_token_size + 2:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

        # 3. final decode: flush remaining text, append task-id, decode until eos
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.speech_token_size:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
cosyvoice/llm/llm_dpo.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Callable, List, Generator
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ from transformers import Qwen2ForCausalLM
19
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
20
+ from cosyvoice.utils.common import IGNORE_ID
21
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
22
+ from cosyvoice.utils.common import th_accuracy
23
+ from cosyvoice.utils.file_utils import logging
24
+ from cosyvoice.utils.mask import make_pad_mask
25
+
26
+
27
class TransformerLM(torch.nn.Module):
    """Autoregressive speech-token language model (DPO-training variant).

    Same architecture as the base ``TransformerLM``: a text encoder projects
    text into the LM space, and the LM predicts discrete speech tokens over a
    vocabulary of ``speech_token_size + 1`` (the extra index is eos).
    """

    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # FIX: inference() reads self.fp16, but it was never initialised here,
        # crashing with AttributeError unless an external wrapper set it.
        # Default to False; external code may still overwrite this attribute.
        self.fp16 = False
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        # project text-encoder output dim onto the LM input dim
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )

        # 2. build speech token language model related modules
        self.sos_eos = 0  # row 0 of llm_embedding: combined sos/eos marker
        self.task_id = 1  # row 1 of llm_embedding: task-id marker
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        # +1 output class for the end-of-speech token (index == speech_token_size)
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        """Run the text encoder and project its output to the LM input size."""
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        """Concatenate per-utterance [sos, spk, text, task, speech] embeddings and re-pad to a batch."""
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Compute the training loss and accuracy for one batch.

        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)

        # 1. prepare llm_target: ignore the 2 marker positions and the text span,
        # supervise speech tokens plus a trailing eos
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                     [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 1. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)

        # 2. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)

        # 3. eos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

        # 4. encode speech_token
        speech_token = self.speech_embedding(speech_token)

        # 5. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)

        # 6. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        """Sample one token id; when ignore_eos, resample (up to 100 tries) until no eos appears."""
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        """Autoregressively decode speech tokens, yielding one token id per step."""
        # self.fp16 defaults to False (set in __init__) and may be toggled externally
        if self.fp16 is True:
            embedding = embedding.half()

        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)

        # 1. encode text
        text, text_len = self.encode(text, text_len)

        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            # no speaker embedding supplied: use an empty (length-0) placeholder
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. cal min/max_length relative to the non-prompt part of the text
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode, carrying the encoder attention/conv caches forward
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # force continue decode first token
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
226
+
227
+
228
class Qwen2Encoder(torch.nn.Module):
    """Thin wrapper around a pretrained Qwen2 causal LM used as the LM backbone."""

    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward_one_step(self, xs, masks, cache=None):
        """Run one incremental decoding step with KV cache; returns (hidden_states, new_cache)."""
        step_mask = masks[:, -1, :]
        result = self.model(
            inputs_embeds=xs,
            attention_mask=step_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        return result.hidden_states[-1], result.past_key_values
246
+
247
+
248
+ class Qwen2LM(TransformerLM):
249
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],  # NOTE(review): mutable default — safe only while never mutated
            dpo: bool = False,
    ):
        """Build the Qwen2-backed speech-token LM; ``dpo=True`` enables DPO-style training."""
        # bypass TransformerLM.__init__ on purpose: no text encoder here
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2

        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        # +3 output classes: eos and two control tokens (incl. bistream fill token)
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)

        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio  # [text_chunk, speech_chunk] token counts

        # 5. [Optional] set dpo
        self.dpo = dpo
290
+
291
+
292
+ def forward(
293
+ self,
294
+ batch: dict,
295
+ device: torch.device,
296
+ ) -> Dict[str, Optional[torch.Tensor]]:
297
+ text_token = batch['text_token'].to(device)
298
+ text_token_len = batch['text_token_len'].to(device)
299
+ speech_token = batch['speech_token'].to(device)
300
+ speech_token_len = batch['speech_token_len'].to(device)
301
+ if self.dpo:
302
+ reject_speech_token = batch['reject_speech_token'].to(device)
303
+ reject_speech_token_len = batch['reject_speech_token_len'].to(device)
304
+ # 1. prepare llm_target
305
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
306
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
307
+ target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
308
+ [self.speech_token_size]) for i in range(text_token.size(0))]
309
+ if self.dpo:
310
+ reject_target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + reject_speech_token[i, :reject_speech_token_len[i]].tolist() +
311
+ [self.speech_token_size]) for i in range(text_token.size(0))]
312
+ target_ids.extend(reject_target_ids)
313
+ target_ids = pad_sequence(target_ids, batch_first=True, padding_value=IGNORE_ID).to(device)
314
+
315
+ # 2. speech token projection
316
+ speech_emb = self.speech_embedding(speech_token)
317
+ if self.dpo:
318
+ reject_speech_emb = self.speech_embedding(reject_speech_token)
319
+
320
+ # 3. text token projection
321
+ text_token_lst = unpad_sequence(text_token, text_token_len, batch_first=True)
322
+ text_emb = [self.llm.model.model.embed_tokens(y) for y in text_token_lst]
323
+
324
+ # 4. prepare llm_input
325
+ speech_emb = unpad_sequence(speech_emb, speech_token_len.cpu(), batch_first=True)
326
+ input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), speech_emb[i]], dim=0)
327
+ for i in range(len(text_emb))]
328
+ if self.dpo:
329
+ reject_speech_emb = unpad_sequence(reject_speech_emb, reject_speech_token_len.cpu(), batch_first=True)
330
+ reject_input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), reject_speech_emb[i]], dim=0)
331
+ for i in range(len(text_emb))]
332
+ input_emb.extend(reject_input_emb)
333
+ input_emb_lengths = torch.tensor([i.size(0) for i in input_emb], dtype=torch.int32).to(device)
334
+ input_emb = pad_sequence(input_emb, batch_first=True, padding_value=IGNORE_ID).to(device)
335
+
336
+ attention_mask = ~make_pad_mask(input_emb_lengths)
337
+
338
+ result = self.llm.model(
339
+ inputs_embeds=input_emb,
340
+ attention_mask=attention_mask,
341
+ return_dict=True
342
+ )
343
+ hidden_states = result.hidden_states
344
+ logits = self.llm_decoder(hidden_states[-1])
345
+ loss = self.criterion_ce(logits[: speech_token.shape[0]], target_ids[: speech_token.shape[0]])
346
+ acc = th_accuracy(
347
+ logits[: speech_token.shape[0]].view(-1, self.speech_token_size + 3),
348
+ target_ids[: speech_token.shape[0]],
349
+ ignore_label=IGNORE_ID,
350
+ )
351
+ if not self.dpo:
352
+ return {
353
+ "loss": loss,
354
+ "acc": acc,
355
+ }
356
+ else:
357
+ all_logps_sum, all_logps_mean = self.get_batch_logps(
358
+ logits, target_ids, attention_mask, text_token_len, average_log_prob=False, ignore_id=IGNORE_ID
359
+ )
360
+ chosen_logps = all_logps_sum[: speech_token.shape[0]]
361
+ rejected_logps = all_logps_sum[speech_token.shape[0]:]
362
+ return {
363
+ "loss": loss,
364
+ "acc": acc,
365
+ "chosen_logps": chosen_logps,
366
+ "rejected_logps": rejected_logps
367
+ }
368
+
369
+
370
+ def get_batch_logps(
371
+ self,
372
+ logits: torch.FloatTensor,
373
+ labels: torch.LongTensor,
374
+ attention_mask,
375
+ prompt_token_lens,
376
+ average_log_prob: bool = False,
377
+ ignore_id: int = -1,
378
+ ) -> torch.FloatTensor:
379
+ """Compute the log probabilities of the given labels under the given logits.
380
+
381
+ Args:
382
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
383
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)
384
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
385
+
386
+ Returns:
387
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
388
+ """
389
+ assert average_log_prob == False
390
+ assert logits.shape[:-1] == labels.shape
391
+ labels = labels[:, 1:].clone()
392
+ logits = logits[:, :-1, :]
393
+ loss_masks = attention_mask.clone().bool()
394
+ # mask prompts
395
+ for mask, text_token_len in zip(loss_masks, prompt_token_lens):
396
+ mask[:text_token_len + 1] = False
397
+ loss_masks = loss_masks[:, 1:]
398
+ labels[loss_masks == False] = 0
399
+ # dummy token; we'll ignore the losses on these tokens later
400
+ ignore = labels == ignore_id
401
+ labels = labels.masked_fill(ignore, 0) # avoid -1 index
402
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) # (bs, time,)
403
+ logprobs_sums = (per_token_logps * loss_masks).sum(-1)
404
+ logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
405
+ return logprobs_sums, logprobs_means
406
+
407
+
408
+ @torch.inference_mode()
409
+ def inference(
410
+ self,
411
+ text: torch.Tensor,
412
+ text_len: torch.Tensor,
413
+ prompt_text: torch.Tensor,
414
+ prompt_text_len: torch.Tensor,
415
+ prompt_speech_token: torch.Tensor,
416
+ prompt_speech_token_len: torch.Tensor,
417
+ embedding: torch.Tensor,
418
+ sampling: int = 25,
419
+ max_token_text_ratio: float = 20,
420
+ min_token_text_ratio: float = 2,
421
+ ) -> Generator[torch.Tensor, None, None]:
422
+ device = text.device
423
+ text = torch.concat([prompt_text, text], dim=1)
424
+ text_len += prompt_text_len
425
+ text = self.llm.model.model.embed_tokens(text)
426
+
427
+ # 3. concat llm_input
428
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
429
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
430
+ if prompt_speech_token_len != 0:
431
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
432
+ else:
433
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
434
+ lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
435
+
436
+ # 4. cal min/max_length
437
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
438
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
439
+
440
+ # 5. step by step decode
441
+ out_tokens = []
442
+ cache = None
443
+ for i in range(max_len):
444
+ y_pred, cache = self.llm.forward_one_step(lm_input,
445
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
446
+ cache=cache)
447
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
448
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
449
+ if top_ids == self.speech_token_size:
450
+ break
451
+ if top_ids > self.speech_token_size:
452
+ continue
453
+ # in stream mode, yield token one by one
454
+ yield top_ids
455
+ out_tokens.append(top_ids)
456
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
457
+
458
+ @torch.inference_mode()
459
+ def inference_bistream(
460
+ self,
461
+ text: Generator,
462
+ prompt_text: torch.Tensor,
463
+ prompt_text_len: torch.Tensor,
464
+ prompt_speech_token: torch.Tensor,
465
+ prompt_speech_token_len: torch.Tensor,
466
+ embedding: torch.Tensor,
467
+ sampling: int = 25,
468
+ max_token_text_ratio: float = 20,
469
+ min_token_text_ratio: float = 2,
470
+ ) -> Generator[torch.Tensor, None, None]:
471
+
472
+ device = prompt_text.device
473
+ # 1. prepare input
474
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
475
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
476
+ if prompt_speech_token_len != 0:
477
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
478
+ else:
479
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
480
+ lm_input = torch.concat([sos_eos_emb], dim=1)
481
+
482
+ # 2. iterate text
483
+ out_tokens = []
484
+ cache = None
485
+ # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
486
+ text_cache = self.llm.model.model.embed_tokens(prompt_text)
487
+ next_fill_index = -1
488
+ for this_text in text:
489
+ text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
490
+ # prompt_speech_token_emb not empty, try append to lm_input
491
+ while prompt_speech_token_emb.size(1) != 0:
492
+ if text_cache.size(1) >= self.mix_ratio[0]:
493
+ lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
494
+ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
495
+ lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
496
+ text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
497
+ else:
498
+ logging.info('not enough text token to decode, wait for more')
499
+ break
500
+ # no prompt_speech_token_emb remain, can decode some speech token
501
+ if prompt_speech_token_emb.size(1) == 0:
502
+ if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
503
+ logging.info('get fill token, need to append more text token')
504
+ if text_cache.size(1) >= self.mix_ratio[0]:
505
+ lm_input_text = text_cache[:, :self.mix_ratio[0]]
506
+ logging.info('append {} text token'.format(lm_input_text.size(1)))
507
+ if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
508
+ lm_input = lm_input_text
509
+ else:
510
+ lm_input = torch.concat([lm_input, lm_input_text], dim=1)
511
+ text_cache = text_cache[:, self.mix_ratio[0]:]
512
+ else:
513
+ logging.info('not enough text token to decode, wait for more')
514
+ continue
515
+ while True:
516
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
517
+ y_pred, cache = self.llm.forward_one_step(lm_input,
518
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
519
+ cache=cache)
520
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
521
+ if next_fill_index != -1 and len(out_tokens) == next_fill_index:
522
+ top_ids = self.speech_token_size + 2
523
+ next_fill_index += (self.mix_ratio[1] + 1)
524
+ else:
525
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
526
+ if top_ids == self.speech_token_size + 2:
527
+ next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
528
+ logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
529
+ out_tokens.append(top_ids)
530
+ if top_ids >= self.speech_token_size:
531
+ if top_ids == self.speech_token_size + 2:
532
+ break
533
+ else:
534
+ raise ValueError('should not get token {}'.format(top_ids))
535
+ yield top_ids
536
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
537
+
538
+ # 3. final decode
539
+ lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
540
+ logging.info('no more text token, decode until met eos')
541
+ while True:
542
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
543
+ y_pred, cache = self.llm.forward_one_step(lm_input,
544
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
545
+ cache=cache)
546
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
547
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
548
+ out_tokens.append(top_ids)
549
+ if top_ids >= self.speech_token_size:
550
+ if top_ids == self.speech_token_size:
551
+ break
552
+ else:
553
+ raise ValueError('should not get token {}'.format(top_ids))
554
+ # in stream mode, yield token one by one
555
+ yield top_ids
556
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
cosyvoice/llm/llm_vllm.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import time
15
+ import queue
16
+ import asyncio
17
+ import threading
18
+ from typing import List, Generator, AsyncGenerator
19
+ import torch
20
+ from cosyvoice.utils.file_utils import logging
21
+ from cosyvoice.llm.llm import Qwen2LM
22
+
23
# Enable the vLLM V1 engine
import os
os.environ["VLLM_USE_V1"] = '1'
from vllm import ModelRegistry
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
from vllm.sampling_params import SamplingParams

# Register the custom CosyVoice2 model so vLLM can instantiate it by name.
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)

# EngineArgs passed to AsyncEngineArgs when building the vLLM engine.
ENGINE_ARGS = {
    "block_size": 16,
    "swap_space": 0,
    # "enforce_eager": True,
    "gpu_memory_utilization": 0.4,
    "max_num_batched_tokens": 1024,
    "max_model_len": 1024,
    "max_num_seqs": 256,
    "disable_log_requests": True,
    "disable_log_stats": True,
    "dtype": "float16"
}

from vllm.sampling_params import RequestOutputKind
# Default SamplingParams for speech-token generation.
SAMPLING_PARAMS = {
    "temperature": 1,  # must not go below 0.8, otherwise lots of empty audio is produced or speech tokens fail to generate properly
    "top_p": 1,  # must not go below 0.8, otherwise lots of empty audio is produced or speech tokens fail to generate properly
    "top_k": 25,
    # "min_tokens": 80,  # setting a minimum token count is unsupported; enabling it crashes vLLM and it cannot start
    # "presence_penalty": 1.0,  # unsupported
    # "frequency_penalty": 0.0,  # unsupported
    "max_tokens": 1024,
    "detokenize": False,  # currently has no effect in vLLM 0.7.3 V1; kept to reduce computation once a later version honors it
    "ignore_eos": False,
    "output_kind": RequestOutputKind.DELTA  # set to DELTA; if you change this, update the handling code in llm_inference accordingly
}
62
+
63
def tensor_to_list(tensor: torch.Tensor) -> list:
    """Flatten a tensor of any shape into a flat Python list of scalars.

    Args:
        tensor: any torch tensor (0-d tensors yield a single-element list).

    Returns:
        list of Python numbers in row-major order.
    """
    # reshape(-1) also handles non-contiguous tensors (view(-1) would raise),
    # and Tensor.tolist() already yields Python scalars, so the numpy
    # round-trip of the previous implementation is unnecessary.
    return tensor.reshape(-1).cpu().tolist()
65
+
66
class VllmQwen2LM(Qwen2LM):
    """Qwen2LM-compatible front end that delegates generation to a vLLM
    AsyncLLMEngine.

    Text, speech, and special tokens share one flat id space: speech tokens
    occupy [0, 6564), text tokens are shifted up by 6564, and sos_eos / task /
    zero ids sit above the text vocabulary. Id 6561 is the speech EOS and
    6563 the bistream fill token (hard-coded below).
    """

    def __init__(
        self,
        model_dir,
        mix_ratio: List[int] = [5, 15],
    ):
        # NOTE(review): Qwen2LM.__init__ is deliberately NOT called — no torch
        # modules are built here; vLLM owns the weights. fp16/half are stubs so
        # callers written against Qwen2LM can still invoke .half().
        self.fp16 = False
        self.half = lambda: None
        self.mix_ratio = mix_ratio
        # ---------------------------------------------
        # vLLM engine configuration
        engine_args = AsyncEngineArgs(
            model=model_dir,
            **ENGINE_ARGS,
        )
        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)

        self.speech_token_size = 6564  # 6561 + 3
        self.llm_token_size = 151936  # llm vocab_size
        # Special ids placed above speech + text vocabularies.
        self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
        self.task_token_id = self.sos_eos_token_id + 1
        self.zero_token_id = self.task_token_id + 1

        # vLLM inference must run inside one fixed event loop, so start a
        # dedicated background thread to host it.
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
        self.loop_thread.start()

    def _run_event_loop(self):
        # Runs forever in the background thread; coroutines are submitted via
        # asyncio.run_coroutine_threadsafe from llm_inference.
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
        """Generate on the vLLM engine, pushing (delta_output, finished) pairs
        into `out_queue` as they stream in."""
        sampling_params = SamplingParams(**SAMPLING_PARAMS)
        sampling_params.stop_token_ids = stop_token_ids or [6561]
        if max_tokens:
            sampling_params.max_tokens = max_tokens
        async for output in self.llm_engine.generate(
                {
                    "prompt_token_ids": prompt_token_ids,
                },
                sampling_params=sampling_params,
                request_id=request_id or f"{time.time()}",
        ):
            out_queue.put((output.outputs[0], output.finished))

    def llm_inference(self, prompt_token_ids: List[int], request_id: str = None, stop_token_ids=None, max_tokens=None):
        """Synchronous generator bridging the async vLLM stream: yields each
        delta CompletionOutput until the engine reports the request finished."""
        out_queue = queue.Queue()
        asyncio.run_coroutine_threadsafe(
            self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
        )
        # Consume results pushed into out_queue by the background coroutine
        # (non-blocking drain when items are ready, blocking wait otherwise).
        finished = False
        while not finished:
            (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
            yield output

    def inference(
        self,
        text: torch.Tensor,
        text_len: torch.Tensor,
        prompt_text: torch.Tensor,
        prompt_text_len: torch.Tensor,
        prompt_speech_token: torch.Tensor,
        prompt_speech_token_len: torch.Tensor,
        embedding: torch.Tensor,
        sampling: int = 25,
        max_token_text_ratio: float = 20,
        min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor | int, None, None]:
        """Non-streaming-text inference: build one flat prompt
        [sos, prompt_text, text, task, prompt_speech] and stream speech token
        ids until the EOS id 6561. Sampling/ratio arguments are kept for
        interface compatibility; vLLM's SAMPLING_PARAMS are used instead.
        """
        # Shift text ids into the shared vocabulary (speech ids come first).
        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
        prompt_speech_token = tensor_to_list(prompt_speech_token)

        text = tensor_to_list(text + torch.tensor(6564))
        prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
            [self.task_token_id] + prompt_speech_token
        max_tokens = len(text) * 20
        for output in self.llm_inference(
                prompt_token_ids,
                stop_token_ids=[6561],
                max_tokens=max_tokens,
        ):
            # Strip the EOS id (6561) if it terminates this delta.
            if output.token_ids[-1] == 6561:
                need_add_tokens = output.token_ids[:-1]
            else:
                need_add_tokens = output.token_ids
            for token in need_add_tokens:
                yield token

    def inference_bistream(
        self,
        text: Generator,
        prompt_text: torch.Tensor,
        prompt_text_len: torch.Tensor,
        prompt_speech_token: torch.Tensor,
        prompt_speech_token_len: torch.Tensor,
        embedding: torch.Tensor,
        sampling: int = 25,
        max_token_text_ratio: float = 20,
        min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        """Streaming-text inference mirroring Qwen2LM.inference_bistream, but
        on flat token ids: text chunks (mix_ratio[0]) and speech chunks
        (mix_ratio[1]) are interleaved into the prompt, generation is stopped
        on the fill token 6563 to wait for more text, and the final pass runs
        until EOS 6561."""
        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
        prompt_speech_token = tensor_to_list(prompt_speech_token)

        last_tokens = []
        prompt_token_ids = [self.sos_eos_token_id]
        text_tokens_cache = prompt_text
        for this_text in text:
            this_text = tensor_to_list(this_text + torch.tensor(6564))
            # text need tokens
            assert isinstance(this_text, list), "text need token ids List[int]."
            text_tokens_cache += this_text
            # Interleave pending text/speech prompt chunks at mix_ratio.
            while len(prompt_speech_token) != 0:
                if len(text_tokens_cache) >= self.mix_ratio[0]:
                    text_input_token = text_tokens_cache[:self.mix_ratio[0]]
                    speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
                    prompt_token_ids += text_input_token + speech_input_token
                    # reset the last cache
                    text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                    prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
                else:
                    break
            if len(prompt_speech_token) == 0:
                # Decode either right after a fill token (6563) or at the very
                # start (prompt holds only sos).
                if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
                    if len(text_tokens_cache) >= self.mix_ratio[0]:
                        text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
                        prompt_token_ids += text_tokens_temp
                        text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                    else:
                        continue
                for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
                    last_tokens = output.token_ids
                    if last_tokens[-1] == 6563:
                        need_add_tokens = last_tokens[:-1]
                    else:
                        need_add_tokens = last_tokens
                    for token in need_add_tokens:
                        yield token
                    prompt_token_ids.extend(need_add_tokens)
        # Final pass: flush remaining text, append the task id, decode to EOS.
        prompt_token_ids += text_tokens_cache + [self.task_token_id]
        for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
            if output.token_ids[-1] == 6561:
                need_add_tokens = output.token_ids[:-1]
            else:
                need_add_tokens = output.token_ids
            for token in need_add_tokens:
                yield token
cosyvoice/llm/vllm_use_cosyvoice2_model.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Adapted from
4
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
5
+ # Copyright 2024 The Qwen team.
6
+ # Copyright 2023 The vLLM team.
7
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+ """Inference-only Qwen2 model compatible with HuggingFace weights."""
26
+ from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
27
+ from typing_extensions import TypeVar
28
+
29
+ import torch
30
+ from torch import nn
31
+
32
+ from vllm.attention import AttentionMetadata
33
+ from vllm.config import VllmConfig
34
+ from vllm.logger import init_logger
35
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
37
+ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
38
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
39
+ from vllm.sequence import IntermediateTensors
40
+
41
+ from vllm.model_executor.models.interfaces import T
42
+ from vllm.model_executor.models.qwen2 import Qwen2Model
43
+
44
+ from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
45
+
46
+ logger = init_logger(__name__)
47
+
48
+ IGNORE_ID = -1
49
+
50
+
51
class CosyVoice2Model(nn.Module):
    """vLLM-registered CosyVoice2 model: a Qwen2 backbone with a separate
    speech-token embedding table and a speech-token LM head.

    The flat input id space is: [0, speech_token_size) speech tokens, then the
    text vocabulary shifted by `llm_token_id_delta`, then sos_eos / task / zero
    special ids above that. `get_input_embeddings` routes each id to the
    matching embedding table.
    """

    # vLLM weight-fusion mapping for the Qwen2 backbone.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        # NOTE(review): hidden size hard-coded for the 0.5B Qwen2 backbone —
        # confirm against the checkpoint config.
        self.llm_input_size = 896
        self.llm_output_size = 896

        self.speech_token_size = 6561+3
        self.llm_token_size = config.vocab_size

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2

        # Only load checkpoint tensors whose names start with "llm." (see
        # convert_weights below).
        self.allow_patterns_overrides = ["llm.*"]
        self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
        self.llm_decoder = ParallelLMHead(self.speech_token_size,
                                          self.llm_output_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "llm_decoder"))
        self.logits_processor = LogitsProcessor(self.speech_token_size)

        # length_normalized_loss: bool = True,
        # lsm_weight: float = 0.0,
        # self.criterion_ce = LabelSmoothingLoss(
        #     size=self.speech_token_size,
        #     padding_idx=IGNORE_ID,
        #     smoothing=lsm_weight,
        #     normalize_length=length_normalized_loss,
        # )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)

        # 4. sampling method
        # use vllm sampling method
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        self.mix_ratio: List[int] = [5, 15]

        # Define special-token id constants (flat vocabulary layout).
        self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
        self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32)  # 163840 + 6564 = 170404
        self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32)  # 170405
        self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)

        # Preallocated buffers reused every step to avoid per-call allocation:
        # zero embeddings for "zero" tokens, and the merged input embeddings.
        self.zero_embed_buffer = torch.zeros(
            (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
            dtype=self.llm_embedding.weight.dtype,
            device=self.llm_embedding.weight.device
        )
        self.inputs_embed_buffer = torch.zeros(
            (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
            dtype=self.llm_embedding.weight.dtype,
            device=self.llm_embedding.weight.device,
        )

    def get_sos_eos_emb(self):
        # (1, 1, dim) embedding of the sos/eos marker.
        return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)

    def get_task_id_emb(self):
        # (1, 1, dim) embedding of the task marker.
        return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[T] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> torch.Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.

        Ids below speech_token_size use `speech_embedding`; the special
        sos_eos/task ids use `llm_embedding`; the zero id maps to an all-zero
        vector; everything else is shifted down by `llm_token_id_delta` and
        looked up in the Qwen2 text embedding table.
        """
        # Build a mask marking which token ids are speech tokens.
        mask = input_ids < self.speech_token_size

        # Keep the original shape of input_ids,
        input_shape = input_ids.shape
        # then flatten input_ids and the mask for uniform handling.
        flat_input_ids = input_ids.view(-1)
        flat_mask = mask.view(-1)

        inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
        inputs_embeds.zero_()

        # Process speech tokens
        if flat_mask.any():
            speech_token_ids = flat_input_ids[flat_mask]
            inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)

        # Handle token ids above the speech-token range (text + specials).
        if (~flat_mask).any():
            llm_token_ids = flat_input_ids[~flat_mask]
            llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])

            sos_eos_mask = llm_token_ids == self.sos_eos_token_id
            task_mask = llm_token_ids == self.task_token_id
            zero_mask = llm_token_ids == self.zero_token_id
            normal_mask = ~(sos_eos_mask | task_mask | zero_mask)

            # Layered handling:
            # sos/eos markers
            if sos_eos_mask.any():
                llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)

            # task markers
            if task_mask.any():
                llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)

            # empty-audio (zero) markers
            if zero_mask.any():
                llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]

            # regular LLM (text) tokens: undo the vocabulary shift
            if normal_mask.any():
                original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
                # print('original_ids: ',original_ids)
                llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)

            inputs_embeds[~flat_mask] = llm_embeds

        inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)

        # Merge multimodal embeddings (if any).
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.audio_token_index
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """vLLM forward: resolve embeddings for the mixed id space, then run
        the Qwen2 backbone."""
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                attn_metadata=attn_metadata,
            )
        return self.model(input_ids, positions, kv_caches,
                          attn_metadata, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Project hidden states onto the speech-token vocabulary only.
        logits = self.logits_processor(self.llm_decoder, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    @staticmethod
    def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
        """Rename CosyVoice checkpoint tensors to vLLM names: keep only
        'llm.*' tensors and strip the 'llm.model.model.' → 'model.' prefix
        for the Qwen2 backbone."""
        for name, param in weights:
            # Map Qwen2Model core weights.
            if name.startswith("llm."):
                if name.startswith("llm.model.model."):
                    name = name.replace("llm.model.model.", "model.")
            else:
                continue
            # print('weights name: ', name)
            yield name, param

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Rename first, then delegate loading to vLLM's AutoWeightsLoader.
        weights = self.convert_weights(weights)
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
cosyvoice/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Optional
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from whisper.tokenizer import Tokenizer
8
+
9
+ import tiktoken
10
+
11
# Language code -> human-readable name. Follows Whisper's language table with
# CosyVoice additions at the end (yue/minnan/wuyu/dialect and mixed zh/en tags).
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

# Audio-event tags registered as special tokens (paired tags such as
# Speech//Speech mark event start/end spans).
AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

# Emotion tags registered as special tokens.
EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

# TTS vocal-control tokens; TTS/SP01..TTS/SP13 are speaker tags.
TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
}
167
+
168
+
169
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Build the tiktoken Encoding used by the whisper-style tokenizer.

    Args:
        name: Basename of the ``.tiktoken`` vocabulary file under ``assets/``.
        num_languages: Number of language tags to register as special tokens.

    Returns:
        A ``tiktoken.Encoding`` containing the base BPE ranks plus all
        special tokens (language / audio-event / emotion / TTS tags,
        ASR placeholder tokens and the 0.00-30.00s timestamp tokens).
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    # Each vocab line is "<base64 token> <rank>".  Open inside a context
    # manager so the file handle is closed deterministically (the previous
    # bare open() leaked until GC), and skip blank/whitespace-only lines:
    # the old `if line` test let "\n" through, which made line.split()
    # return [] and crashed the tuple unpacking.
    with open(vocab_path, encoding="utf-8") as f:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in f if line.strip())
        }
    n_vocab = len(ranks)

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],  # timestamps 0.00 .. 30.00 in 20ms steps
    ]

    # Special-token ids are appended directly after the base vocabulary.
    special_tokens = {token: n_vocab + i for i, token in enumerate(specials)}
    n_vocab += len(specials)

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
209
+
210
+
211
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    """Return a cached whisper ``Tokenizer``.

    ``language`` may be an ISO code or a full name/alias from
    ``TO_LANGUAGE_CODE``; unknown values raise ``ValueError``.  For the
    multilingual vocabulary, language and task default to English
    transcription; for the monolingual gpt2 vocabulary both are forced
    to ``None``.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            # Accept full language names and aliases.
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
241
+
242
+
243
class QwenTokenizer():
    """Wrapper around a HuggingFace Qwen tokenizer with CosyVoice's extra
    paralinguistic special tokens ([breath], [laughter], <crying>, ...)
    registered on top of the base vocabulary."""

    def __init__(self, token_path, skip_special_tokens=True):
        super().__init__()
        # NOTE: non-chat model, all these special tokens keep randomly initialized.
        # NOTE(review): " <laughter/>" carries a leading space — looks
        # intentional (matches training data?), confirm before normalizing.
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
                "<laughter>", "</laughter>",
                "[hissing]", "[sigh]", "[vocalized-noise]",
                "[lipsmack]", "[mn]", "<cough>", "<crying/>", "</crying>", "<crying>", " <laughter/>"
            ]
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        """Tokenize ``text`` and return its token ids as a plain list.

        NOTE(review): ``**kwargs`` is accepted but not forwarded to the
        underlying tokenizer call.
        """
        tokens = self.tokenizer([text], return_tensors="pt")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
        """Decode a sequence of token ids back to text, honoring
        ``self.skip_special_tokens``."""
        tokens = torch.tensor(tokens, dtype=torch.int64)
        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
        return text
274
+
275
+
276
@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool
) -> QwenTokenizer:
    """Return a QwenTokenizer for ``token_path``, cached so repeated calls
    with the same arguments share a single tokenizer instance."""
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
cosyvoice/transformer/__init__.py ADDED
File without changes
cosyvoice/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
class Swish(torch.nn.Module):
    """Swish activation, ``x * sigmoid(x)``, used in Conformer blocks."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Swish elementwise."""
        return x * x.sigmoid()
30
+
31
+
32
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
class Snake(nn.Module):
    '''
    Sine-based periodic activation: Snake(x) := x + (1/alpha) * sin^2(alpha * x).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha: trainable per-channel parameter; higher values give a
          higher-frequency periodic component.
    References:
        - Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> act = Snake(256)
        >>> y = act(torch.randn(4, 256, 10))
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of input channels
            - alpha: initial value of the trainable parameter;
              stored in log scale when ``alpha_logscale`` is set
              (log-scale alphas start at 0, i.e. exp(0) == 1).
        '''
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log scale: start from zeros; linear scale: start from ones * alpha.
        base = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(base * alpha)
        self.alpha.requires_grad = alpha_trainable
        # Small epsilon guarding the 1/alpha division against alpha == 0.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply Snake elementwise to a (B, C, T) input.
        '''
        # Broadcast the per-channel alpha over batch and time: [C] -> [1, C, 1].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        return x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
cosyvoice/transformer/attention.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            # Masked positions are re-zeroed after softmax so fully-masked
            # rows do not propagate NaNs from the all-inf softmax.
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T)  shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L)  shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            pos_emb (torch.Tensor): Unused in this base class; kept so the
                signature matches RelPositionMultiHeadedAttention.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`


        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit  model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Shifts the relative-position score matrix so that each query
        position indexes its own window of relative distances (standard
        Transformer-XL shift trick implemented via pad + reshape).

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit  model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
cosyvoice/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Pointwise conv -> GLU -> (causal) depthwise conv -> norm -> activation
    -> pointwise conv, operating on (#batch, time, channels) input.
    """

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
        """Construct an ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            activation (nn.Module): Activation applied after normalization.
            norm (str): Either 'batch_norm' or 'layer_norm'.
            causal (int): Whether use causal convolution or not
            bias (bool): Whether the conv layers have bias terms.
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for none causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
                NOTE(review): padded positions of x are zeroed via in-place
                masked_fill_, so the caller's tensor is mutated.
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) meas fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New left-context cache (#batch, channels, lorder),
                or a (0, 0, 0) fake cache for non-causal convolution.
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                # First chunk: left-pad with zeros to keep output causal.
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            # LayerNorm normalizes the last dim, so move channels there first.
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
cosyvoice/transformer/decoder.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Decoder definition."""
17
+ from typing import Tuple, List, Optional
18
+
19
+ import torch
20
+ import torch.utils.checkpoint as ckpt
21
+ import logging
22
+
23
+ from cosyvoice.transformer.decoder_layer import DecoderLayer
24
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
25
+ from cosyvoice.utils.class_utils import (
26
+ COSYVOICE_EMB_CLASSES,
27
+ COSYVOICE_ATTENTION_CLASSES,
28
+ COSYVOICE_ACTIVATION_CLASSES,
29
+ )
30
+ from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
31
+
32
+
33
class TransformerDecoder(torch.nn.Module):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        src_attention: if false, encoder-decoder cross attention is not
            applied, such as CIF model
        key_bias: whether use bias in attention.linear_k, False for whisper models.
        gradient_checkpointing: rerunning a forward-pass segment for each
            checkpointed segment during backward.
        tie_word_embedding: Tie or clone module weights depending of whether we are
            using TorchScript or not
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        src_attention: bool = True,
        key_bias: bool = True,
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
    ):
        super().__init__()
        attention_dim = encoder_output_size
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()

        # Token embedding (identity for "no_pos") followed by the positional
        # encoding module selected by `input_layer`.
        self.embed = torch.nn.Sequential(
            torch.nn.Identity() if input_layer == "no_pos" else
            torch.nn.Embedding(vocab_size, attention_dim),
            COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
                                               positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        # Final LayerNorm, applied only when pre-norm layers are used.
        self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
        self.use_output_layer = use_output_layer
        # Identity keeps the module graph TorchScript-friendly when no
        # projection to the vocabulary is wanted.
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = torch.nn.Identity()
        self.num_blocks = num_blocks
        self.decoders = torch.nn.ModuleList([
            DecoderLayer(
                attention_dim,
                COSYVOICE_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim,
                    self_attention_dropout_rate, key_bias),
                # Cross-attention is optional (e.g. CIF-style models pass
                # src_attention=False and run decoder-only).
                COSYVOICE_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim, src_attention_dropout_rate,
                    key_bias) if src_attention else None,
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate, activation),
                dropout_rate,
                normalize_before,
            ) for _ in range(self.num_blocks)
        ])

        self.gradient_checkpointing = gradient_checkpointing
        self.tie_word_embedding = tie_word_embedding

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor = torch.empty(0),
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: not used in transformer decoder, in order to unify api
                with bidirectional decoder
            reverse_weight: not used in transformer decoder, in order to unify
                api with bidirectional decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True,
                torch.tensor(0.0), in order to unify api with bidirectional decoder
                olens: (batch, )
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        tgt = ys_in_pad
        maxlen = tgt.size(1)
        # tgt_mask: (B, 1, L) — True on real (non-pad) positions.
        tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
        tgt_mask = tgt_mask.to(tgt.device)
        # m: (1, L, L) — lower-triangular causal mask.
        m = subsequent_mask(tgt_mask.size(-1),
                            device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L) — combined padding + causal mask.
        tgt_mask = tgt_mask & m
        x, _ = self.embed(tgt)
        # Gradient checkpointing trades compute for memory and is only valid
        # during training.
        if self.gradient_checkpointing and self.training:
            x = self.forward_layers_checkpointed(x, tgt_mask, memory,
                                                 memory_mask)
        else:
            x = self.forward_layers(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.use_output_layer:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, torch.tensor(0.0), olens

    def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
                       memory: torch.Tensor,
                       memory_mask: torch.Tensor) -> torch.Tensor:
        """Run the decoder stack sequentially (no checkpointing)."""
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                     memory_mask)
        return x

    @torch.jit.unused
    def forward_layers_checkpointed(self, x: torch.Tensor,
                                    tgt_mask: torch.Tensor,
                                    memory: torch.Tensor,
                                    memory_mask: torch.Tensor) -> torch.Tensor:
        """Run the decoder stack with activation checkpointing (training only)."""
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
                layer.__call__, x, tgt_mask, memory, memory_mask)
        return x

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.
        This is only used for decoding.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x, _ = self.embed(tgt)
        new_cache = []
        for i, decoder in enumerate(self.decoders):
            if cache is None:
                c = None
            else:
                c = cache[i]
            x, tgt_mask, memory, memory_mask = decoder(x,
                                                       tgt_mask,
                                                       memory,
                                                       memory_mask,
                                                       cache=c)
            # Each layer's full output doubles as its cache for the next step.
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.use_output_layer:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache

    def tie_or_clone_weights(self, jit_mode: bool = True):
        """Tie or clone module weights (between word_emb and output_layer)
        depending of whether we are using TorchScript or not"""
        if not self.use_output_layer:
            return
        if jit_mode:
            # TorchScript cannot share parameters, so clone instead of tying.
            logging.info("clone emb.weight to output.weight")
            self.output_layer.weight = torch.nn.Parameter(
                self.embed[0].weight.clone())
        else:
            logging.info("tie emb.weight with output.weight")
            self.output_layer.weight = self.embed[0].weight

        # Pad the bias so its length matches the (possibly larger) tied
        # embedding vocabulary.
        if getattr(self.output_layer, "bias", None) is not None:
            self.output_layer.bias.data = torch.nn.functional.pad(
                self.output_layer.bias.data,
                (
                    0,
                    self.output_layer.weight.shape[0] -
                    self.output_layer.bias.shape[0],
                ),
                "constant",
                0,
            )
254
+
255
+
256
class BiTransformerDecoder(torch.nn.Module):
    """Bidirectional Transformer decoder: a left-to-right decoder plus an
    optional right-to-left decoder sharing the same configuration.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        r_num_blocks: the number of right to left decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        key_bias: whether use bias in attention.linear_k, False for whisper models.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        r_num_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
    ):

        super().__init__()
        self.tie_word_embedding = tie_word_embedding
        self.left_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
            tie_word_embedding=tie_word_embedding)

        # Right-to-left decoder; with r_num_blocks == 0 it contributes no
        # transformer layers.
        self.right_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            r_num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
            tie_word_embedding=tie_word_embedding)

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor,
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
                used for right to left decoder
            reverse_weight: used for right to left decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True,
                r_x: decoded token score (right to left decoder)
                    before softmax (batch, maxlen_out, vocab_size)
                    if use_output_layer is True,
                olens: (batch, )
        """
        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
                                          ys_in_lens)
        r_x = torch.tensor(0.0)
        # The reverse pass is only computed when it actually contributes to
        # the loss.
        if reverse_weight > 0.0:
            r_x, _, olens = self.right_decoder(memory, memory_mask,
                                               r_ys_in_pad, ys_in_lens)
        return l_x, r_x, olens

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.
        This is only used for decoding; incremental decoding always uses the
        left-to-right decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
                                                  tgt_mask, cache)

    def tie_or_clone_weights(self, jit_mode: bool = True):
        """Tie or clone module weights (between word_emb and output_layer)
        depending of whether we are using TorchScript or not"""
        self.left_decoder.tie_or_clone_weights(jit_mode)
        self.right_decoder.tie_or_clone_weights(jit_mode)
cosyvoice/transformer/decoder_layer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Decoder self-attention layer definition."""
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Inter-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
            If `None` is passed, Inter-attention is not used, such as
            CIF, GPT, and other decoder only model.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: to use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        src_attn: Optional[nn.Module],
        feed_forward: nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.norm3 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor
                (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory
                (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask
                (#batch, maxlen_in).
            cache (torch.Tensor): cached tensors.
                (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            # BUGFIX: the assert message was a plain string ("{cache.shape}
            # == ..."); it must be an f-string to interpolate the shapes.
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :]

        x = residual + self.dropout(
            self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
        if not self.normalize_before:
            x = self.norm1(x)

        # Optional encoder-decoder cross attention (skipped for decoder-only
        # models constructed with src_attn=None).
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm2(x)
            x = residual + self.dropout(
                self.src_attn(x, memory, memory, memory_mask)[0])
            if not self.normalize_before:
                x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        # Re-attach the cached prefix so callers always get the full-length
        # sequence back.
        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
cosyvoice/transformer/embedding.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positional Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Precompute the (1, max_len, d_model) sinusoid table."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        positions = torch.arange(0, self.max_len,
                                 dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        table = torch.zeros(self.max_len, self.d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Leading batch dim of 1 so the table broadcasts over the batch.
        self.pe = table.unsqueeze(0)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input by sqrt(d_model) and add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, torch.tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            torch.Tensor: for compatibility to RelPositionalEncoding
        """
        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        return self.dropout(x * self.xscale + pos_emb), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a none
        streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # How to subscript a Union type:
        # https://github.com/pytorch/pytorch/issues/69434
        scalar = isinstance(offset, int) or (isinstance(offset, torch.Tensor)
                                             and offset.dim() == 0)
        if scalar:
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:  # for batched streaming decoding on GPU
            assert torch.max(offset) + size <= self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size).to(offset.device)  # B X T
            # zero out negative offsets before the table lookup
            index = index * (index > 0)
            pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model

        return self.dropout(pos_emb) if apply_dropout else pos_emb
+ return pos_emb
118
+
119
+
120
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.pe = self.pe.to(x.device)
        # Unlike the absolute variant, the embedding is returned separately
        # instead of being added to the (scaled) input.
        scaled = x * self.xscale
        rel_emb = self.position_encoding(offset, x.size(1), False)
        return self.dropout(scaled), self.dropout(rel_emb)
+ return self.dropout(x), self.dropout(pos_emb)
148
+
149
+
150
class WhisperPositionalEncoding(PositionalEncoding):
    """ Sinusoids position encoding used in openai-whisper.encoder
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
        super().__init__(d_model, dropout_rate, max_len)
        # Whisper adds positions to the input without sqrt(d_model) scaling.
        self.xscale = 1.0
        # Geometrically spaced timescales, as in the reference whisper code.
        increment = np.log(10000) / (d_model // 2 - 1)
        timescales = torch.exp(-increment * torch.arange(d_model // 2))
        phases = torch.arange(max_len).unsqueeze(1) * timescales.unsqueeze(0)
        table = torch.cat([torch.sin(phases), torch.cos(phases)], dim=1)
        # Replace the parent's plain-attribute table with a registered buffer
        # so it follows .to() and is saved in state_dict.
        delattr(self, "pe")
        self.register_buffer("pe", table.unsqueeze(0))
+ self.register_buffer("pe", pe.unsqueeze(0))
165
+
166
+
167
class LearnablePositionalEncoding(PositionalEncoding):
    """ Learnable position encoding used in openai-whisper.decoder
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
        super().__init__(d_model, dropout_rate, max_len)
        # Swap the fixed sinusoid table for a trainable parameter and drop
        # the sqrt(d_model) input scaling (whisper convention).
        # NOTE(review): torch.empty leaves the table uninitialized —
        # presumably weights come from a checkpoint; confirm before training
        # from scratch.
        self.xscale = 1.0
        self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
+ self.xscale = 1.0
176
+
177
+
178
class NoPositionalEncoding(torch.nn.Module):
    """Stub encoding that injects no position information.

    Keeps the (x, pos_emb) return interface of the other encoding modules by
    emitting an all-zero embedding.
    """

    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """ Just return zero vector for interface compatibility
        """
        zeros = torch.zeros(1, x.size(1), self.d_model).to(x.device)
        return self.dropout(x), zeros

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Return a zero encoding of the requested length."""
        return torch.zeros(1, size, self.d_model)
+ return torch.zeros(1, size, self.d_model)
199
+
200
+
201
class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an PositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-build the table for max_len positions; extend_pe only needs the
        # shape/dtype of its argument, hence the zero-expanded dummy tensor.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` means the position of the query vector and `j` means the
        # position of the key vector. We use positive relative positions when
        # keys are to the left (i>j) and negative relative positions otherwise
        # (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        # drop the duplicated relative-position-0 row before concatenation
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        # NOTE: like RelPositionalEncoding, the embedding is returned
        # separately rather than added to the scaled input.
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a none
        streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # How to subscript a Union type:
        # https://github.com/pytorch/pytorch/issues/69434
        # NOTE(review): both branches slice identically (a window of
        # 2*(size+offset)-1 entries centred on relative position 0);
        # presumably the int/Tensor split exists for TorchScript typing —
        # confirm against the export path.
        if isinstance(offset, int):
            pos_emb = self.pe[
                :,
                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
            ]
        elif isinstance(offset, torch.Tensor):
            pos_emb = self.pe[
                :,
                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
            ]
        return pos_emb
cosyvoice/transformer/encoder.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ import torch.utils.checkpoint as ckpt
22
+
23
+ from cosyvoice.transformer.convolution import ConvolutionModule
24
+ from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
25
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
26
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
27
+ from cosyvoice.utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from cosyvoice.utils.mask import make_pad_mask
34
+ from cosyvoice.utils.mask import add_optional_chunk_mask
35
+
36
+
37
class BaseEncoder(torch.nn.Module):
    """Common machinery for Transformer/Conformer speech encoders.

    Owns global CMVN, the subsampling + positional-embedding front-end
    (``self.embed``), chunk-mask construction for streaming, and both the
    full-utterance and chunk-by-chunk forward passes.  Subclasses must
    define ``self.encoders`` (a ModuleList of encoder layers).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether use dynamic chunk size for
                training or not, You can only use fixed chunk(chunk_size > 0)
                or dynamic chunk size(use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                dynamic chunk training
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward.
        """
        super().__init__()
        self._output_size = output_size

        self.global_cmvn = global_cmvn
        # The subsampling module also owns the positional-encoding layer,
        # which it returns as pos_emb from its forward().
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        # Final LayerNorm, applied only in the pre-norm configuration.
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing

    def output_size(self) -> int:
        """Return the encoder's hidden/output dimension."""
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        # True where frames are valid, False on padding.
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # Subsample and obtain positional embedding; masks are subsampled too.
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        # Build the (possibly chunked/causal) attention mask.
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
                                                  mask_pad)
        else:
            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        """Run all encoder layers sequentially (no checkpointing)."""
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    @torch.jit.unused
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    chunk_masks: torch.Tensor,
                                    pos_emb: torch.Tensor,
                                    mask_pad: torch.Tensor) -> torch.Tensor:
        """Run all encoder layers with activation checkpointing.

        Trades compute for memory: each layer's forward is re-run during
        backward.  Not scriptable, hence @torch.jit.unused.
        """
        for layer in self.encoders:
            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
                                                    chunk_masks, pos_emb,
                                                    mask_pad)
        return xs

    @torch.jit.export
    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.

        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        # Recompute pos_emb to span cached history + current chunk; this
        # intentionally overwrites the pos_emb returned by self.embed above.
        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
                                               size=attention_key_size)
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            # On the first chunk the caches are empty (elayers == 0), so the
            # zero-size tensors are passed through unchanged.
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
        if self.normalize_before:
            xs = self.after_norm(xs)

        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        return (xs, r_att_cache, r_cnn_cache)

    @torch.jit.unused
    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size like a streaming
            fashion

        Here we should pay special attention to computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache for:
            1. We can control subsampling module to output the right result by
               overlapping input instead of cache left context, even though it
               wastes some computation, but subsampling only takes a very
               small fraction of computation in the whole model.
            2. Typically, there are several convolution layers with subsampling
               in subsampling module, it is tricky and complicated to do cache
               with different convolution layers with different subsampling
               rate.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling, we need to rewrite it to make it work
               with cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlap input step by step; windows overlap by
        # (decoding_window - stride) frames so subsampling sees full context.
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache,
             cnn_cache) = self.forward_chunk(chunk_xs, offset,
                                             required_cache_size, att_cache,
                                             cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)),
                           device=ys.device,
                           dtype=torch.bool)
        return ys, masks
336
+
337
+
338
class TransformerEncoder(BaseEncoder):
    """Transformer encoder: BaseEncoder plus a stack of vanilla
    self-attention + feed-forward layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        key_bias: bool = True,
        selfattention_layer_type: str = "selfattn",
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
    ):
        """Construct TransformerEncoder.

        See BaseEncoder for the meaning of each parameter.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        # One activation instance is shared by every feed-forward module.
        act_fn = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        attn_cls = COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]
        # Build per layer in (attention, feed-forward) order, matching the
        # parameter-initialization order of the original implementation.
        layers = []
        for _ in range(num_blocks):
            attention = attn_cls(attention_heads,
                                 output_size,
                                 attention_dropout_rate,
                                 key_bias)
            feed_forward = PositionwiseFeedForward(output_size, linear_units,
                                                   dropout_rate, act_fn)
            layers.append(
                TransformerEncoderLayer(output_size, attention, feed_forward,
                                        dropout_rate, normalize_before))
        self.encoders = torch.nn.ModuleList(layers)
385
+
386
+
387
class ConformerEncoder(BaseEncoder):
    """Conformer encoder: BaseEncoder plus a stack of Conformer blocks
    (macaron feed-forward, relative-position self-attention, convolution).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """Construct ConformerEncoder.

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configure
                compatibility.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            key_bias: whether use bias in attention.linear_k, False for
                whisper models.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        # A single activation instance is shared by all sub-modules.
        act_fn = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        attn_cls = COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]

        # Build each block's sub-modules in the order
        # attention -> feed-forward -> macaron feed-forward -> convolution,
        # matching the original construction (and thus RNG/init) order.
        blocks = []
        for _ in range(num_blocks):
            attention = attn_cls(attention_heads, output_size,
                                 attention_dropout_rate, key_bias)
            feed_forward = PositionwiseFeedForward(output_size, linear_units,
                                                   dropout_rate, act_fn)
            macaron_ff = None
            if macaron_style:
                macaron_ff = PositionwiseFeedForward(output_size, linear_units,
                                                     dropout_rate, act_fn)
            conv = None
            if use_cnn_module:
                conv = ConvolutionModule(output_size, cnn_module_kernel,
                                         act_fn, cnn_module_norm, causal)
            blocks.append(
                ConformerEncoderLayer(
                    output_size,
                    attention,
                    feed_forward,
                    macaron_ff,
                    conv,
                    dropout_rate,
                    normalize_before,
                ))
        self.encoders = torch.nn.ModuleList(blocks)
cosyvoice/transformer/encoder_layer.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
class TransformerEncoderLayer(nn.Module):
    """Single Transformer encoder block: self-attention + feed-forward,
    each wrapped in residual + LayerNorm.

    Args:
        size (int): Input/hidden dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, instance can be used as the argument.
        dropout_rate (float): Dropout applied to each sub-block output
            before the residual addition.
        normalize_before (bool):
            True: apply layer_norm before each sub-block (pre-norm).
            False: apply layer_norm after each sub-block (post-norm).
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-12)
        self.norm2 = nn.LayerNorm(size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): just for interface compatibility
                to ConformerEncoderLayer
            mask_pad (torch.Tensor): not used in transformer layer,
                just for unified api with conformer.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2), not used here, it's for interface
                compatibility to ConformerEncoderLayer.
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
        """
        # --- self-attention sub-block ---
        shortcut = x
        if self.normalize_before:
            x = self.norm1(x)
        attn_out, updated_att_cache = self.self_attn(
            x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
        x = shortcut + self.dropout(attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        # --- feed-forward sub-block ---
        shortcut = x
        if self.normalize_before:
            x = self.norm2(x)
        x = shortcut + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        # Transformer layers carry no convolution state; return an empty
        # placeholder cache for interface parity with ConformerEncoderLayer.
        dummy_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, updated_att_cache, dummy_cnn_cache
107
+
108
+
109
class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Sub-block order: (optional) macaron feed-forward -> self-attention ->
    (optional) convolution -> feed-forward -> (optional) final LayerNorm.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
            # Macaron style halves each feed-forward's residual contribution.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-12)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style (first half feed-forward)
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module (second half when macaron style is on)
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        # Final LayerNorm is only defined (and applied) when a convolution
        # module is present.
        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
cosyvoice/transformer/label_smoothing_loss.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Label smoothing module."""
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    In a standard CE loss, the label's data distribution is:
    [0,1,2] ->
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]

    In the smoothing version CE Loss, some probability mass is taken from
    the true label (1.0) and divided evenly among the other labels.

    e.g.
    smoothing=0.1
    [0,1,2] ->
    [
        [0.9, 0.05, 0.05],
        [0.05, 0.9, 0.05],
        [0.05, 0.05, 0.9],
    ]

    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool):
            normalize loss by sequence length if True
            normalize loss by batch size if False
    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct a LabelSmoothingLoss object."""
        super().__init__()
        # Element-wise KL so that padded positions can be masked out later.
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model outputs and data labels are flattened to
        (batch*seqlen, class) and a mask removes padded positions from the
        loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_idx (batch, seqlen)
        Returns:
            loss (torch.Tensor): The KL loss, scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        flat_logits = x.view(-1, self.size)
        flat_target = target.view(-1)
        pad_mask = flat_target == self.padding_idx  # (B*T,)
        num_tokens = len(flat_target) - pad_mask.sum().item()
        # Replace padding ids with 0 so scatter_ never sees an invalid index;
        # those rows are zeroed out of the loss below anyway.
        safe_target = flat_target.masked_fill(pad_mask, 0)
        # Build the smoothed target distribution with zeros_like rather than
        # under torch.no_grad(), since no_grad() cannot be exported by JIT.
        smooth_dist = torch.zeros_like(flat_logits)
        smooth_dist.fill_(self.smoothing / (self.size - 1))
        smooth_dist.scatter_(1, safe_target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(flat_logits, dim=1), smooth_dist)
        denom = num_tokens if self.normalize_length else batch_size
        return kl.masked_fill(pad_mask.unsqueeze(1), 0).sum() / denom
cosyvoice/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
class PositionwiseFeedForward(torch.nn.Module):
    """Position-wise feed forward layer.

    The same two-layer MLP is applied independently at every position of
    the sequence; input and output dimensionality are identical.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply the feed forward network position-wise.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        hidden = self.w_1(xs)
        hidden = self.dropout(self.activation(hidden))
        return self.w_2(hidden)
56
+
57
+
58
class MoEFFNLayer(torch.nn.Module):
    """Sparse mixture-of-experts position-wise feed forward layer.

    See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
    The output dim is same with the input dim.

    Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
    https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219

    Args:
        n_expert: total number of experts.
        n_expert_per_token: number of experts each frame is routed to.
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
    """

    def __init__(
        self,
        n_expert: int,
        n_expert_per_token: int,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        super(MoEFFNLayer, self).__init__()
        # Router producing one logit per expert for every frame.
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.experts = torch.nn.ModuleList(
            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
                                    activation) for _ in range(n_expert))
        self.n_expert_per_token = n_expert_per_token

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Route each frame through its top-k experts and mix the outputs.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        B, L, D = xs.size()  # batch, sequence length, embedding dim (idim)
        flat = xs.view(-1, D)  # (B*L, D)
        router_logits = self.gate(flat)  # (B*L, n_expert)
        # Keep only the k largest router logits per frame;
        # logits/indices: (B*L, n_expert_per_token)
        logits, indices = torch.topk(router_logits, self.n_expert_per_token)
        # Mixing weights over the selected experts only; softmax is computed
        # in fp32 for stability, then cast back to the input dtype.
        weights = torch.nn.functional.softmax(
            logits, dim=1,
            dtype=torch.float).to(dtype=flat.dtype)
        output = torch.zeros_like(flat)  # (B*L, D)
        for expert_id, expert in enumerate(self.experts):
            # Frames (and their slot within top-k) routed to this expert.
            batch_idx, ith_expert = torch.where(indices == expert_id)
            output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
                flat[batch_idx])
        return output.view(B, L, D)
cosyvoice/transformer/subsampling.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
class BaseSubsampling(torch.nn.Module):
    """Common base for all subsampling front-ends.

    Subclasses override ``subsampling_rate`` (frame-rate reduction factor)
    and ``right_context`` (future frames consumed per output frame), and
    must assign a ``pos_enc`` positional-encoding module.
    """

    def __init__(self):
        super().__init__()
        # Defaults for front-ends that do not subsample at all.
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Delegate positional-encoding lookup to ``self.pos_enc``."""
        return self.pos_enc.position_encoding(offset, size)
33
+
34
+
35
class EmbedinigNoSubsampling(BaseSubsampling):
    """Embedding-lookup input layer without subsampling.

    NOTE: the class name carries a historical typo ("Embedinig") that is
    kept for backward compatibility with existing configs.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        super().__init__()
        self.embed = torch.nn.Embedding(idim, odim)
        self.pos_enc = pos_enc_class

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed input token ids; sequence length is unchanged.

        Args:
            x (torch.Tensor): Input token-id tensor (#batch, time).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: positional-encoding offset for streaming decoding.

        Returns:
            torch.Tensor: embedded tensor (#batch, time, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: the unchanged input mask (#batch, 1, time).
        """
        embedded = self.embed(x)
        embedded, pos_emb = self.pos_enc(embedded, offset)
        return embedded, pos_emb, x_mask
67
+
68
+
69
class LinearNoSubsampling(BaseSubsampling):
    """Linear projection of the input without subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a LinearNoSubsampling object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class
        # No frames are dropped, so rate/context stay at the defaults.
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project input frames; sequence length is unchanged.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: positional-encoding offset for streaming decoding.

        Returns:
            torch.Tensor: projected tensor (#batch, time, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: the unchanged input mask (#batch, 1, time).
        """
        projected = self.out(x)
        projected, pos_emb = self.pos_enc(projected, offset)
        return projected, pos_emb, x_mask
114
+
115
+
116
class Conv1dSubsampling2(BaseSubsampling):
    """Convolutional 1D subsampling (to 1/2 length).

    It is designed for Whisper, ref:
    https://github.com/openai/whisper/blob/main/whisper/model.py

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (unused here, kept for API parity).
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv1dSubsampling2 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
            torch.nn.GELU(),
            torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
            torch.nn.GELU(),
        )
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 2
        # Right context per conv layer is (kernel_size - 1) * frame_rate:
        # 4 = (3 - 1) * 1 + (3 - 1) * 1
        self.right_context = 4

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x by a factor of 2.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: subsampled tensor (#batch, time // 2, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: subsampled mask (#batch, 1, time // 2).
        """
        time = x.size(1)
        feats = x.transpose(1, 2)  # (b, f, t)
        feats = self.conv(feats)
        feats = feats.transpose(1, 2)  # (b, t, f)
        feats, pos_emb = self.pos_enc(feats, offset)
        # Stride-2 mask decimation; the start offset keeps the mask aligned
        # with the conv output for both odd and even input lengths.
        return feats, pos_emb, x_mask[:, :, (time + 1) % 2::2]
171
+
172
+
173
class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (unused here, kept for API parity).
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Flatten (channels x reduced frequency) back to odim.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 4
        # Right context per conv layer is (kernel_size - 1) * frame_rate:
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x by a factor of 4.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: subsampled tensor (#batch, time // 4, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: subsampled mask (#batch, 1, time // 4).
        """
        feats = x.unsqueeze(1)  # (b, c=1, t, f)
        feats = self.conv(feats)
        b, c, t, f = feats.size()
        feats = self.out(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        feats, pos_emb = self.pos_enc(feats, offset)
        # Two stride-2 mask decimations mirror the two stride-2 convs.
        return feats, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
228
+
229
+
230
class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (unused here, kept for API parity).
        pos_enc_class (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling6 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        # Flatten (channels x reduced frequency) back to odim.
        self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
                                      odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 6
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x by a factor of 6.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: subsampled tensor (#batch, time // 6, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: subsampled mask (#batch, 1, time // 6).
        """
        feats = x.unsqueeze(1)  # (b, c, t, f)
        feats = self.conv(feats)
        b, c, t, f = feats.size()
        feats = self.linear(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        feats, pos_emb = self.pos_enc(feats, offset)
        # Stride-2 then stride-3 mask decimation mirrors the two convs.
        return feats, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
280
+
281
+
282
class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (unused here, kept for API parity).
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling8 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Flatten (channels x reduced frequency) back to odim.
        self.linear = torch.nn.Linear(
            odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x by a factor of 8.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: subsampled tensor (#batch, time // 8, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: subsampled mask (#batch, 1, time // 8).
        """
        feats = x.unsqueeze(1)  # (b, c, t, f)
        feats = self.conv(feats)
        b, c, t, f = feats.size()
        feats = self.linear(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        feats, pos_emb = self.pos_enc(feats, offset)
        # Three stride-2 mask decimations mirror the three stride-2 convs.
        return feats, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
336
+
337
+
338
class LegacyLinearNoSubsampling(BaseSubsampling):
    """Linear projection of the input without subsampling (legacy variant).

    Differs from ``LinearNoSubsampling`` only by the trailing ReLU.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a LegacyLinearNoSubsampling object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
        )
        self.pos_enc = pos_enc_class
        # No frames are dropped, so rate/context stay at the defaults.
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project input frames; sequence length is unchanged.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: positional-encoding offset for streaming decoding.

        Returns:
            torch.Tensor: projected tensor (#batch, time, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: the unchanged input mask (#batch, 1, time).
        """
        projected = self.out(x)
        projected, pos_emb = self.pos_enc(projected, offset)
        return projected, pos_emb, x_mask
cosyvoice/transformer/upsample_encoder.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+
24
+ from cosyvoice.transformer.convolution import ConvolutionModule
25
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
26
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
27
+ from cosyvoice.utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from cosyvoice.utils.mask import make_pad_mask
34
+ from cosyvoice.utils.mask import add_optional_chunk_mask
35
+
36
+
37
class Upsample1D(nn.Module):
    """A 1D nearest-neighbour upsampling layer followed by a convolution.

    The input is first repeat-interpolated by ``stride`` and then refined by
    a stride-1 Conv1d; left-only padding keeps the layer causal.

    Parameters:
        channels (`int`): number of input channels.
        out_channels (`int`): number of output channels.
        stride (`int`, default `2`): temporal upsampling factor.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat-interpolate, then conv with stride=1.
        # Kernel size stride*2+1 together with the left pad in forward()
        # keeps the output length exactly stride * T.
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Upsample ``inputs`` along time.

        Args:
            inputs: feature tensor (batch, channels, time).
            input_lengths: valid frame counts per batch item (batch,).

        Returns:
            Tuple of the upsampled tensor (batch, out_channels,
            time * stride) and the lengths scaled by ``stride``.
        """
        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
        # Left-only (causal) padding: no future frames are consumed.
        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride
64
+
65
+
66
class PreLookaheadLayer(nn.Module):
    """Convolutional block that peeks ``pre_lookahead_len`` frames ahead.

    ``conv1`` consumes the lookahead window (zero-padded during training,
    real future frames via ``context`` at inference), ``conv2`` is strictly
    causal, and the result is added back to the input as a residual.
    """

    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        context: optional future frames used instead of zero padding; must
            hold exactly ``pre_lookahead_len`` frames when provided.
        """
        y = inputs.transpose(1, 2).contiguous()
        ctx = context.transpose(1, 2).contiguous()
        # look ahead: pad with zeros, or splice in the provided context
        if ctx.size(2) == 0:
            y = F.pad(y, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        else:
            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
            assert ctx.size(2) == self.pre_lookahead_len
            y = F.pad(torch.concat([y, ctx], dim=2), (0, self.pre_lookahead_len - ctx.size(2)), mode='constant', value=0.0)
        y = F.leaky_relu(self.conv1(y))
        # conv2 is made causal by left-only padding
        y = F.pad(y, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
        y = self.conv2(y)
        y = y.transpose(1, 2).contiguous()

        # residual connection
        return y + inputs
103
+
104
+
105
class UpsampleConformerEncoder(torch.nn.Module):
    """Conformer encoder with a lookahead stage at the token rate followed by
    a second conformer stage after 2x temporal upsampling.

    NOTE(review): ``pre_lookahead_layer`` and ``up_layer`` hard-code 512
    channels, so ``output_size`` is effectively expected to be 512 — confirm
    against the configs that instantiate this class.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of decoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether use dynamic chunk size for
                training or not, You can only use fixed chunk(chunk_size > 0)
                or dyanmic chunk size(use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                dynamic chunk training
            key_bias: whether use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward.
        """
        super().__init__()
        self._output_size = output_size

        self.global_cmvn = global_cmvn
        # Front-end: subsampling/projection plus positional encoding.
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)
        # Lookahead over 3 future frames before the first conformer stack.
        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
        # 2x temporal upsampling between the two conformer stacks.
        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
        # Separate front-end (incl. positional encoding) for the upsampled rate.
        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )
        # Second, fixed-depth (4-layer) conformer stack after upsampling.
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(4)
        ])

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        context: torch.Tensor = torch.zeros(0, 0, 0),
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
        streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            context: optional lookahead frames for streaming inference,
                forwarded to the pre-lookahead layer.
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
            streaming: if True, apply static chunk masking in both stages.
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        if context.size(1) != 0:
            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
            context_masks = torch.ones(1, 1, context.size(1)).to(masks)
            # Embed the context with a positional offset continuing after xs.
            context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
        # lookahead + conformer encoder
        xs = self.pre_lookahead_layer(xs, context=context)
        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

        # upsample + conformer encoder
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()
        # Masks and positional encodings are rebuilt at the upsampled rate.
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        """Run the first (token-rate) conformer stack."""
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                          pos_emb: torch.Tensor,
                          mask_pad: torch.Tensor) -> torch.Tensor:
        """Run the second (upsampled-rate) conformer stack."""
        for layer in self.up_encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs
File without changes
cosyvoice/utils/class_utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from cosyvoice.transformer.activation import Swish
18
+ from cosyvoice.transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from cosyvoice.transformer.embedding import (PositionalEncoding,
27
+ RelPositionalEncoding,
28
+ WhisperPositionalEncoding,
29
+ LearnablePositionalEncoding,
30
+ NoPositionalEncoding)
31
+ from cosyvoice.transformer.attention import (MultiHeadedAttention,
32
+ RelPositionMultiHeadedAttention)
33
+ from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
34
+ from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
35
+ from cosyvoice.llm.llm import TransformerLM, Qwen2LM
36
+ from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
37
+ from cosyvoice.hifigan.generator import HiFTGenerator
38
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
39
+
40
+
41
# Registry mapping activation names (as used in YAML/hyperparameter configs)
# to their constructor classes.
COSYVOICE_ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    # torch.nn.SiLU exists in torch >= 1.7; otherwise fall back to the
    # project-local Swish implementation (numerically the same function).
    "swish": getattr(torch.nn, "SiLU", Swish),
    "gelu": torch.nn.GELU,
}
49
+
50
# Registry mapping subsampling-frontend names to their classes.
# NOTE: "EmbedinigNoSubsampling" spelling follows the imported class name.
COSYVOICE_SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
    "linear_legacy": LegacyLinearNoSubsampling,
    "embed": EmbedinigNoSubsampling,
    "conv1d2": Conv1dSubsampling2,
    "conv2d": Conv2dSubsampling4,
    "conv2d6": Conv2dSubsampling6,
    "conv2d8": Conv2dSubsampling8,
    # identity pass-through, presumably for paraformer-style frontends — TODO confirm
    'paraformer_dummy': torch.nn.Identity
}
60
+
61
# Registry mapping positional-encoding names to their classes.
# "embed" and "abs_pos" are aliases for the same absolute encoding.
COSYVOICE_EMB_CLASSES = {
    "embed": PositionalEncoding,
    "abs_pos": PositionalEncoding,
    "rel_pos": RelPositionalEncoding,
    "rel_pos_espnet": EspnetRelPositionalEncoding,
    "no_pos": NoPositionalEncoding,
    "abs_pos_whisper": WhisperPositionalEncoding,
    "embed_learnable_pe": LearnablePositionalEncoding,
}
70
+
71
# Registry mapping self-attention variant names to their classes
# (absolute vs. relative positional attention).
COSYVOICE_ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}
75
+
76
+
77
def get_model_type(configs):
    """Pick the model wrapper class matching the instantiated components.

    Args:
        configs: mapping holding already-constructed 'llm', 'flow' and
            'hift' modules.

    Returns:
        ``CosyVoiceModel`` or ``CosyVoice2Model``.

    Raises:
        TypeError: if the component combination matches no known model.
    """
    # NOTE CosyVoice2Model inherits CosyVoiceModel, so the (llm, flow) pair
    # is what discriminates between the two.
    dispatch = (
        (TransformerLM, MaskedDiffWithXvec, CosyVoiceModel),
        (Qwen2LM, CausalMaskedDiffWithXvec, CosyVoice2Model),
    )
    for llm_cls, flow_cls, model_cls in dispatch:
        if isinstance(configs['llm'], llm_cls) and isinstance(configs['flow'], flow_cls) and isinstance(configs['hift'], HiFTGenerator):
            return model_cls
    raise TypeError('No valid model type found!')