Xinsheng-Wang committed on
Commit c7f3ffb · verified · 1 Parent(s): a81bc3b

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +38 -0
  3. DEPLOY.md +201 -0
  4. LICENSE +201 -0
  5. README.md +226 -9
  6. app.py +63 -0
  7. assets/performance_radar.png +3 -0
  8. assets/soul_wechat01.jpg +3 -0
  9. assets/soulx-logo.png +3 -0
  10. assets/technical-report.pdf +3 -0
  11. cli/inference.py +147 -0
  12. deploy_to_hf.sh +70 -0
  13. example/audio/en_prompt.json +16 -0
  14. example/audio/en_prompt.mp3 +0 -0
  15. example/audio/en_target.json +16 -0
  16. example/audio/en_target.mp3 +0 -0
  17. example/audio/music.json +16 -0
  18. example/audio/music.mp3 +3 -0
  19. example/audio/yue_target.json +16 -0
  20. example/audio/yue_target.mp3 +3 -0
  21. example/audio/zh_prompt.json +16 -0
  22. example/audio/zh_prompt.mp3 +0 -0
  23. example/audio/zh_target.json +16 -0
  24. example/audio/zh_target.mp3 +0 -0
  25. example/infer.sh +28 -0
  26. example/preprocess.sh +41 -0
  27. preprocess/README.md +155 -0
  28. preprocess/pipeline.py +146 -0
  29. preprocess/requirements.txt +33 -0
  30. preprocess/tools/__init__.py +53 -0
  31. preprocess/tools/f0_extraction.py +527 -0
  32. preprocess/tools/g2p.py +72 -0
  33. preprocess/tools/lyric_transcription.py +279 -0
  34. preprocess/tools/midi_parser.py +669 -0
  35. preprocess/tools/note_transcription/__init__.py +0 -0
  36. preprocess/tools/note_transcription/model.py +522 -0
  37. preprocess/tools/note_transcription/modules/__init__.py +1 -0
  38. preprocess/tools/note_transcription/modules/commons/__init__.py +1 -0
  39. preprocess/tools/note_transcription/modules/commons/conformer/__init__.py +1 -0
  40. preprocess/tools/note_transcription/modules/commons/conformer/conformer.py +96 -0
  41. preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py +113 -0
  42. preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py +198 -0
  43. preprocess/tools/note_transcription/modules/commons/conformer/layers.py +260 -0
  44. preprocess/tools/note_transcription/modules/commons/conv.py +175 -0
  45. preprocess/tools/note_transcription/modules/commons/layers.py +85 -0
  46. preprocess/tools/note_transcription/modules/commons/rel_transformer.py +378 -0
  47. preprocess/tools/note_transcription/modules/commons/rnn.py +261 -0
  48. preprocess/tools/note_transcription/modules/commons/transformer.py +751 -0
  49. preprocess/tools/note_transcription/modules/commons/wavenet.py +109 -0
  50. preprocess/tools/note_transcription/modules/pe/__init__.py +1 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/performance_radar.png filter=lfs diff=lfs merge=lfs -text
+ assets/soul_wechat01.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/soulx-logo.png filter=lfs diff=lfs merge=lfs -text
+ assets/technical-report.pdf filter=lfs diff=lfs merge=lfs -text
+ example/audio/music.mp3 filter=lfs diff=lfs merge=lfs -text
+ example/audio/yue_target.mp3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,38 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+
+ dev/
+ results/
+ wandb/
+ .ipynb_checkpoints/
+ .vscode/
+ .cache
+ local/
+ outputs/
+
+ *.pt
+ *.ckpt
+
+ # Logs
+ logs/
+ *.log
+ results/
+ runs/
+ dev*
+ local/
+ generated/
+
+ .DS_Store
+ pretrained_models/
+
+ *.err
+ *.out
+
+ # Dev
+ dev/
+
+ # Data
+ data/
+ outputs/
+ deploy/
+ .gradio/
DEPLOY.md ADDED
@@ -0,0 +1,201 @@
+ # 🚀 Deploying to a Hugging Face Space
+
+ This guide walks you through deploying SoulX-Singer to a Hugging Face Space.
+
+ ## 📋 Prerequisites
+
+ 1. **Hugging Face account**: if you don't have one, register at [huggingface.co](https://huggingface.co/join)
+ 2. **Git**: make sure Git is installed
+ 3. **Hugging Face CLI** (optional but recommended): `pip install huggingface_hub`
+
+ ## 🎯 Deployment Steps
+
+ ### Method 1: Create via the web UI (recommended)
+
+ #### Step 1: Prepare the code repository
+
+ Make sure your code is ready:
+ - ✅ `app.py` - Space entry point
+ - ✅ `webui.py` - Gradio interface code
+ - ✅ `requirements.txt` - Python dependencies
+ - ✅ `README.md` - includes the YAML header with the Space configuration
+
+ #### Step 2: Create the Space
+
+ 1. Visit [huggingface.co/spaces](https://huggingface.co/spaces)
+ 2. Click the **"Create new Space"** button
+ 3. Fill in the Space details:
+    - **Space name**: e.g. `SoulX-Singer` or `soulx-singer-demo`
+    - **SDK**: choose **Gradio**
+    - **Hardware**: **GPU T4 small** is recommended (faster inference; models are cached after the first download)
+    - **Visibility**: Public or Private
+ 4. Click **"Create Space"**
+
+ #### Step 3: Upload the code
+
+ **Option A: Push with Git (recommended)**
+
+ ```bash
+ # 1. Initialize Git in your local code directory (if not done already)
+ git init
+ git add .
+ git commit -m "Initial commit for HF Space"
+
+ # 2. Add the Hugging Face remote
+ # Replace YOUR_USERNAME and YOUR_SPACE_NAME
+ git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
+
+ # 3. Push the code
+ git push -u origin main
+ ```
+
+ **Option B: Upload via the web UI**
+
+ 1. On the Space page, open the **"Files and versions"** tab
+ 2. Click **"Add file"** → **"Upload files"**
+ 3. Drag and drop or select the required files:
+    - `app.py`
+    - `webui.py`
+    - `requirements.txt`
+    - `README.md`
+    - the `soulxsinger/` directory (entire folder)
+    - the `preprocess/` directory (entire folder)
+    - the `cli/` directory (entire folder)
+    - the `example/` directory (entire folder)
+    - the `assets/` directory (entire folder)
+    - other configuration files (e.g. `LICENSE`, `.gitignore`)
+
+ #### Step 4: Wait for the build and the first run
+
+ 1. The Space detects the code automatically and starts building
+ 2. Monitor the build progress in the **"Logs"** tab
+ 3. The first run will:
+    - install the dependencies in `requirements.txt`
+    - execute `app.py`
+    - **automatically download** the `Soul-AILab/SoulX-Singer` and `Soul-AILab/SoulX-Singer-Preprocess` models (this can take 5-15 minutes, depending on network speed)
+ 4. Once the build finishes, the Space starts automatically and the interface appears in the **"App"** tab
+
+ ### Method 2: Using the Hugging Face CLI
+
+ ```bash
+ # 1. Install the Hugging Face Hub CLI
+ pip install huggingface_hub
+
+ # 2. Log in (opens a browser)
+ huggingface-cli login
+
+ # 3. Create the Space (replace YOUR_USERNAME and YOUR_SPACE_NAME)
+ huggingface-cli repo create YOUR_SPACE_NAME --type space --sdk gradio
+
+ # 4. Clone the Space repository
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
+ cd YOUR_SPACE_NAME
+
+ # 5. Copy the code files into the Space directory
+ # (copy everything from your current code directory)
+
+ # 6. Commit and push
+ git add .
+ git commit -m "Deploy SoulX-Singer to HF Space"
+ git push
+ ```
+
+ ## ⚙️ Space Configuration
+
+ The Space is configured in the YAML header of `README.md`:
+
+ ```yaml
+ ---
+ title: SoulX-Singer
+ emoji: 🎤
+ sdk: gradio
+ sdk_version: "6.3.0"
+ app_file: app.py
+ python_version: "3.10"
+ suggested_hardware: t4-small # uncomment this line in README.md to enable GPU
+ ---
+ ```
+
+ ### Hardware Recommendations
+
+ - **CPU Basic**: free, but inference is slow; fine for testing
+ - **GPU T4 Small**: recommended; fast inference, models cached after the first download
+ - **GPU T4 Medium/Large**: for high concurrency or heavier inference
+
+ ### Changing the Hardware
+
+ 1. Open the Space page
+ 2. Click the **"Settings"** tab
+ 3. Choose the desired hardware under **"Hardware"**
+ 4. After saving, the Space restarts
+
+ ## 🔍 Troubleshooting
+
+ ### Issue 1: Build fails
+
+ **Checklist:**
+ - ✅ all dependency versions in `requirements.txt` are compatible
+ - ✅ `app.py` exists and runs
+ - ✅ the YAML configuration in `README.md` is correct
+
+ **Check the logs:**
+ - detailed error messages appear in the **"Logs"** tab on the Space page
+
+ ### Issue 2: Model download fails
+
+ **Possible causes:**
+ - network connectivity problems
+ - Hugging Face Hub authentication problems
+
+ **Fixes:**
+ - make sure the Space has network access (it does by default)
+ - if you use private models, add an HF token in the Space Settings
+
+ ### Issue 3: App is unreachable after startup
+
+ **Checklist:**
+ - ✅ `server_name="0.0.0.0"` is set in `app.py`
+ - ✅ the port comes from the `PORT` environment variable (the Space injects it automatically)
+ - ✅ the **"Logs"** confirm the app actually started
+
+ ### Issue 4: Out of memory
+
+ **Fixes:**
+ - upgrade to larger hardware (T4 Medium/Large)
+ - or optimize the code to reduce memory usage
+
+ ## 📝 Important Notes
+
+ 1. **First-run time**: on the first deployment, the model download can take 5-15 minutes; please be patient
+ 2. **Model cache**: downloaded models are cached in the Space's storage, so restarts don't re-download them
+ 3. **Storage limits**: free Spaces have storage limits; make sure the model files fit within them
+ 4. **Automatic restart**: the Space restarts automatically after code updates
+ 5. **Logs**: when something goes wrong, check the **"Logs"** tab first
+
+ ## 🔗 Links
+
+ - [Hugging Face Spaces docs](https://huggingface.co/docs/hub/spaces)
+ - [Gradio docs](https://gradio.app/docs/)
+ - [SoulX-Singer model page](https://huggingface.co/Soul-AILab/SoulX-Singer)
+ - [SoulX-Singer-Preprocess model page](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess)
+
+ ## ✅ Deployment Checklist
+
+ Before deploying:
+ - [ ] `app.py` exists and is correct
+ - [ ] `requirements.txt` lists all dependencies (including `huggingface_hub`)
+ - [ ] `README.md` contains the correct YAML configuration
+ - [ ] all required code files are uploaded
+ - [ ] `.gitignore` is configured correctly (excludes `pretrained_models/` and `outputs/`)
+ - [ ] the Space hardware is appropriate (GPU T4 Small recommended)
+
+ After deploying:
+ - [ ] the Space builds successfully (no errors in the logs)
+ - [ ] the models download automatically
+ - [ ] the web interface is reachable
+ - [ ] audio files can be uploaded for testing
+ - [ ] inference works
+
+ ---
+
+ **Happy deploying!** 🎉
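
For the checklist under "Issue 3: App is unreachable after startup" in DEPLOY.md above, the essential pattern is a Gradio launch bound to all interfaces on the port the Space injects. A minimal sketch, not the project's actual code: the one-line `demo` interface is a placeholder, while the real app builds its UI via `webui.render_interface()` (see `app.py` further down in this commit):

```python
import os

import gradio as gr

# Placeholder interface; the actual Space builds its UI in webui.render_interface().
demo = gr.Interface(fn=lambda text: text, inputs="text", outputs="text")

demo.launch(
    server_name="0.0.0.0",  # bind to all interfaces so the Space proxy can reach the app
    server_port=int(os.environ.get("PORT", "7860")),  # HF Spaces injects PORT; 7860 is Gradio's default
)
```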
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,14 +1,231 @@
  ---
- title: SoulX Singer
- emoji: 👁
- colorFrom: purple
- colorTo: yellow
+ title: SoulX-Singer
+ emoji: 🎤
  sdk: gradio
- sdk_version: 6.5.1
+ sdk_version: "6.3.0"
  app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: Zero-shot Singing Voice Synthesis
+ python_version: "3.10"
+ # GPU recommended for inference speed (optional: use CPU for light usage)
+ # suggested_hardware: t4-small
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <h1>🎤 SoulX-Singer</h1>
+ <p>
+ Official inference code for<br>
+ <b><em>SoulX-Singer: Towards High-Quality Zero-Shot Singing Voice Synthesis</em></b>
+ </p>
+ <p>
+ <img src="assets/soulx-logo.png" alt="SoulX-Logo" style="height:80px;">
+ </p>
+ <p>
+ <a href="https://soul-ailab.github.io/soulx-singer/"><img src="https://img.shields.io/badge/Demo-Page-lightgrey" alt="Demo Page"></a>
+ <a href="https://huggingface.co/Soul-AILab/SoulX-Singer"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue' alt="HF-model"></a>
+ <a href="assets/technical-report.pdf"><img src="https://img.shields.io/badge/Report-Github-red" alt="Technical Report"></a>
+ <a href="https://github.com/Soul-AILab/SoulX-Singer"><img src="https://img.shields.io/badge/License-Apache%202.0-blue" alt="License"></a>
+ </p>
+ </div>
+
+ ---
+
+ ## 🎵 Overview
+
+ **SoulX-Singer** is a high-fidelity, zero-shot singing voice synthesis model that enables users to generate realistic singing voices for unseen singers.
+ It supports **melody-conditioned (F0 contour)** and **score-conditioned (MIDI notes)** control for precise pitch, rhythm, and expression.
+
+ ---
+
+ ## ✨ Key Features
+
+ - **🎤 Zero-Shot Singing** – Generate high-fidelity voices for unseen singers, no fine-tuning needed.
+ - **🎵 Flexible Control Modes** – Melody (F0) and Score (MIDI) conditioning.
+ - **📚 Large-Scale Dataset** – 42,000+ hours of aligned vocals, lyrics, and notes across Mandarin, English, and Cantonese.
+ - **🧑‍🎤 Timbre Cloning** – Preserve singer identity across languages, styles, and edited lyrics.
+ - **✏️ Singing Voice Editing** – Modify lyrics while keeping natural prosody.
+ - **🌐 Cross-Lingual Synthesis** – High-fidelity synthesis by disentangling timbre from content.
+
+ ---
+
+ <p align="center">
+ <img src="assets/performance_radar.png" width="80%" alt="Performance Radar"/>
+ </p>
+
+ ---
+
+ ## 🎬 Demo Examples
+
+ <div align="center">
+
+ <https://github.com/user-attachments/assets/13306f10-3a29-46ba-bcef-d6308d05cbcc>
+
+ </div>
+ <div align="center">
+
+ <https://github.com/user-attachments/assets/2eb260fe-6f0b-408c-aab8-5b81ddddb284>
+
+ </div>
+
+ ---
+
+ ## 📰 News
+
+ - **[2026-02-06]** SoulX-Singer inference code and models released.
+
+ ---
+
+ ## 🚀 Quick Start
+
+ **Note:** This repo does not ship pretrained weights. The SVS and preprocessing models must be downloaded from Hugging Face (see step 3).
+
+ ### 1. Clone Repository
+
+ ```bash
+ git clone https://github.com/Soul-AILab/SoulX-Singer.git
+ cd SoulX-Singer
+ ```
+
+ ### 2. Set Up Environment
+
+ **1. Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html
+
+ **2. Create and activate a Conda environment:**
+ ```
+ conda create -n soulxsinger -y python=3.10
+ conda activate soulxsinger
+ ```
+ **3. Install dependencies:**
+ ```
+ pip install -r requirements.txt
+ ```
+ ⚠️ If you are in mainland China, use a PyPI mirror:
+ ```
+ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+ ```
+
+ ---
+
+ ### 3. Download Pretrained Models
+
+ **This repository does not include pretrained models.** You must download them from Hugging Face:
+
+ - [Soul-AILab/SoulX-Singer](https://huggingface.co/Soul-AILab/SoulX-Singer) (SVS model)
+ - [Soul-AILab/SoulX-Singer-Preprocess](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess) (preprocessing models)
+
+ Install Hugging Face Hub and download:
+
+ ```sh
+ pip install -U huggingface_hub
+
+ # SoulX-Singer SVS model
+ huggingface-cli download Soul-AILab/SoulX-Singer --local-dir pretrained_models/SoulX-Singer
+
+ # Preprocessing models (vocal separation, F0, ASR, etc.)
+ huggingface-cli download Soul-AILab/SoulX-Singer-Preprocess --local-dir pretrained_models/SoulX-Singer-Preprocess
+ ```
+
+ ### 4. Run the Demo
+
+ Run the inference demo:
+ ```sh
+ bash example/infer.sh
+ ```
+
+ This script relies on metadata generated by the preprocessing pipeline, including vocal separation and transcription. Follow the steps in [preprocess](preprocess/README.md) to prepare the necessary metadata before running the demo on your own data.
+
+ **⚠️ Important Note**
+ The metadata produced by the automatic preprocessing pipeline may not perfectly align the singing audio with the corresponding lyrics and musical notes. For best synthesis quality, we strongly recommend manually correcting the alignment using the 🎼 [Midi-Editor](https://huggingface.co/spaces/Soul-AILab/SoulX-Singer-Midi-Editor).
+
+ How to use the Midi-Editor:
+ - [Editing Metadata with Midi-Editor](preprocess/README.md#L104-L105)
+
+ ### 🌐 WebUI
+
+ You can launch the interactive interface with:
+ ```
+ python webui.py
+ ```
+
+ ### 🚀 Deploy as a Hugging Face Space
+
+ This repo is ready to deploy as a [Hugging Face Space](https://huggingface.co/spaces). **Pretrained models are not included;** `app.py` downloads them from the Hub on first run.
+
+ **📖 For the detailed deployment guide, see [DEPLOY.md](DEPLOY.md).**
+
+ **Quick steps:**
+
+ 1. **Create a Space**: visit [huggingface.co/spaces](https://huggingface.co/spaces), click "Create new Space", and choose the **Gradio** SDK
+ 2. **Upload the code**: push with Git or upload the files through the web UI
+ 3. **Configure hardware**: choose **GPU T4 Small** (recommended) in the Space Settings for faster inference
+ 4. **Wait for startup**: the Space automatically installs dependencies, downloads the models, and starts the app (the first run can take 5-15 minutes)
+
+ The models are downloaded automatically from:
+ - [Soul-AILab/SoulX-Singer](https://huggingface.co/Soul-AILab/SoulX-Singer) (SVS model)
+ - [Soul-AILab/SoulX-Singer-Preprocess](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess) (preprocessing models)
+
+ ## 🚧 Roadmap
+
+ - [ ] 🖥️ Web-based UI for easy and interactive inference
+ - [ ] 🌐 Online demo deployment on Hugging Face Spaces
+ - [ ] 📊 Release the SoulX-Singer-Eval benchmark
+ - [ ] 📚 Comprehensive tutorials and usage documentation
+
+ ## 🙏 Acknowledgements
+
+ Special thanks to the following open-source projects:
+
+ - [F5-TTS](https://github.com/SWivid/F5-TTS)
+ - [Amphion](https://github.com/open-mmlab/Amphion/tree/main)
+ - [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
+ - [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
+ - [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)
+ - [RMVPE](https://github.com/Dream-High/RMVPE)
+ - [Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
+ - [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)
+ - [ROSVOT](https://github.com/RickyL-2000/ROSVOT)
+
+ ## 📄 License
+
+ We use the Apache 2.0 license. Researchers and developers are free to use the code and model weights of SoulX-Singer. See [LICENSE](LICENSE) for details.
+
+ ## ⚠️ Usage Disclaimer
+
+ SoulX-Singer is intended for academic research, educational purposes, and legitimate applications such as personalized singing synthesis and assistive technologies.
+
+ Please note:
+
+ - 🎤 Respect intellectual property, privacy, and personal consent when generating singing content.
+ - 🚫 Do not use the model to impersonate individuals without authorization or to create deceptive audio.
+ - ⚠️ The developers assume no liability for any misuse of this model.
+
+ We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles. For ethics or misuse concerns, please contact us.
+
+ ## 📬 Contact Us
+
+ We welcome your feedback, questions, and collaboration:
+
+ - **Email**: qianjiale@soulapp.cn | menghao@soulapp.cn | wangxinsheng@soulapp.cn
+
+ - **Join discussions**: WeChat or Soul APP groups for technical discussions and updates:
+
+ <p align="center">
+ <!-- <em>Due to group limits, if you can't scan the QR code, please add my WeChat for group access: <strong>Tiamo James</strong></em> -->
+ <br>
+ <span style="display: inline-block; margin-right: 10px;">
+ <img src="assets/soul_wechat01.jpg" width="500" alt="WeChat Group QR Code"/>
+ </span>
+ <!-- <span style="display: inline-block;">
+ <img src="assets/wechat_tiamo.jpg" width="300" alt="WeChat QR Code"/>
+ </span> -->
+ </p>
app.py ADDED
@@ -0,0 +1,63 @@
+ """
+ Hugging Face Space entry point for SoulX-Singer.
+ Downloads pretrained models from the Hub if needed, then launches the Gradio app.
+ """
+ import os
+ import sys
+ from pathlib import Path
+
+ ROOT = Path(__file__).resolve().parent
+ PRETRAINED_DIR = ROOT / "pretrained_models"
+ MODEL_DIR_SVS = PRETRAINED_DIR / "SoulX-Singer"
+ MODEL_DIR_PREPROCESS = PRETRAINED_DIR / "SoulX-Singer-Preprocess"
+
+
+ def ensure_pretrained_models():
+     """Download SoulX-Singer and Preprocess models from Hugging Face Hub if not present."""
+     if (MODEL_DIR_SVS / "model.pt").exists() and MODEL_DIR_PREPROCESS.exists():
+         print("Pretrained models already present, skipping download.", flush=True)
+         return
+
+     try:
+         from huggingface_hub import snapshot_download
+     except ImportError:
+         print(
+             "huggingface_hub not installed. Install with: pip install huggingface_hub",
+             file=sys.stderr,
+             flush=True,
+         )
+         raise
+
+     PRETRAINED_DIR.mkdir(parents=True, exist_ok=True)
+
+     if not (MODEL_DIR_SVS / "model.pt").exists():
+         print("Downloading SoulX-Singer model...", flush=True)
+         snapshot_download(
+             repo_id="Soul-AILab/SoulX-Singer",
+             local_dir=str(MODEL_DIR_SVS),
+             local_dir_use_symlinks=False,  # ignored by recent huggingface_hub; kept for older versions
+         )
+         print("SoulX-Singer model ready.", flush=True)
+
+     if not MODEL_DIR_PREPROCESS.exists():
+         print("Downloading SoulX-Singer-Preprocess models...", flush=True)
+         snapshot_download(
+             repo_id="Soul-AILab/SoulX-Singer-Preprocess",
+             local_dir=str(MODEL_DIR_PREPROCESS),
+             local_dir_use_symlinks=False,  # ignored by recent huggingface_hub; kept for older versions
+         )
+         print("SoulX-Singer-Preprocess models ready.", flush=True)
+
+
+ if __name__ == "__main__":
+     os.chdir(ROOT)
+     ensure_pretrained_models()
+
+     from webui import render_interface
+
+     page = render_interface()
+     page.queue()
+     page.launch(
+         server_name="0.0.0.0",
+         server_port=int(os.environ.get("PORT", "7860")),
+     )
assets/performance_radar.png ADDED

Git LFS Details

  • SHA256: 8a5fe64523e65072d7c8014e4584b9f20b5e4f43bbd54edee9f2a068ef174162
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
assets/soul_wechat01.jpg ADDED

Git LFS Details

  • SHA256: b452c23c33f4d0771f922aed4ceb92c0d6e893e74061f78b69a222f94bbd3c4a
  • Pointer size: 131 Bytes
  • Size of remote file: 835 kB
assets/soulx-logo.png ADDED

Git LFS Details

  • SHA256: 4fe6c191a71be0323d52b236d8ed57f346821ee66c4a9bd8b6232cbca9bf3daf
  • Pointer size: 131 Bytes
  • Size of remote file: 636 kB
assets/technical-report.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab2876f8850ce09e2b8ce7e929f8b9adf7de10f13900cb013f548f9707b80061
+ size 7927691
cli/inference.py ADDED
@@ -0,0 +1,147 @@
+ import os
+ import torch
+ import json
+ import argparse
+ from tqdm import tqdm
+ import numpy as np
+ import soundfile as sf
+ from omegaconf import DictConfig
+
+ from soulxsinger.utils.file_utils import load_config
+ from soulxsinger.models.soulxsinger import SoulXSinger
+ from soulxsinger.utils.data_processor import DataProcessor
+
+
+ def build_model(
+     model_path: str,
+     config: DictConfig,
+     device: str = "cuda",
+ ):
+     """
+     Build the model from the pre-trained model path and model configuration.
+
+     Args:
+         model_path (str): Path to the checkpoint file.
+         config (DictConfig): Model configuration.
+         device (str, optional): Device to use. Defaults to "cuda".
+
+     Returns:
+         torch.nn.Module: The initialized model with the checkpoint loaded.
+     """
+     if not os.path.isfile(model_path):
+         raise FileNotFoundError(
+             f"Model checkpoint not found: {model_path}. "
+             "Please download the pretrained model and place it at this path, or set --model_path."
+         )
+     model = SoulXSinger(config).to(device)
+     print("Model initialized.")
+     print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")
+
+     checkpoint = torch.load(model_path, weights_only=False, map_location=device)
+     if "state_dict" not in checkpoint:
+         raise KeyError(
+             f"Checkpoint at {model_path} has no 'state_dict' key. "
+             "Expected a checkpoint saved with model.state_dict()."
+         )
+     model.load_state_dict(checkpoint["state_dict"], strict=True)
+
+     model.eval()
+     print("Model checkpoint loaded.")
+
+     return model
+
+
+ def process(args, config, model: torch.nn.Module):
+     """Run the full inference pipeline: build a DataProcessor, load the prompt and
+     target metadata, synthesize each target segment, and stitch the results."""
+     if args.control not in ("melody", "score"):
+         raise ValueError(f"control must be 'melody' or 'score', got: {args.control}")
+
+     print(f"prompt_metadata_path: {args.prompt_metadata_path}")
+     print(f"target_metadata_path: {args.target_metadata_path}")
+
+     os.makedirs(args.save_dir, exist_ok=True)
+     data_processor = DataProcessor(
+         hop_size=config.audio.hop_size,
+         sample_rate=config.audio.sample_rate,
+         phoneset_path=args.phoneset_path,
+         device=args.device,
+     )
+
+     with open(args.prompt_metadata_path, "r", encoding="utf-8") as f:
+         prompt_meta_list = json.load(f)
+     if not prompt_meta_list:
+         raise ValueError("Prompt metadata is empty. Please run preprocess on the prompt audio first.")
+     prompt_meta = prompt_meta_list[0]  # load the first segment as the prompt
+     with open(args.target_metadata_path, "r", encoding="utf-8") as f:
+         target_meta_list = json.load(f)
+     infer_prompt_data = data_processor.process(prompt_meta, args.prompt_wav_path)
+
+     assert len(target_meta_list) > 0, "No target segments found in the target metadata."
+     # Segment times are in milliseconds; allocate the output buffer up to the last segment's end.
+     generated_len = int(target_meta_list[-1]["time"][1] / 1000 * config.audio.sample_rate)
+     generated_merged = np.zeros(generated_len, dtype=np.float32)
+
+     for idx, target_meta in enumerate(
+         tqdm(target_meta_list, total=len(target_meta_list), desc="Inferring segments"),
+     ):
+         start_sample_idx = int(target_meta["time"][0] / 1000 * config.audio.sample_rate)
+         end_sample_idx = int(target_meta["time"][1] / 1000 * config.audio.sample_rate)
+         infer_target_data = data_processor.process(target_meta, None)
+
+         infer_data = {
+             "prompt": infer_prompt_data,
+             "target": infer_target_data,
+         }
+
+         with torch.no_grad():
+             generated_audio = model.infer(
+                 infer_data,
+                 auto_shift=args.auto_shift,
+                 pitch_shift=args.pitch_shift,
+                 n_steps=config.infer.n_steps,
+                 cfg=config.infer.cfg,
+                 control=args.control,
+             )
+
+         generated_audio = generated_audio.squeeze().cpu().numpy()
+         # Place the synthesized segment at its original position in the full track.
+         generated_merged[start_sample_idx : start_sample_idx + generated_audio.shape[0]] = generated_audio
+
+     merged_path = os.path.join(args.save_dir, "generated.wav")
+     sf.write(merged_path, generated_merged, config.audio.sample_rate)  # was hardcoded to 24000
+     print(f"Generated audio saved to {merged_path}")
+
+
+ def main(args, config):
+     model = build_model(
+         model_path=args.model_path,
+         config=config,
+         device=args.device,
+     )
+     process(args, config, model)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--model_path", type=str, default="pretrained_models/SoulX-Singer/model.pt")
+     parser.add_argument("--config", type=str, default="soulxsinger/config/soulxsinger.yaml")
+     parser.add_argument("--prompt_wav_path", type=str, default="example/audio/zh_prompt.wav")
+     parser.add_argument("--prompt_metadata_path", type=str, default="example/metadata/zh_prompt.json")
+     parser.add_argument("--target_metadata_path", type=str, default="example/metadata/zh_target.json")
+     parser.add_argument("--phoneset_path", type=str, default="soulxsinger/utils/phoneme/phone_set.json")
+     parser.add_argument("--save_dir", type=str, default="outputs")
+     parser.add_argument("--auto_shift", action="store_true")
+     parser.add_argument("--pitch_shift", type=int, default=0)
+     parser.add_argument(
+         "--control",
+         type=str,
+         default="melody",
+         choices=["melody", "score"],
+         help="Control mode: 'melody' (F0 contour) or 'score' (MIDI notes)",
+     )
+     args = parser.parse_args()
+
+     config = load_config(args.config)
+     main(args, config)
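
The segment stitching in `process` above relies on converting the metadata's millisecond timestamps into sample indices. A small worked sketch of that conversion, assuming a 24 kHz output sample rate (an assumption; the actual value comes from `config.audio.sample_rate` in `soulxsinger/config/soulxsinger.yaml`):

```python
SAMPLE_RATE = 24_000  # assumed; the real value is read from config.audio.sample_rate

def ms_to_samples(ms: float, sample_rate: int = SAMPLE_RATE) -> int:
    """Convert a metadata timestamp in milliseconds to a sample index."""
    return int(ms / 1000 * sample_rate)

# The en_prompt example segment spans [5220, 10280] ms:
start, end = ms_to_samples(5220), ms_to_samples(10280)
print(start, end)                    # 125280 246720
print((end - start) / SAMPLE_RATE)   # 5.06 -> about five seconds of audio
```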
deploy_to_hf.sh ADDED
@@ -0,0 +1,70 @@
+ #!/bin/bash
+ # Quick deployment script: deploy SoulX-Singer to a Hugging Face Space
+ # Usage: ./deploy_to_hf.sh YOUR_USERNAME YOUR_SPACE_NAME
+
+ set -e
+
+ if [ $# -lt 2 ]; then
+     echo "Usage: $0 <YOUR_USERNAME> <YOUR_SPACE_NAME>"
+     echo "Example: $0 myusername soulx-singer-demo"
+     exit 1
+ fi
+
+ USERNAME=$1
+ SPACE_NAME=$2
+ SPACE_REPO="https://huggingface.co/spaces/${USERNAME}/${SPACE_NAME}"
+
+ echo "🚀 Deploying to Hugging Face Space..."
+ echo "Space: ${USERNAME}/${SPACE_NAME}"
+ echo ""
+
+ # Check that huggingface_hub is installed
+ if ! command -v huggingface-cli &> /dev/null; then
+     echo "⚠️ huggingface-cli not found, installing..."
+     pip install -U huggingface_hub
+ fi
+
+ # Check that the user is logged in
+ if ! huggingface-cli whoami &> /dev/null; then
+     echo "🔐 Please log in to Hugging Face first..."
+     huggingface-cli login
+ fi
+
+ # Create the Space (if it does not exist)
+ echo "📦 Checking whether the Space exists..."
+ if ! huggingface-cli repo info "${USERNAME}/${SPACE_NAME}" --repo-type space &> /dev/null; then
+     echo "✨ Creating a new Space..."
+     huggingface-cli repo create "${SPACE_NAME}" --type space --sdk gradio
+ else
+     echo "✅ Space already exists"
+ fi
+
+ # Initialize Git if needed
+ if [ ! -d ".git" ]; then
+     echo "📝 Initializing Git repository..."
+     git init
+     git add .
+     git commit -m "Initial commit for HF Space deployment" || echo "⚠️ Nothing new to commit"
+ fi
+
+ # Configure the remote
+ if git remote | grep -q "^origin$"; then
+     echo "🔄 Updating the remote URL..."
+     git remote set-url origin "${SPACE_REPO}"
+ else
+     echo "➕ Adding the remote..."
+     git remote add origin "${SPACE_REPO}"
+ fi
+
+ # Push the code
+ echo "📤 Pushing code to Hugging Face..."
+ git push -u origin main || git push -u origin master
+
+ echo ""
+ echo "✅ Deployment complete!"
+ echo "🌐 Space URL: ${SPACE_REPO}"
+ echo ""
+ echo "💡 Tips:"
+ echo "  - The Space starts building automatically; check the Logs tab"
+ echo "  - The first run downloads the models and can take 5-15 minutes"
+ echo "  - Selecting GPU T4 Small hardware in the Space Settings is recommended"
example/audio/en_prompt.json ADDED
@@ -0,0 +1,16 @@
+ [
+ {
+ "index": "vocal_5220_10280",
+ "language": "English",
+ "time": [
+ 5220,
+ 10280
+ ],
+ "duration": "0.24 0.36 0.30 0.78 0.24 0.56 0.19 0.53 0.36 0.20 0.32 0.57 0.19 0.22",
+ "text": "<SP> Ooh Ooh <SP> I wish nothing nothing more more the best best <SP>",
+ "phoneme": "<SP> en_UW1 en_UW1 <SP> en_AY1 en_W-IH1-SH en_N-AH1-TH-IH0-NG en_N-AH1-TH-IH0-NG en_M-AO1-R en_M-AO1-R en_DH-AH0 en_B-EH1-S-T en_B-EH1-S-T <SP>",
+ "note_pitch": "0 63 65 0 65 67 68 62 62 64 67 67 65 0",
+ "note_type": "1 2 3 1 2 2 2 3 2 3 2 2 3 1",
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 345.2 343.1 341.6 339.8 337.8 331.9 319.5 312.1 310.8 312.6 315.1 316.1 315.3 314.6 315.3 317.9 322.0 329.6 337.5 344.7 347.5 347.2 344.3 339.5 338.2 341.7 342.8 342.2 340.7 343.0 342.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 347.0 345.3 348.7 350.2 350.9 350.3 344.7 340.3 338.3 338.0 342.8 347.4 348.3 346.7 343.4 339.5 340.5 345.2 350.4 357.7 367.3 376.6 385.9 392.6 393.6 389.9 384.7 381.8 382.0 383.0 380.6 373.5 367.9 377.0 385.4 391.4 393.8 395.6 396.1 397.2 399.8 406.0 413.5 416.1 416.0 414.4 413.5 412.9 415.5 418.9 417.5 408.8 389.2 373.9 0.0 0.0 0.0 288.5 286.0 284.2 285.6 288.9 291.3 293.5 294.5 295.2 297.8 299.5 301.0 303.0 305.9 306.8 306.0 304.4 301.8 301.0 300.8 301.8 310.2 309.8 308.2 305.9 303.6 301.5 299.3 298.5 300.0 302.1 303.5 303.6 302.2 299.7 297.5 296.3 296.4 296.8 298.6 302.6 311.8 322.0 333.8 349.0 368.8 393.3 407.1 410.7 407.0 402.3 401.2 401.7 403.9 405.7 403.5 396.8 387.4 378.6 377.8 381.4 384.0 384.7 383.5 382.5 380.8 377.3 378.4 383.5 390.0 392.7 390.5 387.6 385.3 382.7 381.0 382.8 383.9 382.2 379.6 379.3 380.2 383.1 386.0 386.5 385.4 384.3 383.7 384.4 386.2 388.2 388.5 385.0 378.6 360.4 333.7 328.2 332.4 340.2 348.9 339.6 334.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
+ }
+ ]
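
Each metadata file is a JSON list of segments. The `duration`, `text`, `phoneme`, `note_pitch`, and `note_type` fields are parallel whitespace-separated sequences (14 entries each in the segment above), `time` is the segment span in milliseconds, and `f0` is a frame-level pitch contour. A small validation sketch, inferred from these example files rather than from any documented schema:

```python
import json

TOKEN_FIELDS = ("duration", "text", "phoneme", "note_pitch", "note_type")

def check_segment(seg: dict) -> None:
    """Sanity-check one segment: the token-level fields must have equal lengths."""
    lengths = {field: len(seg[field].split()) for field in TOKEN_FIELDS}
    assert len(set(lengths.values())) == 1, f"misaligned token fields: {lengths}"
    start_ms, end_ms = seg["time"]
    assert start_ms < end_ms, "segment start must precede its end"

with open("example/audio/en_prompt.json", encoding="utf-8") as f:
    for segment in json.load(f):
        check_segment(segment)
        print(segment["index"], "ok")
```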
example/audio/en_prompt.mp3 ADDED
Binary file (86.8 kB).
 
example/audio/en_target.json ADDED
@@ -0,0 +1,16 @@
+ [
+ {
+ "index": "vocal_0_6900",
+ "language": "English",
+ "time": [
+ 0,
+ 6900
+ ],
+ "duration": "0.16 0.24 0.32 0.15 0.17 0.24 0.15 0.44 0.29 0.32 0.24 0.32 0.22 0.18 0.24 0.25 1.01 0.26 0.48 0.29 0.79 0.14",
+ "text": "<SP> Who says you're you're not pretty <SP> pretty <SP> Who says you're you're not beautiful beautiful <SP> Who says says <SP>",
+ "phoneme": "<SP> en_HH-UW1 en_S-EH1-Z en_Y-UH1-R en_Y-UH1-R en_N-AA1-T en_P-R-IH1-T-IY0 <SP> en_P-R-IH1-T-IY0 <SP> en_HH-UW1 en_S-EH1-Z en_Y-UH1-R en_Y-UH1-R en_N-AA1-T en_B-Y-UW1-T-AH0-F-AH0-L en_B-Y-UW1-T-AH0-F-AH0-L <SP> en_HH-UW1 en_S-EH1-Z en_S-EH1-Z <SP>",
+ "note_pitch": "0 68 67 65 63 63 66 67 70 66 68 67 65 63 63 67 65 63 65 61 58 0",
+ "note_type": "1 2 2 2 3 2 2 1 3 1 2 2 2 3 2 2 3 1 2 2 3 1",
+ "f0": "0.0 0.0 382.7 387.7 385.9 379.8 376.0 380.9 390.1 403.2 415.3 423.6 421.6 402.6 385.2 381.1 0.0 0.0 425.8 419.0 409.6 397.8 392.2 389.0 388.5 391.4 389.1 381.4 375.9 0.0 0.0 0.0 0.0 359.0 354.7 353.8 353.7 354.7 353.1 351.1 350.4 349.0 348.9 346.3 337.4 328.0 312.8 303.1 298.4 296.0 298.9 302.0 306.3 307.9 307.3 307.5 307.3 302.9 301.8 0.0 0.0 0.0 0.0 0.0 343.7 364.3 375.9 368.5 358.1 359.1 365.9 378.4 393.1 406.0 412.5 410.9 407.0 404.1 403.5 403.4 401.5 399.4 397.7 395.4 394.4 394.8 395.5 396.5 397.5 400.8 407.9 415.1 417.8 453.1 472.2 481.0 482.3 481.9 480.8 478.7 477.4 476.8 474.8 467.5 446.0 390.4 382.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 374.3 375.5 370.9 370.7 373.1 378.3 392.2 407.6 418.5 423.9 423.2 415.9 395.4 0.0 0.0 0.0 421.5 416.2 405.3 391.0 383.1 380.8 383.0 388.3 388.8 378.3 371.7 0.0 0.0 0.0 371.4 365.1 362.7 358.5 353.0 352.0 353.5 356.1 356.4 353.6 348.3 341.1 330.6 317.7 303.8 293.3 296.5 297.7 301.4 305.3 308.8 308.8 308.2 308.2 306.3 305.6 285.0 269.8 265.6 280.0 304.4 331.2 351.0 357.9 364.2 370.6 381.1 392.9 399.0 399.1 395.0 389.5 379.9 363.0 338.9 318.5 305.6 300.3 299.6 296.3 292.2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 309.6 322.1 329.8 331.2 332.1 332.6 332.4 335.4 340.7 345.0 347.2 346.2 342.6 339.6 337.4 338.3 340.9 342.6 344.0 344.6 344.0 344.2 343.6 341.9 338.8 336.7 337.6 341.1 347.0 350.4 343.0 326.6 330.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 310.0 315.7 317.7 317.4 316.5 314.4 314.1 322.5 336.3 350.0 354.2 352.9 350.5 348.4 347.1 347.3 348.4 349.8 349.8 350.4 350.3 323.9 324.7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 297.3 289.8 279.5 275.1 276.1 276.1 274.9 275.4 274.6 271.8 268.6 264.0 258.3 251.7 244.3 239.9 236.1 233.7 234.0 236.0 237.3 236.9 235.2 233.5 231.7 231.0 232.1 233.6 235.4 236.2 236.7 235.8 234.1 232.2 231.3 232.6 233.5 235.2 236.0 232.3 228.8 229.6 233.8 241.3 239.4 226.3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
+ }
+ ]
example/audio/en_target.mp3 ADDED
Binary file (66.9 kB).
 
example/audio/music.json ADDED
@@ -0,0 +1,16 @@
+ [
+ {
+ "index": "vocal_240_51240",
+ "language": "Mandarin",
+ "time": [
+ 240,
+ 51240
+ ],
+ "duration": "0.21 0.18 0.34 0.38 0.26 0.22 0.50 0.20 0.33 0.13 0.60 0.22 0.38 0.52 0.30 0.12 1.82 0.98 0.20 0.26 0.38 0.36 0.58 0.54 0.26 0.50 1.46 3.03 0.24 0.24 0.29 0.25 0.20 0.14 0.80 0.22 0.54 0.30 0.24 0.16 0.58 0.28 0.38 1.73 0.79 0.24 0.30 0.34 0.30 0.30 0.34 0.21 0.21 0.36 0.34 0.23 1.85 1.90 0.23 0.39 0.68 0.50 0.31 0.43 0.76 0.38 2.00 1.87 0.68 0.72 0.56 0.62 0.80 0.40 0.42 1.68 1.79 0.70 0.66 0.54 0.24 0.48 0.68 0.40 2.34 0.14",
+ "text": "<SP> 只 是 因 为 为 在 人 群 群 中 多 看 了 你 一 眼 <SP> 再 也 没 能 忘 掉 你 容 颜 <SP> 梦 想 着 着 偶 偶 然 然 有 <SP> 一 一 天 再 相 见 <SP> 从 此 我 开 始 始 孤 孤 单 思 念 念 <SP> 想 想 你 时 你 你 在 天 边 <SP> 想 你 时 你 在 眼 前 前 <SP> 想 你 时 你 你 在 脑 海 <SP>",
+ "phoneme": "<SP> zh_zhi3 zh_shi4 zh_yin1 zh_wei4 zh_wei2 zh_zai4 zh_ren2 zh_qun2 zh_qun2 zh_zhong1 zh_duo1 zh_kan4 zh_le5 zh_ni3 zh_yi1 zh_yan3 <SP> zh_zai4 zh_ye3 zh_mei2 zh_neng2 zh_wang4 zh_diao4 zh_ni3 zh_rong2 zh_yan2 <SP> zh_meng4 zh_xiang3 zh_zhe5 zh_zhe5 zh_ou3 zh_ou3 zh_ran2 zh_ran2 zh_you3 <SP> zh_yi1 zh_yi1 zh_tian1 zh_zai4 zh_xiang1 zh_jian4 <SP> zh_cong2 zh_ci3 zh_wo3 zh_kai1 zh_shi3 zh_shi3 zh_gu1 zh_gu1 zh_dan1 zh_si1 zh_nian4 zh_nian4 <SP> zh_xiang3 zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_ni3 zh_zai4 zh_tian1 zh_bian1 <SP> zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_zai4 zh_yan3 zh_qian2 zh_qian2 <SP> zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_ni3 zh_zai4 zh_nao3 zh_hai3 <SP>",
+ "note_pitch": "0 64 64 64 66 68 66 66 66 64 64 64 64 66 64 60 61 0 63 63 63 64 66 63 61 59 56 0 68 66 66 68 68 66 66 64 64 0 64 66 61 61 61 64 0 62 63 63 64 64 66 65 66 61 58 59 56 0 69 71 66 68 68 71 66 64 61 0 66 61 68 66 64 63 61 59 0 71 66 68 68 71 66 64 61 0",
+ "note_type": "1 2 2 2 2 3 2 2 2 3 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 1 2 2 2 3 2 3 2 3 2 1 2 3 2 2 2 2 1 2 2 2 2 2 3 2 3 2 2 2 3 1 2 3 2 2 2 3 2 2 2 1 2 2 2 2 2 2 2 3 1 2 2 2 2 3 2 2 2 1",
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 318.9 332.2 331.2 323.9 312.3 0.0 0.0 0.0 0.0 0.0 0.0 340.9 342.3 340.2 337.4 333.5 329.3 327.8 328.3 329.1 330.5 331.9 331.5 329.2 327.6 327.4 327.7 329.6 330.5 329.5 328.1 328.5 330.1 330.4 329.5 328.9 331.0 333.6 332.8 331.7 330.2 330.6 330.9 329.5 327.6 332.4 343.5 363.9 368.7 369.3 368.2 366.7 371.3 382.0 392.9 402.3 409.8 413.8 416.3 416.4 415.4 414.1 413.4 415.3 417.4 418.3 413.6 416.0 0.0 0.0 0.0 0.0 368.2 362.5 361.3 362.0 362.9 364.3 366.4 367.5 368.4 368.9 368.1 367.6 369.7 371.1 371.1 369.5 368.0 367.3 366.6 367.1 369.5 372.0 371.3 368.2 368.1 369.6 369.9 373.4 376.5 360.2 285.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 384.0 376.7 372.0 371.7 373.9 376.3 375.5 361.2 333.0 317.8 319.0 329.1 333.5 333.2 333.2 333.3 328.0 314.5 0.0 0.0 0.0 320.6 326.5 332.9 335.7 334.1 329.9 327.4 326.3 326.6 329.1 330.5 330.0 328.0 326.7 328.6 331.1 332.4 332.1 332.8 333.2 333.1 331.8 328.6 318.9 293.2 323.6 327.5 330.0 332.8 331.1 319.4 272.2 0.0 0.0 0.0 0.0 278.5 294.5 310.0 318.5 322.7 327.4 334.1 337.2 336.3 328.5 321.0 324.2 341.5 362.2 374.7 375.8 370.5 366.0 365.2 368.7 371.7 374.4 373.0 368.7 370.4 374.8 375.1 372.5 368.7 363.0 357.1 359.0 366.8 377.1 379.5 371.5 359.9 351.0 358.4 371.7 377.9 375.8 367.8 359.7 363.6 375.3 379.9 377.5 372.9 359.8 345.2 337.1 334.5 333.4 332.8 329.6 319.2 292.8 261.7 253.4 262.8 273.0 278.1 279.3 278.6 277.6 277.9 278.3 278.0 277.3 276.7 275.6 275.3 276.1 277.8 278.2 277.7 277.6 277.2 276.4 275.5 273.9 271.8 270.7 274.7 282.7 285.5 281.4 273.1 264.4 262.2 268.9 278.4 284.3 283.6 277.3 268.5 262.2 263.2 270.7 278.8 285.0 287.3 282.0 272.1 265.2 266.0 272.1 278.2 283.8 285.1 281.4 272.6 267.2 269.1 276.1 282.7 289.3 290.5 285.6 273.9 267.5 267.4 271.8 277.7 282.6 283.1 277.8 269.8 265.6 268.6 273.8 281.1 286.2 287.2 282.3 270.4 264.5 263.3 263.0 268.5 268.3 268.7 270.9 277.2 277.4 284.3 285.7 282.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 311.5 306.4 306.0 306.8 305.8 305.0 307.2 310.8 312.4 313.5 314.4 314.6 313.7 312.2 311.1 311.1 312.8 314.6 312.7 306.5 299.7 302.7 308.7 315.2 314.2 310.9 309.4 309.8 310.6 310.7 310.8 311.4 312.2 313.0 312.5 312.2 312.2 311.9 312.0 312.2 311.4 306.1 303.3 309.3 315.0 325.3 324.4 323.9 325.0 326.4 327.7 327.7 328.1 328.6 327.8 328.3 328.3 327.6 328.6 336.1 347.2 359.4 371.0 374.9 371.3 366.5 369.5 375.5 376.2 370.2 365.7 367.3 372.9 376.9 375.3 368.4 360.1 358.0 365.4 380.8 383.5 380.4 374.4 367.6 365.9 372.0 376.3 377.9 375.7 372.6 368.6 361.2 353.7 0.0 309.4 302.3 299.2 300.6 305.7 308.5 308.6 309.0 310.9 310.9 309.5 308.6 309.0 311.3 313.0 314.4 314.5 313.0 312.1 311.2 308.0 299.5 295.8 295.0 285.7 278.0 280.1 281.2 279.9 278.3 278.3 279.5 279.7 272.2 259.1 0.0 0.0 0.0 240.7 236.0 233.8 233.5 234.6 235.7 235.6 236.2 239.8 245.4 248.8 250.1 250.4 249.9 249.1 247.2 241.7 231.7 216.0 197.6 192.1 197.2 204.3 206.5 205.4 201.7 200.8 203.5 208.2 210.9 210.3 207.1 203.5 202.3 202.1 202.8 205.1 207.3 208.1 207.3 205.4 202.1 199.3 200.5 206.1 212.9 212.9 206.5 197.1 191.1 191.2 196.8 203.6 207.8 208.7 206.4 201.7 197.3 197.4 200.7 206.5 209.9 209.9 206.9 203.3 201.2 203.4 207.7 210.5 208.9 207.6 206.5 204.2 202.0 203.6 205.7 210.8 213.1 214.3 210.6 204.1 199.7 202.3 211.9 217.7 215.1 215.0 215.3 213.3 0.0 0.0 216.3 0.0 0.0 195.2 209.0 205.3 201.3 196.8 195.7 195.8 195.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 431.9 425.8 421.0 417.7 419.3 422.7 423.4 421.1 417.9 411.9 408.6 411.5 414.9 419.6 426.7 424.2 0.0 0.0 0.0 280.8 279.4 369.7 371.8 370.6 369.8 370.5 370.5 371.4 374.6 374.0 361.0 332.7 328.0 0.0 359.0 365.0 371.1 373.8 371.9 369.7 369.6 371.0 378.3 390.4 403.4 414.1 417.9 417.5 417.0 416.6 415.7 414.9 414.5 413.8 413.5 413.6 413.8 415.2 416.9 417.3 415.4 415.0 414.1 411.4 409.2 403.5 392.7 375.6 367.5 365.8 368.8 370.7 370.8 369.0 367.0 366.0 366.2 367.8 369.2 368.8 367.2 366.7 367.3 367.3 368.8 368.5 366.8 365.3 364.0 363.0 365.0 367.3 367.7 365.5 364.0 364.9 367.4 370.3 371.3 369.9 366.5 364.1 368.3 385.8 406.7 409.4 399.7 376.0 357.2 359.9 372.4 377.5 366.9 345.1 320.6 315.8 321.8 329.0 329.9 326.4 324.5 324.8 324.9 325.3 325.1 324.7 325.9 326.8 326.6 326.6 326.4 325.4 323.2 322.5 323.9 326.2 328.6 329.9 329.5 329.0 328.2 327.7 327.9 329.6 331.2 331.4 334.2 340.1 335.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 309.0 320.1 329.6 334.4 334.4 334.5 338.2 346.0 355.1 361.4 366.1 369.5 372.7 373.3 372.9 375.2 375.8 374.7 367.6 360.1 0.0 0.0 0.0 0.0 0.0 0.0 283.6 280.4 277.2 274.3 271.8 271.0 273.3 277.2 278.8 278.1 277.2 277.4 278.8 279.3 280.2 281.4 279.2 274.9 268.3 254.1 239.9 0.0 0.0 267.7 267.7 270.1 271.8 273.1 276.3 278.3 269.1 256.1 0.0 0.0 0.0 0.0 286.0 281.1 275.7 272.7 271.7 271.7 272.5 274.0 275.5 277.9 281.1 287.1 314.0 342.2 363.0 376.8 375.4 365.8 351.8 323.3 0.0 0.0 0.0 0.0 299.6 322.7 336.5 336.9 335.4 332.9 330.0 326.6 323.4 322.3 324.4 326.7 328.5 328.5 326.4 324.0 322.5 323.2 325.5 327.7 328.6 328.1 325.4 320.9 316.5 316.1 320.8 327.7 333.7 333.2 325.6 315.5 306.9 305.5 315.1 331.0 343.3 341.6 331.4 319.5 308.3 307.6 318.1 329.5 338.2 342.0 337.4 329.4 321.0 316.6 319.8 330.8 340.6 344.5 342.3 331.4 319.3 315.7 322.9 329.3 336.5 344.0 341.1 328.3 318.7 320.0 328.2 333.5 337.9 339.7 338.5 327.3 323.3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 285.3 284.6 284.3 288.2 295.1 305.6 312.2 315.5 287.8 0.0 0.0 0.0 0.0 0.0 309.8 312.4 313.8 312.8 310.8 310.7 312.5 313.2 310.7 305.7 305.9 308.1 309.9 311.3 311.0 309.8 309.7 310.6 311.1 311.1 311.0 311.9 313.7 316.4 319.5 311.6 281.3 283.4 306.6 0.0 0.0 0.0 334.4 326.7 322.6 321.6 321.7 324.5 332.0 336.7 333.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 332.5 331.7 332.7 334.8 334.7 335.9 339.7 345.6 353.9 362.3 370.2 374.0 373.8 371.4 370.0 368.2 367.0 368.0 370.4 371.2 371.3 370.0 360.3 0.0 0.0 0.0 319.1 326.8 331.5 332.0 331.6 327.2 328.2 334.1 343.1 350.7 358.0 363.5 368.4 370.8 371.8 371.1 370.2 368.5 368.0 371.1 375.5 376.3 369.5 0.0 0.0 0.0 280.4 267.2 263.6 265.8 271.2 275.0 276.4 276.9 277.1 280.9 287.5 284.7 280.7 0.0 0.0 0.0 0.0 0.0 0.0 231.6 226.0 222.9 224.1 226.2 228.5 234.3 240.9 246.6 249.8 249.6 246.3 245.2 245.8 247.2 248.3 249.3 249.0 245.9 242.7 235.8 226.3 217.8 213.1 209.3 208.0 207.0 205.9 205.8 203.4 201.5 205.3 207.5 
209.1 210.4 208.9 203.2 199.2 199.3 201.2 205.6 206.3 204.0 202.7 202.2 203.3 206.1 205.6 201.9 198.8 195.5 195.6 198.7 204.2 214.5 218.3 214.3 207.0 199.1 192.4 189.9 193.6 201.0 210.6 212.4 209.2 202.0 195.2 190.0 190.4 196.0 204.2 209.2 208.9 203.0 195.8 190.7 189.8 194.6 200.8 208.7 213.5 210.5 202.0 193.9 187.7 187.8 193.1 199.8 202.8 204.1 203.8 200.4 197.0 193.9 191.6 192.5 198.4 204.8 203.9 203.5 201.3 200.2 198.4 198.5 201.0 204.0 204.6 205.4 202.0 199.2 194.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 426.6 432.4 436.6 437.2 434.6 433.0 434.0 453.1 473.8 492.8 503.2 501.8 491.3 479.6 478.1 489.7 503.6 508.3 503.2 489.5 482.3 484.2 495.7 505.4 506.6 501.6 495.4 492.7 497.1 500.4 496.1 490.1 480.4 453.8 412.6 373.6 357.7 354.5 356.8 361.0 365.0 367.8 369.5 370.1 371.1 371.3 370.8 370.1 369.7 369.0 369.5 370.4 371.1 370.9 370.2 370.8 371.3 372.7 380.2 386.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 435.0 427.8 417.0 411.6 412.6 411.2 410.7 412.5 413.7 414.9 414.8 415.7 415.4 414.0 413.8 414.6 415.2 416.3 415.2 413.6 413.4 414.9 416.8 416.4 414.7 413.8 413.2 414.7 417.5 420.7 428.2 443.1 459.8 481.7 502.3 510.8 509.1 497.8 490.6 493.5 500.8 508.3 511.8 504.6 493.5 487.3 492.3 507.2 518.8 514.8 496.4 489.6 490.4 491.8 493.3 0.0 0.0 0.0 0.0 0.0 391.9 368.8 360.2 358.3 360.0 362.0 361.8 360.9 360.3 361.8 365.7 368.7 369.3 367.7 364.7 365.6 368.4 370.4 371.7 369.0 366.8 367.5 367.9 369.4 374.6 384.6 396.8 388.2 382.8 0.0 0.0 0.0 0.0 0.0 352.6 339.5 328.2 326.0 326.5 326.9 328.2 329.4 329.4 329.3 337.3 349.1 350.9 343.6 333.4 326.5 327.1 331.5 333.4 326.8 318.4 0.0 0.0 0.0 282.0 283.6 281.6 279.5 273.2 267.9 268.8 273.2 277.5 279.8 278.7 276.8 273.3 268.7 265.0 266.1 272.4 282.3 284.9 279.4 268.5 255.3 253.0 259.9 270.2 281.5 284.3 279.7 270.1 260.2 256.4 260.1 266.6 273.0 279.6 283.6 282.3 272.3 262.1 255.8 259.8 269.6 274.9 278.7 278.4 271.5 261.7 255.8 256.8 263.9 272.1 279.7 278.4 269.9 259.0 255.9 262.1 269.2 274.5 279.6 282.1 280.7 274.3 269.9 270.9 270.7 273.8 276.9 274.9 271.1 266.5 266.7 269.4 276.9 282.9 282.0 279.5 274.9 271.2 267.9 263.1 270.2 281.3 285.5 283.9 280.6 271.1 262.8 263.9 263.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 362.9 364.0 366.8 366.6 363.6 359.8 357.7 359.9 367.6 371.6 365.4 359.3 358.6 359.5 365.5 375.7 378.1 371.5 359.5 350.8 358.6 374.7 381.7 376.8 363.6 353.3 356.9 367.8 378.3 380.8 375.7 372.3 372.6 374.7 372.9 365.2 346.4 314.8 285.6 277.3 275.4 275.3 277.5 279.2 279.2 278.8 277.6 276.4 276.5 277.6 278.8 279.9 281.8 282.7 281.0 279.2 278.6 279.4 279.9 279.2 278.7 281.7 285.8 289.9 288.7 284.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 428.0 417.9 410.0 409.5 411.2 414.0 417.5 417.9 418.1 419.3 419.0 417.8 417.1 417.2 416.6 415.8 415.0 415.1 414.5 410.9 401.5 388.2 371.5 359.2 363.0 371.7 375.7 370.6 359.6 353.4 358.7 367.0 374.5 378.2 373.7 362.8 
358.0 361.6 368.5 373.7 377.7 375.8 368.8 361.7 360.5 366.7 372.5 375.6 380.1 384.9 383.5 376.6 0.0 0.0 0.0 0.0 0.0 331.7 320.0 316.8 320.7 325.1 327.0 327.4 326.6 326.4 325.5 325.1 326.2 327.6 328.5 328.4 327.7 328.7 329.9 329.8 329.7 327.7 326.5 327.8 328.9 329.3 329.8 329.2 328.8 329.4 330.4 331.0 331.0 330.5 329.2 328.2 328.3 328.7 329.4 327.5 322.9 314.8 299.3 290.8 292.3 299.0 305.4 309.2 312.1 313.9 316.8 320.4 327.6 333.1 338.0 341.2 342.1 338.4 331.0 0.0 0.0 0.0 0.0 0.0 0.0 311.7 292.5 281.8 276.1 275.1 276.6 277.6 277.4 274.0 264.8 249.6 239.1 238.8 243.0 245.7 245.0 241.9 237.9 234.6 237.1 242.7 250.4 252.5 248.3 240.7 231.6 228.3 234.0 240.6 247.0 250.6 248.5 242.1 233.6 226.8 227.2 234.0 240.7 246.9 248.9 244.4 237.6 230.8 228.8 236.4 246.9 250.2 249.2 245.4 239.9 232.5 224.7 226.5 238.5 252.0 257.4 255.9 248.4 237.4 230.5 231.6 238.8 247.6 252.6 253.9 253.0 248.9 241.8 236.4 234.1 236.5 246.2 257.7 255.4 248.2 238.5 234.8 237.8 244.4 250.5 253.8 251.7 246.7 240.0 238.3 244.8 251.6 256.4 256.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 484.0 484.5 487.5 490.6 487.2 479.7 478.9 483.8 494.1 505.7 510.3 501.1 485.4 472.4 472.3 495.6 515.8 517.1 503.8 486.5 475.1 479.9 491.4 507.9 509.9 505.0 496.2 485.7 484.5 490.8 495.4 497.7 496.6 489.9 469.1 424.3 387.2 373.0 367.3 364.4 364.6 366.2 366.6 368.8 371.9 373.1 371.4 368.9 367.1 367.0 367.7 368.8 370.1 369.8 369.2 369.8 370.0 369.0 369.0 371.0 375.2 385.3 391.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 425.2 418.5 409.4 409.5 408.1 406.3 408.4 410.3 412.8 413.8 412.6 411.2 410.8 410.4 410.2 410.5 412.2 413.6 413.4 412.6 412.6 413.4 415.6 417.9 418.5 417.0 415.0 415.0 416.8 421.5 437.9 455.0 472.6 493.0 502.2 496.7 484.6 480.6 487.2 498.6 506.3 506.2 499.4 489.7 489.5 496.2 503.8 507.2 501.8 491.1 489.9 501.1 520.5 533.7 525.8 0.0 0.0 0.0 0.0 0.0 421.8 365.2 351.0 352.0 358.0 362.6 365.6 365.4 363.2 362.5 362.2 362.4 364.1 365.2 364.9 363.5 363.4 366.3 368.5 370.0 369.7 368.5 366.9 365.7 366.3 368.6 370.8 372.2 370.2 373.1 377.0 376.5 371.0 346.8 326.1 325.2 328.4 331.4 334.0 332.2 327.2 324.4 331.3 349.7 363.2 362.0 346.1 327.4 321.8 329.8 340.7 348.4 349.9 346.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 269.5 275.7 282.6 282.9 278.8 275.2 274.3 273.5 273.9 275.3 276.4 277.3 276.1 272.3 268.0 268.3 272.5 279.1 281.7 280.1 271.3 261.0 257.3 262.7 273.2 280.8 283.3 279.3 269.1 258.3 258.0 266.3 278.2 286.7 287.8 283.4 274.9 264.7 257.7 260.0 272.0 286.3 294.7 289.8 274.5 263.8 263.4 268.8 277.3 284.1 286.5 285.2 281.7 272.7 264.3 260.3 267.7 281.7 289.2 289.0 281.2 266.2 256.0 254.4 261.5 276.4 286.8 288.8 286.7 273.8 260.9 260.8 270.2 283.4 291.3 292.8 286.1 273.8 265.2 264.0 271.3 283.8 290.3 289.7 277.7 266.9 260.5 263.0 267.8 281.5 286.7 286.1 285.5 280.1 279.0 283.0 284.1 284.5 285.5 282.2 0.0 0.0 280.0 276.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
15
+ }
16
+ ]
example/audio/music.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04b35a7b9d03adc494c304af5c4413aa33a02a54a7110016d6e3b559843d90de
3
+ size 1243961
example/audio/yue_target.json ADDED
@@ -0,0 +1,16 @@
1
+ [
2
+ {
3
+ "index": "vocal_420_14370",
4
+ "language": "Cantonese",
5
+ "time": [
6
+ 420,
7
+ 14370
8
+ ],
9
+ "duration": "0.31 0.26 0.28 0.26 0.40 0.20 0.42 0.24 0.36 0.24 0.32 0.26 0.94 0.32 0.24 0.30 0.34 0.22 0.34 0.90 0.22 0.36 0.32 0.30 0.22 0.36 0.22 0.32 0.34 0.20 0.40 0.24 0.30 0.38 0.22 0.32 0.28 0.36 0.24 0.34 0.26 0.60",
10
+ "text": "<SP> 我 的 心 情 又 像 真 该 等 被 揭 开 嘴 巴 却 再 仰 千 台 人 潮 内 越 文 静 越 变 得 不 受 理 睬 睬 自 己 己 要 交 出 意 外",
11
+ "phoneme": "<SP> yue_ngo5 yue_dik1 yue_sam1 yue_cing4 yue_jau6 yue_zoeng6 yue_zan1 yue_goi1 yue_dang2 yue_bei6 yue_kit3 yue_hoi1 yue_zeoi2 yue_baa1 yue_koek3 yue_zoi3 yue_joeng5 yue_cin1 yue_toi4 yue_jan4 yue_ciu4 yue_noi6 yue_jyut6 yue_man4 yue_zing6 yue_jyut6 yue_bin3 yue_dak1 yue_bat1 yue_sau6 yue_lei5 yue_coi2 yue_coi2 yue_zi6 yue_gei2 yue_gei2 yue_jiu3 yue_gaau1 yue_ceot1 yue_ji3 yue_ngoi6",
12
+ "note_pitch": "0 52 57 59 55 57 59 62 60 58 54 57 59 59 57 55 54 53 57 51 50 54 57 58 54 57 59 61 64 59 54 54 57 59 51 56 58 57 56 56 55 52",
13
+ "note_type": "1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 3 2 2 2 2 2",
14
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 74.8 74.8 82.4 85.7 81.0 76.2 78.8 111.9 129.0 146.8 160.6 175.0 182.1 172.9 163.9 190.3 214.7 218.4 221.5 223.5 220.2 209.1 173.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 254.5 248.7 245.1 243.3 239.8 238.8 241.2 244.0 245.7 245.6 239.8 216.8 191.2 0.0 0.0 170.9 179.4 189.0 192.6 193.3 192.8 193.5 194.1 194.1 194.4 195.4 197.4 199.7 202.2 205.8 210.6 212.3 214.4 216.5 219.4 221.9 222.4 222.4 222.8 222.9 217.1 189.6 175.4 0.0 0.0 255.5 251.2 246.5 247.3 248.9 249.1 249.4 251.9 253.6 250.2 247.1 246.6 239.1 193.8 191.0 0.0 295.9 301.1 302.8 301.5 296.6 287.7 286.4 290.4 294.2 297.1 297.3 294.7 287.9 273.4 221.0 262.3 265.3 259.7 255.9 254.7 255.1 256.2 257.4 259.3 260.8 261.3 249.3 209.1 194.0 236.7 224.9 210.2 202.6 197.9 201.4 210.6 220.6 230.1 237.9 242.4 242.7 241.0 234.7 220.3 190.9 179.0 185.6 182.3 178.3 177.1 179.0 181.3 182.4 184.5 187.6 185.9 172.3 161.7 167.7 0.0 0.0 206.9 210.4 213.0 214.2 215.6 217.3 217.1 194.7 181.9 184.5 182.3 171.8 155.9 161.4 167.7 195.8 235.3 245.0 245.8 241.9 237.0 232.3 231.3 234.9 241.8 250.9 253.2 248.9 238.0 226.0 220.5 224.9 236.6 250.9 259.6 262.0 259.1 252.5 246.5 241.8 236.0 228.8 216.3 210.5 211.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 209.1 197.4 191.2 191.2 200.4 222.5 235.6 248.0 252.9 249.4 243.0 235.6 216.1 182.1 175.8 213.4 214.1 214.7 215.1 215.2 215.9 215.7 214.8 211.6 202.2 187.0 183.6 0.0 0.0 0.0 191.3 192.1 191.4 192.4 193.3 192.6 188.9 171.2 0.0 0.0 0.0 0.0 0.0 0.0 192.1 186.5 181.8 178.2 176.1 176.8 178.2 179.4 179.7 180.4 183.3 184.4 182.7 179.0 176.1 174.6 168.7 166.5 167.2 168.8 171.8 175.6 185.4 194.7 197.2 192.6 181.9 0.0 0.0 0.0 192.6 204.8 217.5 220.0 220.4 218.7 216.0 214.2 216.0 216.8 213.0 200.5 183.9 0.0 0.0 158.5 156.8 159.6 161.4 160.7 159.4 159.5 159.3 157.1 152.4 152.1 156.4 162.9 167.3 166.4 160.3 149.6 146.1 149.5 155.4 161.3 164.1 163.3 161.3 154.8 148.5 144.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 99.4 97.6 102.0 115.3 131.4 142.5 150.1 155.1 159.7 162.5 161.1 152.3 0.0 0.0 0.0 0.0 174.2 183.7 189.8 191.6 190.4 189.1 189.5 190.4 191.9 192.1 192.5 196.2 203.3 212.2 214.1 216.7 217.4 217.2 216.8 216.4 214.9 214.7 216.3 218.9 220.4 221.5 224.6 231.2 237.9 242.2 242.6 240.3 239.0 240.4 241.7 242.3 238.4 184.9 178.6 194.1 203.6 194.7 185.4 180.5 181.9 185.4 188.7 191.6 193.9 194.6 192.2 191.0 189.5 186.7 181.0 0.0 0.0 219.2 217.8 216.8 215.3 213.3 212.0 213.6 215.2 215.0 214.7 215.8 218.2 221.0 224.4 229.9 237.5 244.9 246.6 246.6 248.0 250.2 249.8 193.6 186.3 193.0 0.0 0.0 0.0 288.3 289.9 287.8 287.2 287.7 285.0 283.4 281.9 280.1 283.4 287.8 290.9 291.6 287.6 240.0 238.5 0.0 333.3 328.9 325.3 323.1 321.6 317.4 298.0 274.9 0.0 0.0 0.0 0.0 0.0 0.0 243.7 248.7 245.5 243.3 243.8 246.2 244.5 226.2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 197.1 189.0 184.3 182.5 184.8 188.1 188.8 190.3 190.4 189.5 188.7 188.0 189.8 188.2 185.0 182.2 179.7 178.1 178.7 188.8 208.5 217.2 220.0 218.6 212.8 191.9 167.3 0.0 0.0 0.0 183.9 201.1 210.5 210.3 209.5 206.7 206.5 212.0 221.1 235.9 250.8 253.0 247.7 238.4 229.7 229.1 235.1 244.7 251.7 253.5 251.4 246.7 242.4 234.6 209.5 180.1 173.1 0.0 0.0 0.0 147.9 147.5 155.4 159.1 160.2 161.4 162.0 161.7 159.0 152.4 139.6 121.5 126.4 0.0 0.0 197.2 200.2 196.8 195.8 200.3 203.2 203.8 202.5 203.7 212.8 220.0 225.8 231.6 234.1 231.7 228.3 225.4 226.0 229.7 234.3 237.3 238.2 236.9 232.9 227.1 220.3 215.1 211.3 205.9 210.1 216.0 217.6 218.4 218.9 219.3 218.5 217.1 216.9 217.7 216.4 212.1 196.3 171.7 171.3 
210.9 203.0 194.9 194.8 199.1 205.4 210.5 214.9 219.6 224.8 225.1 221.4 212.4 198.5 0.0 0.0 204.9 204.1 208.9 212.7 212.8 213.2 214.9 214.4 208.4 189.0 159.7 160.4 193.1 198.9 196.0 192.8 193.0 195.3 195.2 194.5 193.6 193.9 192.5 192.3 192.2 183.0 164.8 150.0 147.3 150.7 155.9 160.8 163.6 164.8 161.8 156.6 155.0 160.5 166.2 167.3 165.1 162.0 154.3 142.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
15
+ }
16
+ ]
example/audio/yue_target.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a699c2649eec48ed1e9a6caae2af918bf7d49e5e4ad39cf3cca0916942bc7db2
3
+ size 353361
example/audio/zh_prompt.json ADDED
@@ -0,0 +1,16 @@
1
+ [
2
+ {
3
+ "index": "vocal_320_10687",
4
+ "language": "Mandarin",
5
+ "time": [
6
+ 320,
7
+ 10687
8
+ ],
9
+ "duration": "0.23 0.34 0.26 0.70 0.52 0.46 0.36 0.44 0.14 0.24 0.64 0.47 0.51 1.10 0.28 0.38 0.32 0.32 0.38 0.32 0.31 0.19 1.45",
10
+ "text": "<SP> 除 了 想 你 你 <SP> 除 了 了 爱 你 你 <SP> 我 什 么 什 么 都 愿 愿 意",
11
+ "phoneme": "<SP> zh_chu2 zh_le5 zh_xiang3 zh_ni3 zh_ni3 <SP> zh_chu2 zh_le5 zh_le5 zh_ai4 zh_ni3 zh_ni3 <SP> zh_wo3 zh_shen2 zh_me5 zh_shen2 zh_me5 zh_dou1 zh_yuan4 zh_yuan4 zh_yi4",
12
+ "note_pitch": "0 62 65 67 67 69 0 67 69 67 65 67 69 67 67 66 64 64 60 60 65 67 0",
13
+ "note_type": "1 2 2 2 2 3 1 2 2 3 2 2 3 1 2 2 2 2 2 2 2 3 2",
14
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 294.1 288.2 290.6 294.6 295.7 292.8 291.5 294.4 295.4 294.7 293.5 292.2 294.4 295.8 293.2 297.7 320.1 338.1 348.3 348.7 344.4 342.6 346.2 354.8 356.9 353.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 401.1 403.8 405.0 402.1 398.4 395.6 393.3 392.0 391.4 390.3 390.2 390.3 390.3 391.3 392.8 393.6 391.7 390.5 391.4 391.6 391.5 393.5 393.8 390.8 387.4 387.8 389.3 390.8 392.2 391.6 390.2 389.8 389.1 388.4 390.0 395.5 397.2 396.7 395.5 395.1 394.6 394.9 395.6 395.2 394.6 395.4 395.9 394.0 391.7 390.7 391.7 392.6 391.6 395.7 405.8 441.7 462.5 463.8 450.2 430.9 414.5 415.2 426.7 439.8 454.4 462.9 447.8 422.5 400.6 403.2 423.5 451.3 482.3 492.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 435.9 414.3 406.2 402.8 398.2 396.7 396.7 395.2 398.3 416.9 441.6 446.9 441.5 437.3 435.3 432.9 431.6 432.3 434.8 436.7 433.8 422.8 406.7 390.9 382.0 381.7 384.7 382.7 368.1 357.4 355.6 355.1 352.4 348.1 346.4 348.7 351.6 355.0 354.4 351.7 349.9 349.2 348.1 346.0 345.4 344.2 344.4 345.5 346.6 349.0 349.7 349.1 349.5 349.6 349.7 349.4 349.5 352.2 354.5 355.7 355.6 356.9 359.4 361.5 363.8 360.4 354.3 357.2 363.9 372.4 382.6 399.1 402.9 400.6 395.5 390.5 388.9 390.2 391.1 391.9 391.5 390.4 390.2 391.3 391.6 391.0 388.6 386.2 389.1 403.8 430.2 441.8 449.5 448.1 443.2 438.3 432.9 430.7 434.6 442.0 447.6 446.0 440.3 434.7 431.1 435.7 442.3 445.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 423.8 402.1 398.5 397.6 393.4 390.5 391.4 402.5 427.6 442.5 442.5 435.1 430.1 430.8 439.9 447.4 442.2 426.1 412.3 399.8 391.8 389.5 388.5 387.8 386.4 384.7 384.5 387.9 391.2 391.8 392.9 393.8 392.0 392.0 395.4 398.1 398.2 396.3 393.1 391.1 388.9 386.6 383.1 381.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 357.8 367.6 370.9 365.4 359.3 355.7 358.1 372.7 396.8 404.0 398.4 392.6 389.2 388.8 383.6 362.3 341.1 325.1 326.3 327.9 331.3 333.3 326.0 319.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 356.4 353.5 349.0 346.5 338.4 328.4 323.7 326.2 334.2 338.4 331.6 309.8 280.6 256.4 252.8 256.1 258.3 256.9 257.6 261.1 259.7 258.7 258.6 260.2 262.2 262.4 262.9 264.2 263.4 260.2 256.4 253.2 238.7 223.3 231.6 257.7 258.1 258.4 258.8 258.0 256.8 255.8 254.0 255.1 258.1 261.9 263.7 262.6 256.7 253.1 250.2 246.7 258.7 294.4 327.1 342.6 346.0 344.1 341.2 342.3 345.2 350.1 364.7 382.5 396.1 396.1 389.2 381.8 381.1 387.2 397.0 399.3 390.7 374.3 360.1 350.6 346.6 347.4 350.7 354.4 354.3 351.7 349.8 348.4 346.9 347.0 348.4 349.5 351.0 352.3 353.6 353.3 350.5 348.3 345.5 344.3 344.4 347.0 350.6 352.0 351.0 350.8 350.1 347.6 345.7 347.0 350.3 351.7 350.7 348.5 346.9 347.7 349.0 349.1 348.5 346.8 346.3 348.4 349.0 349.2 351.1 349.6 348.3 350.5 351.1 348.0 347.6 349.1 351.3 356.0 361.3 360.6 354.0 341.0 316.2 302.9 302.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
15
+ }
16
+ ]
example/audio/zh_prompt.mp3 ADDED
Binary file (86.1 kB)
example/audio/zh_target.json ADDED
@@ -0,0 +1,16 @@
1
+ [
2
+ {
3
+ "index": "vocal_0_6710",
4
+ "language": "Mandarin",
5
+ "time": [
6
+ 0,
7
+ 6710
8
+ ],
9
+ "duration": "0.13 0.26 0.24 0.22 0.24 0.33 0.13 0.24 0.22 0.46 0.69 0.84 0.26 0.30 0.16 0.26 0.26 0.20 0.32 0.94",
10
+ "text": "<SP> 像 我 这 样 懦 懦 弱 的 人 人 <SP> 凡 事 都 要 留 留 几 分",
11
+ "phoneme": "<SP> zh_xiang4 zh_wo3 zh_zhe4 zh_yang4 zh_nuo4 zh_nuo4 zh_ruo4 zh_de5 zh_ren2 zh_ren2 <SP> zh_fan2 zh_shi4 zh_dou1 zh_yao4 zh_liu2 zh_liu2 zh_ji3 zh_fen1",
12
+ "note_pitch": "0 50 53 55 53 56 54 53 50 51 53 0 51 53 55 53 54 56 51 53",
13
+ "note_type": "1 2 2 2 2 2 3 2 2 2 3 1 2 2 2 2 2 3 2 2",
14
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 132.7 137.2 144.9 147.0 148.0 148.6 148.9 148.5 147.9 147.4 148.1 149.1 154.3 166.4 173.1 175.0 176.3 175.9 172.7 173.6 175.9 172.9 159.1 165.8 0.0 214.2 213.4 210.2 201.5 198.1 197.1 197.2 197.8 200.8 206.4 206.1 200.5 189.9 180.6 172.0 170.8 171.4 176.2 180.9 182.3 182.2 181.0 180.5 183.4 192.6 211.8 220.6 223.7 219.3 211.9 207.0 203.6 202.6 204.2 204.4 204.1 202.1 198.0 192.9 185.9 177.9 174.1 174.6 174.5 173.8 173.6 172.4 168.3 168.3 172.7 173.2 171.9 170.7 170.2 169.9 170.6 173.1 172.4 164.2 148.2 147.8 152.3 148.5 143.8 145.6 149.2 149.9 150.1 152.5 153.6 154.7 156.1 155.0 152.4 152.1 153.7 155.3 156.4 156.8 157.3 157.7 157.1 156.8 157.8 158.9 157.9 157.5 157.1 157.0 159.1 162.0 167.7 172.1 174.9 176.2 174.5 172.0 170.9 171.3 172.5 173.1 173.5 173.1 174.1 174.6 175.2 176.7 177.2 177.3 176.9 175.9 174.1 172.4 174.1 174.8 171.8 172.1 176.5 177.3 176.0 179.4 179.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 137.5 138.5 145.8 151.0 153.3 154.3 156.2 158.5 161.0 158.1 148.6 151.3 157.0 163.5 172.7 176.4 176.1 175.8 175.6 174.7 171.4 169.3 161.7 194.2 199.3 199.7 201.0 202.7 201.4 199.4 198.5 196.6 194.0 190.6 186.1 181.9 179.4 177.4 177.1 176.3 175.8 175.5 174.8 173.6 172.4 170.8 165.0 160.6 175.9 179.3 179.8 180.4 180.3 178.8 178.6 181.5 185.0 190.7 198.3 206.3 210.6 210.9 207.4 203.5 203.4 204.6 203.5 195.6 182.5 0.0 0.0 0.0 0.0 144.7 144.1 146.6 150.2 151.9 153.4 155.0 156.0 155.9 155.4 155.2 153.5 147.0 144.8 0.0 0.0 0.0 0.0 0.0 0.0 181.1 178.9 178.0 177.5 176.0 173.2 172.9 172.9 174.1 176.2 177.3 178.7 178.8 176.0 175.2 175.1 176.3 178.1 177.6 177.4 177.9 177.6 177.3 177.3 177.5 176.6 175.7 176.6 177.5 177.2 175.9 174.8 173.5 174.0 175.7 177.4 177.8 174.7"
15
+ }
16
+ ]
example/audio/zh_target.mp3 ADDED
Binary file (54.2 kB)
example/infer.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/bin/bash
2
+
3
+ script_dir=$(dirname "$(realpath "$0")")
4
+ root_dir=$(dirname "$script_dir")
5
+
6
+ cd "$root_dir" || exit
7
+ export PYTHONPATH=$root_dir:$PYTHONPATH
8
+
9
+ model_path=pretrained_models/SoulX-Singer/model.pt
10
+ config=soulxsinger/config/soulxsinger.yaml
11
+ prompt_wav_path=example/audio/zh_prompt.mp3
12
+ prompt_metadata_path=example/audio/zh_prompt.json
13
+ target_metadata_path=example/audio/music.json
14
+ phoneset_path=soulxsinger/utils/phoneme/phone_set.json
15
+ save_dir=example/generated/music
16
+ control=score # melody or score
17
+
18
+ python -m cli.inference \
19
+ --device cuda \
20
+ --model_path $model_path \
21
+ --config $config \
22
+ --prompt_wav_path $prompt_wav_path \
23
+ --prompt_metadata_path $prompt_metadata_path \
24
+ --target_metadata_path $target_metadata_path \
25
+ --phoneset_path $phoneset_path \
26
+ --save_dir $save_dir \
27
+ --auto_shift \
28
+ --pitch_shift 0
example/preprocess.sh ADDED
@@ -0,0 +1,41 @@
1
+ #!/bin/bash
2
+
3
+ script_dir=$(dirname "$(realpath "$0")")
4
+ root_dir=$(dirname "$script_dir")
5
+
6
+ cd "$root_dir" || exit
7
+ export PYTHONPATH=$root_dir:$PYTHONPATH
8
+
9
+ device=cuda
10
+
11
+
12
+ ####### Run Prompt Annotation #######
13
+ audio_path=example/audio/zh_prompt.mp3
14
+ save_dir=example/transcriptions/zh_prompt
15
+ language=Mandarin
16
+ vocal_sep=False
17
+ max_merge_duration=30000
18
+
19
+ python -m preprocess.pipeline \
20
+ --audio_path $audio_path \
21
+ --save_dir $save_dir \
22
+ --language $language \
23
+ --device $device \
24
+ --vocal_sep $vocal_sep \
25
+ --max_merge_duration $max_merge_duration
26
+
27
+
28
+ ####### Run Target Annotation #######
29
+ audio_path=example/audio/music.mp3
30
+ save_dir=example/transcriptions/music
31
+ language=Mandarin
32
+ vocal_sep=True
33
+ max_merge_duration=60000
34
+
35
+ python -m preprocess.pipeline \
36
+ --audio_path $audio_path \
37
+ --save_dir $save_dir \
38
+ --language $language \
39
+ --device $device \
40
+ --vocal_sep $vocal_sep \
41
+ --max_merge_duration $max_merge_duration
preprocess/README.md ADDED
@@ -0,0 +1,155 @@
1
+ # 🎵 SoulX-Singer-Preprocess
2
+
3
+ This module offers a comprehensive **singing transcription and editing toolkit** for real-world music audio, covering the pipeline from vocal extraction to high-level annotation and optimized for SVS dataset construction. By integrating state-of-the-art models, it transforms raw audio into structured singing data and supports the **customizable creation and editing of lyric-aligned MIDI scores**.
4
+
5
+
6
+ ## ✨ Features
7
+
8
+ The toolkit includes the following core modules:
9
+
10
+ - 🎤 **Clean Dry Vocal Extraction**
11
+ Extracts the lead vocal track from polyphonic music audio and applies dereverberation.
12
+
13
+ - 📝 **Lyrics Transcription**
14
+ Automatically transcribes lyrics from the clean vocal track.
15
+
16
+ - 🎶 **Note Transcription**
17
+ Converts singing voice into note-level representations for SVS.
18
+
19
+ - 🎼 **MIDI Editor**
20
+ Supports customizable creation and editing of MIDI scores integrated with lyrics.
21
+
22
+
23
+ ## 🔧 Python Environment
24
+
25
+ Before running the pipeline, set up the Python environment as follows:
26
+
27
+ 1. **Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html
28
+
29
+ 2. **Activate or create a conda environment** (recommended Python 3.10):
30
+
31
+ - If you already have the `soulxsinger` environment:
32
+
33
+ ```bash
34
+ conda activate soulxsinger
35
+ ```
36
+
37
+ - Otherwise, create it first:
38
+
39
+ ```bash
40
+ conda create -n soulxsinger -y python=3.10
41
+ conda activate soulxsinger
42
+ ```
43
+
44
+ 3. **Install dependencies** from the `preprocess` directory:
45
+
46
+ ```bash
47
+ cd preprocess
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ ## 📁 Data Preparation
52
+
53
+ Before running the pipeline, prepare the following inputs:
54
+
55
+ - **Prompt audio**
56
+ Reference audio that provides the target timbre and singing style.
57
+
58
+ - **Target audio**
59
+ Original vocal or music audio to be processed and transcribed.
60
+
61
+ Configure the corresponding parameters in:
62
+
63
+ ```
64
+ example/preprocess.sh
65
+ ```
66
+
67
+ Typical configuration includes:
68
+ - Input / output paths
69
+ - Module enable switches
70
+
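+ For reference, the corresponding variables in `example/preprocess.sh` look like this (the values below are the shipped example defaults):
+
+ ```bash
+ audio_path=example/audio/music.mp3      # input audio to transcribe
+ save_dir=example/transcriptions/music   # output directory for intermediate results
+ language=Mandarin                       # Mandarin, Cantonese, or English
+ vocal_sep=True                          # separate vocals first (set False for clean dry vocals)
+ max_merge_duration=60000                # maximum merged segment length in milliseconds
+ ```
+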
71
+ ## 🚀 Usage
72
+
73
+ After configuring `preprocess.sh`, run the transcription pipeline with:
74
+
75
+ ```bash
76
+ bash example/preprocess.sh
77
+ ```
78
+
79
+ The script will automatically execute the following steps:
80
+
81
+ 1. **Vocal separation and dereverberation**
82
+ 2. **F0 extraction and voice activity detection (VAD)**
83
+ 3. **Lyrics transcription**
84
+ 4. **Note transcription**
85
+
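+ The same pipeline can also be driven from Python; a minimal sketch mirroring the `main()` entry point of `preprocess/pipeline.py`:
+
+ ```python
+ from preprocess.pipeline import PreprocessPipeline
+
+ pipeline = PreprocessPipeline(
+     device="cuda",
+     language="Mandarin",
+     save_dir="example/transcriptions/music",
+     vocal_sep=True,
+     max_merge_duration=60000,
+ )
+ pipeline.run(audio_path="example/audio/music.mp3")
+ ```
+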
86
+ ---
87
+
88
+ After the pipeline completes, you will obtain **SoulX-Singer–style metadata** that can be directly used for Singing Voice Synthesis (SVS).
89
+
90
+ **Output paths:**
91
+ - The final metadata (**JSON file**) is written **in the same directory as your input audio**, with the **same filename** (e.g. `audio.mp3` → `audio.json`)
92
+ - All **intermediate results** (separated vocal and accompaniment, F0, VAD outputs, etc.) are also saved under the configured **`save_dir`**.
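+
+ Each metadata entry is a JSON object with word-level annotations; a minimal sketch for inspecting one (field names follow the shipped examples under `example/audio/`):
+
+ ```python
+ import json
+
+ with open("example/audio/music.json", encoding="utf-8") as f:
+     metadata = json.load(f)
+
+ seg = metadata[0]
+ print(seg["language"], seg["time"])  # language and [start_ms, end_ms] of the segment
+ print(seg["text"])                   # word-level lyrics; "<SP>" marks silence
+ print(seg["note_pitch"])             # per-word MIDI note numbers (0 = rest)
+ ```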
93
+
94
+ ⚠️ **Important Note**
95
+
96
+ Transcription errors—especially in **lyrics** and **note annotations**—can significantly affect the final SVS quality. We **strongly recommend manually reviewing and correcting** the generated metadata before inference.
97
+
98
+ To support this, we provide a **MIDI Editor** for editing lyrics, phoneme alignment, note pitches, and durations. The workflow is:
99
+
100
+ **Export metadata to MIDI** → edit in the MIDI Editor → **Import edited MIDI back to metadata** for SVS.
101
+
102
+ ---
103
+
104
+ #### Step 1: Metadata → MIDI (for editing)
105
+
106
+ Convert SoulX-Singer metadata to a MIDI file so you can open it in the MIDI Editor:
107
+
108
+ ```bash
109
+ preprocess_root=example/transcriptions/music
110
+
111
+ python -m preprocess.tools.midi_parser \
112
+ --meta2midi \
113
+ --meta "${preprocess_root}/metadata.json" \
114
+ --midi "${preprocess_root}/vocal.mid"
115
+ ```
116
+
117
+ #### Step 2: Edit in the MIDI Editor
118
+
119
+ Open the MIDI Editor (see [MIDI Editor Tutorial](tools/midi_editor/README.md)), load `vocal.mid`, and correct lyrics, pitches, or durations as needed. Save the result as e.g. `vocal_edited.mid`.
120
+
121
+ #### Step 3: MIDI → Metadata (for SoulX-Singer inference)
122
+
123
+ Convert the edited MIDI back into SoulX-Singer-style metadata (and cut wavs) for SVS:
124
+
125
+ ```bash
126
+ python -m preprocess.tools.midi_parser \
127
+ --midi2meta \
128
+ --midi "${preprocess_root}/vocal_edited.mid" \
129
+ --meta "${preprocess_root}/edit_metadata.json" \
130
+ --vocal "${preprocess_root}/vocal.wav" \
131
+ ```
132
+
133
+ Use `edit_metadata.json` (and the wavs under `edit_cut_wavs`) as the target metadata in your inference pipeline.
134
+
135
+
136
+ ## 🔗 References & Dependencies
137
+
138
+ This project builds upon the following excellent open-source works:
139
+
140
+ ### 🎧 Vocal Separation & Dereverberation
141
+ - [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
142
+ - [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
143
+ - [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)
144
+
145
+ ### 🎼 F0 Extraction
146
+ - [RMVPE](https://github.com/Dream-High/RMVPE)
147
+
148
+ ### 📝 Lyrics Transcription (ASR)
149
+ - [Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
150
+ - [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)
151
+
152
+ ### 🎶 Note Transcription
153
+ - [ROSVOT](https://github.com/RickyL-2000/ROSVOT)
154
+
155
+ We sincerely thank the authors of these repositories for their exceptional open-source contributions, which have been fundamental to the development of this toolkit.
preprocess/pipeline.py ADDED
@@ -0,0 +1,146 @@
1
+ import json
2
+ import shutil
3
+ import soundfile as sf
4
+ from pathlib import Path
5
+ import librosa
6
+
7
+ from preprocess.utils import convert_metadata, merge_short_segments
8
+
9
+ from preprocess.tools import (
10
+ F0Extractor,
11
+ VocalDetector,
12
+ VocalSeparator,
13
+ NoteTranscriber,
14
+ LyricTranscriber,
15
+ )
16
+
17
+
18
+ class PreprocessPipeline:
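+ # End-to-end flow: (optional) vocal separation -> f0 -> VAD segmentation -> lyrics -> notes -> merge & export metadata.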
19
+ def __init__(self, device: str, language: str, save_dir: str, vocal_sep: bool = True, max_merge_duration: int = 60000):
20
+ self.device = device
21
+ self.language = language
22
+ self.save_dir = save_dir
23
+ self.vocal_sep = vocal_sep
24
+ self.max_merge_duration = max_merge_duration
25
+
26
+ if vocal_sep:
27
+ self.vocal_separator = VocalSeparator(
28
+ sep_model_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt",
29
+ sep_config_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/config_karaoke_becruily.yaml",
30
+ der_model_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
31
+ der_config_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml",
32
+ device=device
33
+ )
34
+ else:
35
+ self.vocal_separator = None
36
+ self.f0_extractor = F0Extractor(
37
+ model_path="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
38
+ device=device,
39
+ )
40
+ self.vocal_detector = VocalDetector(
41
+ cut_wavs_output_dir= f"{save_dir}/cut_wavs",
42
+ )
43
+ self.lyric_transcriber = LyricTranscriber(
44
+ zh_model_path="pretrained_models/SoulX-Singer-Preprocess/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
45
+ en_model_path="pretrained_models/SoulX-Singer-Preprocess/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
46
+ device=device
47
+ )
48
+ self.note_transcriber = NoteTranscriber(
49
+ rosvot_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rosvot/model.pt",
50
+ rwbd_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rwbd/model.pt",
51
+ device=device
52
+ )
53
+
54
+ def run(
55
+ self,
56
+ audio_path: str,
57
+ vocal_sep: bool | None = None,
58
+ max_merge_duration: int | None = None,
59
+ language: str | None = None
60
+ ) -> None:
61
+ vocal_sep = self.vocal_sep if vocal_sep is None else vocal_sep
62
+ max_merge_duration = self.max_merge_duration if max_merge_duration is None else max_merge_duration
63
+ language = self.language if language is None else language
64
+ output_dir = Path(self.save_dir)
65
+ output_dir.mkdir(parents=True, exist_ok=True)
66
+
67
+ if vocal_sep:
68
+ # Perform vocal/accompaniment separation
69
+ sep = self.vocal_separator.process(audio_path)
70
+ vocal = sep.vocals_dereverbed.T
71
+ acc = sep.accompaniment.T
72
+ sample_rate = sep.sample_rate
73
+
74
+ vocal_path = output_dir / "vocal.wav"
75
+ acc_path = output_dir / "acc.wav"
76
+ sf.write(vocal_path, vocal, sample_rate)
77
+ sf.write(acc_path, acc, sample_rate)
78
+ else:
79
+ # Use the original audio as vocal source (no separation)
80
+ vocal, sample_rate = librosa.load(audio_path, sr=None, mono=True)
81
+ vocal_path = output_dir / "vocal.wav"
82
+ sf.write(vocal_path, vocal, sample_rate)
83
+
84
+ vocal_f0 = self.f0_extractor.process(str(vocal_path))
85
+ segments = self.vocal_detector.process(str(vocal_path), f0=vocal_f0)
86
+
87
+ metadata = []
88
+ for seg in segments:
89
+ self.f0_extractor.process(seg["wav_fn"], f0_path=seg["wav_fn"].replace(".wav", "_f0.npy"))
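+ # The saved *_f0.npy sits next to the cut wav; LyricTranscriber loads it to refine word/silence boundaries.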
90
+ words, durs = self.lyric_transcriber.process(
91
+ seg["wav_fn"], language
92
+ )
93
+ seg["words"] = words
94
+ seg["word_durs"] = durs
95
+ seg["language"] = language
96
+ metadata.append(
97
+ self.note_transcriber.process(seg, segment_info=seg)
98
+ )
99
+
100
+ merged = merge_short_segments(
101
+ vocal,
102
+ sample_rate,
103
+ metadata,
104
+ output_dir / "long_cut_wavs",
105
+ max_duration_ms=max_merge_duration,
106
+ )
107
+
108
+ final_metadata = []
109
+
110
+ for item in merged:
111
+ self.f0_extractor.process(item.wav_fn, f0_path=item.wav_fn.replace(".wav", "_f0.npy"))
112
+ final_metadata.append(convert_metadata(item))
113
+
114
+ with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
115
+ json.dump(final_metadata, f, ensure_ascii=False, indent=2)
116
+
117
+ shutil.copy(output_dir / "metadata.json", audio_path.replace(".wav", ".json").replace(".mp3", ".json").replace(".flac", ".json"))
118
+
119
+
120
+ def main(args):
121
+ pipeline = PreprocessPipeline(
122
+ device=args.device,
123
+ language=args.language,
124
+ save_dir=args.save_dir,
125
+ vocal_sep=args.vocal_sep,
126
+ max_merge_duration=args.max_merge_duration,
127
+ )
128
+ pipeline.run(
129
+ audio_path=args.audio_path,
130
+ language=args.language
131
+ )
132
+
133
+
134
+ if __name__ == "__main__":
135
+ import argparse
136
+
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument("--audio_path", type=str, required=True, help="Path to the input audio file")
139
+ parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the output files")
140
+ parser.add_argument("--language", type=str, default="Mandarin", help="Language of the audio")
141
+ parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the models on")
142
+ parser.add_argument("--vocal_sep", type=bool, default=True, help="Whether to perform vocal separation")
143
+ parser.add_argument("--max_merge_duration", type=int, default=60000, help="Maximum merged segment duration in milliseconds")
144
+ args = parser.parse_args()
145
+
146
+ main(args)
preprocess/requirements.txt ADDED
@@ -0,0 +1,33 @@
1
+ beartype==0.22.9
2
+ einops==0.8.2
3
+ funasr==1.3.0
4
+ g2p_en==2.1.0
5
+ g2pM==0.1.2.5
6
+ librosa==0.11.0
7
+ loralib==0.1.2
8
+ matplotlib==3.10.8
9
+ mido==1.3.3
10
+ ml_collections==1.1.0
11
+ nemo_toolkit==2.6.1
12
+ nltk==3.9.2
13
+ numba==0.63.1
14
+ numpy==2.2.6
15
+ omegaconf==2.3.0
16
+ packaging==24.2
17
+ praat-parselmouth==0.4.7
18
+ pretty_midi==0.2.11
19
+ pyloudnorm==0.2.0
20
+ pyworld==0.3.5
21
+ rotary_embedding_torch==0.8.9
22
+ sageattention==1.0.6
23
+ scikit_learn==1.7.2
24
+ scipy==1.15.3
25
+ six==1.17.0
26
+ scikit_image==0.25.2
27
+ soundfile==0.13.1
28
+ ToJyutping==3.2.0
29
+ torch==2.10.0
30
+ torchaudio==2.10.0
31
+ tqdm==4.67.1
32
+ wandb==0.24.2
33
+ webrtcvad==2.0.10
preprocess/tools/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ """Preprocess tools.
2
+
3
+ This package provides a thin, stable import surface for common preprocess components.
4
+
5
+ Examples:
6
+ from preprocess.tools import (
7
+ F0Extractor,
8
+ VocalDetector,
9
+ VocalSeparator,
10
+ NoteTranscriber,
11
+ LyricTranscriber,
12
+ )
15
+
16
+ Note:
17
+ Keep these imports lightweight. If a tool pulls heavy dependencies at import time,
18
+ consider switching to lazy imports.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ # Core tools
24
+ from .f0_extraction import F0Extractor
25
+ from .vocal_detection import VocalDetector
26
+
27
+ # Some tools may live outside this package in different layouts across branches.
28
+ # Keep the public surface stable while avoiding hard import failures.
29
+ try:
30
+ from .vocal_separation.model import VocalSeparator # type: ignore
31
+ except Exception: # pragma: no cover
32
+ VocalSeparator = None # type: ignore
33
+
34
+ try:
35
+ from .note_transcription.model import NoteTranscriber # type: ignore
36
+ except Exception: # pragma: no cover
37
+ NoteTranscriber = None # type: ignore
38
+ try:
39
+ from .lyric_transcription import LyricTranscriber
40
+ except Exception: # pragma: no cover
41
+ LyricTranscriber = None # type: ignore
42
+
43
+ __all__ = [
44
+ "F0Extractor",
45
+ "VocalDetector",
46
+ ]
47
+
48
+ if VocalSeparator is not None:
49
+ __all__.append("VocalSeparator")
50
+ if LyricTranscriber is not None:
51
+ __all__.append("LyricTranscriber")
52
+ if NoteTranscriber is not None:
53
+ __all__.append("NoteTranscriber")
preprocess/tools/f0_extraction.py ADDED
@@ -0,0 +1,527 @@
1
+ # https://github.com/Dream-High/RMVPE
2
+ import math
3
+ import time
4
+ import librosa
5
+ import numpy as np
6
+ from librosa.filters import mel
7
+ from scipy.interpolate import interp1d
8
+
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class BiGRU(nn.Module):
17
+ def __init__(self, input_features, hidden_features, num_layers):
18
+ super(BiGRU, self).__init__()
19
+ self.gru = nn.GRU(
20
+ input_features,
21
+ hidden_features,
22
+ num_layers=num_layers,
23
+ batch_first=True,
24
+ bidirectional=True,
25
+ )
26
+
27
+ def forward(self, x):
28
+ return self.gru(x)[0]
29
+
30
+
31
+ class ConvBlockRes(nn.Module):
32
+ def __init__(self, in_channels, out_channels, momentum=0.01):
33
+ super(ConvBlockRes, self).__init__()
34
+ self.conv = nn.Sequential(
35
+ nn.Conv2d(
36
+ in_channels=in_channels,
37
+ out_channels=out_channels,
38
+ kernel_size=(3, 3),
39
+ stride=(1, 1),
40
+ padding=(1, 1),
41
+ bias=False,
42
+ ),
43
+ nn.BatchNorm2d(out_channels, momentum=momentum),
44
+ nn.ReLU(),
45
+ nn.Conv2d(
46
+ in_channels=out_channels,
47
+ out_channels=out_channels,
48
+ kernel_size=(3, 3),
49
+ stride=(1, 1),
50
+ padding=(1, 1),
51
+ bias=False,
52
+ ),
53
+ nn.BatchNorm2d(out_channels, momentum=momentum),
54
+ nn.ReLU(),
55
+ )
56
+ if in_channels != out_channels:
57
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
58
+
59
+ def forward(self, x):
60
+ if not hasattr(self, "shortcut"):
61
+ return self.conv(x) + x
62
+ else:
63
+ return self.conv(x) + self.shortcut(x)
64
+
65
+
66
+ class ResEncoderBlock(nn.Module):
67
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
68
+ super(ResEncoderBlock, self).__init__()
69
+ self.n_blocks = n_blocks
70
+ self.conv = nn.ModuleList()
71
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
72
+ for i in range(n_blocks - 1):
73
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
74
+ self.kernel_size = kernel_size
75
+ if self.kernel_size is not None:
76
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
77
+
78
+ def forward(self, x):
79
+ for conv in self.conv:
80
+ x = conv(x)
81
+ if self.kernel_size is not None:
82
+ return x, self.pool(x)
83
+ else:
84
+ return x
85
+
86
+
87
+ class Encoder(nn.Module):
88
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
89
+ super(Encoder, self).__init__()
90
+ self.n_encoders = n_encoders
91
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
92
+ self.layers = nn.ModuleList()
93
+ self.latent_channels = []
94
+ for i in range(self.n_encoders):
95
+ self.layers.append(
96
+ ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)
97
+ )
98
+ self.latent_channels.append([out_channels, in_size])
99
+ in_channels = out_channels
100
+ out_channels *= 2
101
+ in_size //= 2
102
+ self.out_size = in_size
103
+ self.out_channel = out_channels
104
+
105
+ def forward(self, x):
106
+ concat_tensors = []
107
+ x = self.bn(x)
108
+ for layer in self.layers:
109
+ t, x = layer(x)
110
+ concat_tensors.append(t)
111
+ return x, concat_tensors
112
+
113
+
114
+ class Intermediate(nn.Module):
115
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
116
+ super(Intermediate, self).__init__()
117
+ self.n_inters = n_inters
118
+ self.layers = nn.ModuleList()
119
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
120
+ for i in range(self.n_inters - 1):
121
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
122
+
123
+ def forward(self, x):
124
+ for layer in self.layers:
125
+ x = layer(x)
126
+ return x
127
+
128
+
129
+ class ResDecoderBlock(nn.Module):
130
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
131
+ super(ResDecoderBlock, self).__init__()
132
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
133
+ self.n_blocks = n_blocks
134
+ self.conv1 = nn.Sequential(
135
+ nn.ConvTranspose2d(
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ kernel_size=(3, 3),
139
+ stride=stride,
140
+ padding=(1, 1),
141
+ output_padding=out_padding,
142
+ bias=False,
143
+ ),
144
+ nn.BatchNorm2d(out_channels, momentum=momentum),
145
+ nn.ReLU(),
146
+ )
147
+ self.conv2 = nn.ModuleList()
148
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
149
+ for i in range(n_blocks - 1):
150
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
151
+
152
+ def forward(self, x, concat_tensor):
153
+ x = self.conv1(x)
154
+ x = torch.cat((x, concat_tensor), dim=1)
155
+ for conv2 in self.conv2:
156
+ x = conv2(x)
157
+ return x
158
+
159
+
160
+ class Decoder(nn.Module):
161
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
162
+ super(Decoder, self).__init__()
163
+ self.layers = nn.ModuleList()
164
+ self.n_decoders = n_decoders
165
+ for i in range(self.n_decoders):
166
+ out_channels = in_channels // 2
167
+ self.layers.append(
168
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
169
+ )
170
+ in_channels = out_channels
171
+
172
+ def forward(self, x, concat_tensors):
173
+ for i, layer in enumerate(self.layers):
174
+ x = layer(x, concat_tensors[-1 - i])
175
+ return x
176
+
177
+
178
+ class DeepUnet(nn.Module):
179
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
180
+ super(DeepUnet, self).__init__()
181
+ self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
182
+ self.intermediate = Intermediate(
183
+ self.encoder.out_channel // 2,
184
+ self.encoder.out_channel,
185
+ inter_layers,
186
+ n_blocks,
187
+ )
188
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
189
+
190
+ def forward(self, x):
191
+ x, concat_tensors = self.encoder(x)
192
+ x = self.intermediate(x)
193
+ x = self.decoder(x, concat_tensors)
194
+ return x
195
+
196
+
197
+ class E2E(nn.Module):
198
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
199
+ super(E2E, self).__init__()
200
+ self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
201
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
202
+ if n_gru:
203
+ self.fc = nn.Sequential(
204
+ BiGRU(3 * 128, 256, n_gru),
205
+ nn.Linear(512, 360),
206
+ nn.Dropout(0.25),
207
+ nn.Sigmoid(),
208
+ )
209
+ else:
210
+ self.fc = nn.Sequential(
211
+ nn.Linear(3 * 128, 360),
212
+ nn.Dropout(0.25),
213
+ nn.Sigmoid()
214
+ )
215
+
216
+ def forward(self, mel):
217
+ mel = mel.transpose(-1, -2).unsqueeze(1)
218
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
219
+ x = self.fc(x)
220
+ return x
221
+
222
+
223
+
224
+ class MelSpectrogram(torch.nn.Module):
225
+ def __init__(self, is_half, n_mel_channels, sampling_rate, win_length, hop_length,
226
+ n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
227
+ super().__init__()
228
+ n_fft = win_length if n_fft is None else n_fft
229
+ self.hann_window = {}
230
+ mel_basis = mel(
231
+ sr=sampling_rate,
232
+ n_fft=n_fft,
233
+ n_mels=n_mel_channels,
234
+ fmin=mel_fmin,
235
+ fmax=mel_fmax,
236
+ htk=True,
237
+ )
238
+ mel_basis = torch.from_numpy(mel_basis).float()
239
+ self.register_buffer("mel_basis", mel_basis)
240
+ self.n_fft = win_length if n_fft is None else n_fft
241
+ self.hop_length = hop_length
242
+ self.win_length = win_length
243
+ self.sampling_rate = sampling_rate
244
+ self.n_mel_channels = n_mel_channels
245
+ self.clamp = clamp
246
+ self.is_half = is_half
247
+
248
+ def forward(self, audio, keyshift=0, speed=1, center=True):
249
+ factor = 2 ** (keyshift / 12)
250
+ n_fft_new = int(np.round(self.n_fft * factor))
251
+ win_length_new = int(np.round(self.win_length * factor))
252
+ hop_length_new = int(np.round(self.hop_length * speed))
253
+
254
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
255
+ if keyshift_key not in self.hann_window:
256
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
257
+
258
+ fft = torch.stft(
259
+ audio,
260
+ n_fft=n_fft_new,
261
+ hop_length=hop_length_new,
262
+ win_length=win_length_new,
263
+ window=self.hann_window[keyshift_key],
264
+ center=center,
265
+ return_complex=True,
266
+ )
267
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
268
+
269
+ if keyshift != 0:
270
+ size = self.n_fft // 2 + 1
271
+ resize = magnitude.size(1)
272
+ if resize < size:
273
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
274
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
275
+
276
+ mel_output = torch.matmul(self.mel_basis, magnitude)
277
+ if self.is_half:
278
+ mel_output = mel_output.half()
279
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
280
+ return log_mel_spec
281
+
282
+
283
+
284
+ class RMVPE:
285
+ def __init__(self, model_path: str, is_half, device=None):
286
+ self.is_half = is_half
287
+ if device is None:
288
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
289
+ self.device = torch.device(device) if isinstance(device, str) else device
290
+
291
+ self.mel_extractor = MelSpectrogram(
292
+ is_half=is_half,
293
+ n_mel_channels=128,
294
+ sampling_rate=16000,
295
+ win_length=1024,
296
+ hop_length=160,
297
+ n_fft=None,
298
+ mel_fmin=30,
299
+ mel_fmax=8000
300
+ ).to(self.device)
301
+
302
+ model = E2E(n_blocks=4, n_gru=1, kernel_size=(2, 2))
303
+ ckpt = torch.load(model_path, map_location=self.device)
304
+ model.load_state_dict(ckpt)
305
+ model.eval()
306
+
307
+ if is_half:
308
+ model = model.half()
309
+ else:
310
+ model = model.float()
311
+
312
+ self.model = model.to(self.device)
313
+
314
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
315
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
316
+
317
+ def mel2hidden(self, mel):
318
+ with torch.no_grad():
319
+ n_frames = mel.shape[-1]
320
+ n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
321
+ if n_pad > 0:
322
+ mel = F.pad(mel, (0, n_pad), mode="constant")
323
+ mel = mel.half() if self.is_half else mel.float()
324
+ hidden = self.model(mel)
325
+ return hidden[:, :n_frames]
326
+
327
+ def decode(self, hidden, thred=0.03):
328
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
329
+ f0 = 10 * (2 ** (cents_pred / 1200))
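+ # cents are measured against a 10 Hz reference, so unvoiced frames (cents 0) decode to exactly 10 Hz and are zeroed next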
330
+ f0[f0 == 10] = 0
331
+ return f0
332
+
333
+ def infer_from_audio(self, audio, thred=0.03):
334
+ if not torch.is_tensor(audio):
335
+ audio = torch.from_numpy(audio)
336
+
337
+ mel = self.mel_extractor(audio.float().to(self.device).unsqueeze(0), center=True)
338
+ hidden = self.mel2hidden(mel)
339
+ hidden = hidden.squeeze(0).cpu().numpy()
340
+
341
+ if self.is_half:
342
+ hidden = hidden.astype("float32")
343
+
344
+ f0 = self.decode(hidden, thred=thred)
345
+ return f0
346
+
347
+ def to_local_average_cents(self, salience, thred=0.05):
348
+ center = np.argmax(salience, axis=1)
349
+ salience = np.pad(salience, ((0, 0), (4, 4)))
350
+ center += 4
351
+
352
+ todo_salience = []
353
+ todo_cents_mapping = []
354
+ starts = center - 4
355
+ ends = center + 5
356
+
357
+ for idx in range(salience.shape[0]):
358
+ todo_salience.append(salience[:, starts[idx]:ends[idx]][idx])
359
+ todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])
360
+
361
+ todo_salience = np.array(todo_salience)
362
+ todo_cents_mapping = np.array(todo_cents_mapping)
363
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
364
+ weight_sum = np.sum(todo_salience, 1)
365
+ divided = product_sum / weight_sum
366
+
367
+ maxx = np.max(salience, axis=1)
368
+ divided[maxx <= thred] = 0
369
+
370
+ return divided
371
+
372
+ class F0Extractor:
373
+ """Extract frame-level f0 from singing voice.
374
+
375
+ Wrapper around an RMVPE network that:
376
+ 1) loads the checkpoint once in ``__init__``
377
+ 2) exposes a simple :py:meth:`process` API and optionally saves ``*_f0.npy``.
378
+ """
379
+ def __init__(
380
+ self,
381
+ model_path: str,
382
+ device: str = "cpu",
383
+ *,
384
+ is_half: bool = False,
385
+ input_sr: int = 16000,
386
+ target_sr: int = 24000,
387
+ hop_size: int = 480,
388
+ max_duration: float = 300,
389
+ thred: float = 0.03,
390
+ verbose: bool = True,
391
+ ):
392
+ """Initialize the f0 extractor.
393
+
394
+ Args:
395
+ model_path: Path to RMVPE checkpoint.
396
+ device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
397
+ is_half: Whether to run the model in fp16.
398
+ input_sr: Input resample rate used by RMVPE frontend.
399
+ target_sr: Target sample rate for the output f0 grid.
400
+ hop_size: Target hop size for the output f0 grid.
401
+ max_duration: Max duration (seconds) for interpolation grid.
402
+ thred: Voicing threshold used when decoding salience.
403
+ verbose: Whether to print verbose logs.
404
+ """
405
+ self.model_path = model_path
406
+ self.input_sr = input_sr
407
+ self.target_sr = target_sr
408
+ self.hop_size = hop_size
409
+ self.max_duration = max_duration
410
+ self.thred = thred
411
+
412
+ self.verbose = verbose
413
+
414
+ self.model = RMVPE(model_path, is_half=is_half, device=device)
415
+
416
+ if self.verbose:
417
+ print(
418
+ "[f0 extraction] init success:",
419
+ f"device={device}",
420
+ f"model_path={model_path}",
421
+ f"is_half={is_half}",
422
+ f"input_sr={input_sr}",
423
+ f"target_sr={target_sr}",
424
+ f"hop_size={hop_size}",
425
+ f"thred={thred}",
426
+ )
427
+
428
+ @staticmethod
429
+ def interpolate_f0(
430
+ f0_16k: np.ndarray,
431
+ original_length: int,
432
+ original_sr: int,
433
+ *,
434
+ target_sr: int = 48000,
435
+ hop_size: int = 256,
436
+ max_duration: float = 20.0,
437
+ ) -> np.ndarray:
438
+ """Interpolate f0 from RMVPE's 16k hop grid to target mel hop grid."""
439
+ mel_target_sr = target_sr
440
+ mel_hop_size = hop_size
441
+ mel_max_duration = max_duration
442
+
443
+ batch_max_length = int(mel_max_duration * mel_target_sr / mel_hop_size)
444
+ duration_in_seconds = original_length / original_sr
445
+ effective_target_length = int(duration_in_seconds * mel_target_sr)
446
+ original_frames = math.ceil(effective_target_length / mel_hop_size)
447
+ target_frames = min(original_frames, batch_max_length)
448
+
449
+ rmvpe_hop = 160
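+ # RMVPE frames are 160 / 16000 = 10 ms apart; the target grid spacing is hop_size / target_sr (480 / 24000 = 20 ms with the class defaults)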
450
+ t_16k = np.arange(len(f0_16k)) * (rmvpe_hop / 16000.0)
451
+ t_target = np.arange(target_frames) * (mel_hop_size / float(mel_target_sr))
452
+
453
+ if len(f0_16k) > 0:
454
+ f_interp = interp1d(
455
+ t_16k,
456
+ f0_16k,
457
+ kind="linear",
458
+ bounds_error=False,
459
+ fill_value=0.0,
460
+ assume_sorted=True,
461
+ )
462
+ f0 = f_interp(t_target)
463
+ else:
464
+ f0 = np.zeros(target_frames)
465
+
466
+ if len(f0) != target_frames:
467
+ f0 = (
468
+ f0[:target_frames]
469
+ if len(f0) > target_frames
470
+ else np.pad(f0, (0, target_frames - len(f0)), "constant")
471
+ )
472
+
473
+ return f0
474
+
475
+ def process(self, audio_path: str, *, f0_path: str | None = None, verbose: Optional[bool] = None) -> np.ndarray:
476
+ """Run f0 extraction for a single wav.
477
+
478
+ Args:
479
+ audio_path: Path to the input wav file.
480
+ f0_path: if is not None, save the f0 data to this path.
481
+ verbose: Override instance-level verbose flag for this call.
482
+
483
+ Returns:
484
+ np.ndarray: shape ``[T]``, f0 in Hz (0 for unvoiced).
485
+ """
486
+ verbose = self.verbose if verbose is None else verbose
487
+ if verbose:
488
+ print(f"[f0 extraction] process: start: {audio_path}")
489
+ t0 = time.time()
490
+
491
+ audio, _ = librosa.load(audio_path, sr=self.input_sr)
492
+ f0_16k = self.model.infer_from_audio(audio, thred=self.thred)
493
+ f0 = self.interpolate_f0(
494
+ f0_16k,
495
+ original_length=audio.shape[-1],
496
+ original_sr=self.input_sr,
497
+ target_sr=self.target_sr,
498
+ hop_size=self.hop_size,
499
+ max_duration=self.max_duration,
500
+ )
501
+
502
+ if verbose:
503
+ dt = time.time() - t0
504
+ voiced_ratio = float(np.mean(f0 > 0)) if len(f0) else 0.0
505
+ print(
506
+ "[f0 extraction] process: done:",
507
+ f"frames={len(f0)}",
508
+ f"voiced_ratio={voiced_ratio:.3f}",
509
+ f"time={dt:.3f}s",
510
+ )
511
+ if f0_path is not None:
512
+ np.save(f0_path, f0)
513
+
514
+ return f0
515
+
516
+
517
+ if __name__ == "__main__":
518
+ model_path = (
519
+ "pretrained_models/rmvpe/rmvpe.pt"
520
+ )
521
+ audio_path = "./outputs/transcription/test.wav"
522
+
523
+ pe = F0Extractor(
524
+ model_path,
525
+ device="cuda",
526
+ )
527
+ f0 = pe.process(audio_path)
preprocess/tools/g2p.py ADDED
@@ -0,0 +1,72 @@
1
+ import re
2
+
3
+ import ToJyutping
4
+ from g2pM import G2pM
5
+ from g2p_en import G2p as G2pE
6
+
7
+ _EN_WORD_RE = re.compile(r"^[A-Za-z]+(?:'[A-Za-z]+)*$")
8
+ _ZH_WORD_RE = re.compile(r"[\u4e00-\u9fff]")
9
+
10
+ EN_FLAG = "en_"
11
+ YUE_FLAG = "yue_"
12
+ ZH_FLAG = "zh_"
13
+
14
+ g2p_zh = G2pM()
15
+ g2p_en = G2pE()
16
+
17
+
18
+ def is_chinese_char(word: str) -> bool:
19
+ if len(word) != 1:
20
+ return False
21
+ return bool(_ZH_WORD_RE.fullmatch(word))
22
+
23
+ def is_english_word(word: str) -> bool:
24
+ if not word:
25
+ return False
26
+ return bool(_EN_WORD_RE.fullmatch(word))
27
+
28
+ def g2p_cantonese(sent):
29
+ return ToJyutping.get_jyutping_list(sent) # with tone
30
+
31
+ def g2p_mandarin(sent):
32
+ return g2p_zh(sent, tone=True, char_split=False)
33
+
34
+ def g2p_english(word):
35
+ return g2p_en(word)
36
+
37
+ def g2p_transform(words, lang):
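+ # e.g. g2p_transform(["<SP>", "除", "了"], "Mandarin") -> ["<SP>", "zh_chu2", "zh_le5"] (cf. the phoneme strings in example/audio/*.json)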
38
+
39
+ zh_words = []
40
+ transformed_words = [0] * len(words)
41
+
42
+ for idx, w in enumerate(words):
43
+ if w == "<SP>":
44
+ transformed_words[idx] = w
45
+ continue
46
+
47
+ w = w.replace("?", "").replace(".", "").replace("!", "").replace(",", "")
48
+
49
+ if is_chinese_char(w):
50
+ zh_words.append([idx, w])
51
+ else:
52
+ if is_english_word(w):
53
+ w = EN_FLAG + "-".join(g2p_english(w.lower()))
54
+ else:
55
+ w = "<SP>"
56
+ transformed_words[idx] = w
57
+
58
+ sent = "".join([k[1] for k in zh_words])
59
+
60
+ # zh (zh and yue) transformer to g2p
61
+ if len(sent) > 0:
62
+ if lang == "Cantonese":
63
+ g2pm_rst = g2p_cantonese(sent) # with tone
64
+ g2pm_rst = [YUE_FLAG + k[1] for k in g2pm_rst]
65
+ else:
66
+ g2pm_rst = g2p_mandarin(sent)
67
+ g2pm_rst = [ZH_FLAG + k for k in g2pm_rst]
68
+ for p, w in zip([k[0] for k in zh_words], g2pm_rst):
69
+ transformed_words[p] = w
70
+
71
+ return transformed_words
72
+
preprocess/tools/lyric_transcription.py ADDED
@@ -0,0 +1,279 @@
1
+ # https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary
2
+ # https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2
3
+ import os
4
+ import re
5
+ import time
6
+ from typing import Any, Dict, List, Tuple
7
+
8
+ import librosa
9
+ import numpy as np
10
+ from funasr import AutoModel
11
+
12
+
13
+ def _build_words_with_gaps(raw_words, raw_timestamps, wav_fn: str):
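+ # Insert "<SP>" (silence) tokens wherever the ASR timestamps leave a gap, so word durations tile the whole clip.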
14
+ words, word_durs = [], []
15
+ prev = 0.0
16
+ for w, t in zip(raw_words, raw_timestamps):
17
+ s, e = float(t[0]), float(t[1])
18
+ if s > prev:
19
+ words.append("<SP>")
20
+ word_durs.append(s - prev)
21
+ words.append(w)
22
+ word_durs.append(e - s)
23
+ prev = e
24
+
25
+ wav_len = librosa.get_duration(path=wav_fn)  # librosa >= 0.10 renamed the "filename" keyword to "path"
26
+ if wav_len > prev:
27
+ if len(words) == 0:
28
+ words.append("<SP>")
29
+ word_durs.append(wav_len)
30
+ return words, word_durs
31
+ if words[-1] != "<SP>":
32
+ words.append("<SP>")
33
+ word_durs.append(wav_len - prev)
34
+ else:
35
+ word_durs[-1] += wav_len - prev
36
+
37
+ return words, word_durs
38
+
39
+ def _word_dur_post_process(words, word_durs, f0):
40
+ """Post-process word durations using f0 to better place silences.
41
+ """
42
+ # f0 time grid parameters
43
+ sr = 24000 # f0 sample rate
44
+ hop_length = 480 # f0 hop length
45
+
46
+ # Convert word durations (seconds) to frame boundaries on the f0 grid.
47
+ boundaries = np.cumsum([
48
+ 0,
49
+ *[
50
+ int(dur * sr / hop_length)
51
+ for dur in word_durs
52
+ ],
53
+ ]).tolist()
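+ # e.g. word_durs=[0.23, 0.34] on the 24 kHz / 480-sample grid -> boundaries=[0, 11, 28] (frame indices)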
54
+
55
+ sil_tolerance = 5 # tolerance frames for silence detection
56
+ ext_tolerance = 5 # tolerance frames for vocal extension
57
+
58
+ new_words: list[str] = []
59
+ new_word_durs: list[float] = []
60
+ if words:
61
+ new_words.append(words[0])
62
+ new_word_durs.append(word_durs[0])
63
+
64
+ for i in range(1, len(words)):
65
+ word = words[i]
66
+ if word == "<SP>":
67
+ start_frame = boundaries[i]
68
+ end_frame = boundaries[i + 1]
69
+
70
+ num_frames = end_frame - start_frame
71
+ frame_idx = start_frame
72
+
73
+ # Find first region with at least 5 consecutive "unvoiced" frames.
74
+ unvoiced_count = 0
75
+ while frame_idx < end_frame:
76
+ if f0[frame_idx] <= 1: # unvoiced
77
+ unvoiced_count += 1
78
+ if unvoiced_count >= sil_tolerance:
79
+ frame_idx -= sil_tolerance - 1 # step back to the first frame of the unvoiced run
80
+ break
81
+ else:
82
+ unvoiced_count = 0
83
+ frame_idx += 1
84
+
85
+ voice_frames = frame_idx - start_frame
86
+
87
+ if voice_frames >= int(num_frames * 0.9): # over 90% voiced
88
+ # Mostly voiced: drop the "<SP>" and merge its duration into the previous word.
89
+ new_word_durs[-1] += word_durs[i]
90
+ elif voice_frames >= ext_tolerance: # at least ext_tolerance voiced frames at the head
91
+ # Extend the previous word by the voiced head; keep the remainder as "<SP>".
92
+ dur = voice_frames * hop_length / sr
93
+ new_word_durs[-1] += dur
94
+ new_words.append("<SP>")
95
+ new_word_durs.append(word_durs[i] - dur)
96
+ else:
97
+ # Too short to adjust, keep as-is.
98
+ new_words.append(word)
99
+ new_word_durs.append(word_durs[i])
100
+ else:
101
+ new_words.append(word)
102
+ new_word_durs.append(word_durs[i])
103
+
104
+ return new_words, new_word_durs
105
+
106
+
107
+ class _ASRZhModel:
108
+ """Mandarin/Cantonese ASR wrapper."""
109
+
110
+ def __init__(self, model_path: str, device: str):
111
+ self.model = AutoModel(
112
+ model=model_path,
113
+ disable_update=True,
114
+ device=device,
115
+ )
116
+
117
+ def process(self, wav_fn):
118
+ out = self.model.generate(wav_fn, output_timestamp=True)[0]
119
+ raw_words = out["text"].replace("@", "").split(" ")
120
+ raw_timestamps = [[t[0] / 1000, t[1] / 1000] for t in out["timestamp"]]
121
+ words, word_durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
122
+
123
+ if os.path.exists(wav_fn.replace(".wav", "_f0.npy")):
124
+ words, word_durs = _word_dur_post_process(
125
+ words, word_durs, np.load(wav_fn.replace(".wav", "_f0.npy"))
126
+ )
127
+
128
+ return words, word_durs
129
+
130
+
131
+ class _ASREnModel:
132
+ """English ASR wrapper for NeMo Parakeet-TDT."""
133
+
134
+ def __init__(self, model_path: str, device: str):
135
+ try:
136
+ import nemo.collections.asr as nemo_asr # type: ignore
137
+ except Exception as e: # pragma: no cover
138
+ raise ImportError(
139
+ "NeMo (nemo_toolkit) is required for ASR English but is not available in this Python env. "
140
+ "Install it in the active environment, then retry."
141
+ ) from e
142
+
143
+ self.model = nemo_asr.models.ASRModel.restore_from(
144
+ restore_path=model_path,
145
+ map_location=device,
146
+ )
147
+ self.model.eval()
148
+
149
+ @staticmethod
150
+ def _clean_word(word: str) -> str:
151
+ return re.sub(r"[\?\.,:]", "", word).strip()
152
+
153
+ @staticmethod
154
+ def _extract_word_segments(output: Any) -> List[Dict[str, Any]]:
155
+ ts = getattr(output, "timestamp", None)
156
+ if not ts or not isinstance(ts, dict):
157
+ return []
158
+ word_ts = ts.get("word")
159
+ return word_ts if isinstance(word_ts, list) else []
160
+
161
+ def process(self, wav_fn: str) -> Tuple[List[str], List[float]]:
162
+ outputs = self.model.transcribe(
163
+ [wav_fn],
164
+ timestamps=True,
165
+ batch_size=1,
166
+ num_workers=0,
167
+ )
168
+ output = outputs[0] if outputs else None
169
+
170
+ raw_words: List[str] = []
171
+ raw_timestamps: List[List[float]] = []
172
+ if output is not None:
173
+ for w in self._extract_word_segments(output):
174
+ s, e = float(w.get("start", 0.0)), float(w.get("end", 0.0))
175
+ word = self._clean_word(str(w.get("word", "")))
176
+ if word:
177
+ raw_words.append(word)
178
+ raw_timestamps.append([s, e])
179
+
180
+ words, durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
181
+
182
+ if os.path.exists(wav_fn.replace(".wav", "_f0.npy")):
183
+ words, durs = _word_dur_post_process(
184
+ words, durs, np.load(wav_fn.replace(".wav", "_f0.npy"))
185
+ )
186
+
187
+ return words, durs
188
+
189
+
190
+ class LyricTranscriber:
191
+ """Transcribe lyrics from singing voice segment
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ zh_model_path: str,
197
+ en_model_path: str,
198
+ device: str = "cuda",
199
+ *,
200
+ verbose: bool = True,
201
+ ):
202
+ """Initialize lyric transcriber.
203
+
204
+ Args:
205
+ zh_model_path (str): Path to the Chinese model file.
206
+ en_model_path (str): Path to the English model file.
207
+ device (str): Device to use for tensor operations.
208
+ verbose (bool): Whether to print verbose logs.
209
+ """
210
+ self.verbose = verbose
211
+ self.device = device
212
+ self.zh_model_path = zh_model_path
213
+ self.en_model_path = en_model_path
214
+
215
+ if self.verbose:
216
+ print(
217
+ "[lyric transcription] init: start:",
218
+ f"device={device}",
219
+ f"model_path={zh_model_path}",
220
+ )
221
+
222
+ # Always initialize Chinese ASR.
223
+ self.zh_model = _ASRZhModel(device=device, model_path=zh_model_path)
224
+
225
+ # English ASR is initialized lazily on the first English request, to avoid the slow NeMo import at startup.
226
+ self.en_model = None
227
+
228
+ if self.verbose:
229
+ print("[lyric transcription] init: success")
230
+
231
+ def process(self, wav_fn, language: str | None = "Mandarin", *, verbose: bool | None = None):
232
+ """ Lyric transcriber process
233
+
234
+ Args:
235
+ wav_fn (str): Path to the audio file.
236
+ language (str | None): Language of the audio. Defaults to "Mandarin". Supports "Mandarin", "Cantonese" and "English".
237
+ verbose (bool | None): Whether to print verbose logs. Defaults to None.
238
+ """
239
+ v = self.verbose if verbose is None else verbose
240
+ if language not in {"Mandarin", "Cantonese", "English"}:
241
+ raise ValueError(f"Unsupported language: {language}, should be one of ['Mandarin', 'Cantonese', 'English']")
242
+ if v:
243
+ print(f"[lyric transcription] process: start: wav_fn={wav_fn} language={language}")
244
+ t0 = time.time()
245
+
246
+ lang = language.lower()
247
+ if lang == "english":
248
+ if self.en_model is None:
249
+ # Lazy-load NeMo model only when English is actually used.
250
+ if v:
251
+ print("[lyric transcription] init English ASR, please make sure NeMo is installed")
252
+ self.en_model = _ASREnModel(model_path=self.en_model_path, device=self.device)
253
+ out = self.en_model.process(wav_fn)
254
+ else:
255
+ out = self.zh_model.process(wav_fn)
256
+
257
+ if v:
258
+ words, durs = out
259
+ n_words = len(words) if isinstance(words, list) else 0
260
+ dur_sum = float(sum(durs)) if isinstance(durs, list) else 0.0
261
+ dt = time.time() - t0
262
+ print(
263
+ "[lyric transcription] process: done:",
264
+ f"n_words={n_words}",
265
+ f"dur_sum={dur_sum:.3f}s",
266
+ f"time={dt:.3f}s",
267
+ )
268
+
269
+ return out
270
+
271
+
272
+ if __name__ == "__main__":
273
+ m = LyricTranscriber(
274
+ zh_model_path="pretrained_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
275
+ en_model_path="pretrained_models/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
276
+ device="cuda"
277
+ )
278
+ print(m.process("example/test/asr_zh.wav", language="Mandarin"))
279
+ print(m.process("example/test/asr_en.wav", language="English"))
preprocess/tools/midi_parser.py ADDED
@@ -0,0 +1,669 @@
1
+ """
2
+ SoulX-Singer MIDI <-> metadata converter.
3
+
4
+ Converts between SoulX-Singer-style metadata JSON (with note_text, note_dur,
5
+ note_pitch, note_type per segment) and standard MIDI files. Uses an internal
6
+ Note dataclass (start_s, note_dur, note_text, note_pitch, note_type) as the
7
+ intermediate representation.
8
+ """
9
+ import os
10
+ import json
11
+ import shutil
12
+ from dataclasses import dataclass
13
+ from typing import Any, List, Tuple, Union
14
+
15
+ import librosa
16
+ import mido
17
+ from soundfile import write
18
+
19
+ from .f0_extraction import F0Extractor
20
+ from .g2p import g2p_transform
21
+
22
+
23
+ # Audio and segmenting constants (used by _edit_data_to_meta)
24
+ SAMPLE_RATE = 44100
25
+ DEFAULT_LANGUAGE = "Mandarin"
26
+ MAX_GAP_SEC = 5.0 # gap (sec) above which we start a new segment
27
+ MAX_SEGMENT_DUR_SUM_SEC = 60.0 # max cumulative note duration per segment (sec)
28
+ MIN_GAP_THRESHOLD_SEC = 0.001 # ignore gaps smaller than this
29
+ LONG_SILENCE_THRESHOLD_SEC = 0.05 # treat as separate <SP> if gap larger
30
+ MAX_LEADING_SP_DUR_SEC = 2.0 # cap leading silence in a segment to this (sec)
31
+ DEFAULT_RMVPE_MODEL_PATH = "pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt"
32
+
33
+
34
+ @dataclass
35
+ class Note:
36
+ """Single note: text, duration (seconds), pitch (MIDI), type. start_s is absolute start time in seconds (for ordering / MIDI)."""
37
+ start_s: float
38
+ note_dur: float
39
+ note_text: str
40
+ note_pitch: int
41
+ note_type: int
42
+
43
+ @property
44
+ def end_s(self) -> float:
45
+ return self.start_s + self.note_dur
46
+
47
+
48
+
49
+ def remove_duplicate_segments(meta_data: List[dict]) -> None:
50
+ """Merge consecutive identical notes (same text, pitch, type) within each segment. Mutates meta_data in place."""
51
+ for idx, segment in enumerate(meta_data):
52
+ texts = segment["note_text"]
53
+ durs = segment["note_dur"]
54
+ pitches = segment["note_pitch"]
55
+ types = segment["note_type"]
56
+ new_texts = []
57
+ new_durs = []
58
+ new_pitches = []
59
+ new_types = []
60
+ for i in range(len(texts)):
61
+ if i == 0:
62
+ new_texts.append(texts[i])
63
+ new_durs.append(durs[i])
64
+ new_pitches.append(pitches[i])
65
+ new_types.append(types[i])
66
+ continue
67
+ t, d, p, ty = texts[i], durs[i], pitches[i], types[i]
68
+ if t == "<SP>" and texts[i - 1] == "<SP>":
69
+ new_durs[-1] += d
70
+ continue
71
+ if t == texts[i - 1] and p == pitches[i - 1] and ty == types[i - 1]:
72
+ new_durs[-1] += d
73
+ else:
74
+ new_texts.append(t)
75
+ new_durs.append(d)
76
+ new_pitches.append(p)
77
+ new_types.append(ty)
78
+ meta_data[idx]["note_text"] = new_texts
79
+ meta_data[idx]["note_dur"] = new_durs
80
+ meta_data[idx]["note_pitch"] = new_pitches
81
+ meta_data[idx]["note_type"] = new_types
82
+
83
+ def meta2notes(meta_path: str) -> List[Note]:
84
+ """Parse SoulX-Singer metadata JSON into a flat list of Note (absolute start_s)."""
85
+ with open(meta_path, "r", encoding="utf-8") as f:
86
+ segments = json.load(f)
87
+ if not isinstance(segments, list):
88
+ raise ValueError(f"Metadata must be a list of segments, got {type(segments).__name__}")
89
+ if not segments:
90
+ raise ValueError("Metadata has no segments.")
91
+
92
+ notes: List[Note] = []
93
+ for seg in segments:
94
+ offset_s = seg["time"][0] / 1000
95
+ words = [str(x).replace("<AP>", "<SP>") for x in seg["text"].split()]
96
+ word_durs = [float(x) for x in seg["duration"].split()]
97
+ pitches = [int(x) for x in seg["note_pitch"].split()]
98
+ types = [int(x) if words[i] != "<SP>" else 1 for i, x in enumerate(seg["note_type"].split())]
99
+ if len(words) != len(word_durs) or len(word_durs) != len(pitches) or len(pitches) != len(types):
100
+ raise ValueError(
101
+ f"Length mismatch in segment {seg.get('item_name', '?')}: "
102
+ "note_text, note_dur, note_pitch, note_type must have same length"
103
+ )
104
+ current_s = offset_s
105
+ for text, dur, pitch, type_ in zip(words, word_durs, pitches, types):
106
+ notes.append(
107
+ Note(
108
+ start_s=current_s,
109
+ note_dur=float(dur),
110
+ note_text=str(text),
111
+ note_pitch=int(pitch),
112
+ note_type=int(type_),
113
+ )
114
+ )
115
+ current_s += float(dur)
116
+ return notes
117
+
118
+ def _append_segment_to_meta(
119
+ meta_path_str: str,
120
+ cut_wavs_output_dir: str,
121
+ vocal_file: str,
122
+ audio_data: Any,
123
+ meta_data: List[dict],
124
+ note_start: List[float],
125
+ note_end: List[float],
126
+ note_text: List[Any],
127
+ note_pitch: List[Any],
128
+ note_type: List[Any],
129
+ note_dur: List[float],
130
+ end_time_ms_override: float | None = None,
131
+ ) -> None:
132
+ """Write one segment wav and append one segment dict to meta_data. Caller clears note_* lists after."""
133
+ base_name = os.path.splitext(os.path.basename(meta_path_str))[0]
134
+ item_name = f"{base_name}_{len(meta_data)}"
135
+ wav_fn = os.path.join(cut_wavs_output_dir, f"{item_name}.wav")
136
+ start_ms = int(note_start[0] * 1000)
137
+ end_ms = (
138
+ int(end_time_ms_override)
139
+ if end_time_ms_override is not None
140
+ else int(note_end[-1] * 1000)
141
+ )
142
+ start_sample = int(note_start[0] * SAMPLE_RATE)
143
+ end_sample = int(note_end[-1] * SAMPLE_RATE)
144
+ write(wav_fn, audio_data[start_sample:end_sample], SAMPLE_RATE)
145
+ meta_data.append({
146
+ "item_name": item_name,
147
+ "wav_fn": wav_fn,
148
+ "origin_wav_fn": vocal_file,
149
+ "start_time_ms": start_ms,
150
+ "end_time_ms": end_ms,
151
+ "language": DEFAULT_LANGUAGE,
152
+ "note_text": list(note_text),
153
+ "note_pitch": list(note_pitch),
154
+ "note_type": list(note_type),
155
+ "note_dur": list(note_dur),
156
+ })
157
+
158
+
159
+ def convert_meta(meta_data: List[dict], rmvpe_model_path, device="cuda"):
160
+ pitch_extractor = F0Extractor(rmvpe_model_path, device=device, verbose=False)
161
+ converted_data = []
162
+
163
+ for item in meta_data:
164
+ wav_fn = item.get("wav_fn")
165
+ if not wav_fn or not os.path.isfile(wav_fn):
166
+ raise FileNotFoundError(f"Segment wav file not found: {wav_fn}")
167
+ f0 = pitch_extractor.process(wav_fn)
168
+ converted_item = {
169
+ "index": item.get("item_name"),
170
+ "language": item.get("language"),
171
+ "time": [item.get("start_time_ms", 0), item.get("end_time_ms", sum(item["note_dur"]) * 1000)],
172
+ "duration": " ".join(str(round(x, 2)) for x in item.get("note_dur", [])),
173
+ "text": " ".join(item.get("note_text", [])),
174
+ "phoneme": " ".join(g2p_transform(item.get("note_text", []), DEFAULT_LANGUAGE)),
175
+ "note_pitch": " ".join(str(x) for x in item.get("note_pitch", [])),
176
+ "note_type": " ".join(str(x) for x in item.get("note_type", [])),
177
+ "f0": " ".join(str(round(float(x), 1)) for x in f0),
178
+ }
179
+ converted_data.append(converted_item)
180
+
181
+ return converted_data
182
+
183
+
184
+ def _edit_data_to_meta(
185
+ meta_path_str: str,
186
+ edit_data: List[dict],
187
+ vocal_file: str,
188
+ rmvpe_model_path: str | None = None,
189
+ device: str = "cuda",
190
+ ) -> None:
191
+ """Write SoulX-Singer metadata JSON from edit_data (list of {start, end, note_text, note_pitch, note_type})."""
192
+ # Use a fixed temporary directory for cut wavs
193
+ cut_wavs_output_dir = os.path.join(os.path.dirname(vocal_file), "cut_wavs_tmp")
194
+ os.makedirs(cut_wavs_output_dir, exist_ok=True)
195
+
196
+ note_text: List[Any] = []
197
+ note_pitch: List[Any] = []
198
+ note_type: List[Any] = []
199
+ note_dur: List[float] = []
200
+ note_start: List[float] = []
201
+ note_end: List[float] = []
202
+ prev_end = 0.0
203
+ meta_data: List[dict] = []
204
+ audio_data, _ = librosa.load(vocal_file, sr=SAMPLE_RATE, mono=True)
205
+ dur_sum = 0.0
206
+
207
+ for entry in edit_data:
208
+ start = float(entry["start"])
209
+ end = float(entry["end"])
210
+ text = entry["note_text"]
211
+ pitch = entry["note_pitch"]
212
+ type_ = entry["note_type"]
213
+
214
+ if text == "" or pitch == "" or type_ == "":
215
+ note_text.append("<SP>")
216
+ note_pitch.append(0)
217
+ note_type.append(1)
218
+ note_dur.append(end - start)
219
+ note_start.append(start)
220
+ note_end.append(end)
221
+ prev_end = end
222
+ dur_sum += end - start
223
+ continue
224
+
225
+ if (
226
+ len(note_text) > 0
227
+ and note_text[-1] == "<SP>"
228
+ and note_dur[-1] > MAX_LEADING_SP_DUR_SEC
229
+ ):
230
+ cut_time = note_dur[-1] - MAX_LEADING_SP_DUR_SEC
231
+ note_dur[-1] = MAX_LEADING_SP_DUR_SEC
232
+ end_ms_override = note_end[-1] * 1000 - cut_time * 1000
233
+ _append_segment_to_meta(
234
+ meta_path_str,
235
+ cut_wavs_output_dir,
236
+ vocal_file,
237
+ audio_data,
238
+ meta_data,
239
+ note_start,
240
+ note_end,
241
+ note_text,
242
+ note_pitch,
243
+ note_type,
244
+ note_dur,
245
+ end_time_ms_override=end_ms_override,
246
+ )
247
+ note_text = []
248
+ note_pitch = []
249
+ note_type = []
250
+ note_dur = []
251
+ note_start = []
252
+ note_end = []
253
+ prev_end = start
254
+ dur_sum = 0.0
255
+
256
+ gap_from_prev = start - prev_end
257
+ gap_from_last_note = (start - note_end[-1]) if note_end else 0.0
258
+ if (
259
+ gap_from_prev >= MAX_GAP_SEC
260
+ or gap_from_last_note >= MAX_GAP_SEC
261
+ or dur_sum >= MAX_SEGMENT_DUR_SUM_SEC
262
+ ):
263
+ if len(note_text) > 0:
264
+ _append_segment_to_meta(
265
+ meta_path_str,
266
+ cut_wavs_output_dir,
267
+ vocal_file,
268
+ audio_data,
269
+ meta_data,
270
+ note_start,
271
+ note_end,
272
+ note_text,
273
+ note_pitch,
274
+ note_type,
275
+ note_dur,
276
+ )
277
+ note_text = []
278
+ note_pitch = []
279
+ note_type = []
280
+ note_dur = []
281
+ note_start = []
282
+ note_end = []
283
+ prev_end = start
284
+ dur_sum = 0.0
285
+
286
+ if start - prev_end > MIN_GAP_THRESHOLD_SEC:
287
+ if start - prev_end > LONG_SILENCE_THRESHOLD_SEC or len(note_text) == 0:
288
+ note_text.append("<SP>")
289
+ note_pitch.append(0)
290
+ note_type.append(1)
291
+ note_dur.append(start - prev_end)
292
+ note_start.append(prev_end)
293
+ note_end.append(start)
294
+ else:
295
+ if len(note_dur) > 0:
296
+ note_dur[-1] += start - prev_end
297
+ note_end[-1] = start
298
+
299
+ prev_end = end
300
+ note_text.append(text)
301
+ note_pitch.append(int(pitch))
302
+ note_type.append(int(type_))
303
+ note_dur.append(end - start)
304
+ note_start.append(start)
305
+ note_end.append(end)
306
+ dur_sum += end - start
307
+
308
+ if len(note_text) > 0:
309
+ _append_segment_to_meta(
310
+ meta_path_str,
311
+ cut_wavs_output_dir,
312
+ vocal_file,
313
+ audio_data,
314
+ meta_data,
315
+ note_start,
316
+ note_end,
317
+ note_text,
318
+ note_pitch,
319
+ note_type,
320
+ note_dur,
321
+ )
322
+
323
+ remove_duplicate_segments(meta_data)
324
+
325
+ _rmvpe_path = rmvpe_model_path or DEFAULT_RMVPE_MODEL_PATH
326
+ converted_data = convert_meta(meta_data, _rmvpe_path, device)
327
+
328
+ with open(meta_path_str, "w", encoding="utf-8") as f:
329
+ json.dump(converted_data, f, ensure_ascii=False, indent=2)
330
+
331
+ # Clean up temporary cut wavs directory
332
+ try:
333
+ shutil.rmtree(cut_wavs_output_dir, ignore_errors=True)
334
+ except Exception:
335
+ pass
336
+
337
+
338
+ def notes2meta(
339
+ notes: List[Note],
340
+ meta_path: str,
341
+ vocal_file: str,
342
+ rmvpe_model_path: str | None = None,
343
+ device: str = "cuda",
344
+ ) -> None:
345
+ """Write SoulX-Singer metadata JSON from a list of Note (segmenting + wav cuts)."""
346
+ edit_data = [
347
+ {
348
+ "start": n.start_s,
349
+ "end": n.end_s,
350
+ "note_text": n.note_text,
351
+ "note_pitch": str(n.note_pitch),
352
+ "note_type": str(n.note_type),
353
+ }
354
+ for n in notes
355
+ ]
356
+ _edit_data_to_meta(
357
+ str(meta_path),
358
+ edit_data,
359
+ vocal_file,
360
+ rmvpe_model_path=rmvpe_model_path,
361
+ device=device,
362
+ )
363
+
364
+
365
+ @dataclass(frozen=True)
366
+ class MidiDefaults:
367
+ ticks_per_beat: int = 500
368
+ tempo: int = 500000 # microseconds per beat (120 BPM)
369
+ time_signature: Tuple[int, int] = (4, 4)
370
+ velocity: int = 64
371
+
372
+
373
+ def _seconds_to_ticks(seconds: float, ticks_per_beat: int, tempo: int) -> int:
374
+ return int(round(seconds * ticks_per_beat * 1_000_000 / tempo))
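+ # e.g. with MidiDefaults (ticks_per_beat=500, tempo=500000 us/beat): 1.0 s -> round(1.0 * 500 * 1e6 / 5e5) = 1000 ticks.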
375
+
376
+
377
+ def notes2midi(
378
+ notes: List[Note],
379
+ midi_path: str,
380
+ defaults: MidiDefaults | None = None,
381
+ ) -> None:
382
+ """Write MIDI file from a list of Note."""
383
+ defaults = defaults or MidiDefaults()
384
+ if not notes:
385
+ raise ValueError("Empty note list.")
386
+
387
+ events: List[Tuple[int, int, Union[mido.Message, mido.MetaMessage]]] = []
388
+ for n in notes:
389
+ start_s = n.start_s
390
+ end_s = n.end_s
391
+ if end_s <= start_s:
392
+ continue
393
+
394
+ start_ticks = _seconds_to_ticks(
395
+ start_s, defaults.ticks_per_beat, defaults.tempo
396
+ )
397
+ end_ticks = _seconds_to_ticks(
398
+ end_s, defaults.ticks_per_beat, defaults.tempo
399
+ )
400
+ if end_ticks <= start_ticks:
401
+ end_ticks = start_ticks + 1
402
+
403
+ lyric = n.note_text
404
+ # mido stores meta-message text as latin-1; round-trip the UTF-8 bytes
+ # through latin-1 so non-ASCII lyrics survive in the saved file.
+ try:
405
+ lyric = lyric.encode("utf-8").decode("latin1")
406
+ except (UnicodeEncodeError, UnicodeDecodeError):
407
+ pass
408
+ if n.note_type == 3:
409
+ lyric = "-"
410
+
411
+ events.append(
412
+ (start_ticks, 1, mido.MetaMessage("lyrics", text=lyric, time=0))
413
+ )
414
+ events.append(
415
+ (
416
+ start_ticks,
417
+ 2,
418
+ mido.Message(
419
+ "note_on",
420
+ note=n.note_pitch,
421
+ velocity=defaults.velocity,
422
+ time=0,
423
+ ),
424
+ )
425
+ )
426
+ events.append(
427
+ (
428
+ end_ticks,
429
+ 0,
430
+ mido.Message("note_off", note=n.note_pitch, velocity=0, time=0),
431
+ )
432
+ )
433
+
434
+ events.sort(key=lambda x: (x[0], x[1]))
435
+
436
+ mid = mido.MidiFile(ticks_per_beat=defaults.ticks_per_beat)
437
+ track = mido.MidiTrack()
438
+ mid.tracks.append(track)
439
+
440
+ track.append(mido.MetaMessage("set_tempo", tempo=defaults.tempo, time=0))
441
+ track.append(
442
+ mido.MetaMessage(
443
+ "time_signature",
444
+ numerator=defaults.time_signature[0],
445
+ denominator=defaults.time_signature[1],
446
+ time=0,
447
+ )
448
+ )
449
+
450
+ last_tick = 0
451
+ for tick, _, msg in events:
452
+ msg.time = max(0, tick - last_tick)
453
+ track.append(msg)
454
+ last_tick = tick
455
+
456
+ track.append(mido.MetaMessage("end_of_track", time=0))
457
+ mid.save(midi_path)
458
+
459
+
460
+ def midi2notes(midi_path: str) -> List[Note]:
461
+ """Parse MIDI file into a list of Note. Merges all tracks; tempo from last set_tempo event."""
462
+ mid = mido.MidiFile(midi_path)
463
+ ticks_per_beat = mid.ticks_per_beat
464
+ tempo = 500000
465
+
466
+ raw_notes: List[dict] = []
467
+ lyrics: List[Tuple[int, str]] = []
468
+
469
+ for track in mid.tracks:
470
+ abs_ticks = 0
471
+ active = {}
472
+ for msg in track:
473
+ abs_ticks += msg.time
474
+ if msg.type == "set_tempo":
475
+ tempo = msg.tempo
476
+ elif msg.type == "lyrics":
477
+ text = msg.text
478
+ # Undo the latin-1 round trip applied when the lyrics were written (see notes2midi).
+ try:
479
+ text = text.encode("latin1").decode("utf-8")
480
+ except Exception:
481
+ pass
482
+ lyrics.append((abs_ticks, text))
483
+ elif msg.type == "note_on":
484
+ key = (msg.channel, msg.note)
485
+ if msg.velocity > 0:
486
+ active[key] = (abs_ticks, msg.velocity)
487
+ else:
488
+ if key in active:
489
+ start_ticks, vel = active.pop(key)
490
+ raw_notes.append(
491
+ {
492
+ "midi": msg.note,
493
+ "start_ticks": start_ticks,
494
+ "duration_ticks": abs_ticks - start_ticks,
495
+ "velocity": vel,
496
+ "lyric": "",
497
+ }
498
+ )
499
+ elif msg.type == "note_off":
500
+ key = (msg.channel, msg.note)
501
+ if key in active:
502
+ start_ticks, vel = active.pop(key)
503
+ raw_notes.append(
504
+ {
505
+ "midi": msg.note,
506
+ "start_ticks": start_ticks,
507
+ "duration_ticks": abs_ticks - start_ticks,
508
+ "velocity": vel,
509
+ "lyric": "",
510
+ }
511
+ )
512
+
513
+ if not raw_notes:
514
+ raise ValueError("No notes found in MIDI file")
515
+
516
+ for n in raw_notes:
517
+ n["end_ticks"] = n["start_ticks"] + n["duration_ticks"]
518
+
519
+ raw_notes.sort(key=lambda n: n["start_ticks"])
520
+ lyrics.sort(key=lambda x: x[0])
521
+
522
+ trimmed = []
523
+ for note in raw_notes:
524
+ while trimmed:
525
+ prev = trimmed[-1]
526
+ if note["start_ticks"] < prev["end_ticks"]:
527
+ prev["end_ticks"] = note["start_ticks"]
528
+ prev["duration_ticks"] = prev["end_ticks"] - prev["start_ticks"]
529
+ if prev["duration_ticks"] <= 0:
530
+ trimmed.pop()
531
+ continue
532
+ break
533
+ trimmed.append(note)
534
+ raw_notes = trimmed
535
+
536
+ tolerance = ticks_per_beat // 100
537
+ lyric_idx = 0
538
+ for note in raw_notes:
539
+ while lyric_idx < len(lyrics) and lyrics[lyric_idx][0] < note["start_ticks"] - tolerance:
540
+ lyric_idx += 1
541
+ if lyric_idx < len(lyrics):
542
+ lyric_ticks, lyric_text = lyrics[lyric_idx]
543
+ if abs(lyric_ticks - note["start_ticks"]) <= tolerance:
544
+ note["lyric"] = lyric_text
545
+ lyric_idx += 1
546
+
547
+ def ticks_to_seconds(ticks: int) -> float:
548
+ return (ticks / ticks_per_beat) * (tempo / 1_000_000)
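+ # e.g. with ticks_per_beat=500 and tempo=500000: 1000 ticks -> (1000 / 500) * 0.5 = 1.0 s.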
549
+
550
+ result: List[Note] = []
551
+ prev_end_s = 0.0
552
+ for idx, n in enumerate(raw_notes):
553
+ start_s = ticks_to_seconds(n["start_ticks"])
554
+ end_s = ticks_to_seconds(n["end_ticks"])
555
+ if prev_end_s > start_s:
556
+ start_s = prev_end_s
557
+ dur_s = end_s - start_s
558
+ if dur_s <= 0:
559
+ continue
560
+
561
+ lyric = n.get("lyric", "")
562
+ if not lyric:
563
+ tp = 2
564
+ text = "啦"
565
+ elif lyric == "<SP>":
566
+ tp = 1
567
+ text = "<SP>"
568
+ elif lyric == "-":
569
+ tp = 3
570
+ text = raw_notes[idx - 1].get("lyric", "-") if idx > 0 else "-"
571
+ else:
572
+ tp = 2
573
+ text = lyric
574
+
575
+ result.append(
576
+ Note(
577
+ start_s=start_s,
578
+ note_dur=dur_s,
579
+ note_text=text,
580
+ note_pitch=n["midi"],
581
+ note_type=tp,
582
+ )
583
+ )
584
+ prev_end_s = end_s
585
+
586
+ return result
587
+
588
+
589
+ def meta2midi(meta_path: str, midi_path: str, defaults: MidiDefaults | None = None) -> None:
590
+ """Convert SoulX-Singer metadata JSON to MIDI file (meta -> List[Note] -> midi)."""
591
+ notes = meta2notes(meta_path)
592
+ notes2midi(notes, midi_path, defaults)
593
+ print(f"Saved MIDI to {midi_path}")
594
+
595
+
596
+ def midi2meta(
597
+ midi_path: str,
598
+ meta_path: str,
599
+ vocal_file: str,
600
+ rmvpe_model_path: str | None = None,
601
+ device: str = "cuda",
602
+ ) -> None:
603
+ """Convert MIDI file to SoulX-Singer metadata JSON (midi -> List[Note] -> meta)."""
604
+ meta_dir = os.path.dirname(meta_path)
605
+ if meta_dir:
606
+ os.makedirs(meta_dir, exist_ok=True)
607
+ # cut_wavs will be written to a fixed temporary directory inside _edit_data_to_meta
608
+ notes = midi2notes(midi_path)
609
+ notes2meta(
610
+ notes,
611
+ meta_path,
612
+ vocal_file,
613
+ rmvpe_model_path=rmvpe_model_path,
614
+ device=device,
615
+ )
616
+ print(f"Saved Meta to {meta_path}")
617
+
618
+
619
+ if __name__ == "__main__":
620
+ import argparse
621
+
622
+ parser = argparse.ArgumentParser(
623
+ description="Convert SoulX-Singer metadata JSON <-> MIDI."
624
+ )
625
+ parser.add_argument("--meta", type=str, help="Path to metadata JSON")
626
+ parser.add_argument("--midi", type=str, help="Path to MIDI file")
627
+ parser.add_argument("--vocal", type=str, help="Path to vocal wav (for midi2meta)")
628
+ parser.add_argument(
629
+ "--meta2midi",
630
+ action="store_true",
631
+ help="Convert meta -> midi (requires --meta and --midi)",
632
+ )
633
+ parser.add_argument(
634
+ "--midi2meta",
635
+ action="store_true",
636
+ help="Convert midi -> meta (requires --midi, --meta, --vocal, --cut_wavs_dir)",
637
+ )
638
+ parser.add_argument(
639
+ "--rmvpe_model_path",
640
+ type=str,
641
+ help="Path to RMVPE model",
642
+ default="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
643
+ )
644
+ parser.add_argument(
645
+ "--device",
646
+ type=str,
647
+ help="Device to use for RMVPE",
648
+ default="cuda",
649
+ )
650
+ args = parser.parse_args()
651
+
652
+ if args.meta2midi:
653
+ if not args.meta or not args.midi:
654
+ parser.error("--meta2midi requires --meta and --midi")
655
+ meta2midi(args.meta, args.midi)
656
+ elif args.midi2meta:
657
+ if not args.midi or not args.meta or not args.vocal:
658
+ parser.error(
659
+ "--midi2meta requires --midi, --meta, --vocal"
660
+ )
661
+ midi2meta(
662
+ args.midi,
663
+ args.meta,
664
+ args.vocal,
665
+ rmvpe_model_path=args.rmvpe_model_path,
666
+ device=args.device,
667
+ )
668
+ else:
669
+ parser.print_help()
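A hedged usage sketch of the converter API (the file paths are placeholders, and importing the module as `preprocess.tools.midi_parser` assumes the repo root is on PYTHONPATH):

    from preprocess.tools.midi_parser import meta2midi, midi2notes

    meta2midi("song.json", "song.mid")     # meta -> List[Note] -> MIDI
    for n in midi2notes("song.mid")[:3]:   # MIDI -> List[Note]
        print(f"{n.start_s:.2f}s dur={n.note_dur:.2f}s "
              f"pitch={n.note_pitch} type={n.note_type} text={n.note_text}")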
preprocess/tools/note_transcription/__init__.py ADDED
File without changes
preprocess/tools/note_transcription/model.py ADDED
@@ -0,0 +1,522 @@
1
+ # https://github.com/RickyL-2000/ROSVOT
2
+ import math
3
+ import sys
4
+ import traceback
5
+ import json
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ import librosa
11
+ import numpy as np
12
+ import torch
13
+ import matplotlib.pyplot as plt
14
+
15
+ from .utils.os_utils import safe_path
16
+ from .utils.commons.hparams import set_hparams
17
+ from .utils.commons.ckpt_utils import load_ckpt
18
+ from .utils.commons.dataset_utils import pad_or_cut_xd
19
+ from .utils.audio.mel import MelNet
20
+ from .utils.audio.pitch_utils import (
21
+ norm_interp_f0,
22
+ denorm_f0,
23
+ f0_to_coarse,
24
+ boundary2Interval,
25
+ save_midi,
26
+ midi_to_hz,
27
+ )
28
+ from .utils.rosvot_utils import (
29
+ get_mel_len,
30
+ align_word,
31
+ regulate_real_note_itv,
32
+ regulate_ill_slur,
33
+ bd_to_durs,
34
+ )
35
+ from .modules.pe.rmvpe import RMVPE
36
+ from .modules.rosvot.rosvot import MidiExtractor, WordbdExtractor
37
+
38
+
39
+ @torch.no_grad()
40
+ def infer_sample(
41
+ item: Dict[str, Any],
42
+ hparams: Dict[str, Any],
43
+ models: Dict[str, Any],
44
+ device: torch.device,
45
+ *,
46
+ save_dir: Optional[str] = None,
47
+ apply_rwbd: Optional[bool] = None,
48
+ # outputs
49
+ save_plot: bool = False,
50
+ no_save_midi: bool = True,
51
+ no_save_npy: bool = True,
52
+ verbose: bool = False,
53
+ ) -> Dict[str, Any]:
54
+ if "item_name" not in item or "wav_fn" not in item:
55
+ raise ValueError('item must contain keys: "item_name" and "wav_fn"')
56
+
57
+ item_name = item["item_name"]
58
+ wav_src = item["wav_fn"]
59
+
60
+ # Decide RWBD usage
61
+ if apply_rwbd is None:
62
+ apply_rwbd_ = ("word_durs" not in item)
63
+ else:
64
+ apply_rwbd_ = bool(apply_rwbd)
65
+
66
+ # Models
67
+ model = models["model"]
68
+ mel_net = models["mel_net"]
69
+ pe = models.get("pe")
70
+ wbd_predictor = models.get("wbd_predictor")
71
+
72
+ if wbd_predictor is None and apply_rwbd_:
73
+ raise ValueError("apply_rwbd is True but wbd_predictor model is not provided in models")
74
+
75
+ # ---- Prepare Data ----
76
+ if isinstance(wav_src, str):
77
+ wav, _ = librosa.core.load(wav_src, sr=hparams["audio_sample_rate"])
78
+ else:
79
+ wav = wav_src
80
+ if not isinstance(wav, np.ndarray):
81
+ wav = np.asarray(wav)
82
+ wav = wav.astype(np.float32)
83
+
84
+ # Calculate timestamps and alignment lengths
85
+ wav_len_samples = wav.shape[-1]
86
+ mel_len = get_mel_len(wav_len_samples, hparams["hop_size"])
87
+
88
+ # Word boundary preparation
89
+ mel2word = None
90
+ word_durs_filtered = None
91
+
92
+ if not apply_rwbd_:
93
+ if "word_durs" not in item:
94
+ raise ValueError('apply_rwbd=False but item has no "word_durs"')
95
+
96
+ wd_raw = list(item["word_durs"])
97
+ min_word_dur = hparams.get("min_word_dur", 20) / 1000
98
+ word_durs_filtered = []
99
+
100
+ for i, wd in enumerate(wd_raw):
101
+ if wd < min_word_dur:
102
+ if i == 0 and len(wd_raw) > 1:
103
+ wd_raw[i + 1] += wd
104
+ elif len(word_durs_filtered) > 0:
105
+ word_durs_filtered[-1] += wd
106
+ else:
107
+ word_durs_filtered.append(wd)
108
+
109
+ mel2word, _ = align_word(word_durs_filtered, mel_len, hparams["hop_size"], hparams["audio_sample_rate"])
110
+ mel2word = np.asarray(mel2word)
111
+ if mel2word.size > 0 and mel2word[0] == 0:
112
+ mel2word = mel2word + 1
113
+
114
+ mel2word_len = int(np.sum(mel2word > 0))
115
+ real_len = min(mel_len, mel2word_len)
116
+ else:
117
+ real_len = min(mel_len, hparams["max_frames"])
118
+
119
+ T = math.ceil(min(real_len, hparams["max_frames"]) / hparams["frames_multiple"]) * hparams["frames_multiple"]
120
+
121
+ # ---- Input Tensors & Padding ----
122
+ target_samples = T * hparams["hop_size"]
123
+ wav_t = torch.from_numpy(wav).float().to(device).unsqueeze(0) # [1, L]
124
+ if wav_t.shape[-1] < target_samples:
125
+ wav_t = pad_or_cut_xd(wav_t, target_samples, 1)
126
+
127
+ # ---- Pitch Extraction ----
128
+ if pe is not None:
129
+ f0s, uvs = pe.get_pitch_batch(
130
+ wav_t,
131
+ sample_rate=hparams["audio_sample_rate"],
132
+ hop_size=hparams["hop_size"],
133
+ lengths=[real_len],
134
+ fmax=hparams["f0_max"],
135
+ fmin=hparams["f0_min"],
136
+ )
137
+ f0_1d, uv_1d = norm_interp_f0(f0s[0][:T])
138
+ f0_t = pad_or_cut_xd(torch.FloatTensor(f0_1d).to(device), T, 0).unsqueeze(0)
139
+ uv_t = pad_or_cut_xd(torch.FloatTensor(uv_1d).to(device), T, 0).long().unsqueeze(0)
140
+ pitch_coarse = f0_to_coarse(denorm_f0(f0_t, uv_t)).to(device)
141
+ f0_np = denorm_f0(f0_t, uv_t)[0].detach().cpu().numpy()[:real_len]
142
+ else:
143
+ f0_t = uv_t = pitch_coarse = None
144
+ f0_np = None
145
+
146
+ # ---- Mel Extraction ----
147
+ mel = mel_net(wav_t) # [1, T_padded, C]
148
+ mel = pad_or_cut_xd(mel, T, 1)
149
+
150
+ # Construct non-padding mask
151
+ mel_nonpadding_mask = torch.zeros(1, T, device=device)
152
+ mel_nonpadding_mask[:, :real_len] = 1.0
153
+
154
+ # Apply mask to mel (zero out padding)
155
+ mel = (mel.transpose(1, 2) * mel_nonpadding_mask.unsqueeze(1)).transpose(1, 2)
156
+ # Re-calculate non_padding bool mask
157
+ mel_nonpadding = mel.abs().sum(-1) > 0
158
+
159
+ # ---- Word Boundary ----
160
+ word_durs_used = None
161
+ if apply_rwbd_:
162
+ mel_input = mel[:, :, : hparams.get("wbd_use_mel_bins", 80)]
163
+ wbd_outputs = wbd_predictor(
164
+ mel=mel_input,
165
+ pitch=pitch_coarse,
166
+ uv=uv_t,
167
+ non_padding=mel_nonpadding,
168
+ train=False,
169
+ )
170
+ word_bd = wbd_outputs["word_bd_pred"] # [1, T]
171
+ else:
172
+ # Construct word_bd from provided durs
173
+ mel2word_t = pad_or_cut_xd(torch.LongTensor(mel2word).to(device), T, 0)
174
+ word_bd = torch.zeros_like(mel2word_t)
175
+ # Vectorized check
176
+ word_bd[1:] = (mel2word_t[1:] != mel2word_t[:-1]).long()
177
+ word_bd[real_len:] = 0
178
+ word_bd = word_bd.unsqueeze(0) # [1, T]
179
+
180
+ word_durs_used = np.array(word_durs_filtered)
181
+
182
+ # ---- Main Inference ----
183
+ mel_input = mel[:, :, : hparams.get("use_mel_bins", 80)]
184
+ outputs = model(
185
+ mel=mel_input,
186
+ word_bd=word_bd,
187
+ pitch=pitch_coarse,
188
+ uv=uv_t,
189
+ non_padding=mel_nonpadding,
190
+ train=False,
191
+ )
192
+
193
+ note_lengths = outputs["note_lengths"].detach().cpu().numpy()
194
+ note_bd_pred = outputs["note_bd_pred"][0].detach().cpu().numpy()[:real_len]
195
+ note_pred = outputs["note_pred"][0].detach().cpu().numpy()[: note_lengths[0]]
196
+ note_bd_logits = torch.sigmoid(outputs["note_bd_logits"])[0].detach().cpu().numpy()[:real_len]
197
+
198
+ if note_pred.shape == (0,):
199
+ if verbose:
200
+ print(f"skip {item_name}: no notes detected")
201
+ return {
202
+ "item_name": item_name,
203
+ "pitches": [],
204
+ "note_durs": [],
205
+ "note2words": None,
206
+ }
207
+
208
+ # ---- Post-Processing & Regulation ----
209
+ note_itv_pred = boundary2Interval(note_bd_pred)
210
+ note2words = None
211
+
212
+ if apply_rwbd_:
213
+ word_bd_np = outputs['word_bd_pred'][0].detach().cpu().numpy()[:real_len]
214
+ word_durs_derived = np.array(bd_to_durs(word_bd_np)) * hparams['hop_size'] / hparams['audio_sample_rate']
215
+ word_durs_for_reg = word_durs_derived
216
+ word_bd_for_reg = word_bd_np
217
+ else:
218
+ word_bd_for_reg = word_bd[0].detach().cpu().numpy()[:real_len]
219
+ word_durs_for_reg = word_durs_used
220
+
221
+ should_regulate = hparams.get("infer_regulate_real_note_itv", True) and (not apply_rwbd_)
222
+
223
+ if should_regulate and (word_durs_for_reg is not None):
224
+ try:
225
+ note_itv_pred_secs, note2words = regulate_real_note_itv(
226
+ note_itv_pred,
227
+ note_bd_pred,
228
+ word_bd_for_reg,
229
+ word_durs_for_reg,
230
+ hparams["hop_size"],
231
+ hparams["audio_sample_rate"],
232
+ )
233
+ note_pred, note_itv_pred_secs, note2words = regulate_ill_slur(note_pred, note_itv_pred_secs, note2words)
234
+ except Exception as err:
235
+ if verbose:
236
+ _, exc_value, exc_tb = sys.exc_info()
237
+ tb = traceback.extract_tb(exc_tb)[-1]
238
+ print(f"postprocess failed: {err}: {exc_value} in {tb[0]}:{tb[1]} '{tb[2]}' in {tb[3]}")
239
+ # Fallback
240
+ note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
241
+ note2words = None
242
+ else:
243
+ note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
244
+
245
+ # ---- Output ----
246
+ note_durs = [float((itv[1] - itv[0])) for itv in note_itv_pred_secs]
247
+
248
+ out = {
249
+ "item_name": item_name,
250
+ "pitches": note_pred.tolist(),
251
+ "note_durs": note_durs,
252
+ "note2words": note2words.tolist() if note2words is not None else None,
253
+ }
254
+
255
+ # ---- Saving ----
256
+ if save_dir is not None:
257
+ save_dir_path = Path(save_dir)
258
+ save_dir_path.mkdir(parents=True, exist_ok=True)
259
+ fn = str(item_name)
260
+
261
+ if not no_save_midi:
262
+ save_midi(note_pred, note_itv_pred_secs, safe_path(save_dir_path / "midi" / f"{fn}.mid"))
263
+
264
+ if not no_save_npy:
265
+ np.save(safe_path(save_dir_path / "npy" / f"[note]{fn}.npy"), out, allow_pickle=True)
266
+
267
+ if save_plot:
268
+ fig = plt.figure()
269
+ if f0_np is not None:
270
+ plt.plot(f0_np, color="red", label="f0")
271
+
272
+ midi_pred = np.zeros(note_bd_pred.shape[0], dtype=np.float32)
273
+ itvs = np.round(note_itv_pred_secs * hparams["audio_sample_rate"] / hparams["hop_size"]).astype(int)
274
+ for i, itv in enumerate(itvs):
275
+ midi_pred[itv[0] : itv[1]] = note_pred[i]
276
+ plt.plot(midi_to_hz(midi_pred), color="blue", label="pred midi")
277
+ plt.plot(note_bd_logits * 100, color="green", label="note bd logits x100")
278
+ plt.legend()
279
+ plt.tight_layout()
280
+ plt.savefig(safe_path(save_dir_path / "plot" / f"[MIDI]{fn}.png"), format="png")
281
+ plt.close(fig)
282
+
283
+ return out
284
+
285
+
286
+ def load_rosvot_models(ckpt, config="", wbd_ckpt="", wbd_config="", device="cuda:0", verbose=False, thr=0.85):
287
+ """
288
+ Load models once to reuse across multiple items.
289
+ """
290
+ dev = torch.device(device)
291
+
292
+ # 1. Hparams
293
+ config_path = Path(ckpt).with_name("config.yaml") if config == "" else config
294
+ pe_ckpt = Path(ckpt).parent.parent / "rmvpe/model.pt"
295
+ hparams = set_hparams(
296
+ config=config_path,
297
+ print_hparams=verbose,
298
+ hparams_str=f"note_bd_threshold={thr}",
299
+ )
300
+
301
+ # 2. Main Model
302
+ model = MidiExtractor(hparams)
303
+ load_ckpt(model, ckpt, verbose=verbose)
304
+ model.eval().to(dev)
305
+
306
+ # 3. MelNet
307
+ mel_net = MelNet(hparams)
308
+ mel_net.to(dev)
309
+
310
+ # 4. Pitch Extractor
311
+ pe = None
312
+ if hparams.get("use_pitch_embed", False):
313
+ pe = RMVPE(pe_ckpt, device=dev)
314
+
315
+ # 5. Word Boundary Predictor (optional but we load if ckpt provided or needed)
316
+ wbd_predictor = None
317
+ if wbd_ckpt:
318
+ wbd_config_path = Path(wbd_ckpt).with_name("config.yaml") if wbd_config == "" else wbd_config
319
+ wbd_hparams = set_hparams(
320
+ config=wbd_config_path,
321
+ print_hparams=False,
322
+ hparams_str="",
323
+ )
324
+ hparams.update({
325
+ "wbd_use_mel_bins": wbd_hparams["use_mel_bins"],
326
+ "min_word_dur": wbd_hparams["min_word_dur"],
327
+ })
328
+ wbd_predictor = WordbdExtractor(wbd_hparams)
329
+ load_ckpt(wbd_predictor, wbd_ckpt, verbose=verbose)
330
+ wbd_predictor.eval().to(dev)
331
+
332
+ models = {
333
+ "model": model,
334
+ "mel_net": mel_net,
335
+ "pe": pe,
336
+ "wbd_predictor": wbd_predictor
337
+ }
338
+ return hparams, models
339
+
340
+
341
+ class NoteTranscriber:
342
+ """Note transcription wrapper based on ROSVOT.
343
+
344
+ Loads ROSVOT and optional RWBD models once in ``__init__`` and
345
+ exposes a :py:meth:`process` API that turns an item dict into
346
+ aligned note metadata for downstream SVS.
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ rosvot_model_path: str,
352
+ rwbd_model_path: str,
353
+ *,
354
+ rosvot_config_path: str = "",
355
+ rwbd_config_path: str = "",
356
+ device: str = "cuda:0",
357
+ thr: float = 0.85,
358
+ verbose: bool = True,
359
+ ):
360
+ """Initialize the note transcriber.
361
+
362
+ Args:
363
+ rosvot_model_path: Path to the main ROSVOT checkpoint.
364
+ rosvot_config_path: Optional config YAML path for ROSVOT.
365
+ rwbd_model_path: Word-boundary (RWBD) checkpoint path.
366
+ rwbd_config_path: Optional config YAML path for RWBD.
367
+ device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
368
+ thr: Note boundary threshold.
369
+ verbose: Whether to print verbose logs.
370
+ """
371
+ self.verbose = verbose
372
+ self.device = torch.device(device)
373
+ self.hparams, self.models = load_rosvot_models(
374
+ ckpt=rosvot_model_path,
375
+ config=rosvot_config_path,
376
+ wbd_ckpt=rwbd_model_path,
377
+ wbd_config=rwbd_config_path,
378
+ device=device,
379
+ verbose=verbose,
380
+ thr=thr,
381
+ )
382
+
383
+ if self.verbose:
384
+ print(
385
+ "[note transcription] init success:",
386
+ f"device={self.device}",
387
+ f"rosvot_model_path={rosvot_model_path}",
388
+ f"rwbd_model_path={rwbd_model_path if rwbd_model_path else 'None'}",
389
+ f"thr={thr}",
390
+ )
391
+
392
+ def process(
393
+ self,
394
+ item: Dict[str, Any],
395
+ *,
396
+ segment_info: Optional[Dict[str, Any]] = None,
397
+ save_dir: Optional[str] = None,
398
+ apply_rwbd: Optional[bool] = None,
399
+ save_plot: bool = False,
400
+ no_save_midi: bool = True,
401
+ no_save_npy: bool = True,
402
+ verbose: Optional[bool] = None,
403
+ ) -> Dict[str, Any]:
404
+ """Run ROSVOT on a single item and post-process outputs.
405
+
406
+ Args:
407
+ item: Input metadata dict with at least ``item_name`` and ``wav_fn``.
408
+ segment_info: Optional segment metadata for sliced audio.
409
+ save_dir: Optional directory for debug artifacts (plots, midis).
410
+ apply_rwbd: Whether to run RWBD-based word boundary refinement.
411
+ save_plot: Whether to save diagnostic plots.
412
+ no_save_midi: If True, skip saving midi.
413
+ no_save_npy: If True, skip saving numpy intermediates.
414
+ verbose: Override instance-level verbose flag for this call.
415
+
416
+ Returns:
417
+ Dict with aligned note information for downstream SVS.
418
+ """
419
+ v = self.verbose if verbose is None else verbose
420
+ if v:
421
+ item_name = item.get("item_name", "")
422
+ wav_fn = item.get("wav_fn", "")
423
+ print(f"[note transcription] process: start: item_name={item_name} wav_fn={wav_fn}")
424
+ t0 = time.time()
425
+
426
+ rosvot_out = infer_sample(
427
+ item,
428
+ self.hparams,
429
+ self.models,
430
+ device=self.device,
431
+ save_dir=save_dir,
432
+ apply_rwbd=apply_rwbd,
433
+ save_plot=save_plot,
434
+ no_save_midi=no_save_midi,
435
+ no_save_npy=no_save_npy,
436
+ verbose=v,
437
+ )
438
+
439
+ out = self.post_process(
440
+ metadata=item,
441
+ segment_info=segment_info,
442
+ rosvot_out=rosvot_out,
443
+ )
444
+
445
+ if v:
446
+ dt = time.time() - t0
447
+ print(
448
+ "[note transcription] process: done:",
449
+ f"item_name={out.get('item_name','')}",
450
+ f"n_notes={len(out.get('note_pitch', []) or [])}",
451
+ f"time={dt:.3f}s",
452
+ )
453
+
454
+ return out
455
+
456
+ @staticmethod
457
+ def _normalize_note2words(note2words: list[int]) -> list[int]:
458
+ if not note2words:
459
+ return []
460
+ normalized = [note2words[0]]
461
+ for idx in range(1, len(note2words)):
462
+ if note2words[idx] < normalized[-1]:
463
+ normalized.append(normalized[-1])
464
+ else:
465
+ normalized.append(note2words[idx])
466
+ return normalized
467
+
468
+ @staticmethod
469
+ def _build_ep_types(note2words: list[int], align_words: list[str]) -> list[int]:
470
+ ep_types: list[int] = []
471
+ prev = -1
472
+ for i, w in zip(note2words, align_words):
473
+ if w == "<SP>":
474
+ ep_types.append(1)
475
+ else:
476
+ ep_types.append(2 if i != prev else 3)
477
+ prev = i
478
+ return ep_types
479
+
480
+ def post_process(
481
+ self,
482
+ *,
483
+ metadata: Dict[str, Any],
484
+ segment_info: Dict[str, Any],
485
+ rosvot_out: Dict[str, Any],
486
+ ) -> Dict[str, Any]:
487
+ """Build aligned note metadata using ROSVOT outputs."""
488
+ note2words_raw = rosvot_out.get("note2words") or []
489
+ note2words = self._normalize_note2words(note2words_raw)
490
+ align_words = [
491
+ metadata["words"][idx - 1]
492
+ for idx in note2words_raw
493
+ if 0 < idx <= len(metadata["words"])
494
+ ]
495
+ ep_types = self._build_ep_types(note2words, align_words) if align_words else []
496
+
497
+ return {
498
+ "item_name": rosvot_out.get("item_name", "") if not segment_info else segment_info["item_name"],
499
+ "wav_fn": metadata.get("wav_fn", "") if not segment_info else segment_info["wav_fn"],
500
+ "origin_wav_fn": metadata.get("origin_wav_fn", "") if not segment_info else segment_info["origin_wav_fn"],
501
+ "start_time_ms": "" if not segment_info else segment_info["start_time_ms"],
502
+ "end_time_ms": "" if not segment_info else segment_info["end_time_ms"],
503
+ "language": metadata.get("language", ""),
504
+ "note_text": align_words,
505
+ "note_dur": rosvot_out.get("note_durs", []),
506
+ "note_type": ep_types,
507
+ "note_pitch": rosvot_out.get("pitches", []),
508
+ }
509
+
510
+ if __name__ == "__main__":
511
+
512
+ items = json.load(open("example/test/rosvot_input.json", "r"))
513
+ item = items[0]
514
+
515
+ m = NoteTranscriber(
516
+ rosvot_model_path="pretrained_models/rosvot/rosvot/model.pt",
517
+ rwbd_model_path="pretrained_models/rosvot/rwbd/model.pt",
518
+ device="cuda"
519
+ )
520
+ out = m.process(item)
521
+
522
+ print(out)
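The expected shape of the input item is implicit in `infer_sample` and `post_process`; a minimal sketch (all field values are placeholders):

    item = {
        "item_name": "demo_0",                 # required
        "wav_fn": "example/test/demo_0.wav",   # required: path or waveform array
        # Optional per-word durations in seconds; when present, RWBD is skipped
        # because apply_rwbd defaults to ("word_durs" not in item).
        "word_durs": [0.40, 0.52, 0.61],
        # Consumed by post_process to attach lyrics to notes via note2words.
        "words": ["<SP>", "你", "好"],
    }
    out = m.process(item)  # keys: note_text, note_dur, note_type, note_pitch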
preprocess/tools/note_transcription/modules/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """ROSVOT model submodules."""
preprocess/tools/note_transcription/modules/commons/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Common ROSVOT layers and utilities."""
preprocess/tools/note_transcription/modules/commons/conformer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Conformer layers for ROSVOT."""
preprocess/tools/note_transcription/modules/commons/conformer/conformer.py ADDED
@@ -0,0 +1,96 @@
1
+ from torch import nn
2
+ from .espnet_positional_embedding import RelPositionalEncoding, ScaledPositionalEncoding, PositionalEncoding
3
+ from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
4
+ from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
5
+ from ..layers import Embedding
6
+
7
+
8
+ class ConformerLayers(nn.Module):
9
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
10
+ use_last_norm=True, save_hidden=False):
11
+ super().__init__()
12
+ self.use_last_norm = use_last_norm
13
+ self.layers = nn.ModuleList()
14
+ positionwise_layer = MultiLayeredConv1d
15
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
16
+ self.pos_embed = RelPositionalEncoding(hidden_size, dropout)
17
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
18
+ hidden_size,
19
+ RelPositionMultiHeadedAttention(num_heads, hidden_size, 0.0),
20
+ positionwise_layer(*positionwise_layer_args),
21
+ positionwise_layer(*positionwise_layer_args),
22
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
23
+ dropout,
24
+ ) for _ in range(num_layers)])
25
+ if self.use_last_norm:
26
+ self.layer_norm = nn.LayerNorm(hidden_size)
27
+ else:
28
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
29
+ self.save_hidden = save_hidden
30
+ if save_hidden:
31
+ self.hiddens = []
32
+
33
+ def forward(self, x, padding_mask=None):
34
+ """
35
+
36
+ :param x: [B, T, H]
37
+ :param padding_mask: [B, T] (currently unused; the non-padding mask is derived from x)
38
+ :return: [B, T, H]
39
+ """
40
+ self.hiddens = []
41
+ nonpadding_mask = x.abs().sum(-1) > 0
42
+ x = self.pos_embed(x)
43
+ for l in self.encoder_layers:
44
+ x, mask = l(x, nonpadding_mask[:, None, :])
45
+ if self.save_hidden:
46
+ self.hiddens.append(x[0])
47
+ x = x[0]
48
+ x = self.layer_norm(x) * nonpadding_mask.float()[:, :, None]
49
+ return x
50
+
51
+ class FastConformerLayers(ConformerLayers):
52
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
53
+ use_last_norm=True, save_hidden=False):
54
+ # Only run nn.Module.__init__ here; this class rebuilds the positional
+ # embedding and attention below instead of reusing ConformerLayers'.
+ super(ConformerLayers, self).__init__()
55
+ self.use_last_norm = use_last_norm
56
+ self.layers = nn.ModuleList()
57
+ positionwise_layer = MultiLayeredConv1d
58
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
59
+ self.pos_embed = PositionalEncoding(hidden_size, dropout)
60
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
61
+ hidden_size,
62
+ MultiHeadedAttention(num_heads, hidden_size, 0.0, flash=True),
63
+ positionwise_layer(*positionwise_layer_args),
64
+ positionwise_layer(*positionwise_layer_args),
65
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
66
+ dropout,
67
+ ) for _ in range(num_layers)])
68
+ if self.use_last_norm:
69
+ self.layer_norm = nn.LayerNorm(hidden_size)
70
+ else:
71
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
72
+ self.save_hidden = save_hidden
73
+ if save_hidden:
74
+ self.hiddens = []
75
+
76
+ class ConformerEncoder(ConformerLayers):
77
+ def __init__(self, hidden_size, dict_size, num_layers=None):
78
+ conformer_enc_kernel_size = 9
79
+ super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
80
+ self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
81
+
82
+ def forward(self, x):
83
+ """
84
+
85
+ :param x: [B, T] token ids
86
+ :return: [B x T x C]
87
+ """
88
+ x = self.embed(x) # [B, T, H]
89
+ x = super(ConformerEncoder, self).forward(x)
90
+ return x
91
+
92
+
93
+ class ConformerDecoder(ConformerLayers):
94
+ def __init__(self, hidden_size, num_layers):
95
+ conformer_dec_kernel_size = 9
96
+ super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
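A quick shape check for ConformerLayers (a sketch; assumes torch is installed and the class above is in scope). All-zero frames act as padding because the mask is derived from x.abs().sum(-1) > 0:

    import torch

    layers = ConformerLayers(hidden_size=256, num_layers=2)
    x = torch.randn(2, 128, 256)
    x[1, 100:] = 0.0                       # simulate right-padding on item 1
    y = layers(x)
    print(y.shape)                         # torch.Size([2, 128, 256])
    print(float(y[1, 100:].abs().sum()))   # ~0.0: padded frames stay zeroed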
preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py ADDED
@@ -0,0 +1,113 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ class PositionalEncoding(torch.nn.Module):
6
+ """Positional encoding.
7
+ Args:
8
+ d_model (int): Embedding dimension.
9
+ dropout_rate (float): Dropout rate.
10
+ max_len (int): Maximum input length.
11
+ reverse (bool): Whether to reverse the input position.
12
+ """
13
+
14
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
15
+ """Construct an PositionalEncoding object."""
16
+ super(PositionalEncoding, self).__init__()
17
+ self.d_model = d_model
18
+ self.reverse = reverse
19
+ self.xscale = math.sqrt(self.d_model)
20
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
21
+ self.pe = None
22
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
23
+
24
+ def extend_pe(self, x):
25
+ """Reset the positional encodings."""
26
+ if self.pe is not None:
27
+ if self.pe.size(1) >= x.size(1):
28
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
29
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
30
+ return
31
+ pe = torch.zeros(x.size(1), self.d_model)
32
+ if self.reverse:
33
+ position = torch.arange(
34
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
35
+ ).unsqueeze(1)
36
+ else:
37
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
38
+ div_term = torch.exp(
39
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
40
+ * -(math.log(10000.0) / self.d_model)
41
+ )
42
+ pe[:, 0::2] = torch.sin(position * div_term)
43
+ pe[:, 1::2] = torch.cos(position * div_term)
44
+ pe = pe.unsqueeze(0)
45
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ """Add positional encoding.
49
+ Args:
50
+ x (torch.Tensor): Input tensor (batch, time, `*`).
51
+ Returns:
52
+ torch.Tensor: Encoded tensor (batch, time, `*`).
53
+ """
54
+ self.extend_pe(x)
55
+ x = x * self.xscale + self.pe[:, : x.size(1)]
56
+ return self.dropout(x)
57
+
58
+
59
+ class ScaledPositionalEncoding(PositionalEncoding):
60
+ """Scaled positional encoding module.
61
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
62
+ Args:
63
+ d_model (int): Embedding dimension.
64
+ dropout_rate (float): Dropout rate.
65
+ max_len (int): Maximum input length.
66
+ """
67
+
68
+ def __init__(self, d_model, dropout_rate, max_len=5000):
69
+ """Initialize class."""
70
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
71
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
72
+
73
+ def reset_parameters(self):
74
+ """Reset parameters."""
75
+ self.alpha.data = torch.tensor(1.0)
76
+
77
+ def forward(self, x):
78
+ """Add positional encoding.
79
+ Args:
80
+ x (torch.Tensor): Input tensor (batch, time, `*`).
81
+ Returns:
82
+ torch.Tensor: Encoded tensor (batch, time, `*`).
83
+ """
84
+ self.extend_pe(x)
85
+ x = x + self.alpha * self.pe[:, : x.size(1)]
86
+ return self.dropout(x)
87
+
88
+
89
+ class RelPositionalEncoding(PositionalEncoding):
90
+ """Relative positional encoding module.
91
+ See : Appendix B in https://arxiv.org/abs/1901.02860
92
+ Args:
93
+ d_model (int): Embedding dimension.
94
+ dropout_rate (float): Dropout rate.
95
+ max_len (int): Maximum input length.
96
+ """
97
+
98
+ def __init__(self, d_model, dropout_rate, max_len=5000):
99
+ """Initialize class."""
100
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
101
+
102
+ def forward(self, x):
103
+ """Compute positional encoding.
104
+ Args:
105
+ x (torch.Tensor): Input tensor (batch, time, `*`).
106
+ Returns:
107
+ torch.Tensor: Encoded tensor (batch, time, `*`).
108
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
109
+ """
110
+ self.extend_pe(x)
111
+ x = x * self.xscale
112
+ pos_emb = self.pe[:, : x.size(1)]
113
+ return self.dropout(x), self.dropout(pos_emb)
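The three classes differ mainly in what they return; a small sketch (assumes torch is installed and the classes above are in scope):

    import torch

    x = torch.randn(2, 50, 64)
    print(PositionalEncoding(d_model=64, dropout_rate=0.0)(x).shape)  # [2, 50, 64]

    x_scaled, pos_emb = RelPositionalEncoding(64, 0.0)(x)  # tuple for rel-pos attention
    print(x_scaled.shape, pos_emb.shape)                   # [2, 50, 64] and [1, 50, 64]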
preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py ADDED
@@ -0,0 +1,198 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Multi-Head Attention layer definition."""
8
+
9
+ from packaging import version
10
+ import math
11
+
12
+ import numpy
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ class MultiHeadedAttention(nn.Module):
18
+ """Multi-Head Attention layer.
19
+ Args:
20
+ n_head (int): The number of heads.
21
+ n_feat (int): The number of features.
22
+ dropout_rate (float): Dropout rate.
23
+ """
24
+
25
+ def __init__(self, n_head, n_feat, dropout_rate, flash=False):
26
+ """Construct an MultiHeadedAttention object."""
27
+ super(MultiHeadedAttention, self).__init__()
28
+ assert n_feat % n_head == 0
29
+ # We assume d_v always equals d_k
30
+ self.d_k = n_feat // n_head
31
+ self.h = n_head
32
+ self.linear_q = nn.Linear(n_feat, n_feat)
33
+ self.linear_k = nn.Linear(n_feat, n_feat)
34
+ self.linear_v = nn.Linear(n_feat, n_feat)
35
+ self.linear_out = nn.Linear(n_feat, n_feat)
36
+ self.attn = None
37
+ self.dropout = nn.Dropout(p=dropout_rate)
38
+ self.dropout_rate = dropout_rate
39
+ self.flash = flash
40
+
41
+ def forward_qkv(self, query, key, value):
42
+ """Transform query, key and value.
43
+ Args:
44
+ query (torch.Tensor): Query tensor (#batch, time1, size).
45
+ key (torch.Tensor): Key tensor (#batch, time2, size).
46
+ value (torch.Tensor): Value tensor (#batch, time2, size).
47
+ Returns:
48
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
49
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
50
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
51
+ """
52
+ n_batch = query.size(0)
53
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
54
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
55
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
56
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
57
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
58
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
59
+
60
+ return q, k, v
61
+
62
+ def forward_attention(self, value, scores, mask):
63
+ """Compute attention context vector.
64
+ Args:
65
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
66
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
67
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
68
+ Returns:
69
+ torch.Tensor: Transformed value (#batch, time1, d_model)
70
+ weighted by the attention score (#batch, time1, time2).
71
+ """
72
+ n_batch = value.size(0)
73
+ if mask is not None:
74
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
75
+ min_value = float(
76
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
77
+ )
78
+ scores = scores.masked_fill(mask, min_value)
79
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
80
+ mask, 0.0
81
+ ) # (batch, head, time1, time2)
82
+ else:
83
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
84
+
85
+ p_attn = self.dropout(self.attn)
86
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
87
+ x = (
88
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
89
+ ) # (batch, time1, d_model)
90
+
91
+ return self.linear_out(x) # (batch, time1, d_model)
92
+
93
+ def forward(self, query, key, value, mask):
94
+ """Compute scaled dot product attention.
95
+ Args:
96
+ query (torch.Tensor): Query tensor (#batch, time1, size).
97
+ key (torch.Tensor): Key tensor (#batch, time2, size).
98
+ value (torch.Tensor): Value tensor (#batch, time2, size).
99
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
100
+ (#batch, time1, time2).
101
+ Returns:
102
+ torch.Tensor: Output tensor (#batch, time1, d_model).
103
+ """
104
+ q, k, v = self.forward_qkv(query, key, value)
105
+ if version.parse(torch.__version__) >= version.parse("2.0") and self.flash:
106
+ n_batch = value.size(0)
107
+ x = torch.nn.functional.scaled_dot_product_attention(
108
+ q, k, v, attn_mask=mask.unsqueeze(1) if mask is not None else None, dropout_p=self.dropout_rate)
109
+ x = (
110
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
111
+ ) # (batch, time1, d_model)
112
+ return self.linear_out(x)
113
+ else:
114
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
115
+ return self.forward_attention(v, scores, mask)
116
+
117
+
118
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
119
+ """Multi-Head Attention layer with relative position encoding.
120
+ Paper: https://arxiv.org/abs/1901.02860
121
+ Args:
122
+ n_head (int): The number of heads.
123
+ n_feat (int): The number of features.
124
+ dropout_rate (float): Dropout rate.
125
+ """
126
+
127
+ def __init__(self, n_head, n_feat, dropout_rate):
128
+ """Construct an RelPositionMultiHeadedAttention object."""
129
+ super().__init__(n_head, n_feat, dropout_rate)
130
+ # linear transformation for positional ecoding
131
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
132
+ # these two learnable bias are used in matrix c and matrix d
133
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
134
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
135
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
136
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
137
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
138
+
139
+ def rel_shift(self, x, zero_triu=False):
140
+ """Compute relative positinal encoding.
141
+ Args:
142
+ x (torch.Tensor): Input tensor (batch, time, size).
143
+ zero_triu (bool): If true, return the lower triangular part of the matrix.
144
+ Returns:
145
+ torch.Tensor: Output tensor.
146
+ """
147
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
148
+ x_padded = torch.cat([zero_pad, x], dim=-1)
149
+
150
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
151
+ x = x_padded[:, :, 1:].view_as(x)
152
+
153
+ if zero_triu:
154
+ ones = torch.ones((x.size(2), x.size(3)))
155
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
156
+
157
+ return x
158
+
159
+ def forward(self, query, key, value, pos_emb, mask):
160
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
161
+ Args:
162
+ query (torch.Tensor): Query tensor (#batch, time1, size).
163
+ key (torch.Tensor): Key tensor (#batch, time2, size).
164
+ value (torch.Tensor): Value tensor (#batch, time2, size).
165
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
166
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
167
+ (#batch, time1, time2).
168
+ Returns:
169
+ torch.Tensor: Output tensor (#batch, time1, d_model).
170
+ """
171
+ q, k, v = self.forward_qkv(query, key, value)
172
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
173
+
174
+ n_batch_pos = pos_emb.size(0)
175
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
176
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
177
+
178
+ # (batch, head, time1, d_k)
179
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
180
+ # (batch, head, time1, d_k)
181
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
182
+
183
+ # compute attention score
184
+ # first compute matrix a and matrix c
185
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
186
+ # (batch, head, time1, time2)
187
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
188
+
189
+ # compute matrix b and matrix d
190
+ # (batch, head, time1, time2)
191
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
192
+ matrix_bd = self.rel_shift(matrix_bd)
193
+
194
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
195
+ self.d_k
196
+ ) # (batch, head, time1, time2)
197
+
198
+ return self.forward_attention(v, scores, mask)
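A quick shape check for the plain `MultiHeadedAttention` defined above. The import path is derived from this commit's file layout; the head count, feature size, and mask are illustrative:

```python
# Illustrative usage sketch; shapes follow the docstrings above.
import torch
from preprocess.tools.note_transcription.modules.commons.conformer.espnet_transformer_attn import (
    MultiHeadedAttention,
)

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)                     # (batch, time, n_feat)
mask = torch.ones(2, 1, 50, dtype=torch.bool)   # (batch, 1, time2); 1 = valid position
out = attn(x, x, x, mask)                       # self-attention
print(out.shape)                                # torch.Size([2, 50, 256])
```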
preprocess/tools/note_transcription/modules/commons/conformer/layers.py ADDED
@@ -0,0 +1,260 @@
+from torch import nn
+import torch
+
+from ..layers import LayerNorm
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+    """
+
+    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+        """Construct a ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.BatchNorm1d(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = activation
+
+    def forward(self, x):
+        """Compute convolution module.
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, channels).
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, channels).
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.transpose(1, 2)
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x)
+
+        return x.transpose(1, 2)
+
+
+class MultiLayeredConv1d(torch.nn.Module):
+    """Multi-layered conv1d for Transformer block.
+    This is a module of multi-layered conv1d designed
+    to replace the position-wise feed-forward network
+    in a Transformer block, as introduced in
+    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
+    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
+        https://arxiv.org/pdf/1905.09263.pdf
+    """
+
+    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
+        """Initialize MultiLayeredConv1d module.
+        Args:
+            in_chans (int): Number of input channels.
+            hidden_chans (int): Number of hidden channels.
+            kernel_size (int): Kernel size of conv1d.
+            dropout_rate (float): Dropout rate.
+        """
+        super(MultiLayeredConv1d, self).__init__()
+        self.w_1 = torch.nn.Conv1d(
+            in_chans,
+            hidden_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+        )
+        self.w_2 = torch.nn.Conv1d(
+            hidden_chans,
+            in_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+        )
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(self, x):
+        """Calculate forward propagation.
+        Args:
+            x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+        Returns:
+            torch.Tensor: Batch of output tensors (B, T, in_chans).
+        """
+        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
+
+
+class Swish(torch.nn.Module):
+    """Construct a Swish activation object."""
+
+    def forward(self, x):
+        """Return Swish activation function."""
+        return x * torch.sigmoid(x)
+
+
+class EncoderLayer(nn.Module):
+    """Encoder layer module.
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            If True, an additional linear layer will be applied,
+            i.e. x -> x + linear(concat(x, att(x)));
+            if False, no additional linear layer will be applied, i.e. x -> x + att(x).
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        feed_forward,
+        feed_forward_macaron,
+        conv_module,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayer, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.conv_module = conv_module
+        self.norm_ff = LayerNorm(size)  # for the FNN module
+        self.norm_mha = LayerNorm(size)  # for the MHA module
+        if feed_forward_macaron is not None:
+            self.norm_ff_macaron = LayerNorm(size)
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+        if self.conv_module is not None:
+            self.norm_conv = LayerNorm(size)  # for the CNN module
+            self.norm_final = LayerNorm(size)  # for the final output of the block
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+
+    def forward(self, x_input, mask, cache=None):
+        """Compute encoded features.
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+        """
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+
+        # whether to use macaron style
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # multi-headed self-attention module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_mha(x)
+
+        if cache is None:
+            x_q = x
+        else:
+            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            x_q = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            mask = None if mask is None else mask[:, -1:, :]
+
+        if pos_emb is not None:
+            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+        else:
+            x_att = self.self_attn(x_q, x, x, mask)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, x_att), dim=-1)
+            x = residual + self.concat_linear(x_concat)
+        else:
+            x = residual + self.dropout(x_att)
+        if not self.normalize_before:
+            x = self.norm_mha(x)
+
+        # convolution module
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+            x = residual + self.dropout(self.conv_module(x))
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask
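The `EncoderLayer` signature above shows how a full Conformer block is assembled from the pieces in this file and the attention file. The following sketch wires one up with illustrative sizes; it is not code from this commit:

```python
# Illustrative assembly of one Conformer block from the modules above.
import torch
from preprocess.tools.note_transcription.modules.commons.conformer.espnet_transformer_attn import (
    MultiHeadedAttention,
)
from preprocess.tools.note_transcription.modules.commons.conformer.layers import (
    ConvolutionModule, EncoderLayer, MultiLayeredConv1d, Swish,
)

size = 256  # illustrative model dimension
layer = EncoderLayer(
    size=size,
    self_attn=MultiHeadedAttention(4, size, 0.1),
    feed_forward=MultiLayeredConv1d(size, 1024, kernel_size=3, dropout_rate=0.1),
    feed_forward_macaron=MultiLayeredConv1d(size, 1024, kernel_size=3, dropout_rate=0.1),
    conv_module=ConvolutionModule(size, kernel_size=31, activation=Swish()),
    dropout_rate=0.1,
)
x = torch.randn(2, 50, size)
out, mask = layer(x, torch.ones(2, 1, 50, dtype=torch.bool))  # no pos emb -> plain tensor out
print(out.shape)  # torch.Size([2, 50, 256])
```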
preprocess/tools/note_transcription/modules/commons/conv.py ADDED
@@ -0,0 +1,175 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .layers import LayerNorm, Embedding
+
+
+class LambdaLayer(nn.Module):
+    def __init__(self, lambd):
+        super(LambdaLayer, self).__init__()
+        self.lambd = lambd
+
+    def forward(self, x):
+        return self.lambd(x)
+
+
+def init_weights_func(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv1d") != -1:
+        torch.nn.init.xavier_uniform_(m.weight)
+
+
+def get_norm_builder(norm_type, channels, ln_eps=1e-6):
+    if norm_type == 'bn':
+        norm_builder = lambda: nn.BatchNorm1d(channels)
+    elif norm_type == 'in':
+        norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
+    elif norm_type == 'gn':
+        norm_builder = lambda: nn.GroupNorm(8, channels)
+    elif norm_type == 'ln':
+        norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
+    else:
+        norm_builder = lambda: nn.Identity()
+    return norm_builder
+
+
+def get_act_builder(act_type):
+    if act_type == 'gelu':
+        act_builder = lambda: nn.GELU()
+    elif act_type == 'relu':
+        act_builder = lambda: nn.ReLU(inplace=True)
+    elif act_type == 'leakyrelu':
+        act_builder = lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True)
+    elif act_type == 'swish':
+        act_builder = lambda: nn.SiLU(inplace=True)
+    else:
+        act_builder = lambda: nn.Identity()
+    return act_builder
+
+
+class ResidualBlock(nn.Module):
+    """Applies n residual blocks of (norm -> conv -> activation -> conv)."""
+
+    def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
+                 c_multiple=2, ln_eps=1e-12, act_type='gelu'):
+        super(ResidualBlock, self).__init__()
+
+        norm_builder = get_norm_builder(norm_type, channels, ln_eps)
+        act_builder = get_act_builder(act_type)
+
+        self.blocks = [
+            nn.Sequential(
+                norm_builder(),
+                nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
+                          padding=(dilation * (kernel_size - 1)) // 2),
+                LambdaLayer(lambda x: x * kernel_size ** -0.5),
+                act_builder(),
+                nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation),
+            )
+            for i in range(n)
+        ]
+
+        self.blocks = nn.ModuleList(self.blocks)
+        self.dropout = dropout
+
+    def forward(self, x):
+        nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+        for b in self.blocks:
+            x_ = b(x)
+            if self.dropout > 0 and self.training:
+                x_ = F.dropout(x_, self.dropout, training=self.training)
+            x = x + x_
+            x = x * nonpadding
+        return x
+
+
+class ConvBlocks(nn.Module):
+    """Decodes the expanded phoneme encoding into spectrograms."""
+
+    def __init__(self, hidden_size, out_dims, dilations, kernel_size,
+                 norm_type='ln', layers_in_block=2, c_multiple=2,
+                 dropout=0.0, ln_eps=1e-5,
+                 init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3, act_type='gelu'):
+        super(ConvBlocks, self).__init__()
+        self.is_BTC = is_BTC
+        if num_layers is not None:
+            dilations = [1] * num_layers
+        self.res_blocks = nn.Sequential(
+            *[ResidualBlock(hidden_size, kernel_size, d,
+                            n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
+                            dropout=dropout, ln_eps=ln_eps, act_type=act_type)
+              for d in dilations],
+        )
+        norm = get_norm_builder(norm_type, hidden_size, ln_eps)()
+        self.last_norm = norm
+        self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
+                                   padding=post_net_kernel // 2)
+        if init_weights:
+            self.apply(init_weights_func)
+
+    def forward(self, x, nonpadding=None):
+        """
+        :param x: [B, T, H]
+        :return: [B, T, H]
+        """
+        if self.is_BTC:
+            x = x.transpose(1, 2)
+        if nonpadding is None:
+            nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+        elif self.is_BTC:
+            nonpadding = nonpadding.transpose(1, 2)
+        x = self.res_blocks(x) * nonpadding
+        x = self.last_norm(x) * nonpadding
+        x = self.post_net1(x) * nonpadding
+        if self.is_BTC:
+            x = x.transpose(1, 2)
+        return x
+
+
+class TextConvEncoder(ConvBlocks):
+    def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
+                 norm_type='ln', layers_in_block=2, c_multiple=2,
+                 dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
+        super().__init__(hidden_size, out_dims, dilations, kernel_size,
+                         norm_type, layers_in_block, c_multiple,
+                         dropout, ln_eps, init_weights, num_layers=num_layers,
+                         post_net_kernel=post_net_kernel)
+        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+        self.embed_scale = math.sqrt(hidden_size)
+
+    def forward(self, txt_tokens):
+        """
+        :param txt_tokens: [B, T]
+        :return: {
+            'encoder_out': [B x T x C]
+        }
+        """
+        x = self.embed_scale * self.embed_tokens(txt_tokens)
+        return super().forward(x)
+
+
+class ConditionalConvBlocks(ConvBlocks):
+    def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
+                 norm_type='ln', layers_in_block=2, c_multiple=2,
+                 dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
+        super().__init__(hidden_size, c_out, dilations, kernel_size,
+                         norm_type, layers_in_block, c_multiple,
+                         dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
+        self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1)
+        self.is_BTC_ = is_BTC
+        if init_weights:
+            self.g_prenet.apply(init_weights_func)
+
+    def forward(self, x, cond, nonpadding=None):
+        if self.is_BTC_:
+            x = x.transpose(1, 2)
+            cond = cond.transpose(1, 2)
+            if nonpadding is not None:
+                nonpadding = nonpadding.transpose(1, 2)
+        if nonpadding is None:
+            nonpadding = x.abs().sum(1)[:, None]
+        x = x + self.g_prenet(cond)
+        x = x * nonpadding
+        x = super(ConditionalConvBlocks, self).forward(x)  # input needs to be BTC
+        if self.is_BTC_:
+            x = x.transpose(1, 2)
+        return x
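A minimal shape check for `ConvBlocks`, assuming the import path implied by the file location; all hyperparameters are illustrative. Rows that are entirely zero are treated as padding by the non-padding mask:

```python
# Illustrative only; dilations and sizes are made up for the example.
import torch
from preprocess.tools.note_transcription.modules.commons.conv import ConvBlocks

net = ConvBlocks(hidden_size=192, out_dims=80, dilations=[1, 2, 4] * 2, kernel_size=5)
x = torch.randn(2, 100, 192)   # [B, T, H]; zero frames would be masked out
y = net(x)
print(y.shape)                  # torch.Size([2, 100, 80])
```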
preprocess/tools/note_transcription/modules/commons/layers.py ADDED
@@ -0,0 +1,85 @@
+import torch
+from torch import nn
+from torch.autograd import Function
+
+
+class LayerNorm(torch.nn.LayerNorm):
+    """Layer normalization module.
+    :param int nout: output dim size
+    :param int dim: dimension to be normalized
+    """
+
+    def __init__(self, nout, dim=-1, eps=1e-5):
+        """Construct a LayerNorm object."""
+        super(LayerNorm, self).__init__(nout, eps=eps)
+        self.dim = dim
+
+    def forward(self, x):
+        """Apply layer normalization.
+        :param torch.Tensor x: input tensor
+        :return: layer normalized tensor
+        :rtype: torch.Tensor
+        """
+        if self.dim == -1:
+            return super(LayerNorm, self).forward(x)
+        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class Reshape(nn.Module):
+    def __init__(self, *args):
+        super(Reshape, self).__init__()
+        self.shape = args
+
+    def forward(self, x):
+        return x.view(self.shape)
+
+
+class Permute(nn.Module):
+    def __init__(self, *args):
+        super(Permute, self).__init__()
+        self.args = args
+
+    def forward(self, x):
+        return x.permute(self.args)
+
+
+def Linear(in_features, out_features, bias=True, init_type='xavier'):
+    m = nn.Linear(in_features, out_features, bias)
+    if init_type == 'xavier':
+        nn.init.xavier_uniform_(m.weight)
+    elif init_type == 'kaiming':
+        nn.init.kaiming_normal_(m.weight, mode='fan_in')
+    if bias:
+        nn.init.constant_(m.bias, 0.)
+    return m
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None, init_type='normal'):
+    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    if init_type == 'normal':
+        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    elif init_type == 'kaiming':
+        nn.init.kaiming_normal_(m.weight, mode='fan_in')
+    if padding_idx is not None:
+        nn.init.constant_(m.weight[padding_idx], 0)
+    return m
+
+
+class GradientReverseFunction(Function):
+    @staticmethod
+    def forward(ctx, input, coeff=1.):
+        ctx.coeff = coeff
+        output = input * 1.0
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.neg() * ctx.coeff, None
+
+
+class GRL(nn.Module):
+    def __init__(self):
+        super(GRL, self).__init__()
+
+    def forward(self, *input):
+        return GradientReverseFunction.apply(*input)
+
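`GradientReverseFunction`/`GRL` above implement a gradient reversal layer, as used for adversarial objectives: identity on the forward pass, negated (and optionally scaled) gradients on the backward pass. A small sanity check, illustrative rather than from this commit:

```python
# Verify that GRL passes activations through but flips gradient sign.
import torch
from preprocess.tools.note_transcription.modules.commons.layers import GRL

grl = GRL()
x = torch.ones(3, requires_grad=True)
y = grl(x)           # forward is the identity
y.sum().backward()
print(x.grad)        # tensor([-1., -1., -1.]) -- gradients are negated
```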
preprocess/tools/note_transcription/modules/commons/rel_transformer.py ADDED
@@ -0,0 +1,378 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .layers import Embedding
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class Encoder(nn.Module):
+    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
+                 window_size=None, block_length=None, pre_ln=False, **kwargs):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.block_length = block_length
+        self.pre_ln = pre_ln
+
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
+                                   p_dropout=p_dropout, block_length=block_length))
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+        if pre_ln:
+            self.last_ln = LayerNorm(hidden_channels)
+
+    def forward(self, x, x_mask):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        for i in range(self.n_layers):
+            x = x * x_mask
+            x_ = x
+            if self.pre_ln:
+                x = self.norm_layers_1[i](x)
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = x_ + y
+            if not self.pre_ln:
+                x = self.norm_layers_1[i](x)
+
+            x_ = x
+            if self.pre_ln:
+                x = self.norm_layers_2[i](x)
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = x_ + y
+            if not self.pre_ln:
+                x = self.norm_layers_2[i](x)
+        if self.pre_ln:
+            x = self.last_ln(x)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
+                 block_length=None, proximal_bias=False, proximal_init=False):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.p_dropout = p_dropout
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels ** -0.5
+            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        if proximal_init:
+            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+        if self.window_size is not None:
+            assert t_s == t_t, "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+            rel_logits = self._relative_position_to_absolute_position(rel_logits)
+            scores_local = rel_logits / math.sqrt(self.k_channels)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+                scores = scores * block_mask + -1e4 * (1 - block_mask)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along column
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+            length: an integer scalar.
+        Returns:
+            a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        return x * x_mask
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-4):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        n_dims = len(x.shape)
+        mean = torch.mean(x, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+        x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+        shape = [1, -1] + [1] * (n_dims - 2)
+        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 1."
+
+        self.conv_layers = nn.ModuleList()
+        self.norm_layers = nn.ModuleList()
+        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = nn.Sequential(
+            nn.ReLU(),
+            nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
+
+
+class RelTransformerEncoder(nn.Module):
+    def __init__(self,
+                 n_vocab,
+                 out_channels,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout=0.0,
+                 window_size=4,
+                 block_length=None,
+                 prenet=True,
+                 pre_ln=True,
+                 ):
+
+        super().__init__()
+
+        self.n_vocab = n_vocab
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.block_length = block_length
+        self.prenet = prenet
+        if n_vocab > 0:
+            self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
+
+        if prenet:
+            self.pre = ConvReluNorm(hidden_channels, hidden_channels, hidden_channels,
+                                    kernel_size=5, n_layers=3, p_dropout=0)
+        self.encoder = Encoder(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            window_size=window_size,
+            block_length=block_length,
+            pre_ln=pre_ln,
+        )
+
+    def forward(self, x, x_mask=None):
+        if self.n_vocab > 0:
+            x_lengths = (x > 0).long().sum(-1)
+            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+        else:
+            x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        if self.prenet:
+            x = self.pre(x, x_mask)
+        x = self.encoder(x, x_mask)
+        return x.transpose(1, 2)
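A sketch of running a padded token batch through `RelTransformerEncoder`; since the embedding uses `padding_idx=0`, sequence lengths are inferred from the non-zero tokens. All sizes are illustrative, not values used by this commit:

```python
# Illustrative usage of RelTransformerEncoder with a padded token batch.
import torch
from preprocess.tools.note_transcription.modules.commons.rel_transformer import RelTransformerEncoder

enc = RelTransformerEncoder(
    n_vocab=100, out_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=4, kernel_size=3,
)
tokens = torch.tensor([[5, 7, 9, 0, 0], [3, 4, 6, 8, 2]])  # [B, T]; 0 = padding
out = enc(tokens)
print(out.shape)  # torch.Size([2, 5, 192]); padded frames are masked to zero
```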
preprocess/tools/note_transcription/modules/commons/rnn.py ADDED
@@ -0,0 +1,261 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PreNet(nn.Module):
+    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
+        super().__init__()
+        self.fc1 = nn.Linear(in_dims, fc1_dims)
+        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
+        self.p = dropout
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = F.dropout(x, self.p, training=self.training)
+        x = self.fc2(x)
+        x = F.relu(x)
+        x = F.dropout(x, self.p, training=self.training)
+        return x
+
+
+class HighwayNetwork(nn.Module):
+    def __init__(self, size):
+        super().__init__()
+        self.W1 = nn.Linear(size, size)
+        self.W2 = nn.Linear(size, size)
+        self.W1.bias.data.fill_(0.)
+
+    def forward(self, x):
+        x1 = self.W1(x)
+        x2 = self.W2(x)
+        g = torch.sigmoid(x2)
+        y = g * F.relu(x1) + (1. - g) * x
+        return y
+
+
+class BatchNormConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel, relu=True):
+        super().__init__()
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
+        self.bnorm = nn.BatchNorm1d(out_channels)
+        self.relu = relu
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = F.relu(x) if self.relu is True else x
+        return self.bnorm(x)
+
+
+class ConvNorm(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
+        super(ConvNorm, self).__init__()
+        if padding is None:
+            assert (kernel_size % 2 == 1)
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                    kernel_size=kernel_size, stride=stride,
+                                    padding=padding, dilation=dilation,
+                                    bias=bias)
+
+        torch.nn.init.xavier_uniform_(
+            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+    def forward(self, signal):
+        conv_signal = self.conv(signal)
+        return conv_signal
+
+
+class CBHG(nn.Module):
+    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
+        super().__init__()
+
+        # List of all rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
+        self.bank_kernels = [i for i in range(1, K + 1)]
+        self.conv1d_bank = nn.ModuleList()
+        for k in self.bank_kernels:
+            conv = BatchNormConv(in_channels, channels, k)
+            self.conv1d_bank.append(conv)
+
+        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
+        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
+        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
+
+        # Fix the highway input if necessary
+        if proj_channels[-1] != channels:
+            self.highway_mismatch = True
+            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
+        else:
+            self.highway_mismatch = False
+
+        self.highways = nn.ModuleList()
+        for i in range(num_highways):
+            hn = HighwayNetwork(channels)
+            self.highways.append(hn)
+
+        self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
+        self._to_flatten.append(self.rnn)
+
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()
+
+    def forward(self, x):
+        # Although we `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again
+        self._flatten_parameters()
+
+        # Save these for later
+        residual = x
+        seq_len = x.size(-1)
+        conv_bank = []
+
+        # Convolution Bank
+        for conv in self.conv1d_bank:
+            c = conv(x)  # Convolution
+            conv_bank.append(c[:, :, :seq_len])
+
+        # Stack along the channel axis
+        conv_bank = torch.cat(conv_bank, dim=1)
+
+        # dump the last padding to fit residual
+        x = self.maxpool(conv_bank)[:, :, :seq_len]
+
+        # Conv1d projections
+        x = self.conv_project1(x)
+        x = self.conv_project2(x)
+
+        # Residual Connect
+        x = x + residual
+
+        # Through the highways
+        x = x.transpose(1, 2)
+        if self.highway_mismatch is True:
+            x = self.pre_highway(x)
+        for h in self.highways:
+            x = h(x)
+
+        # And then the RNN
+        x, _ = self.rnn(x)
+        return x
+
+    def _flatten_parameters(self):
+        """Calls `flatten_parameters` on all the RNNs used by this module. Used
+        to improve efficiency and avoid PyTorch warnings."""
+        [m.flatten_parameters() for m in self._to_flatten]
+
+
+class TacotronEncoder(nn.Module):
+    def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
+        super().__init__()
+        self.embedding = nn.Embedding(num_chars, embed_dims)
+        self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
+        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
+                         proj_channels=[cbhg_channels, cbhg_channels],
+                         num_highways=num_highways)
+        self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.pre_net(x)
+        x.transpose_(1, 2)
+        x = self.cbhg(x)
+        x = self.proj_out(x)
+        return x
+
+
+class RNNEncoder(nn.Module):
+    def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
+        super(RNNEncoder, self).__init__()
+        self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
+        convolutions = []
+        for _ in range(n_convolutions):
+            conv_layer = nn.Sequential(
+                ConvNorm(embedding_dim,
+                         embedding_dim,
+                         kernel_size=kernel_size, stride=1,
+                         padding=int((kernel_size - 1) / 2),
+                         dilation=1, w_init_gain='relu'),
+                nn.BatchNorm1d(embedding_dim))
+            convolutions.append(conv_layer)
+        self.convolutions = nn.ModuleList(convolutions)
+
+        self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
+                            batch_first=True, bidirectional=True)
+
+    def forward(self, x):
+        input_lengths = (x > 0).sum(-1)
+        input_lengths = input_lengths.cpu().numpy()
+
+        x = self.embedding(x)
+        x = x.transpose(1, 2)  # [B, H, T]
+        for conv in self.convolutions:
+            x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
+        x = x.transpose(1, 2)  # [B, T, H]
+
+        # PyTorch tensors are not reversible, hence the conversion
+        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+
+        self.lstm.flatten_parameters()
+        outputs, _ = self.lstm(x)
+        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+        return outputs
+
+
+class DecoderRNN(torch.nn.Module):
+    def __init__(self, hidden_size, decoder_rnn_dim, dropout):
+        super(DecoderRNN, self).__init__()
+        self.in_conv1d = nn.Sequential(
+            torch.nn.Conv1d(
+                in_channels=hidden_size,
+                out_channels=hidden_size,
+                kernel_size=9, padding=4,
+            ),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(
+                in_channels=hidden_size,
+                out_channels=hidden_size,
+                kernel_size=9, padding=4,
+            ),
+        )
+        self.ln = nn.LayerNorm(hidden_size)
+        if decoder_rnn_dim == 0:
+            decoder_rnn_dim = hidden_size * 2
+        self.rnn = torch.nn.LSTM(
+            input_size=hidden_size,
+            hidden_size=decoder_rnn_dim,
+            num_layers=1,
+            batch_first=True,
+            bidirectional=True,
+            dropout=dropout
+        )
+        self.rnn.flatten_parameters()
+        self.conv1d = torch.nn.Conv1d(
+            in_channels=decoder_rnn_dim * 2,
+            out_channels=hidden_size,
+            kernel_size=3,
+            padding=1,
+        )
+
+    def forward(self, x):
+        input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
+        input_lengths = input_masks.sum([-1, -2])
+        input_lengths = input_lengths.cpu().numpy()
+
+        x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
+        x = self.ln(x)
+        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+        self.rnn.flatten_parameters()
+        x, _ = self.rnn(x)  # [B, T, C]
+        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
+        x = x * input_masks
+        pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2)  # [B, T, C]
+        pre_mel = pre_mel * input_masks
+        return pre_mel
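An illustrative call of `TacotronEncoder`, which chains the `PreNet`, `CBHG` bank, and output projection defined above. Sizes are made up for the example and are not values used by this commit:

```python
# Illustrative shape check for TacotronEncoder.
import torch
from preprocess.tools.note_transcription.modules.commons.rnn import TacotronEncoder

enc = TacotronEncoder(embed_dims=256, num_chars=100, cbhg_channels=256,
                      K=8, num_highways=4, dropout=0.5)
x = torch.randint(0, 100, (2, 50))  # [B, T] character IDs
out = enc(x)
print(out.shape)                     # torch.Size([2, 50, 256])
```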
preprocess/tools/note_transcription/modules/commons/transformer.py ADDED
@@ -0,0 +1,751 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Parameter, Linear
5
+ from .layers import LayerNorm, Embedding
6
+ from ...utils.nn.seq_utils import (
7
+ get_incremental_state,
8
+ set_incremental_state,
9
+ softmax,
10
+ make_positions,
11
+ )
12
+ import torch.nn.functional as F
13
+
14
+ DEFAULT_MAX_SOURCE_POSITIONS = 2000
15
+ DEFAULT_MAX_TARGET_POSITIONS = 2000
16
+
17
+
18
+ class SinusoidalPositionalEmbedding(nn.Module):
19
+ """This module produces sinusoidal positional embeddings of any length.
20
+
21
+ Padding symbols are ignored.
22
+ """
23
+
24
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
25
+ super().__init__()
26
+ self.embedding_dim = embedding_dim
27
+ self.padding_idx = padding_idx
28
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
29
+ init_size,
30
+ embedding_dim,
31
+ padding_idx,
32
+ )
33
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
34
+
35
+ @staticmethod
36
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
37
+ """Build sinusoidal embeddings.
38
+
39
+ This matches the implementation in tensor2tensor, but differs slightly
40
+ from the description in Section 3.5 of "Attention Is All You Need".
41
+ """
42
+ half_dim = embedding_dim // 2
43
+ emb = math.log(10000) / (half_dim - 1)
44
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
45
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
46
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
47
+ if embedding_dim % 2 == 1:
48
+ # zero pad
49
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
50
+ if padding_idx is not None:
51
+ emb[padding_idx, :] = 0
52
+ return emb
53
+
54
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
55
+ """Input is expected to be of size [bsz x seqlen]."""
56
+ bsz, seq_len = input.shape[:2]
57
+ max_pos = self.padding_idx + 1 + seq_len
58
+ if self.weights is None or max_pos > self.weights.size(0):
59
+ # recompute/expand embeddings if needed
60
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
61
+ max_pos,
62
+ self.embedding_dim,
63
+ self.padding_idx,
64
+ )
65
+ self.weights = self.weights.to(self._float_tensor)
66
+
67
+ if incremental_state is not None:
68
+ # positions is the same for every token when decoding a single step
69
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
70
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
71
+
72
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
73
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
74
+
75
+ def max_positions(self):
76
+ """Maximum number of supported positions."""
77
+ return int(1e5) # an arbitrary large number
78
+
79
+
80
+ class TransformerFFNLayer(nn.Module):
81
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
82
+ super().__init__()
83
+ self.kernel_size = kernel_size
84
+ self.dropout = dropout
85
+ self.act = act
86
+ if padding == 'SAME':
87
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
88
+ elif padding == 'LEFT':
89
+ self.ffn_1 = nn.Sequential(
90
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
91
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
92
+ )
93
+ self.ffn_2 = Linear(filter_size, hidden_size)
94
+
95
+ def forward(self, x, incremental_state=None):
96
+ # x: T x B x C
97
+ if incremental_state is not None:
98
+ saved_state = self._get_input_buffer(incremental_state)
99
+ if 'prev_input' in saved_state:
100
+ prev_input = saved_state['prev_input']
101
+ x = torch.cat((prev_input, x), dim=0)
102
+ x = x[-self.kernel_size:]
103
+ saved_state['prev_input'] = x
104
+ self._set_input_buffer(incremental_state, saved_state)
105
+
106
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
107
+ x = x * self.kernel_size ** -0.5
108
+
109
+ if incremental_state is not None:
110
+ x = x[-1:]
111
+ if self.act == 'gelu':
112
+ x = F.gelu(x)
113
+ if self.act == 'relu':
114
+ x = F.relu(x)
115
+ x = F.dropout(x, self.dropout, training=self.training)
116
+ x = self.ffn_2(x)
117
+ return x
118
+
119
+ def _get_input_buffer(self, incremental_state):
120
+ return get_incremental_state(
121
+ self,
122
+ incremental_state,
123
+ 'f',
124
+ ) or {}
125
+
126
+ def _set_input_buffer(self, incremental_state, buffer):
127
+ set_incremental_state(
128
+ self,
129
+ incremental_state,
130
+ 'f',
131
+ buffer,
132
+ )
133
+
134
+ def clear_buffer(self, incremental_state):
135
+ if incremental_state is not None:
136
+ saved_state = self._get_input_buffer(incremental_state)
137
+ if 'prev_input' in saved_state:
138
+ del saved_state['prev_input']
139
+ self._set_input_buffer(incremental_state, saved_state)
140
+
141
+
+ class MultiheadAttention(nn.Module):
+     def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+                  add_bias_kv=False, add_zero_attn=False, self_attention=False,
+                  encoder_decoder_attention=False):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.kdim = kdim if kdim is not None else embed_dim
+         self.vdim = vdim if vdim is not None else embed_dim
+         self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.head_dim = embed_dim // num_heads
+         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+         self.scaling = self.head_dim ** -0.5
+
+         self.self_attention = self_attention
+         self.encoder_decoder_attention = encoder_decoder_attention
+
+         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+                                                              'value to be of the same size'
+
+         if self.qkv_same_dim:
+             self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+         else:
+             self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+             self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+             self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+         if bias:
+             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+         else:
+             self.register_parameter('in_proj_bias', None)
+
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         if add_bias_kv:
+             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+             self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+         else:
+             self.bias_k = self.bias_v = None
+
+         self.add_zero_attn = add_zero_attn
+
+         self.reset_parameters()
+
+         self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
+         self.last_attn_probs = None
+
+     def reset_parameters(self):
+         if self.qkv_same_dim:
+             nn.init.xavier_uniform_(self.in_proj_weight)
+         else:
+             nn.init.xavier_uniform_(self.k_proj_weight)
+             nn.init.xavier_uniform_(self.v_proj_weight)
+             nn.init.xavier_uniform_(self.q_proj_weight)
+
+         nn.init.xavier_uniform_(self.out_proj.weight)
+         if self.in_proj_bias is not None:
+             nn.init.constant_(self.in_proj_bias, 0.)
+             nn.init.constant_(self.out_proj.bias, 0.)
+         if self.bias_k is not None:
+             nn.init.xavier_normal_(self.bias_k)
+         if self.bias_v is not None:
+             nn.init.xavier_normal_(self.bias_v)
+
+     def forward(
+         self,
+         query, key, value,
+         key_padding_mask=None,
+         incremental_state=None,
+         need_weights=True,
+         static_kv=False,
+         attn_mask=None,
+         before_softmax=False,
+         need_head_weights=False,
+         enc_dec_attn_constraint_mask=None,
+         reset_attn_weight=None
+     ):
+         """Input shape: Time x Batch x Channel
+
+         Args:
+             key_padding_mask (ByteTensor, optional): mask to exclude
+                 keys that are pads, of shape `(batch, src_len)`, where
+                 padding elements are indicated by 1s.
+             need_weights (bool, optional): return the attention weights,
+                 averaged over heads (default: True).
+             attn_mask (ByteTensor, optional): typically used to
+                 implement causal attention, where the mask prevents the
+                 attention from looking forward in time (default: None).
+             before_softmax (bool, optional): return the raw attention
+                 weights and values before the attention softmax.
+             need_head_weights (bool, optional): return the attention
+                 weights for each head. Implies *need_weights*. Default:
+                 return the average attention weights over all heads.
+         """
+         if need_head_weights:
+             need_weights = True
+
+         tgt_len, bsz, embed_dim = query.size()
+         assert embed_dim == self.embed_dim
+         assert list(query.size()) == [tgt_len, bsz, embed_dim]
+         if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+             if self.qkv_same_dim:
+                 return F.multi_head_attention_forward(query, key, value,
+                                                       self.embed_dim, self.num_heads,
+                                                       self.in_proj_weight,
+                                                       self.in_proj_bias, self.bias_k, self.bias_v,
+                                                       self.add_zero_attn, self.dropout,
+                                                       self.out_proj.weight, self.out_proj.bias,
+                                                       self.training, key_padding_mask, need_weights,
+                                                       attn_mask)
+             else:
+                 return F.multi_head_attention_forward(query, key, value,
+                                                       self.embed_dim, self.num_heads,
+                                                       torch.empty([0]),
+                                                       self.in_proj_bias, self.bias_k, self.bias_v,
+                                                       self.add_zero_attn, self.dropout,
+                                                       self.out_proj.weight, self.out_proj.bias,
+                                                       self.training, key_padding_mask, need_weights,
+                                                       attn_mask, use_separate_proj_weight=True,
+                                                       q_proj_weight=self.q_proj_weight,
+                                                       k_proj_weight=self.k_proj_weight,
+                                                       v_proj_weight=self.v_proj_weight)
+
+         if incremental_state is not None:
+             saved_state = self._get_input_buffer(incremental_state)
+             if 'prev_key' in saved_state:
+                 # previous time steps are cached - no need to recompute
+                 # key and value if they are static
+                 if static_kv:
+                     assert self.encoder_decoder_attention and not self.self_attention
+                     key = value = None
+         else:
+             saved_state = None
+
+         if self.self_attention:
+             # self-attention
+             q, k, v = self.in_proj_qkv(query)
+         elif self.encoder_decoder_attention:
+             # encoder-decoder attention
+             q = self.in_proj_q(query)
+             if key is None:
+                 assert value is None
+                 k = v = None
+             else:
+                 k = self.in_proj_k(key)
+                 v = self.in_proj_v(key)
+         else:
+             q = self.in_proj_q(query)
+             k = self.in_proj_k(key)
+             v = self.in_proj_v(value)
+         q *= self.scaling
+
+         if self.bias_k is not None:
+             assert self.bias_v is not None
+             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+             if attn_mask is not None:
+                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+         q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if k is not None:
+             k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if v is not None:
+             v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+         if saved_state is not None:
+             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+             if 'prev_key' in saved_state:
+                 prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     k = prev_key
+                 else:
+                     k = torch.cat((prev_key, k), dim=1)
+             if 'prev_value' in saved_state:
+                 prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+                 if static_kv:
+                     v = prev_value
+                 else:
+                     v = torch.cat((prev_value, v), dim=1)
+             if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+                 prev_key_padding_mask = saved_state['prev_key_padding_mask']
+                 if static_kv:
+                     key_padding_mask = prev_key_padding_mask
+                 else:
+                     key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+             saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+             saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+             saved_state['prev_key_padding_mask'] = key_padding_mask
+
+             self._set_input_buffer(incremental_state, saved_state)
+
+         src_len = k.size(1)
+
+         # This is part of a workaround to get around fork/join parallelism
+         # not supporting Optional types.
+         if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+             key_padding_mask = None
+
+         if key_padding_mask is not None:
+             assert key_padding_mask.size(0) == bsz
+             assert key_padding_mask.size(1) == src_len
+
+         if self.add_zero_attn:
+             src_len += 1
+             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+             v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+             if attn_mask is not None:
+                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+         attn_weights = torch.bmm(q, k.transpose(1, 2))
+         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+         if attn_mask is not None:
+             if len(attn_mask.shape) == 2:
+                 attn_mask = attn_mask.unsqueeze(0)
+             elif len(attn_mask.shape) == 3:
+                 attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+                     bsz * self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights + attn_mask
+
+         if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.masked_fill(
+                 enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+                 -1e8,
+             )
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         if key_padding_mask is not None:
+             # don't attend to padding symbols
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.masked_fill(
+                 key_padding_mask.unsqueeze(1).unsqueeze(2),
+                 -1e8,
+             )
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+         if before_softmax:
+             return attn_weights, v
+
+         attn_weights_float = softmax(attn_weights, dim=-1)
+         attn_weights = attn_weights_float.type_as(attn_weights)
+         attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+         if reset_attn_weight is not None:
+             if reset_attn_weight:
+                 self.last_attn_probs = attn_probs.detach()
+             else:
+                 assert self.last_attn_probs is not None
+                 attn_probs = self.last_attn_probs
+         attn = torch.bmm(attn_probs, v)
+         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+         attn = self.out_proj(attn)
+
+         if need_weights:
+             attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+             if not need_head_weights:
+                 # average attention weights over heads
+                 attn_weights = attn_weights.mean(dim=0)
+         else:
+             attn_weights = None
+
+         return attn, (attn_weights, attn_logits)
+
+     def in_proj_qkv(self, query):
+         return self._in_proj(query).chunk(3, dim=-1)
+
+     def in_proj_q(self, query):
+         if self.qkv_same_dim:
+             return self._in_proj(query, end=self.embed_dim)
+         else:
+             bias = self.in_proj_bias
+             if bias is not None:
+                 bias = bias[:self.embed_dim]
+             return F.linear(query, self.q_proj_weight, bias)
+
+     def in_proj_k(self, key):
+         if self.qkv_same_dim:
+             return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+         else:
+             weight = self.k_proj_weight
+             bias = self.in_proj_bias
+             if bias is not None:
+                 bias = bias[self.embed_dim:2 * self.embed_dim]
+             return F.linear(key, weight, bias)
+
+     def in_proj_v(self, value):
+         if self.qkv_same_dim:
+             return self._in_proj(value, start=2 * self.embed_dim)
+         else:
+             weight = self.v_proj_weight
+             bias = self.in_proj_bias
+             if bias is not None:
+                 bias = bias[2 * self.embed_dim:]
+             return F.linear(value, weight, bias)
+
+     def _in_proj(self, input, start=0, end=None):
+         weight = self.in_proj_weight
+         bias = self.in_proj_bias
+         weight = weight[start:end, :]
+         if bias is not None:
+             bias = bias[start:end]
+         return F.linear(input, weight, bias)
+
+     def _get_input_buffer(self, incremental_state):
+         return get_incremental_state(
+             self,
+             incremental_state,
+             'attn_state',
+         ) or {}
+
+     def _set_input_buffer(self, incremental_state, buffer):
+         set_incremental_state(
+             self,
+             incremental_state,
+             'attn_state',
+             buffer,
+         )
+
+     def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+         return attn_weights
+
+     def clear_buffer(self, incremental_state=None):
+         if incremental_state is not None:
+             saved_state = self._get_input_buffer(incremental_state)
+             if 'prev_key' in saved_state:
+                 del saved_state['prev_key']
+             if 'prev_value' in saved_state:
+                 del saved_state['prev_value']
+             self._set_input_buffer(incremental_state, saved_state)
+
+
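For reference, a minimal usage sketch of the MultiheadAttention module above. Shapes follow its docstring (Time x Batch x Channel); the sizes are illustrative assumptions, not values taken from this repository:

import torch

attn = MultiheadAttention(embed_dim=256, num_heads=4, self_attention=True)
x = torch.randn(50, 2, 256)                 # [T, B, C] query/key/value for self-attention
pad = torch.zeros(2, 50, dtype=torch.bool)  # [B, T]; 1s/True mark padded keys
out, attn_info = attn(query=x, key=x, value=x, key_padding_mask=pad)
# out: [50, 2, 256]; attn_info carries the attention weights (its exact
# structure depends on whether the fused F.multi_head_attention_forward
# fast path was taken).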
+ class EncSALayer(nn.Module):
+     def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+                  relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu'):
+         super().__init__()
+         self.c = c
+         self.dropout = dropout
+         self.num_heads = num_heads
+         if num_heads > 0:
+             self.layer_norm1 = LayerNorm(c)
+             self.self_attn = MultiheadAttention(
+                 self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+         self.layer_norm2 = LayerNorm(c)
+         self.ffn = TransformerFFNLayer(
+             c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+     def forward(self, x, encoder_padding_mask=None, **kwargs):
+         layer_norm_training = kwargs.get('layer_norm_training', None)
+         if layer_norm_training is not None:
+             self.layer_norm1.training = layer_norm_training
+             self.layer_norm2.training = layer_norm_training
+         if self.num_heads > 0:
+             residual = x
+             x = self.layer_norm1(x)
+             x, _ = self.self_attn(
+                 query=x,
+                 key=x,
+                 value=x,
+                 key_padding_mask=encoder_padding_mask
+             )
+             x = F.dropout(x, self.dropout, training=self.training)
+             x = residual + x
+             x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+         residual = x
+         x = self.layer_norm2(x)
+         x = self.ffn(x)
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+         x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+         return x
+
+
+ class DecSALayer(nn.Module):
+     def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+                  kernel_size=9, act='gelu'):
+         super().__init__()
+         self.c = c
+         self.dropout = dropout
+         self.layer_norm1 = LayerNorm(c)
+         self.self_attn = MultiheadAttention(
+             c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+         )
+         self.layer_norm2 = LayerNorm(c)
+         self.encoder_attn = MultiheadAttention(
+             c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+         )
+         self.layer_norm3 = LayerNorm(c)
+         self.ffn = TransformerFFNLayer(
+             c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+     def forward(
+         self,
+         x,
+         encoder_out=None,
+         encoder_padding_mask=None,
+         incremental_state=None,
+         self_attn_mask=None,
+         self_attn_padding_mask=None,
+         attn_out=None,
+         reset_attn_weight=None,
+         **kwargs,
+     ):
+         layer_norm_training = kwargs.get('layer_norm_training', None)
+         if layer_norm_training is not None:
+             self.layer_norm1.training = layer_norm_training
+             self.layer_norm2.training = layer_norm_training
+             self.layer_norm3.training = layer_norm_training
+         residual = x
+         x = self.layer_norm1(x)
+         x, _ = self.self_attn(
+             query=x,
+             key=x,
+             value=x,
+             key_padding_mask=self_attn_padding_mask,
+             incremental_state=incremental_state,
+             attn_mask=self_attn_mask
+         )
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+
+         attn_logits = None
+         if encoder_out is not None or attn_out is not None:
+             residual = x
+             x = self.layer_norm2(x)
+             if encoder_out is not None:
+                 x, attn = self.encoder_attn(
+                     query=x,
+                     key=encoder_out,
+                     value=encoder_out,
+                     key_padding_mask=encoder_padding_mask,
+                     incremental_state=incremental_state,
+                     static_kv=True,
+                     enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+                                                                        'enc_dec_attn_constraint_mask'),
+                     reset_attn_weight=reset_attn_weight
+                 )
+                 attn_logits = attn[1]
+             elif attn_out is not None:
+                 x = self.encoder_attn.in_proj_v(attn_out)
+             x = F.dropout(x, self.dropout, training=self.training)
+             x = residual + x
+
+         residual = x
+         x = self.layer_norm3(x)
+         x = self.ffn(x, incremental_state=incremental_state)
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+         return x, attn_logits
+
+     def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+         self.encoder_attn.clear_buffer(incremental_state)
+         self.ffn.clear_buffer(incremental_state)
+
+     def set_buffer(self, name, tensor, incremental_state):
+         return set_incremental_state(self, incremental_state, name, tensor)
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.dropout = dropout
+         self.num_heads = num_heads
+         self.op = EncSALayer(
+             hidden_size, num_heads, dropout=dropout,
+             attention_dropout=0.0, relu_dropout=dropout,
+             kernel_size=kernel_size)
+
+     def forward(self, x, **kwargs):
+         return self.op(x, **kwargs)
+
+
+ class TransformerDecoderLayer(nn.Module):
+     def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.dropout = dropout
+         self.num_heads = num_heads
+         self.op = DecSALayer(
+             hidden_size, num_heads, dropout=dropout,
+             attention_dropout=0.0, relu_dropout=dropout,
+             kernel_size=kernel_size)
+
+     def forward(self, x, **kwargs):
+         return self.op(x, **kwargs)
+
+     def clear_buffer(self, *args):
+         return self.op.clear_buffer(*args)
+
+     def set_buffer(self, *args):
+         return self.op.set_buffer(*args)
+
+
+ class FFTBlocks(nn.Module):
+     def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
+                  num_heads=2, use_pos_embed=True, use_last_norm=True,
+                  use_pos_embed_alpha=True):
+         super().__init__()
+         self.num_layers = num_layers
+         embed_dim = self.hidden_size = hidden_size
+         self.dropout = dropout
+         self.use_pos_embed = use_pos_embed
+         self.use_last_norm = use_last_norm
+         if use_pos_embed:
+             self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+             self.padding_idx = 0
+             self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+             self.embed_positions = SinusoidalPositionalEmbedding(
+                 embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+             )
+
+         self.layers = nn.ModuleList([])
+         self.layers.extend([
+             TransformerEncoderLayer(self.hidden_size, self.dropout,
+                                     kernel_size=ffn_kernel_size, num_heads=num_heads)
+             for _ in range(self.num_layers)
+         ])
+         if self.use_last_norm:
+             self.layer_norm = nn.LayerNorm(embed_dim)
+         else:
+             self.layer_norm = None
+
+     def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+         """
+         :param x: [B, T, C]
+         :param padding_mask: [B, T]
+         :return: [B, T, C] or [L, B, T, C]
+         """
+         padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+         nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
+         if self.use_pos_embed:
+             positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+             x = x + positions
+             x = F.dropout(x, p=self.dropout, training=self.training)
+         # B x T x C -> T x B x C
+         x = x.transpose(0, 1) * nonpadding_mask_TB
+         hiddens = []
+         for layer in self.layers:
+             x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+             hiddens.append(x)
+         if self.use_last_norm:
+             x = self.layer_norm(x) * nonpadding_mask_TB
+         if return_hiddens:
+             x = torch.stack(hiddens, 0)  # [L, T, B, C]
+             x = x.transpose(1, 2)  # [L, B, T, C]
+         else:
+             x = x.transpose(0, 1)  # [B, T, C]
+         return x
+
+
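A minimal sketch of driving the FFTBlocks stack above; the hidden size, batch size, and sequence length are illustrative assumptions:

import torch

blocks = FFTBlocks(hidden_size=256, num_layers=4, num_heads=2)
x = torch.randn(2, 100, 256)                     # [B, T, C] frame-level features
padding = torch.zeros(2, 100, dtype=torch.bool)  # [B, T]; True marks padded frames
y = blocks(x, padding_mask=padding)              # [B, T, C]
hs = blocks(x, padding_mask=padding, return_hiddens=True)  # [L, B, T, C], one slice per layer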
+ class FastSpeechEncoder(FFTBlocks):
+     def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2,
+                  dropout=0.0):
+         super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
+                          use_pos_embed=False, dropout=dropout)  # use_pos_embed_alpha for compatibility
+         self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+         self.embed_scale = math.sqrt(hidden_size)
+         self.padding_idx = 0
+         self.embed_positions = SinusoidalPositionalEmbedding(
+             hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+         )
+
+     def forward(self, txt_tokens, attn_mask=None):
+         """
+         :param txt_tokens: [B, T]
+         :return: [B, T, C] encoder output
+         """
+         encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
+         x = self.forward_embedding(txt_tokens)  # [B, T, H]
+         if self.num_layers > 0:
+             x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
+         return x
+
+     def forward_embedding(self, txt_tokens):
+         # embed tokens and positions
+         x = self.embed_scale * self.embed_tokens(txt_tokens)
+         positions = self.embed_positions(txt_tokens)
+         x = x + positions
+         x = F.dropout(x, p=self.dropout, training=self.training)
+         return x
+
+
+ class FastSpeechDecoder(FFTBlocks):
+     def __init__(self, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2):
+         super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
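And a matching sketch for the token-level encoder; the vocabulary size and token ids below are made up for illustration:

import torch

encoder = FastSpeechEncoder(dict_size=100, hidden_size=256, num_layers=4)
tokens = torch.randint(1, 100, (2, 30))  # [B, T] token ids; id 0 is reserved for padding
out = encoder(tokens)                    # [B, T, C] hidden states, padded positions zeroed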
preprocess/tools/note_transcription/modules/commons/wavenet.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ from torch import nn
+ from packaging import version
+
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+     n_channels_int = n_channels[0]
+     in_act = input_a + input_b
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+     return acts
+
+ jit_fused_add_tanh_sigmoid_multiply = fused_add_tanh_sigmoid_multiply
+
+ def script_function():
+     if version.parse(torch.__version__) >= version.parse('2.0'):
+         global jit_fused_add_tanh_sigmoid_multiply
+         jit_fused_add_tanh_sigmoid_multiply = torch.jit.script(fused_add_tanh_sigmoid_multiply)
+
+
+ class WN(torch.nn.Module):
+     def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0,
+                  p_dropout=0, share_cond_layers=False, is_BTC=False):
+         super(WN, self).__init__()
+         assert (kernel_size % 2 == 1)
+         assert (hidden_size % 2 == 0)
+         self.is_BTC = is_BTC
+         self.hidden_size = hidden_size
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.gin_channels = c_cond
+         self.p_dropout = p_dropout
+         self.share_cond_layers = share_cond_layers
+
+         self.in_layers = torch.nn.ModuleList()
+         self.res_skip_layers = torch.nn.ModuleList()
+         self.drop = nn.Dropout(p_dropout)
+
+         if c_cond != 0 and not share_cond_layers:
+             cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1)
+             self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+         for i in range(n_layers):
+             dilation = dilation_rate ** i
+             padding = int((kernel_size * dilation - dilation) / 2)
+             in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size,
+                                        dilation=dilation, padding=padding)
+             in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+             self.in_layers.append(in_layer)
+
+             # last one is not necessary
+             if i < n_layers - 1:
+                 res_skip_channels = 2 * hidden_size
+             else:
+                 res_skip_channels = hidden_size
+
+             res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1)
+             res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+             self.res_skip_layers.append(res_skip_layer)
+
+         script_function()
+
+     def forward(self, x, nonpadding=None, cond=None):
+         if self.is_BTC:
+             x = x.transpose(1, 2)
+             cond = cond.transpose(1, 2) if cond is not None else None
+             nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None
+         if nonpadding is None:
+             nonpadding = 1
+         output = torch.zeros_like(x)
+         n_channels_tensor = torch.IntTensor([self.hidden_size])
+
+         if cond is not None and not self.share_cond_layers:
+             cond = self.cond_layer(cond)
+
+         for i in range(self.n_layers):
+             x_in = self.in_layers[i](x)
+             x_in = self.drop(x_in)
+             if cond is not None:
+                 cond_offset = i * 2 * self.hidden_size
+                 cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :]
+             else:
+                 cond_l = torch.zeros_like(x_in)
+
+             if version.parse(torch.__version__) >= version.parse('2.0'):
+                 acts = jit_fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+             else:
+                 acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+
+             res_skip_acts = self.res_skip_layers[i](acts)
+             if i < self.n_layers - 1:
+                 x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding
+                 output = output + res_skip_acts[:, self.hidden_size:, :]
+             else:
+                 output = output + res_skip_acts
+         output = output * nonpadding
+         if self.is_BTC:
+             output = output.transpose(1, 2)
+         return output
+
+     def remove_weight_norm(self):
+         def remove_weight_norm(m):
+             try:
+                 nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(remove_weight_norm)
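A minimal sketch of the WaveNet-style conditioning stack above; the channel sizes are illustrative assumptions (e.g. an 80-bin mel-spectrogram as the condition):

import torch

wn = WN(hidden_size=192, kernel_size=3, dilation_rate=2, n_layers=4, c_cond=80)
x = torch.randn(2, 192, 100)    # [B, C, T]; the default layout is channels-first
cond = torch.randn(2, 80, 100)  # conditioning features, projected internally per layer
mask = torch.ones(2, 1, 100)    # nonpadding mask, broadcast over channels
y = wn(x, nonpadding=mask, cond=cond)  # [B, C, T]
wn.remove_weight_norm()         # optional: strip weight norm before inference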
preprocess/tools/note_transcription/modules/pe/__init__.py ADDED
@@ -0,0 +1 @@
+ """Pitch extractor modules for ROSVOT."""